Python tensorflow.python.platform.gfile module: Open() code examples

From open-source Python projects, we extracted the following 39 code examples that illustrate how to use tensorflow.python.platform.gfile.Open().
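
Before the project samples, a minimal sketch of the API itself (the path is illustrative): gfile.Open() returns a file-like object that supports the context-manager protocol and works with local paths as well as other filesystems TensorFlow supports, such as GCS.

from tensorflow.python.platform import gfile

# Minimal sketch: write a text file, then read it back (path is illustrative).
with gfile.Open('/tmp/example.txt', 'w') as f:
    f.write('hello gfile')
with gfile.Open('/tmp/example.txt') as f:  # default mode is 'r'
    print(f.read())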

Project: benchmarks    Author: tensorflow    | project source | file source
def read_data_files(self, subset='train'):
    """Reads from data file and returns images and labels in a numpy array."""
    assert self.data_dir, ('Cannot call `read_data_files` when using synthetic '
                           'data')
    if subset == 'train':
      filenames = [os.path.join(self.data_dir, 'data_batch_%d' % i)
                   for i in xrange(1, 6)]
    elif subset == 'validation':
      filenames = [os.path.join(self.data_dir, 'test_batch')]
    else:
      raise ValueError('Invalid data subset "%s"' % subset)

    inputs = []
    for filename in filenames:
      with gfile.Open(filename, 'rb') as f:  # pickle data is binary
        inputs.append(cPickle.load(f))
    # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
    # input format.
    all_images = np.concatenate(
        [each_input['data'] for each_input in inputs]).astype(np.float32)
    all_labels = np.concatenate(
        [each_input['labels'] for each_input in inputs])
    return all_images, all_labels
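
The same unpickling pattern in isolation, as a hedged Python 3 sketch (the path is an assumption; CIFAR-10 batches were pickled under Python 2, hence the explicit encoding):

import pickle
from tensorflow.python.platform import gfile

with gfile.Open('/tmp/cifar-10-batches-py/data_batch_1', 'rb') as f:
    batch = pickle.load(f, encoding='latin1')  # Python 2 pickles need an encoding
print(batch['data'].shape)    # (10000, 3072) uint8 image rows
print(len(batch['labels']))   # 10000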
Project: lsdc    Author: febert    | project source | file source
def load_csv_with_header(filename,
                         target_dtype,
                         features_dtype,
                         target_column=-1):
  """Load dataset from CSV file with a header row."""
  with gfile.Open(filename) as csv_file:
    data_file = csv.reader(csv_file)
    header = next(data_file)
    n_samples = int(header[0])
    n_features = int(header[1])
    data = np.zeros((n_samples, n_features), dtype=features_dtype)
    target = np.zeros((n_samples,), dtype=target_dtype)
    for i, row in enumerate(data_file):
      target[i] = np.asarray(row.pop(target_column), dtype=target_dtype)
      data[i] = np.asarray(row, dtype=features_dtype)

  return Dataset(data=data, target=target)
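
A usage sketch of the header convention this loader expects: the first two header fields are the sample and feature counts. The path and file contents are illustrative; Dataset is the namedtuple used by tf.contrib.learn.

import numpy as np
from tensorflow.python.platform import gfile

# Build a tiny CSV following the convention: 2 samples, 3 features each.
with gfile.Open('/tmp/mini.csv', 'w') as f:
    f.write('2,3,label_a,label_b\n')
    f.write('5.1,3.5,1.4,0\n')
    f.write('7.0,3.2,4.7,1\n')

ds = load_csv_with_header('/tmp/mini.csv',
                          target_dtype=np.int64,
                          features_dtype=np.float32)
print(ds.data.shape)  # (2, 3)
print(ds.target)      # [0 1]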
Project: stuff    Author: yaroslavvb    | project source | file source
def read_data_files(self, subset='train'):
    """Reads from data file and return images and labels in a numpy array."""
    if subset == 'train':
      filenames = [os.path.join(self.data_dir, 'data_batch_%d' % i)
                   for i in xrange(1, 6)]
    elif subset == 'validation':
      filenames = [os.path.join(self.data_dir, 'test_batch')]
    else:
      raise ValueError('Invalid data subset "%s"' % subset)

    inputs = []
    for filename in filenames:
      with gfile.Open(filename, 'rb') as f:  # pickle data is binary
        inputs.append(cPickle.load(f))
    # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
    # input format.
    all_images = np.concatenate(
        [each_input['data'] for each_input in inputs]).astype(np.float32)
    all_labels = np.concatenate(
        [each_input['labels'] for each_input in inputs])
    return all_images, all_labels
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi    | project source | file source
def load_csv_with_header(filename,
                         target_dtype,
                         features_dtype,
                         target_column=-1):
  """Load dataset from CSV file with a header row."""
  with gfile.Open(filename) as csv_file:
    data_file = csv.reader(csv_file)
    header = next(data_file)
    n_samples = int(header[0])
    n_features = int(header[1])
    data = np.zeros((n_samples, n_features), dtype=features_dtype)
    target = np.zeros((n_samples,), dtype=target_dtype)
    for i, row in enumerate(data_file):
      target[i] = np.asarray(row.pop(target_column), dtype=target_dtype)
      data[i] = np.asarray(row, dtype=features_dtype)

  return Dataset(data=data, target=target)
Project: Caption-Generation    Author: m516825    | project source | file source
def save(self, filename):
        with gfile.Open(filename, 'wb') as f:
            f.write(pickle.dumps(self))
Project: Caption-Generation    Author: m516825    | project source | file source
def restore(cls, filename):
        with gfile.Open(filename, 'rb') as f:
            return pickle.loads(f.read())
Project: LIE    Author: EmbraceLife    | project source | file source
def _write_plugin_assets(self, graph):
    plugin_assets = plugin_asset.get_all_plugin_assets(graph)
    logdir = self.event_writer.get_logdir()
    for asset_container in plugin_assets:
      plugin_name = asset_container.plugin_name
      plugin_dir = os.path.join(logdir, _PLUGINS_DIR, plugin_name)
      gfile.MakeDirs(plugin_dir)
      assets = asset_container.assets()
      for (asset_name, content) in assets.items():
        asset_path = os.path.join(plugin_dir, asset_name)
        with gfile.Open(asset_path, "w") as f:
          f.write(content)
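
The create-then-write pattern above, in isolation (paths are illustrative). gfile.MakeDirs creates all intermediate directories and succeeds if they already exist.

import os
from tensorflow.python.platform import gfile

plugin_dir = '/tmp/logdir/plugins/example'
gfile.MakeDirs(plugin_dir)  # recursive and idempotent
with gfile.Open(os.path.join(plugin_dir, 'asset.txt'), 'w') as f:
    f.write('asset contents')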
Project: tefla    Author: openAGI    | project source | file source
def main(unused_args):
    if not gfile.Exists(FLAGS.input):
        print("Input graph file '" + FLAGS.input + "' does not exist!")
        return -1

    input_graph_def = graph_pb2.GraphDef()
    with gfile.Open(FLAGS.input, "r") as f:
        data = f.read()
        if FLAGS.frozen_graph:
            input_graph_def.ParseFromString(data)
        else:
            text_format.Merge(data.decode("utf-8"), input_graph_def)

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def,
        FLAGS.input_names.split(","),
        FLAGS.output_names.split(","), FLAGS.placeholder_type_enum)

    if FLAGS.frozen_graph:
        f = gfile.FastGFile(FLAGS.output, "wb")  # serialized protobuf is binary
        f.write(output_graph_def.SerializeToString())
    else:
        graph_io.write_graph(output_graph_def,
                             os.path.dirname(FLAGS.output),
                             os.path.basename(FLAGS.output))
    return 0
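
The GraphDef-loading half of this tool as a standalone sketch (the path is an assumption; a frozen graph is a binary protobuf, hence 'rb' and ParseFromString):

from tensorflow.core.framework import graph_pb2
from tensorflow.python.platform import gfile

graph_def = graph_pb2.GraphDef()
with gfile.Open('/tmp/frozen_graph.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())
print(len(graph_def.node), 'nodes')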
Project: lsdc    Author: febert    | project source | file source
def _create_temp_file(self, lines):
    tempdir = tempfile.mkdtemp()
    filename = os.path.join(tempdir, "temp_file")
    gfile.Open(filename, "w").write(lines)
    return filename
Project: lsdc    Author: febert    | project source | file source
def _create_sorted_temp_files(self, lines_list):
    tempdir = tempfile.mkdtemp()
    filenames = []
    for i, lines in enumerate(lines_list):
      filename = os.path.join(tempdir, "temp_file%05d" % i)
      gfile.Open(filename, "w").write(lines)
      filenames.append(filename)
    return filenames
Project: lsdc    Author: febert    | project source | file source
def _write_with_backup(filename, content):
  if gfile.Exists(filename):
    gfile.Rename(filename, filename + '.old', overwrite=True)
  with gfile.Open(filename, 'w') as f:
    f.write(content)
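
Usage sketch (the filename is illustrative): a second call moves the previous contents aside rather than overwriting them.

_write_with_backup('/tmp/report.txt', 'first version')
_write_with_backup('/tmp/report.txt', 'second version')
# /tmp/report.txt     now holds 'second version'
# /tmp/report.txt.old now holds 'first version'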
Project: lsdc    Author: febert    | project source | file source
def shrink_csv(filename, ratio):
  """Create a smaller dataset of only 1/ratio of original data."""
  filename_small = filename.replace('.', '_small.')
  with gfile.Open(filename_small, 'w') as csv_file_small:
    writer = csv.writer(csv_file_small)
    with gfile.Open(filename) as csv_file:
      reader = csv.reader(csv_file)
      i = 0
      for row in reader:
        if i % ratio == 0:
          writer.writerow(row)
        i += 1
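
Usage sketch, with one caveat: str.replace substitutes every '.', so a path with extra dots (e.g. './train.csv') yields a surprising output name.

shrink_csv('train.csv', 10)  # keeps every 10th row, writes train_small.csv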
Project: lsdc    Author: febert    | project source | file source
def save(self, filename):
    """Saves vocabulary processor into given file.

    Args:
      filename: Path to output file.
    """
    with gfile.Open(filename, 'wb') as f:
      f.write(pickle.dumps(self))
Project: lsdc    Author: febert    | project source | file source
def restore(cls, filename):
    """Restores vocabulary processor from given file.

    Args:
      filename: Path to file to load from.

    Returns:
      VocabularyProcessor object.
    """
    with gfile.Open(filename, 'rb') as f:
      return pickle.loads(f.read())
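
These methods belong to tf.contrib.learn's VocabularyProcessor; a round-trip sketch (the path is illustrative):

import tensorflow as tf

vp = tf.contrib.learn.preprocessing.VocabularyProcessor(max_document_length=10)
list(vp.fit_transform(['a small corpus']))  # build the vocabulary
vp.save('/tmp/vocab.pkl')                   # pickles the processor via gfile.Open
vp2 = tf.contrib.learn.preprocessing.VocabularyProcessor.restore('/tmp/vocab.pkl')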
Project: lsdc    Author: febert    | project source | file source
def _create_temp_file(self, lines):
    tempdir = tempfile.mkdtemp()
    filename = os.path.join(tempdir, "temp_file")
    gfile.Open(filename, "w").write(lines)
    return filename
Project: lsdc    Author: febert    | project source | file source
def _create_sorted_temp_files(self, lines_list):
    tempdir = tempfile.mkdtemp()
    filenames = []
    for i, lines in enumerate(lines_list):
      filename = os.path.join(tempdir, "temp_file%05d" % i)
      gfile.Open(filename, "w").write(lines)
      filenames.append(filename)
    return filenames
Project: lsdc    Author: febert    | project source | file source
def load_csv_without_header(filename,
                            target_dtype,
                            features_dtype,
                            target_column=-1):
  """Load dataset from CSV file without a header row."""
  with gfile.Open(filename) as csv_file:
    data_file = csv.reader(csv_file)
    data, target = [], []
    for row in data_file:
      target.append(row.pop(target_column))
      data.append(np.asarray(row, dtype=features_dtype))

  target = np.array(target, dtype=target_dtype)
  data = np.array(data)
  return Dataset(data=data, target=target)
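
Its headerless counterpart in use, as a sketch (path and contents are illustrative; each row is the features followed by a trailing label):

import numpy as np
from tensorflow.python.platform import gfile

with gfile.Open('/tmp/plain.csv', 'w') as f:
    f.write('5.1,3.5,1.4,0\n7.0,3.2,4.7,1\n')

ds = load_csv_without_header('/tmp/plain.csv',
                             target_dtype=np.int64,
                             features_dtype=np.float32)
print(ds.target)  # [0 1]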
Project: lsdc    Author: febert    | project source | file source
def shrink_csv(filename, ratio):
  """Create a smaller dataset of only 1/ratio of original data."""
  filename_small = filename.replace('.', '_small.')
  with gfile.Open(filename_small, 'w') as csv_file_small:
    writer = csv.writer(csv_file_small)
    with gfile.Open(filename) as csv_file:
      reader = csv.reader(csv_file)
      i = 0
      for row in reader:
        if i % ratio == 0:
          writer.writerow(row)
        i += 1
Project: lsdc    Author: febert    | project source | file source
def save(self, filename):
    """Saves vocabulary processor into given file.

    Args:
      filename: Path to output file.
    """
    with gfile.Open(filename, 'wb') as f:
      f.write(pickle.dumps(self))
Project: lsdc    Author: febert    | project source | file source
def restore(cls, filename):
    """Restores vocabulary processor from given file.

    Args:
      filename: Path to file to load from.

    Returns:
      VocabularyProcessor object.
    """
    with gfile.Open(filename, 'rb') as f:
      return pickle.loads(f.read())
Project: Conditional-GAN    Author: m516825    | project source | file source
def save(self, filename):
        with gfile.Open(filename, 'wb') as f:
            f.write(pickle.dumps(self))
Project: Conditional-GAN    Author: m516825    | project source | file source
def restore(cls, filename):
        with gfile.Open(filename, 'rb') as f:
            return pickle.loads(f.read())
Project: gan_tensorflow    Author: dantkz    | project source | file source
def extract_images(filename):
  """Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
  print('Extracting', filename)
  with gfile.Open(filename, 'rb') as f, gzip.GzipFile(fileobj=f) as bytestream:
    magic = _read32(bytestream)
    if magic != 2051:
      raise ValueError('Invalid magic number %d in MNIST image file: %s' %
                       (magic, filename))
    num_images = _read32(bytestream)
    rows = _read32(bytestream)
    cols = _read32(bytestream)
    buf = bytestream.read(rows * cols * num_images)
    data = numpy.frombuffer(buf, dtype=numpy.uint8)
    data = data.reshape(num_images, rows, cols, 1)
    return data
Project: gan_tensorflow    Author: dantkz    | project source | file source
def extract_labels(filename, one_hot=False, num_classes=10):
  """Extract the labels into a 1D uint8 numpy array [index]."""
  print('Extracting', filename)
  with gfile.Open(filename, 'rb') as f, gzip.GzipFile(fileobj=f) as bytestream:
    magic = _read32(bytestream)
    if magic != 2049:
      raise ValueError('Invalid magic number %d in MNIST label file: %s' %
                       (magic, filename))
    num_items = _read32(bytestream)
    buf = bytestream.read(num_items)
    labels = numpy.frombuffer(buf, dtype=numpy.uint8)
    if one_hot:
      return dense_to_one_hot(labels, num_classes)
    return labels
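
A usage sketch for both extractors, assuming the standard MNIST archives have already been downloaded (paths are illustrative):

images = extract_images('/tmp/mnist/train-images-idx3-ubyte.gz')
labels = extract_labels('/tmp/mnist/train-labels-idx1-ubyte.gz', one_hot=True)
print(images.shape)  # (60000, 28, 28, 1), dtype uint8
print(labels.shape)  # (60000, 10)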
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi    | project source | file source
def _create_temp_file(self, lines):
    tempdir = tempfile.mkdtemp()
    filename = os.path.join(tempdir, "temp_file")
    gfile.Open(filename, "w").write(lines)
    return filename
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi    | project source | file source
def _create_sorted_temp_files(self, lines_list):
    tempdir = tempfile.mkdtemp()
    filenames = []
    for i, lines in enumerate(lines_list):
      filename = os.path.join(tempdir, "temp_file%05d" % i)
      gfile.Open(filename, "w").write(lines)
      filenames.append(filename)
    return filenames
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi    | project source | file source
def load_csv_without_header(filename,
                            target_dtype,
                            features_dtype,
                            target_column=-1):
  """Load dataset from CSV file without a header row."""
  with gfile.Open(filename) as csv_file:
    data_file = csv.reader(csv_file)
    data, target = [], []
    for row in data_file:
      target.append(row.pop(target_column))
      data.append(np.asarray(row, dtype=features_dtype))

  target = np.array(target, dtype=target_dtype)
  data = np.array(data)
  return Dataset(data=data, target=target)
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi    | project source | file source
def shrink_csv(filename, ratio):
  """Create a smaller dataset of only 1/ratio of original data."""
  filename_small = filename.replace('.', '_small.')
  with gfile.Open(filename_small, 'w') as csv_file_small:
    writer = csv.writer(csv_file_small)
    with gfile.Open(filename) as csv_file:
      reader = csv.reader(csv_file)
      i = 0
      for row in reader:
        if i % ratio == 0:
          writer.writerow(row)
        i += 1
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi    | project source | file source
def restore(cls, filename):
    """Restores vocabulary processor from given file.

    Args:
      filename: Path to file to load from.

    Returns:
      VocabularyProcessor object.
    """
    with gfile.Open(filename, 'rb') as f:
      return pickle.loads(f.read())
Project: benchmarks    Author: tensorflow    | project source | file source
def benchmark_one_step(sess,
                       fetches,
                       step,
                       batch_size,
                       step_train_times,
                       trace_filename,
                       image_producer,
                       params,
                       summary_op=None):
  """Advance one step of benchmarking."""
  if trace_filename and step == -1:
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()
  else:
    run_options = None
    run_metadata = None
  summary_str = None
  start_time = time.time()
  if summary_op is None:
    results = sess.run(fetches, options=run_options, run_metadata=run_metadata)
  else:
    (results, summary_str) = sess.run(
        [fetches, summary_op], options=run_options, run_metadata=run_metadata)

  if not params.forward_only:
    lossval = results['total_loss']
  else:
    lossval = 0.
  image_producer.notify_image_consumption()
  train_time = time.time() - start_time
  step_train_times.append(train_time)
  if step >= 0 and (step == 0 or (step + 1) % params.display_every == 0):
    log_str = '%i\t%s\t%.3f' % (
        step + 1, get_perf_timing_str(batch_size, step_train_times), lossval)
    if 'top_1_accuracy' in results:
      log_str += '\t%.3f\t%.3f' % (results['top_1_accuracy'],
                                   results['top_5_accuracy'])
    log_fn(log_str)
  if trace_filename and step == -1:
    log_fn('Dumping trace to %s' % trace_filename)
    trace = timeline.Timeline(step_stats=run_metadata.step_stats)
    with gfile.Open(trace_filename, 'w') as trace_file:
      trace_file.write(trace.generate_chrome_trace_format(show_memory=True))
  return summary_str
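
The trace-dumping pattern at the end, isolated into a self-contained TF 1.x sketch (the output path is illustrative; load the resulting JSON in chrome://tracing):

import tensorflow as tf
from tensorflow.python.client import timeline
from tensorflow.python.platform import gfile

run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
with tf.Session() as sess:
    x = tf.constant(1.0) + tf.constant(2.0)
    sess.run(x, options=run_options, run_metadata=run_metadata)
trace = timeline.Timeline(step_stats=run_metadata.step_stats)
with gfile.Open('/tmp/trace.json', 'w') as trace_file:
    trace_file.write(trace.generate_chrome_trace_format(show_memory=True))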
Project: benchmarks    Author: tensorflow    | project source | file source
def store_data_in_json(
    stat_entries, timestamp, output_file=None, test_name=None):
  """Stores benchmark results in JSON format.

  Args:
    stat_entries: list of StatEntry objects.
    timestamp: (datetime) start time of the test run.
    output_file: if specified, writes benchmark results to output_file.
      Otherwise, if TF_DIST_BENCHMARK_RESULTS_FILE environment variable is set,
      writes to file specified by this environment variable. If neither
      output_file is passed in, nor TF_DIST_BENCHMARK_RESULTS_FILE is set,
      does nothing.
    test_name: benchmark name. This argument is required if
      TF_DIST_BENCHMARK_NAME environment variable is not set.

  Raises:
    ValueError: when neither test_name is passed in nor
      TF_DIST_BENCHMARK_NAME is set.
  """
  test_result = test_log_pb2.TestResults(
      start_time=calendar.timegm(timestamp.timetuple()))
  if not output_file:
    if _OUTPUT_FILE_ENV_VAR not in os.environ:
      logging.warning(
          'Skipping storing json output, since we could not determine '
          'location to store results at. Either output_file argument or '
          '%s environment variable needs to be set.', _OUTPUT_FILE_ENV_VAR)
      return
    output_file = os.environ[_OUTPUT_FILE_ENV_VAR]

  if test_name is not None:
    test_result.name = test_name
  elif _TEST_NAME_ENV_VAR in os.environ:
    test_result.name = os.environ[_TEST_NAME_ENV_VAR]
  else:
    raise ValueError(
        'Could not determine test name. test_name argument is not passed in '
        'and TF_DIST_BENCHMARK_NAME environment variable is not set.')

  for stat_entry in stat_entries:
    test_result.entries.entry.add(
        name=stat_entry.name,
        iters=stat_entry.num_samples,
        wall_time=stat_entry.stat_value
    )
  json_test_results = json_format.MessageToJson(test_result)

  with gfile.Open(output_file, 'wb') as jsonfile:
    jsonfile.write(json_test_results)
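
An illustrative call (StatEntry is defined elsewhere in the benchmarks project; the output path and test name are assumptions):

import datetime

store_data_in_json(stat_entries=[],  # StatEntry objects from the project go here
                   timestamp=datetime.datetime.now(),
                   output_file='/tmp/results.json',
                   test_name='resnet50_synthetic')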
Project: tflearn    Author: tflearn    | project source | file source
def load_csv(filepath, target_column=-1, columns_to_ignore=None,
             has_header=True, categorical_labels=False, n_classes=None):
    """ load_csv.

    Load data from a CSV file. By default the labels are considered to be the
    last column, but it can be changed by filling 'target_column' parameter.

    Arguments:
        filepath: `str`. The csv file path.
        target_column: The id of the column representing the labels.
            Default: -1 (The last column).
        columns_to_ignore: `list of int`. A list of column indexes to ignore.
        has_header: `bool`. Whether the csv file has a header or not.
        categorical_labels: `bool`. If True, labels are returned as binary
            vectors (to be used with 'categorical_crossentropy').
        n_classes: `int`. Total number of classes (needed if
            categorical_labels is True).

    Returns:
        A tuple (data, target).

    """

    from tensorflow.python.platform import gfile
    with gfile.Open(filepath) as csv_file:
        data_file = csv.reader(csv_file)
        if not columns_to_ignore:
            columns_to_ignore = []
        if has_header:
            header = next(data_file)
        data, target = [], []
        # Fix column to ignore ids after removing target_column
        for i, c in enumerate(columns_to_ignore):
            if c > target_column:
                columns_to_ignore[i] -= 1
        for i, d in enumerate(data_file):
            target.append(d.pop(target_column))
            data.append([_d for j, _d in enumerate(d) if j not in columns_to_ignore])
        if categorical_labels:
            assert isinstance(n_classes, int), "n_classes not specified!"
            target = to_categorical(target, n_classes)
        return data, target
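
A usage sketch in the style of the tflearn tutorials (the CSV file is assumed to exist, with labels in column 0):

data, labels = load_csv('titanic_dataset.csv', target_column=0,
                        categorical_labels=True, n_classes=2)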
Project: sonnet    Author: deepmind    | project source | file source
def __init__(self, num_steps=1, batch_size=1,
               subset="train", random=False, dtype=tf.float32,
               name="tiny_shakespeare_dataset"):
    """Initializes a TinyShakespeare sequence data object.

    Args:
      num_steps: sequence_length.
      batch_size: batch size.
      subset: 'train', 'valid' or 'test'.
      random: boolean indicating whether to do random sampling of sequences.
        Default is false (sequential sampling).
      dtype: type of generated tensors (both observations and targets).
      name: object name.

    Raises:
      ValueError: if subset is not train, valid or test.
    """

    if subset not in [self.TRAIN, self.VALID, self.TEST]:
      raise ValueError("subset should be %s, %s, or %s. Received %s instead."
                       % (self.TRAIN, self.VALID, self.TEST, subset))

    super(TinyShakespeareDataset, self).__init__(name=name)

    # Generate vocab from train set.

    self._vocab_file = gfile.Open(
        os.path.join(self._RESOURCE_ROOT, "ts.train.txt"))
    self._data_file = gfile.Open(
        os.path.join(self._RESOURCE_ROOT, "ts.{}.txt".format(subset)))
    self._num_steps = num_steps
    self._batch_size = batch_size
    self._random_sampling = random
    self._dtype = dtype

    self._data_source = TokenDataSource(
        data_file=self._data_file,
        vocab_data_file=self._vocab_file)

    self._vocab_size = self._data_source.vocab_size
    self._flat_data = self._data_source.flat_data
    self._n_flat_elements = self._data_source.num_tokens

    self._num_batches = self._n_flat_elements // (self._num_steps * batch_size)
    self._reset_head_indices()

    self._queue_capacity = 10
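
Illustrative construction, assuming the bundled ts.*.txt resource files are present:

dataset = TinyShakespeareDataset(num_steps=64, batch_size=32, subset='train')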
Project: lsdc    Author: febert    | project source | file source
def restore(cls, path, config=None):
    # pylint: disable=unused-argument
    """Restores model from give path.

    Args:
      path: Path to the checkpoints and other model information.
      config: RunConfig object that controls the configurations of the session,
        e.g. num_cores, gpu_memory_fraction, etc. This is allowed to be
          reconfigured.

    Returns:
      Estimator, object of the subclass of TensorFlowEstimator.

    Raises:
      ValueError: if `path` does not contain a model definition.
    """
    model_def_filename = os.path.join(path, 'model.def')
    if not os.path.exists(model_def_filename):
      raise ValueError("Restore folder doesn't contain model definition.")
    # list of parameters that are allowed to be reconfigured
    reconfigurable_params = ['_config']
    _config = config  # pylint: disable=unused-variable,invalid-name
    with gfile.Open(model_def_filename) as fmodel:
      model_def = json.loads(fmodel.read())
      # TensorFlow binding requires parameters to be strings not unicode.
      # Only issue in Python2.
      for key, value in model_def.items():
        if isinstance(value, string_types) and not isinstance(value, str):
          model_def[key] = str(value)
        if key in reconfigurable_params:
          new_value = locals()[key]
          if new_value is not None:
            model_def[key] = new_value

    class_name = model_def.pop('class_name')
    if class_name == 'TensorFlowEstimator':
      custom_estimator = TensorFlowEstimator(model_fn=None, **model_def)
      # pylint: disable=protected-access
      custom_estimator._restore(path)
      return custom_estimator

    # To avoid cyclical dependencies, import inside the function instead of
    # the beginning of the file.
    # pylint: disable=g-import-not-at-top
    from tensorflow.contrib.learn.python.learn import estimators
    # Estimator must be one of the defined estimators in the __init__ file.
    result = getattr(estimators, class_name)(**model_def)
    # pylint: disable=protected-access
    result._restore(path)
    return result
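
Usage sketch (the path is illustrative; it must contain the 'model.def' file written at save time):

estimator = TensorFlowEstimator.restore('/tmp/my_model')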
Project: tensorflow-for-poets-2    Author: googlecodelabs    | project source | file source
def main(unused_args):
  if not gfile.Exists(FLAGS.input):
    print("Input graph file '" + FLAGS.input + "' does not exist!")
    return -1

  known_modes = [
      "round", "quantize", "eightbit", "weights", "test", "weights_rounded"
  ]
  if not any(FLAGS.mode in s for s in known_modes):
    print("mode is '" + FLAGS.mode + "', not in " + ", ".join(known_modes) +
          ".")
    return -1

  tf_graph = graph_pb2.GraphDef()
  with gfile.Open(FLAGS.input, "rb") as f:
    data = f.read()
    tf_graph.ParseFromString(data)

  graph = ops.Graph()
  with graph.as_default():
    importer.import_graph_def(tf_graph, input_map={}, name="")

  quantized_input_range = None
  if FLAGS.quantized_input:
    quantized_input_range = [
        FLAGS.quantized_input_min, FLAGS.quantized_input_max
    ]

  fallback_quantization_range = None
  if (FLAGS.quantized_fallback_min is not None or
      FLAGS.quantized_fallback_max is not None):
    assert FLAGS.quantized_fallback_min is not None
    assert FLAGS.quantized_fallback_max is not None
    fallback_quantization_range = [
        FLAGS.quantized_fallback_min, FLAGS.quantized_fallback_max
    ]

  rewriter = GraphRewriter(tf_graph, FLAGS.mode, quantized_input_range,
                           fallback_quantization_range)

  output_graph = rewriter.rewrite(FLAGS.output_node_names.split(","))

  f = gfile.FastGFile(FLAGS.output, "wb")
  f.write(output_graph.SerializeToString())

  return 0
Project: MobileNet    Author: Zehaos    | project source | file source
def main(unused_args):
  if not gfile.Exists(FLAGS.input):
    print("Input graph file '" + FLAGS.input + "' does not exist!")
    return -1

  known_modes = [
      "round", "quantize", "eightbit", "weights", "test", "weights_rounded"
  ]
  if not any(FLAGS.mode in s for s in known_modes):
    print("mode is '" + FLAGS.mode + "', not in " + ", ".join(known_modes) +
          ".")
    return -1

  tf_graph = graph_pb2.GraphDef()
  with gfile.Open(FLAGS.input, "rb") as f:
    data = f.read()
    tf_graph.ParseFromString(data)

  graph = ops.Graph()
  with graph.as_default():
    importer.import_graph_def(tf_graph, input_map={}, name="")

  quantized_input_range = None
  if FLAGS.quantized_input:
    quantized_input_range = [
        FLAGS.quantized_input_min, FLAGS.quantized_input_max
    ]

  fallback_quantization_range = None
  if (FLAGS.quantized_fallback_min is not None or
      FLAGS.quantized_fallback_max is not None):
    assert FLAGS.quantized_fallback_min is not None
    assert FLAGS.quantized_fallback_max is not None
    fallback_quantization_range = [
        FLAGS.quantized_fallback_min, FLAGS.quantized_fallback_max
    ]

  rewriter = GraphRewriter(tf_graph, FLAGS.mode, quantized_input_range,
                           fallback_quantization_range)

  output_graph = rewriter.rewrite(FLAGS.output_node_names.split(","))

  f = gfile.FastGFile(FLAGS.output, "wb")
  f.write(output_graph.SerializeToString())

  return 0
Project: stuff    Author: yaroslavvb    | project source | file source
def benchmark_one_step(sess,
                       fetches,
                       step,
                       batch_size,
                       step_train_times,
                       trace_filename,
                       image_producer,
                       params,
                       summary_op=None):
  """Advance one step of benchmarking."""
  if trace_filename is not None and step == -1:
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()
  else:
    run_options = None
    run_metadata = None
  summary_str = None
  start_time = time.time()
  if summary_op is None:
    results = sess.run(fetches, options=run_options, run_metadata=run_metadata)
  else:
    (results, summary_str) = sess.run(
        [fetches, summary_op], options=run_options, run_metadata=run_metadata)

  if not params.forward_only:
    lossval = results['total_loss']
  else:
    lossval = 0.
  image_producer.notify_image_consumption()
  train_time = time.time() - start_time
  step_train_times.append(train_time)
  if step >= 0 and (step == 0 or (step + 1) % params.display_every == 0):
    log_str = '%i\t%s\t%.3f' % (
        step + 1, get_perf_timing_str(batch_size, step_train_times), lossval)
    if 'top_1_accuracy' in results:
      log_str += '\t%.3f\t%.3f' % (results['top_1_accuracy'],
                                   results['top_5_accuracy'])
    log_fn(log_str)
  if trace_filename is not None and step == -1:
    log_fn('Dumping trace to %s' % trace_filename)
    trace = timeline.Timeline(step_stats=run_metadata.step_stats)
    with gfile.Open(trace_filename, 'w') as trace_file:
      trace_file.write(trace.generate_chrome_trace_format(show_memory=True))
  return summary_str
Project: stuff    Author: yaroslavvb    | project source | file source
def store_data_in_json(
    stat_entries, timestamp, output_file=None, test_name=None):
  """Stores benchmark results in JSON format.

  Args:
    stat_entries: list of StatEntry objects.
    timestamp: (datetime) start time of the test run.
    output_file: if specified, writes benchmark results to output_file.
      Otherwise, if TF_DIST_BENCHMARK_RESULTS_FILE environment variable is set,
      writes to file specified by this environment variable. If neither
      output_file is passed in, nor TF_DIST_BENCHMARK_RESULTS_FILE is set,
      does nothing.
    test_name: benchmark name. This argument is required if
      TF_DIST_BENCHMARK_NAME environment variable is not set.

  Raises:
    ValueError: when neither test_name is passed in nor
      TF_DIST_BENCHMARK_NAME is set.
  """
  test_result = test_log_pb2.TestResults(
      start_time=calendar.timegm(timestamp.timetuple()))
  if not output_file:
    if _OUTPUT_FILE_ENV_VAR not in os.environ:
      logging.warning(
          'Skipping storing json output, since we could not determine '
          'location to store results at. Either output_file argument or '
          '%s environment variable needs to be set.', _OUTPUT_FILE_ENV_VAR)
      return
    output_file = os.environ[_OUTPUT_FILE_ENV_VAR]

  if test_name is not None:
    test_result.name = test_name
  elif _TEST_NAME_ENV_VAR in os.environ:
    test_result.name = os.environ[_TEST_NAME_ENV_VAR]
  else:
    raise ValueError(
        'Could not determine test name. test_name argument is not passed in '
        'and TF_DIST_BENCHMARK_NAME environment variable is not set.')

  for stat_entry in stat_entries:
    test_result.entries.entry.add(
        name=stat_entry.name,
        iters=stat_entry.num_samples,
        wall_time=stat_entry.stat_value
    )
  json_test_results = json_format.MessageToJson(test_result)

  with gfile.Open(output_file, 'wb') as jsonfile:
    jsonfile.write(json_test_results)
Project: imgrec    Author: Marsan-Ma    | project source | file source
def load_csv(filepath, target_column=-1, columns_to_ignore=None,
             has_header=True, categorical_labels=False, n_classes=None):
    """ load_csv.

    Load data from a CSV file. By default the labels are considered to be the
    last column, but it can be changed by filling 'target_column' parameter.

    Arguments:
        filepath: `str`. The csv file path.
        target_column: The id of the column representing the labels.
            Default: -1 (The last column).
        columns_to_ignore: `list of int`. A list of column indexes to ignore.
        has_header: `bool`. Whether the csv file has a header or not.
        categorical_labels: `bool`. If True, labels are returned as binary
            vectors (to be used with 'categorical_crossentropy').
        n_classes: `int`. Total number of classes (needed if
            categorical_labels is True).

    Returns:
        A tuple (data, target).

    """

    from tensorflow.python.platform import gfile
    with gfile.Open(filepath) as csv_file:
        data_file = csv.reader(csv_file)
        if not columns_to_ignore:
            columns_to_ignore = []
        if has_header:
            header = next(data_file)
        data, target = [], []
        # Fix column to ignore ids after removing target_column
        for i, c in enumerate(columns_to_ignore):
            if c > target_column:
                columns_to_ignore[i] -= 1
        for i, d in enumerate(data_file):
            target.append(d.pop(target_column))
            data.append([_d for j, _d in enumerate(d) if j not in columns_to_ignore])
        if categorical_labels:
            assert isinstance(n_classes, int), "n_classes not specified!"
            target = to_categorical(target, n_classes)
        return data, target