Python tensorflow.python.platform.gfile module: Glob() code examples

We extracted the following 47 code examples from open-source Python projects to illustrate how to use tensorflow.python.platform.gfile.Glob().
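
Before the per-project examples, here is a minimal sketch of the call itself under TensorFlow 1.x. The directory and pattern below are made up purely for illustration.

# Minimal gfile.Glob() sketch (TensorFlow 1.x); the directory is hypothetical.
import os
from tensorflow.python.platform import gfile

pattern = os.path.join("/tmp/images", "*.jpg")   # hypothetical image directory
matched_files = gfile.Glob(pattern)              # list of paths matching the pattern
print("Matched %d files" % len(matched_files))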

Project: dahoam2017    Author: KarimJedda
def create_data_list(image_dir):
  if not gfile.Exists(image_dir):
    print("Image director '" + image_dir + "' not found.")
    return None
  extensions = ['jpg', 'JPG', 'jpeg', 'JPEG', 'png', 'PNG']
  print("Looking for images in '" + image_dir + "'")
  file_list = []
  for extension in extensions:
    file_glob = os.path.join(image_dir, '*.' + extension)
    file_list.extend(gfile.Glob(file_glob))
  if not file_list:
    print("No files found in '" + image_dir + "'")
    return None
  images = []
  labels = []
  for file_name in file_list:
    image = Image.open(file_name)
    image_gray = image.convert('L')
    image_resize = image_gray.resize(size=(IMAGE_WIDTH,IMAGE_HEIGHT))
    input_img = np.array(image_resize, dtype='int16')
    image.close()
    label_name = os.path.basename(file_name).split('_')[0]
    images.append(input_img)
    labels.append(label_name)
  return zip(images, labels)
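
The excerpt above omits its imports and the IMAGE_WIDTH/IMAGE_HEIGHT constants. Below is a hedged preamble and call where every value is an assumption for illustration, not part of the original project.

# Assumed preamble for create_data_list(); sizes and the path are illustrative guesses.
import os
import numpy as np
from PIL import Image
from tensorflow.python.platform import gfile

IMAGE_WIDTH = 160    # hypothetical
IMAGE_HEIGHT = 60    # hypothetical

data = create_data_list("/tmp/captcha_images")   # hypothetical directory
if data is not None:
    for img, label in data:
        print(label, img.shape)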
Project: lsdc    Author: febert
def test_read_text_lines_multifile(self):
    gfile.Glob = self._orig_glob
    filenames = self._create_sorted_temp_files(["ABC\n", "DEF\nGHK\n"])

    batch_size = 1
    queue_capacity = 5
    name = "my_batch"

    with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
      inputs = tf.contrib.learn.io.read_batch_examples(
          filenames, batch_size, reader=tf.TextLineReader,
          randomize_input=False, num_epochs=1, queue_capacity=queue_capacity,
          name=name)
      session.run(tf.initialize_local_variables())

      coord = tf.train.Coordinator()
      tf.train.start_queue_runners(session, coord=coord)

      self.assertAllEqual(session.run(inputs), [b"ABC"])
      self.assertAllEqual(session.run(inputs), [b"DEF"])
      self.assertAllEqual(session.run(inputs), [b"GHK"])
      with self.assertRaises(errors.OutOfRangeError):
        session.run(inputs)

      coord.request_stop()
Project: lsdc    Author: febert
def test_batch_text_lines(self):
    gfile.Glob = self._orig_glob
    filename = self._create_temp_file("A\nB\nC\nD\nE\n")

    batch_size = 3
    queue_capacity = 10
    name = "my_batch"

    with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
      inputs = tf.contrib.learn.io.read_batch_examples(
          [filename], batch_size, reader=tf.TextLineReader,
          randomize_input=False, num_epochs=1, queue_capacity=queue_capacity,
          read_batch_size=10, name=name)
      session.run(tf.initialize_local_variables())

      coord = tf.train.Coordinator()
      tf.train.start_queue_runners(session, coord=coord)

      self.assertAllEqual(session.run(inputs), [b"A", b"B", b"C"])
      self.assertAllEqual(session.run(inputs), [b"D", b"E"])
      with self.assertRaises(errors.OutOfRangeError):
        session.run(inputs)

      coord.request_stop()
Project: Machine-Learning    Author: sfeng15
def restore(self, sess, save_path):
    """Restores previously saved variables.

    This method runs the ops added by the constructor for restoring variables.
    It requires a session in which the graph was launched.  The variables to
    restore do not have to have been initialized, as restoring is itself a way
    to initialize variables.

    The `save_path` argument is typically a value previously returned from a
    `save()` call, or a call to `latest_checkpoint()`.

    Args:
      sess: A `Session` to use to restore the parameters.
      save_path: Path where parameters were previously saved.

    Raises:
      ValueError: If the given `save_path` does not point to a file.
    """
    if not gfile.Glob(save_path):
      raise ValueError("Restore called with invalid save path %s" % save_path)
    sess.run(self.saver_def.restore_op_name,
             {self.saver_def.filename_tensor_name: save_path})
Project: Machine-Learning    Author: sfeng15
def latest_checkpoint(checkpoint_dir, latest_filename=None):
  """Finds the filename of latest saved checkpoint file.

  Args:
    checkpoint_dir: Directory where the variables were saved.
    latest_filename: Optional name for the protocol buffer file that
      contains the list of most recent checkpoint filenames.
      See the corresponding argument to `Saver.save()`.

  Returns:
    The full path to the latest checkpoint or `None` if no checkpoint was found.
  """
  # Pick the latest checkpoint based on checkpoint state.
  ckpt = get_checkpoint_state(checkpoint_dir, latest_filename)
  if ckpt and ckpt.model_checkpoint_path:
    if gfile.Glob(ckpt.model_checkpoint_path):
      return ckpt.model_checkpoint_path

  return None
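
The two Saver excerpts above pair naturally: locate the newest checkpoint, then restore it. A hedged sketch using the public tf.train API in TF 1.x graph mode; the directory is hypothetical and the model graph is assumed to already define its variables.

import tensorflow as tf

checkpoint_dir = "/tmp/train_logs"                      # hypothetical directory
ckpt_path = tf.train.latest_checkpoint(checkpoint_dir)  # None if nothing was saved
if ckpt_path:
    saver = tf.train.Saver()   # assumes the model variables already exist in the graph
    with tf.Session() as sess:
        saver.restore(sess, ckpt_path)   # no prior variable initialization required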
Project: captcha_recognize    Author: PatrickLib
def create_data_list(image_dir):
  if not gfile.Exists(image_dir):
    print("Image director '" + image_dir + "' not found.")
    return None
  extensions = ['jpg', 'JPG', 'jpeg', 'JPEG', 'png', 'PNG']
  print("Looking for images in '" + image_dir + "'")
  file_list = []
  for extension in extensions:
    file_glob = os.path.join(image_dir, '*.' + extension)
    file_list.extend(gfile.Glob(file_glob))
  if not file_list:
    print("No files found in '" + image_dir + "'")
    return None
  images = []
  labels = []
  for file_name in file_list:
    image = Image.open(file_name)
    image_gray = image.convert('L')
    image_resize = image_gray.resize(size=(IMAGE_WIDTH,IMAGE_HEIGHT))
    input_img = np.array(image_resize, dtype='int16')
    image.close()
    label_name = os.path.basename(file_name).split('_')[0]
    images.append(input_img)
    labels.append(label_name)
  return zip(images, labels)
Project: dahoam2017    Author: KarimJedda
def input_data(image_dir):
  if not gfile.Exists(image_dir):
    print(">> Image director '" + image_dir + "' not found.")
    return None
  extensions = ['jpg', 'JPG', 'jpeg', 'JPEG', 'png', 'PNG']
  print(">> Looking for images in '" + image_dir + "'")
  file_list = []
  for extension in extensions:
    file_glob = os.path.join(image_dir, '*.' + extension)
    file_list.extend(gfile.Glob(file_glob))
  if not file_list:
    print(">> No files found in '" + image_dir + "'")
    return None
  batch_size = len(file_list)
  images = np.zeros([batch_size, IMAGE_HEIGHT*IMAGE_WIDTH], dtype='float32')
  files = []
  i = 0
  for file_name in file_list:
    image = Image.open(file_name)
    image_gray = image.convert('L')
    image_resize = image_gray.resize(size=(IMAGE_WIDTH,IMAGE_HEIGHT))
    image.close()
    input_img = np.array(image_resize, dtype='float32')
    input_img = np.multiply(input_img.flatten(), 1./255) - 0.5    
    images[i,:] = input_img
    base_name = os.path.basename(file_name)
    files.append(base_name)
    i += 1
  return images, files
Project: Youtube8mdataset_kagglechallenge    Author: jasonlee27
def get_data(data_path,
             data_usedfor,
             data_lvl,
             feature_type="rgb",
             preprocess=None,
             shuffle=True,
             num_epochs=1):
    files_pattern = data_usedfor+"*.tfrecord"
    data_files = gfile.Glob(data_path + files_pattern)
    filename_queue = tf.train.string_input_producer(data_files, num_epochs=num_epochs, shuffle=shuffle)
    tfrecord_list = tfrecord_reader(filename_queue, data_lvl)
    vids = np.array([tfrecord_list[i][GLOBAL_FEAT_NAMES[0]] for i, _ in enumerate(tfrecord_list)])
    labels = np.array([tfrecord_list[i][GLOBAL_FEAT_NAMES[1]] for i, _ in enumerate(tfrecord_list)])

    if data_lvl == "video":
        if feature_type == "rgb":
            X = [tfrecord_list[i][VID_LVL_FEAT_NAMES[0]] for i, _ in enumerate(tfrecord_list)]
        elif feature_type == "audio":
            X = [tfrecord_list[i][VID_LVL_FEAT_NAMES[1]] for i, _ in enumerate(tfrecord_list)]
    elif data_lvl == "frame":
        if feature_type == "rgb":
            X = [tfrecord_list[i][FRM_LVL_FEAT_NAMES[0]] for i, _ in enumerate(tfrecord_list)]
            #X = [np.concatenate((tfrecord_list[i][FRM_LVL_FEAT_NAMES[0]],
            #                        get_framediff(tfrecord_list[i][FRM_LVL_FEAT_NAMES[0]])))
            #        for i, _ in enumerate(tfrecord_list)]
        elif feature_type == "audio":
            X = [tfrecord_list[i][FRM_LVL_FEAT_NAMES[1]] for i, _ in enumerate(tfrecord_list)]
    Y = to_multi_categorical(labels, NUM_CLASSES)
    print "get_data done."
    return X, Y
Project: lsdc    Author: febert
def read_batch_record_features(file_pattern, batch_size, features,
                               randomize_input=True, num_epochs=None,
                               queue_capacity=10000, reader_num_threads=1,
                               parser_num_threads=1,
                               name='dequeue_record_examples'):
  """Reads TFRecord, queues, batches and parses `Example` proto.

  See more detailed description in `read_examples`.

  Args:
    file_pattern: List of files or pattern of file paths containing
        `Example` records. See `tf.gfile.Glob` for pattern rules.
    batch_size: An int or scalar `Tensor` specifying the batch size to use.
    features: A `dict` mapping feature keys to `FixedLenFeature` or
      `VarLenFeature` values.
    randomize_input: Whether the input should be randomized.
    num_epochs: Integer specifying the number of times to read through the
      dataset. If None, cycles through the dataset forever. NOTE - If specified,
      creates a variable that must be initialized, so call
      tf.initialize_local_variables() as shown in the tests.
    queue_capacity: Capacity for input queue.
    reader_num_threads: The number of threads to read examples.
    parser_num_threads: The number of threads to parse examples.
    name: Name of resulting op.

  Returns:
    A dict of `Tensor` or `SparseTensor` objects for each in `features`.

  Raises:
    ValueError: for invalid inputs.
  """
  return read_batch_features(
      file_pattern=file_pattern, batch_size=batch_size, features=features,
      reader=io_ops.TFRecordReader,
      randomize_input=randomize_input, num_epochs=num_epochs,
      queue_capacity=queue_capacity, reader_num_threads=reader_num_threads,
      parser_num_threads=parser_num_threads, name=name)
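
A hedged calling sketch for the wrapper above, assuming it is importable as graph_io (the alias the tests later on this page use); the TFRecord pattern and feature spec are hypothetical.

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.learn_io import graph_io

features = graph_io.read_batch_record_features(
    file_pattern="/data/train-*.tfrecord",   # hypothetical; resolved internally via gfile.Glob
    batch_size=32,
    features={"label": tf.FixedLenFeature([1], tf.int64)},
    randomize_input=False,
    num_epochs=1)
# features["label"] yields batched int64 values once queue runners are started.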
Project: lsdc    Author: febert
def setUp(self):
    super(GraphIOTest, self).setUp()
    random.seed(42)
    self._orig_glob = gfile.Glob
    gfile.Glob = self._mock_glob
Project: lsdc    Author: febert
def tearDown(self):
    gfile.Glob = self._orig_glob
    super(GraphIOTest, self).tearDown()
Project: lsdc    Author: febert
def test_keyed_read_text_lines(self):
    gfile.Glob = self._orig_glob
    filename = self._create_temp_file("ABC\nDEF\nGHK\n")

    batch_size = 1
    queue_capacity = 5
    name = "my_batch"

    with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
      keys, inputs = tf.contrib.learn.io.read_keyed_batch_examples(
          filename, batch_size,
          reader=tf.TextLineReader, randomize_input=False,
          num_epochs=1, queue_capacity=queue_capacity, name=name)
      session.run(tf.initialize_local_variables())

      coord = tf.train.Coordinator()
      tf.train.start_queue_runners(session, coord=coord)

      self.assertAllEqual(session.run([keys, inputs]),
                          [[filename.encode("utf-8") + b":1"], [b"ABC"]])
      self.assertAllEqual(session.run([keys, inputs]),
                          [[filename.encode("utf-8") + b":2"], [b"DEF"]])
      self.assertAllEqual(session.run([keys, inputs]),
                          [[filename.encode("utf-8") + b":3"], [b"GHK"]])
      with self.assertRaises(errors.OutOfRangeError):
        session.run(inputs)

      coord.request_stop()
Project: lsdc    Author: febert
def test_keyed_parse_json(self):
    gfile.Glob = self._orig_glob
    filename = self._create_temp_file(
        '{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}\n'
        '{"features": {"feature": {"age": {"int64_list": {"value": [1]}}}}}\n'
        '{"features": {"feature": {"age": {"int64_list": {"value": [2]}}}}}\n'
    )

    batch_size = 1
    queue_capacity = 5
    name = "my_batch"

    with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
      dtypes = {"age": tf.FixedLenFeature([1], tf.int64)}
      parse_fn = lambda example: tf.parse_single_example(  # pylint: disable=g-long-lambda
          tf.decode_json_example(example), dtypes)
      keys, inputs = tf.contrib.learn.io.read_keyed_batch_examples(
          filename, batch_size,
          reader=tf.TextLineReader, randomize_input=False,
          num_epochs=1, queue_capacity=queue_capacity,
          parse_fn=parse_fn, name=name)
      session.run(tf.initialize_local_variables())

      coord = tf.train.Coordinator()
      tf.train.start_queue_runners(session, coord=coord)

      key, age = session.run([keys, inputs["age"]])
      self.assertAllEqual(age, [[0]])
      self.assertAllEqual(key, [filename.encode("utf-8") + b":1"])
      key, age = session.run([keys, inputs["age"]])
      self.assertAllEqual(age, [[1]])
      self.assertAllEqual(key, [filename.encode("utf-8") + b":2"])
      key, age = session.run([keys, inputs["age"]])
      self.assertAllEqual(age, [[2]])
      self.assertAllEqual(key, [filename.encode("utf-8") + b":3"])
      with self.assertRaises(errors.OutOfRangeError):
        session.run(inputs)

      coord.request_stop()
Project: lsdc    Author: febert
def _expand_file_names(filepatterns):
  """Takes a list of file patterns and returns a list of resolved file names."""
  if not isinstance(filepatterns, (list, tuple, set)):
    filepatterns = [filepatterns]
  filenames = set()
  for filepattern in filepatterns:
    names = set(gfile.Glob(filepattern))
    filenames |= names
  return list(filenames)
Project: lsdc    Author: febert
def _get_file_names(file_pattern, randomize_input):
  """Parse list of file names from pattern, optionally shuffled.

  Args:
    file_pattern: File glob pattern, or list of strings.
    randomize_input: Whether to shuffle the order of file names.

  Returns:
    List of file names matching `file_pattern`.

  Raises:
    ValueError: If `file_pattern` is empty, or pattern matches no files.
  """
  if isinstance(file_pattern, list):
    file_names = file_pattern
    if not file_names:
      raise ValueError('No files given to dequeue_examples.')
  else:
    file_names = list(gfile.Glob(file_pattern))
    if not file_names:
      raise ValueError('No files match %s.' % file_pattern)

  # Sort files so it will be deterministic for unit tests. They'll be shuffled
  # in `string_input_producer` if `randomize_input` is enabled.
  if not randomize_input:
    file_names = sorted(file_names)
  return file_names
Project: lsdc    Author: febert
def read_batch_record_features(file_pattern, batch_size, features,
                               randomize_input=True, num_epochs=None,
                               queue_capacity=10000, reader_num_threads=1,
                               name='dequeue_record_examples'):
  """Reads TFRecord, queues, batches and parses `Example` proto.

  See more detailed description in `read_examples`.

  Args:
    file_pattern: List of files or pattern of file paths containing
        `Example` records. See `tf.gfile.Glob` for pattern rules.
    batch_size: An int or scalar `Tensor` specifying the batch size to use.
    features: A `dict` mapping feature keys to `FixedLenFeature` or
      `VarLenFeature` values.
    randomize_input: Whether the input should be randomized.
    num_epochs: Integer specifying the number of times to read through the
      dataset. If None, cycles through the dataset forever. NOTE - If specified,
      creates a variable that must be initialized, so call
      tf.local_variables_initializer() as shown in the tests.
    queue_capacity: Capacity for input queue.
    reader_num_threads: The number of threads to read examples.
    name: Name of resulting op.

  Returns:
    A dict of `Tensor` or `SparseTensor` objects for each in `features`.

  Raises:
    ValueError: for invalid inputs.
  """
  return read_batch_features(
      file_pattern=file_pattern,
      batch_size=batch_size,
      features=features,
      reader=io_ops.TFRecordReader,
      randomize_input=randomize_input,
      num_epochs=num_epochs,
      queue_capacity=queue_capacity,
      reader_num_threads=reader_num_threads,
      name=name)
Project: lsdc    Author: febert
def tearDown(self):
    gfile.Glob = self._orig_glob
    super(GraphIOTest, self).tearDown()
Project: lsdc    Author: febert
def test_read_text_lines(self):
    gfile.Glob = self._orig_glob
    filename = self._create_temp_file("ABC\nDEF\nGHK\n")

    batch_size = 1
    queue_capacity = 5
    name = "my_batch"

    with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
      inputs = tf.contrib.learn.io.read_batch_examples(
          filename, batch_size, reader=tf.TextLineReader,
          randomize_input=False, num_epochs=1, queue_capacity=queue_capacity,
          name=name)
      self.assertAllEqual((None,), inputs.get_shape().as_list())
      session.run(tf.local_variables_initializer())

      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(session, coord=coord)

      self.assertAllEqual(session.run(inputs), [b"ABC"])
      self.assertAllEqual(session.run(inputs), [b"DEF"])
      self.assertAllEqual(session.run(inputs), [b"GHK"])
      with self.assertRaises(errors.OutOfRangeError):
        session.run(inputs)

      coord.request_stop()
      coord.join(threads)
Project: lsdc    Author: febert
def test_read_text_lines_multifile(self):
    gfile.Glob = self._orig_glob
    filenames = self._create_sorted_temp_files(["ABC\n", "DEF\nGHK\n"])

    batch_size = 1
    queue_capacity = 5
    name = "my_batch"

    with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
      inputs = tf.contrib.learn.io.read_batch_examples(
          filenames, batch_size, reader=tf.TextLineReader,
          randomize_input=False, num_epochs=1, queue_capacity=queue_capacity,
          name=name)
      self.assertAllEqual((None,), inputs.get_shape().as_list())
      session.run(tf.local_variables_initializer())

      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(session, coord=coord)

      self.assertEqual("%s:1" % name, inputs.name)
      file_name_queue_name = "%s/file_name_queue" % name
      file_names_name = "%s/input" % file_name_queue_name
      example_queue_name = "%s/fifo_queue" % name
      test_util.assert_ops_in_graph({
          file_names_name: "Const",
          file_name_queue_name: "FIFOQueue",
          "%s/read/TextLineReader" % name: "TextLineReader",
          example_queue_name: "FIFOQueue",
          name: "QueueDequeueUpTo"
      }, g)

      self.assertAllEqual(session.run(inputs), [b"ABC"])
      self.assertAllEqual(session.run(inputs), [b"DEF"])
      self.assertAllEqual(session.run(inputs), [b"GHK"])
      with self.assertRaises(errors.OutOfRangeError):
        session.run(inputs)

      coord.request_stop()
      coord.join(threads)
Project: lsdc    Author: febert
def test_keyed_read_text_lines(self):
    gfile.Glob = self._orig_glob
    filename = self._create_temp_file("ABC\nDEF\nGHK\n")

    batch_size = 1
    queue_capacity = 5
    name = "my_batch"

    with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
      keys, inputs = tf.contrib.learn.io.read_keyed_batch_examples(
          filename, batch_size,
          reader=tf.TextLineReader, randomize_input=False,
          num_epochs=1, queue_capacity=queue_capacity, name=name)
      self.assertAllEqual((None,), keys.get_shape().as_list())
      self.assertAllEqual((None,), inputs.get_shape().as_list())
      session.run(tf.local_variables_initializer())

      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(session, coord=coord)

      self.assertAllEqual(session.run([keys, inputs]),
                          [[filename.encode("utf-8") + b":1"], [b"ABC"]])
      self.assertAllEqual(session.run([keys, inputs]),
                          [[filename.encode("utf-8") + b":2"], [b"DEF"]])
      self.assertAllEqual(session.run([keys, inputs]),
                          [[filename.encode("utf-8") + b":3"], [b"GHK"]])
      with self.assertRaises(errors.OutOfRangeError):
        session.run(inputs)

      coord.request_stop()
      coord.join(threads)
Project: lsdc    Author: febert
def _expand_file_names(filepatterns):
  """Takes a list of file patterns and returns a list of resolved file names."""
  if not isinstance(filepatterns, (list, tuple, set)):
    filepatterns = [filepatterns]
  filenames = set()
  for filepattern in filepatterns:
    names = set(gfile.Glob(filepattern))
    filenames |= names
  return list(filenames)
Project: captcha_recognize    Author: PatrickLib
def input_data(image_dir):
  if not gfile.Exists(image_dir):
    print(">> Image director '" + image_dir + "' not found.")
    return None
  extensions = ['jpg', 'JPG', 'jpeg', 'JPEG', 'png', 'PNG']
  print(">> Looking for images in '" + image_dir + "'")
  file_list = []
  for extension in extensions:
    file_glob = os.path.join(image_dir, '*.' + extension)
    file_list.extend(gfile.Glob(file_glob))
  if not file_list:
    print(">> No files found in '" + image_dir + "'")
    return None
  batch_size = len(file_list)
  images = np.zeros([batch_size, IMAGE_HEIGHT*IMAGE_WIDTH], dtype='float32')
  files = []
  i = 0
  for file_name in file_list:
    image = Image.open(file_name)
    image_gray = image.convert('L')
    image_resize = image_gray.resize(size=(IMAGE_WIDTH,IMAGE_HEIGHT))
    image.close()
    input_img = np.array(image_resize, dtype='float32')
    input_img = np.multiply(input_img.flatten(), 1./255) - 0.5    
    images[i,:] = input_img
    base_name = os.path.basename(file_name)
    files.append(base_name)
    i += 1
  return images, files
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi
def _get_file_names(file_pattern, randomize_input):
  """Parse list of file names from pattern, optionally shuffled.

  Args:
    file_pattern: File glob pattern, or list of strings.
    randomize_input: Whether to shuffle the order of file names.

  Returns:
    List of file names matching `file_pattern`.

  Raises:
    ValueError: If `file_pattern` is empty, or pattern matches no files.
  """
  if isinstance(file_pattern, list):
    file_names = file_pattern
    if not file_names:
      raise ValueError('No files given to dequeue_examples.')
  else:
    file_names = list(gfile.Glob(file_pattern))
    if not file_names:
      raise ValueError('No files match %s.' % file_pattern)

  # Sort files so it will be deterministic for unit tests. They'll be shuffled
  # in `string_input_producer` if `randomize_input` is enabled.
  if not randomize_input:
    file_names = sorted(file_names)
  return file_names
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi
def setUp(self):
    super(GraphIOTest, self).setUp()
    random.seed(42)
    self._orig_glob = gfile.Glob
    gfile.Glob = self._mock_glob
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi
def tearDown(self):
    gfile.Glob = self._orig_glob
    super(GraphIOTest, self).tearDown()
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi
def test_read_text_lines(self):
    gfile.Glob = self._orig_glob
    filename = self._create_temp_file("ABC\nDEF\nGHK\n")

    batch_size = 1
    queue_capacity = 5
    name = "my_batch"

    with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
      inputs = graph_io.read_batch_examples(
          filename,
          batch_size,
          reader=io_ops.TextLineReader,
          randomize_input=False,
          num_epochs=1,
          queue_capacity=queue_capacity,
          name=name)
      self.assertAllEqual((None,), inputs.get_shape().as_list())
      session.run(variables.local_variables_initializer())

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(session, coord=coord)

      self.assertAllEqual(session.run(inputs), [b"ABC"])
      self.assertAllEqual(session.run(inputs), [b"DEF"])
      self.assertAllEqual(session.run(inputs), [b"GHK"])
      with self.assertRaises(errors.OutOfRangeError):
        session.run(inputs)

      coord.request_stop()
      coord.join(threads)
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi
def test_batch_text_lines(self):
    gfile.Glob = self._orig_glob
    filename = self._create_temp_file("A\nB\nC\nD\nE\n")

    batch_size = 3
    queue_capacity = 10
    name = "my_batch"

    with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
      inputs = graph_io.read_batch_examples(
          [filename],
          batch_size,
          reader=io_ops.TextLineReader,
          randomize_input=False,
          num_epochs=1,
          queue_capacity=queue_capacity,
          read_batch_size=10,
          name=name)
      self.assertAllEqual((None,), inputs.get_shape().as_list())
      session.run(variables.local_variables_initializer())

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(session, coord=coord)

      self.assertAllEqual(session.run(inputs), [b"A", b"B", b"C"])
      self.assertAllEqual(session.run(inputs), [b"D", b"E"])
      with self.assertRaises(errors.OutOfRangeError):
        session.run(inputs)

      coord.request_stop()
      coord.join(threads)
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi
def test_keyed_read_text_lines(self):
    gfile.Glob = self._orig_glob
    filename = self._create_temp_file("ABC\nDEF\nGHK\n")

    batch_size = 1
    queue_capacity = 5
    name = "my_batch"

    with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
      keys, inputs = graph_io.read_keyed_batch_examples(
          filename,
          batch_size,
          reader=io_ops.TextLineReader,
          randomize_input=False,
          num_epochs=1,
          queue_capacity=queue_capacity,
          name=name)
      self.assertAllEqual((None,), keys.get_shape().as_list())
      self.assertAllEqual((None,), inputs.get_shape().as_list())
      session.run(variables.local_variables_initializer())

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(session, coord=coord)

      self.assertAllEqual(
          session.run([keys, inputs]),
          [[filename.encode("utf-8") + b":1"], [b"ABC"]])
      self.assertAllEqual(
          session.run([keys, inputs]),
          [[filename.encode("utf-8") + b":2"], [b"DEF"]])
      self.assertAllEqual(
          session.run([keys, inputs]),
          [[filename.encode("utf-8") + b":3"], [b"GHK"]])
      with self.assertRaises(errors.OutOfRangeError):
        session.run(inputs)

      coord.request_stop()
      coord.join(threads)
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi
def _expand_file_names(filepatterns):
  """Takes a list of file patterns and returns a list of resolved file names."""
  if not isinstance(filepatterns, (list, tuple, set)):
    filepatterns = [filepatterns]
  filenames = set()
  for filepattern in filepatterns:
    names = set(gfile.Glob(filepattern))
    filenames |= names
  return list(filenames)
Project: PlantImageRecognition    Author: HeavenMin
def createImageLists(imageDir, testingPercentage, validationPercventage):
    if not gfile.Exists(imageDir):
        print("Image dir'" + imageDir +"'not found.'")
        return None
    result = {}
    subDirs = [x[0] for x in gfile.Walk(imageDir)]
    isRootDir = True
    for subDir in subDirs:
        if isRootDir:
            isRootDir = False
            continue
        extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
        fileList = []
        dirName = os.path.basename(subDir)
        if dirName == imageDir:
            continue
        print("Looking for images in '" + dirName + "'")
        for extension in extensions:
            fileGlob = os.path.join(imageDir, dirName, '*.' + extension)
            fileList.extend(gfile.Glob(fileGlob))
        if not fileList:
            print('No file found')
            continue
        labelName = re.sub(r'[^a-z0-9]+', ' ', dirName.lower())
        trainingImages = []
        testingImages =[]
        validationImages = []
        for fileName in fileList:
            baseName = os.path.basename(fileName)
            hashName = re.sub(r'_nohash_.*$', '', fileName)
            hashNameHased = hashlib.sha1(compat.as_bytes(hashName)).hexdigest()
            percentHash = ((int(hashNameHased, 16) %
                            (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                            (100.0 / MAX_NUM_IMAGES_PER_CLASS))
            if percentHash < validationPercventage:
                validationImages.append(baseName)
            elif percentHash < (testingPercentage + validationPercventage):
                testingImages.append(baseName)
            else:
                trainingImages.append(baseName)
        result[labelName] = {
            'dir': dirName,
            'training': trainingImages,
            'testing': testingImages,
            'validation': validationImages,
        }
    return result
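
A hedged sketch of the imports, constant, and call the excerpt above relies on; every value here is an assumption for illustration and is not part of the original project.

# Assumed preamble for createImageLists(); the cap below is a guess based on common TF retraining examples.
import os
import re
import hashlib
from tensorflow.python.platform import gfile
from tensorflow.python.util import compat

MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1   # illustrative guess

image_lists = createImageLists("/tmp/plant_photos",          # hypothetical directory
                               testingPercentage=10,
                               validationPercventage=10)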
Project: PlantImageRecognition    Author: HeavenMin
def createImageLists(imageDir, testingPercentage, validationPercventage):
    if not gfile.Exists(imageDir):
        print("Image dir'" + imageDir +"'not found.'")
        return None
    result = {}
    subDirs = [x[0] for x in gfile.Walk(imageDir)]
    isRootDir = True
    for subDir in subDirs:
        if isRootDir:
            isRootDir = False
            continue
        extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
        fileList = []
        dirName = os.path.basename(subDir)
        if dirName == imageDir:
            continue
        print("Looking for images in '" + dirName + "'")
        for extension in extensions:
            fileGlob = os.path.join(imageDir, dirName, '*.' + extension)
            fileList.extend(gfile.Glob(fileGlob))
        if not fileList:
            print('No file found')
            continue
        labelName = re.sub(r'[^a-z0-9]+', ' ', dirName.lower())
        trainingImages = []
        testingImages =[]
        validationImages = []
        for fileName in fileList:
            baseName = os.path.basename(fileName)
            hashName = re.sub(r'_nohash_.*$', '', fileName)
            hashNameHased = hashlib.sha1(compat.as_bytes(hashName)).hexdigest()
            percentHash = ((int(hashNameHased, 16) %
                            (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                            (100.0 / MAX_NUM_IMAGES_PER_CLASS))
            if percentHash < validationPercventage:
                validationImages.append(baseName)
            elif percentHash < (testingPercentage + validationPercventage):
                testingImages.append(baseName)
            else:
                trainingImages.append(baseName)
        result[labelName] = {
            'dir': dirName,
            'training': trainingImages,
            'testing': testingImages,
            'validation': validationImages,
        }
    return result
Project: PlantImageRecognition    Author: HeavenMin
def createImageLists(imageDir, testingPercentage, validationPercventage):
    if not gfile.Exists(imageDir):
        print("Image dir'" + imageDir +"'not found.'")
        return None
    result = {}
    subDirs = [x[0] for x in gfile.Walk(imageDir)]
    isRootDir = True
    for subDir in subDirs:
        if isRootDir:
            isRootDir = False
            continue
        extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
        fileList = []
        dirName = os.path.basename(subDir)
        if dirName == imageDir:
            continue
        print("Looking for images in '" + dirName + "'")
        for extension in extensions:
            fileGlob = os.path.join(imageDir, dirName, '*.' + extension)
            fileList.extend(gfile.Glob(fileGlob))
        if not fileList:
            print('No file found')
            continue
        labelName = re.sub(r'[^a-z0-9]+', ' ', dirName.lower())
        trainingImages = []
        testingImages =[]
        validationImages = []
        for fileName in fileList:
            baseName = os.path.basename(fileName)
            hashName = re.sub(r'_nohash_.*$', '', fileName)
            hashNameHased = hashlib.sha1(compat.as_bytes(hashName)).hexdigest()
            percentHash = ((int(hashNameHased, 16) %
                            (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                            (100.0 / MAX_NUM_IMAGES_PER_CLASS))
            if percentHash < validationPercventage:
                validationImages.append(baseName)
            elif percentHash < (testingPercentage + validationPercventage):
                testingImages.append(baseName)
            else:
                trainingImages.append(baseName)
        result[labelName] = {
            'dir': dirName,
            'training': trainingImages,
            'testing': testingImages,
            'validation': validationImages,
        }
    return result
Project: PlantImageRecognition    Author: HeavenMin
def createImageLists(imageDir, testingPercentage, validationPercventage):
    if not gfile.Exists(imageDir):
        print("Image dir'" + imageDir +"'not found.'")
        return None
    result = {}
    subDirs = [x[0] for x in gfile.Walk(imageDir)]
    isRootDir = True
    for subDir in subDirs:
        if isRootDir:
            isRootDir = False
            continue
        extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
        fileList = []
        dirName = os.path.basename(subDir)
        if dirName == imageDir:
            continue
        print("Looking for images in '" + dirName + "'")
        for extension in extensions:
            fileGlob = os.path.join(imageDir, dirName, '*.' + extension)
            fileList.extend(gfile.Glob(fileGlob))
        if not fileList:
            print('No file found')
            continue
        labelName = re.sub(r'[^a-z0-9]+', ' ', dirName.lower())
        trainingImages = []
        testingImages =[]
        validationImages = []
        for fileName in fileList:
            baseName = os.path.basename(fileName)
            hashName = re.sub(r'_nohash_.*$', '', fileName)
            hashNameHased = hashlib.sha1(compat.as_bytes(hashName)).hexdigest()
            percentHash = ((int(hashNameHased, 16) %
                            (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                            (100.0 / MAX_NUM_IMAGES_PER_CLASS))
            if percentHash < validationPercventage:
                validationImages.append(baseName)
            elif percentHash < (testingPercentage + validationPercventage):
                testingImages.append(baseName)
            else:
                trainingImages.append(baseName)
        result[labelName] = {
            'dir': dirName,
            'training': trainingImages,
            'testing': testingImages,
            'validation': validationImages,
        }
    return result
Project: PlantImageRecognition    Author: HeavenMin
def createImageLists(imageDir, testingPercentage, validationPercventage):
    if not gfile.Exists(imageDir):
        print("Image dir'" + imageDir +"'not found.'")
        return None
    result = {}
    subDirs = [x[0] for x in gfile.Walk(imageDir)]
    isRootDir = True
    for subDir in subDirs:
        if isRootDir:
            isRootDir = False
            continue
        extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
        fileList = []
        dirName = os.path.basename(subDir)
        if dirName == imageDir:
            continue
        print("Looking for images in '" + dirName + "'")
        for extension in extensions:
            fileGlob = os.path.join(imageDir, dirName, '*.' + extension)
            fileList.extend(gfile.Glob(fileGlob))
        if not fileList:
            print('No file found')
            continue
        labelName = re.sub(r'[^a-z0-9]+', ' ', dirName.lower())
        trainingImages = []
        testingImages =[]
        validationImages = []
        for fileName in fileList:
            baseName = os.path.basename(fileName)
            hashName = re.sub(r'_nohash_.*$', '', fileName)
            hashNameHased = hashlib.sha1(compat.as_bytes(hashName)).hexdigest()
            percentHash = ((int(hashNameHased, 16) %
                            (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                            (100.0 / MAX_NUM_IMAGES_PER_CLASS))
            if percentHash < validationPercventage:
                validationImages.append(baseName)
            elif percentHash < (testingPercentage + validationPercventage):
                testingImages.append(baseName)
            else:
                trainingImages.append(baseName)
        result[labelName] = {
            'dir': dirName,
            'training': trainingImages,
            'testing': testingImages,
            'validation': validationImages,
        }
    return result
Project: lsdc    Author: febert
def read_batch_examples(file_pattern, batch_size, reader,
                        randomize_input=True, num_epochs=None,
                        queue_capacity=10000, num_threads=1,
                        read_batch_size=1, parse_fn=None,
                        name=None):
  """Adds operations to read, queue, batch `Example` protos.

  Given file pattern (or list of files), will setup a queue for file names,
  read `Example` proto using provided `reader`, use batch queue to create
  batches of examples of size `batch_size`.

  All queue runners are added to the queue runners collection, and may be
  started via `start_queue_runners`.

  All ops are added to the default graph.

  Use `parse_fn` if you need to do parsing / processing on single examples.

  Args:
    file_pattern: List of files or pattern of file paths containing
        `Example` records. See `tf.gfile.Glob` for pattern rules.
    batch_size: An int or scalar `Tensor` specifying the batch size to use.
    reader: A function or class that returns an object with
      `read` method, (filename tensor) -> (example tensor).
    randomize_input: Whether the input should be randomized.
    num_epochs: Integer specifying the number of times to read through the
      dataset. If `None`, cycles through the dataset forever.
      NOTE - If specified, creates a variable that must be initialized, so call
      `tf.initialize_all_variables()` as shown in the tests.
    queue_capacity: Capacity for input queue.
    num_threads: The number of threads enqueuing examples.
    read_batch_size: An int or scalar `Tensor` specifying the number of
      records to read at once
    parse_fn: Parsing function, takes `Example` Tensor returns parsed
      representation. If `None`, no parsing is done.
    name: Name of resulting op.

  Returns:
    String `Tensor` of batched `Example` proto.

  Raises:
    ValueError: for invalid inputs.
  """
  _, examples = read_keyed_batch_examples(
      file_pattern=file_pattern, batch_size=batch_size, reader=reader,
      randomize_input=randomize_input, num_epochs=num_epochs,
      queue_capacity=queue_capacity, num_threads=num_threads,
      read_batch_size=read_batch_size, parse_fn=parse_fn, name=name)
  return examples
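
A short usage sketch mirroring how the tests earlier on this page call this function through tf.contrib.learn.io; the file pattern is hypothetical.

import tensorflow as tf

examples = tf.contrib.learn.io.read_batch_examples(
    "/data/lines-*.txt",          # hypothetical pattern, expanded with gfile.Glob
    batch_size=4,
    reader=tf.TextLineReader,
    randomize_input=False,
    num_epochs=1)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())   # needed because num_epochs is set
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord=coord)
    print(sess.run(examples))
    coord.request_stop()
    coord.join(threads)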
Project: lsdc    Author: febert
def read_batch_features(file_pattern, batch_size, features, reader,
                        randomize_input=True, num_epochs=None,
                        queue_capacity=10000, feature_queue_capacity=100,
                        reader_num_threads=1, parser_num_threads=1,
                        parse_fn=None, name=None):
  """Adds operations to read, queue, batch and parse `Example` protos.

  Given file pattern (or list of files), will setup a queue for file names,
  read `Example` proto using provided `reader`, use batch queue to create
  batches of examples of size `batch_size` and parse example given `features`
  specification.

  All queue runners are added to the queue runners collection, and may be
  started via `start_queue_runners`.

  All ops are added to the default graph.

  Args:
    file_pattern: List of files or pattern of file paths containing
        `Example` records. See `tf.gfile.Glob` for pattern rules.
    batch_size: An int or scalar `Tensor` specifying the batch size to use.
    features: A `dict` mapping feature keys to `FixedLenFeature` or
      `VarLenFeature` values.
    reader: A function or class that returns an object with
      `read` method, (filename tensor) -> (example tensor).
    randomize_input: Whether the input should be randomized.
    num_epochs: Integer specifying the number of times to read through the
      dataset. If None, cycles through the dataset forever. NOTE - If specified,
      creates a variable that must be initialized, so call
      tf.initialize_local_variables() as shown in the tests.
    queue_capacity: Capacity for input queue.
    feature_queue_capacity: Capacity of the parsed features queue. Set this
      value to a small number, for example 5 if the parsed features are large.
    reader_num_threads: The number of threads to read examples.
    parser_num_threads: The number of threads to parse examples.
    parse_fn: Parsing function, takes `Example` Tensor returns parsed
      representation. If `None`, no parsing is done.
    name: Name of resulting op.

  Returns:
    A dict of `Tensor` or `SparseTensor` objects for each in `features`.

  Raises:
    ValueError: for invalid inputs.
  """
  _, features = read_keyed_batch_features(
      file_pattern, batch_size, features, reader,
      randomize_input=randomize_input, num_epochs=num_epochs,
      queue_capacity=queue_capacity,
      feature_queue_capacity=feature_queue_capacity,
      reader_num_threads=reader_num_threads,
      parser_num_threads=parser_num_threads,
      parse_fn=parse_fn, name=name)
  return features
Project: lsdc    Author: febert
def read_batch_examples(file_pattern, batch_size, reader,
                        randomize_input=True, num_epochs=None,
                        queue_capacity=10000, num_threads=1,
                        read_batch_size=1, parse_fn=None,
                        name=None):
  """Adds operations to read, queue, batch `Example` protos.

  Given file pattern (or list of files), will setup a queue for file names,
  read `Example` proto using provided `reader`, use batch queue to create
  batches of examples of size `batch_size`.

  All queue runners are added to the queue runners collection, and may be
  started via `start_queue_runners`.

  All ops are added to the default graph.

  Use `parse_fn` if you need to do parsing / processing on single examples.

  Args:
    file_pattern: List of files or pattern of file paths containing
        `Example` records. See `tf.gfile.Glob` for pattern rules.
    batch_size: An int or scalar `Tensor` specifying the batch size to use.
    reader: A function or class that returns an object with
      `read` method, (filename tensor) -> (example tensor).
    randomize_input: Whether the input should be randomized.
    num_epochs: Integer specifying the number of times to read through the
      dataset. If `None`, cycles through the dataset forever.
      NOTE - If specified, creates a variable that must be initialized, so call
      `tf.global_variables_initializer()` as shown in the tests.
    queue_capacity: Capacity for input queue.
    num_threads: The number of threads enqueuing examples.
    read_batch_size: An int or scalar `Tensor` specifying the number of
      records to read at once
    parse_fn: Parsing function, takes `Example` Tensor returns parsed
      representation. If `None`, no parsing is done.
    name: Name of resulting op.

  Returns:
    String `Tensor` of batched `Example` proto.

  Raises:
    ValueError: for invalid inputs.
  """
  _, examples = read_keyed_batch_examples(
      file_pattern=file_pattern, batch_size=batch_size, reader=reader,
      randomize_input=randomize_input, num_epochs=num_epochs,
      queue_capacity=queue_capacity, num_threads=num_threads,
      read_batch_size=read_batch_size, parse_fn=parse_fn, name=name)
  return examples
Project: lsdc    Author: febert
def read_keyed_batch_examples(
    file_pattern, batch_size, reader,
    randomize_input=True, num_epochs=None,
    queue_capacity=10000, num_threads=1,
    read_batch_size=1, parse_fn=None,
    name=None):
  """Adds operations to read, queue, batch `Example` protos.

  Given file pattern (or list of files), will setup a queue for file names,
  read `Example` proto using provided `reader`, use batch queue to create
  batches of examples of size `batch_size`.

  All queue runners are added to the queue runners collection, and may be
  started via `start_queue_runners`.

  All ops are added to the default graph.

  Use `parse_fn` if you need to do parsing / processing on single examples.

  Args:
    file_pattern: List of files or pattern of file paths containing
        `Example` records. See `tf.gfile.Glob` for pattern rules.
    batch_size: An int or scalar `Tensor` specifying the batch size to use.
    reader: A function or class that returns an object with
      `read` method, (filename tensor) -> (example tensor).
    randomize_input: Whether the input should be randomized.
    num_epochs: Integer specifying the number of times to read through the
      dataset. If `None`, cycles through the dataset forever.
      NOTE - If specified, creates a variable that must be initialized, so call
      `tf.global_variables_initializer()` as shown in the tests.
    queue_capacity: Capacity for input queue.
    num_threads: The number of threads enqueuing examples.
    read_batch_size: An int or scalar `Tensor` specifying the number of
      records to read at once
    parse_fn: Parsing function, takes `Example` Tensor returns parsed
      representation. If `None`, no parsing is done.
    name: Name of resulting op.

  Returns:
    Returns tuple of:
    - `Tensor` of string keys.
    - String `Tensor` of batched `Example` proto.

  Raises:
    ValueError: for invalid inputs.
  """
  return _read_keyed_batch_examples_helper(
      file_pattern,
      batch_size,
      reader,
      randomize_input,
      num_epochs,
      queue_capacity,
      num_threads,
      read_batch_size,
      parse_fn,
      setup_shared_queue=False,
      name=name)
Project: lsdc    Author: febert
def read_batch_features(file_pattern,
                        batch_size,
                        features,
                        reader,
                        randomize_input=True,
                        num_epochs=None,
                        queue_capacity=10000,
                        feature_queue_capacity=100,
                        reader_num_threads=1,
                        parse_fn=None,
                        name=None):
  """Adds operations to read, queue, batch and parse `Example` protos.

  Given file pattern (or list of files), will setup a queue for file names,
  read `Example` proto using provided `reader`, use batch queue to create
  batches of examples of size `batch_size` and parse example given `features`
  specification.

  All queue runners are added to the queue runners collection, and may be
  started via `start_queue_runners`.

  All ops are added to the default graph.

  Args:
    file_pattern: List of files or pattern of file paths containing
        `Example` records. See `tf.gfile.Glob` for pattern rules.
    batch_size: An int or scalar `Tensor` specifying the batch size to use.
    features: A `dict` mapping feature keys to `FixedLenFeature` or
      `VarLenFeature` values.
    reader: A function or class that returns an object with
      `read` method, (filename tensor) -> (example tensor).
    randomize_input: Whether the input should be randomized.
    num_epochs: Integer specifying the number of times to read through the
      dataset. If None, cycles through the dataset forever. NOTE - If specified,
      creates a variable that must be initialized, so call
      tf.local_variables_initializer() as shown in the tests.
    queue_capacity: Capacity for input queue.
    feature_queue_capacity: Capacity of the parsed features queue. Set this
      value to a small number, for example 5 if the parsed features are large.
    reader_num_threads: The number of threads to read examples.
    parse_fn: Parsing function, takes `Example` Tensor returns parsed
      representation. If `None`, no parsing is done.
    name: Name of resulting op.

  Returns:
    A dict of `Tensor` or `SparseTensor` objects for each in `features`.

  Raises:
    ValueError: for invalid inputs.
  """
  _, features = read_keyed_batch_features(
      file_pattern, batch_size, features, reader,
      randomize_input=randomize_input, num_epochs=num_epochs,
      queue_capacity=queue_capacity,
      feature_queue_capacity=feature_queue_capacity,
      reader_num_threads=reader_num_threads,
      parse_fn=parse_fn, name=name)
  return features
Project: lsdc    Author: febert
def test_read_text_lines_large(self):
    gfile.Glob = self._orig_glob
    sequence_prefix = "abcdefghijklmnopqrstuvwxyz123456789"
    num_records = 49999
    lines = ["".join([sequence_prefix, str(l)]).encode("ascii")
             for l in xrange(num_records)]
    json_lines = ["".join(['{"features": { "feature": { "sequence": {',
                           '"bytes_list": { "value": ["',
                           base64.b64encode(l).decode("ascii"),
                           '"]}}}}}\n']) for l in lines]
    filename = self._create_temp_file("".join(json_lines))
    batch_size = 10000
    queue_capacity = 10000
    name = "my_large_batch"

    features = {"sequence": tf.FixedLenFeature([], tf.string)}

    with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
      keys, result = tf.contrib.learn.read_keyed_batch_features(
          filename, batch_size, features, tf.TextLineReader,
          randomize_input=False, num_epochs=1, queue_capacity=queue_capacity,
          num_enqueue_threads=2, parse_fn=tf.decode_json_example, name=name)
      self.assertAllEqual((None,), keys.get_shape().as_list())
      self.assertEqual(1, len(result))
      self.assertAllEqual((None,), result["sequence"].get_shape().as_list())
      session.run(tf.local_variables_initializer())
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(session, coord=coord)

      data = []
      try:
        while not coord.should_stop():
          data.append(session.run(result))
      except errors.OutOfRangeError:
        pass
      finally:
        coord.request_stop()

      coord.join(threads)

    parsed_records = [item for sublist in [d["sequence"] for d in data]
                      for item in sublist]
    # Check that the number of records matches expected and all records
    # are present.
    self.assertEqual(len(parsed_records), num_records)
    self.assertEqual(set(parsed_records), set(lines))
Project: lsdc    Author: febert
def test_read_text_lines_multifile_with_shared_queue(self):
    gfile.Glob = self._orig_glob
    filenames = self._create_sorted_temp_files(["ABC\n", "DEF\nGHK\n"])

    batch_size = 1
    queue_capacity = 5
    name = "my_batch"

    with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
      keys, inputs = _read_keyed_batch_examples_shared_queue(
          filenames,
          batch_size,
          reader=tf.TextLineReader,
          randomize_input=False,
          num_epochs=1,
          queue_capacity=queue_capacity,
          name=name)
      self.assertAllEqual((None,), keys.get_shape().as_list())
      self.assertAllEqual((None,), inputs.get_shape().as_list())
      session.run(tf.local_variables_initializer())

      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(session, coord=coord)

      self.assertEqual("%s:1" % name, inputs.name)
      shared_file_name_queue_name = "%s/file_name_queue" % name
      file_names_name = "%s/input" % shared_file_name_queue_name
      example_queue_name = "%s/fifo_queue" % name
      worker_file_name_queue_name = "%s/file_name_queue/fifo_queue" % name
      test_util.assert_ops_in_graph({
          file_names_name: "Const",
          shared_file_name_queue_name: "FIFOQueue",
          "%s/read/TextLineReader" % name: "TextLineReader",
          example_queue_name: "FIFOQueue",
          worker_file_name_queue_name: "FIFOQueue",
          name: "QueueDequeueUpTo"
      }, g)

      self.assertAllEqual(session.run(inputs), [b"ABC"])
      self.assertAllEqual(session.run(inputs), [b"DEF"])
      self.assertAllEqual(session.run(inputs), [b"GHK"])
      with self.assertRaises(errors.OutOfRangeError):
        session.run(inputs)

      coord.request_stop()
      coord.join(threads)
Project: lsdc    Author: febert
def test_keyed_parse_json(self):
    gfile.Glob = self._orig_glob
    filename = self._create_temp_file(
        '{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}\n'
        '{"features": {"feature": {"age": {"int64_list": {"value": [1]}}}}}\n'
        '{"features": {"feature": {"age": {"int64_list": {"value": [2]}}}}}\n'
    )

    batch_size = 1
    queue_capacity = 5
    name = "my_batch"

    with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
      dtypes = {"age": tf.FixedLenFeature([1], tf.int64)}
      parse_fn = lambda example: tf.parse_single_example(  # pylint: disable=g-long-lambda
          tf.decode_json_example(example), dtypes)
      keys, inputs = tf.contrib.learn.io.read_keyed_batch_examples(
          filename, batch_size,
          reader=tf.TextLineReader, randomize_input=False,
          num_epochs=1, queue_capacity=queue_capacity,
          parse_fn=parse_fn, name=name)
      self.assertAllEqual((None,), keys.get_shape().as_list())
      self.assertEqual(1, len(inputs))
      self.assertAllEqual((None, 1), inputs["age"].get_shape().as_list())
      session.run(tf.local_variables_initializer())

      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(session, coord=coord)

      key, age = session.run([keys, inputs["age"]])
      self.assertAllEqual(age, [[0]])
      self.assertAllEqual(key, [filename.encode("utf-8") + b":1"])
      key, age = session.run([keys, inputs["age"]])
      self.assertAllEqual(age, [[1]])
      self.assertAllEqual(key, [filename.encode("utf-8") + b":2"])
      key, age = session.run([keys, inputs["age"]])
      self.assertAllEqual(age, [[2]])
      self.assertAllEqual(key, [filename.encode("utf-8") + b":3"])
      with self.assertRaises(errors.OutOfRangeError):
        session.run(inputs)

      coord.request_stop()
      coord.join(threads)
Project: Machine-Learning    Author: sfeng15
def _MaybeDeleteOldCheckpoints(self, latest_save_path,
                                 meta_graph_suffix="meta"):
    """Deletes old checkpoints if necessary.

    Always keep the last `max_to_keep` checkpoints.  If
    `keep_checkpoint_every_n_hours` was specified, keep an additional checkpoint
    every `N` hours. For example, if `N` is 0.5, an additional checkpoint is
    kept for every 0.5 hours of training; if `N` is 10, an additional
    checkpoint is kept for every 10 hours of training.

    Args:
      latest_save_path: Name including path of checkpoint file to save.
      meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
    """
    if not self.saver_def.max_to_keep:
      return
    # Remove first from list if the same name was used before.
    for p in self._last_checkpoints:
      if latest_save_path == self._CheckpointFilename(p):
        self._last_checkpoints.remove(p)
    # Append new path to list
    self._last_checkpoints.append((latest_save_path, time.time()))
    # If more than max_to_keep, remove oldest.
    if len(self._last_checkpoints) > self.saver_def.max_to_keep:
      p = self._last_checkpoints.pop(0)
      # Do not delete the file if keep_checkpoint_every_n_hours is set and we
      # have reached N hours of training.
      should_keep = p[1] > self._next_checkpoint_time
      if should_keep:
        self._next_checkpoint_time += (
            self.saver_def.keep_checkpoint_every_n_hours * 3600)
        return
      # Otherwise delete the files.
      for f in gfile.Glob(self._CheckpointFilename(p)):
        try:
          gfile.Remove(f)
          meta_graph_filename = self._MetaGraphFilename(
              f, meta_graph_suffix=meta_graph_suffix)
          if gfile.Exists(meta_graph_filename):
            gfile.Remove(meta_graph_filename)
        except OSError as e:
          logging.warning("Ignoring: %s", str(e))
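For reference, a minimal sketch (TensorFlow 1.x) of the Saver options that drive this pruning; the checkpoint path is hypothetical. Every saver.save() call eventually reaches _MaybeDeleteOldCheckpoints, which uses gfile.Glob to locate, and gfile.Remove to delete, the files of checkpoints that fall out of the max_to_keep window.

import tensorflow as tf

step = tf.Variable(0, name="global_step", trainable=False)
saver = tf.train.Saver(
    max_to_keep=5,                    # always keep the 5 most recent checkpoints
    keep_checkpoint_every_n_hours=2)  # additionally keep one checkpoint every 2 hours

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for i in range(10):
    sess.run(step.assign(i))
    # Checkpoints older than the max_to_keep window are removed internally.
    saver.save(sess, "/tmp/model.ckpt", global_step=step)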
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi    | project source | file source
def read_batch_record_features(file_pattern,
                               batch_size,
                               features,
                               randomize_input=True,
                               num_epochs=None,
                               queue_capacity=10000,
                               reader_num_threads=1,
                               name='dequeue_record_examples'):
  """Reads TFRecord, queues, batches and parses `Example` proto.

  See more detailed description in `read_examples`.

  Args:
    file_pattern: List of files or pattern of file paths containing
        `Example` records. See `tf.gfile.Glob` for pattern rules.
    batch_size: An int or scalar `Tensor` specifying the batch size to use.
    features: A `dict` mapping feature keys to `FixedLenFeature` or
      `VarLenFeature` values.
    randomize_input: Whether the input should be randomized.
    num_epochs: Integer specifying the number of times to read through the
      dataset. If None, cycles through the dataset forever. NOTE - If specified,
      creates a variable that must be initialized, so call
      tf.local_variables_initializer() and run the op in a session.
    queue_capacity: Capacity for input queue.
    reader_num_threads: The number of threads to read examples.
    name: Name of resulting op.

  Returns:
    A dict of `Tensor` or `SparseTensor` objects for each key in `features`.

  Raises:
    ValueError: for invalid inputs.
  """
  return read_batch_features(
      file_pattern=file_pattern,
      batch_size=batch_size,
      features=features,
      reader=io_ops.TFRecordReader,
      randomize_input=randomize_input,
      num_epochs=num_epochs,
      queue_capacity=queue_capacity,
      reader_num_threads=reader_num_threads,
      name=name)
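A hedged usage sketch for this helper, assuming TensorFlow 1.x with tf.contrib available and TFRecord files that actually match the (hypothetical) pattern below; as the docstring notes, file_pattern is expanded with tf.gfile.Glob rules.

import tensorflow as tf

features = {"age": tf.FixedLenFeature([1], tf.int64)}
batch = tf.contrib.learn.io.read_batch_record_features(
    file_pattern="/tmp/data-*.tfrecord",  # hypothetical pattern, resolved via gfile.Glob
    batch_size=32,
    features=features,
    num_epochs=1)

with tf.Session() as sess:
  # num_epochs creates a local variable, so local variables must be initialized.
  sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess, coord=coord)
  try:
    ages = sess.run(batch["age"])  # shape (32, 1) once enough records are queued
  except tf.errors.OutOfRangeError:
    pass  # fewer records than one full batch
  finally:
    coord.request_stop()
    coord.join(threads)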
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi    | project source | file source
def test_read_text_lines_multifile(self):
    gfile.Glob = self._orig_glob
    filenames = self._create_sorted_temp_files(["ABC\n", "DEF\nGHK\n"])

    batch_size = 1
    queue_capacity = 5
    name = "my_batch"

    with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
      inputs = graph_io.read_batch_examples(
          filenames,
          batch_size,
          reader=io_ops.TextLineReader,
          randomize_input=False,
          num_epochs=1,
          queue_capacity=queue_capacity,
          name=name)
      self.assertAllEqual((None,), inputs.get_shape().as_list())
      session.run(variables.local_variables_initializer())

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(session, coord=coord)

      self.assertEqual("%s:1" % name, inputs.name)
      file_name_queue_name = "%s/file_name_queue" % name
      file_names_name = "%s/input" % file_name_queue_name
      example_queue_name = "%s/fifo_queue" % name
      test_util.assert_ops_in_graph({
          file_names_name: "Const",
          file_name_queue_name: "FIFOQueueV2",
          "%s/read/TextLineReaderV2" % name: "TextLineReaderV2",
          example_queue_name: "FIFOQueueV2",
          name: "QueueDequeueUpToV2"
      }, g)

      self.assertAllEqual(session.run(inputs), [b"ABC"])
      self.assertAllEqual(session.run(inputs), [b"DEF"])
      self.assertAllEqual(session.run(inputs), [b"GHK"])
      with self.assertRaises(errors.OutOfRangeError):
        session.run(inputs)

      coord.request_stop()
      coord.join(threads)
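These tests first restore the real gfile.Glob saved by the test fixture (other tests in the suite presumably stub it out) before reading the temporary files. For reference, gfile.Glob can also be called directly and follows the usual glob pattern rules; the directory below is hypothetical.

from tensorflow.python.platform import gfile

# List every PNG under a (hypothetical) directory; Glob returns a list of matching paths.
for path in gfile.Glob("/tmp/images/*.png"):
  print(path)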
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi    | project source | file source
def test_read_text_lines_multifile_with_shared_queue(self):
    gfile.Glob = self._orig_glob
    filenames = self._create_sorted_temp_files(["ABC\n", "DEF\nGHK\n"])

    batch_size = 1
    queue_capacity = 5
    name = "my_batch"

    with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
      keys, inputs = _read_keyed_batch_examples_shared_queue(
          filenames,
          batch_size,
          reader=io_ops.TextLineReader,
          randomize_input=False,
          num_epochs=1,
          queue_capacity=queue_capacity,
          name=name)
      self.assertAllEqual((None,), keys.get_shape().as_list())
      self.assertAllEqual((None,), inputs.get_shape().as_list())
      session.run([
          variables.local_variables_initializer(),
          variables.global_variables_initializer()
      ])

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(session, coord=coord)

      self.assertEqual("%s:1" % name, inputs.name)
      example_queue_name = "%s/fifo_queue" % name
      worker_file_name_queue_name = "%s/file_name_queue/fifo_queue" % name
      test_util.assert_ops_in_graph({
          "%s/read/TextLineReaderV2" % name: "TextLineReaderV2",
          example_queue_name: "FIFOQueueV2",
          worker_file_name_queue_name: "FIFOQueueV2",
          name: "QueueDequeueUpToV2"
      }, g)

      self.assertAllEqual(session.run(inputs), [b"ABC"])
      self.assertAllEqual(session.run(inputs), [b"DEF"])
      self.assertAllEqual(session.run(inputs), [b"GHK"])
      with self.assertRaises(errors.OutOfRangeError):
        session.run(inputs)

      coord.request_stop()
      coord.join(threads)
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi    | project source | file source
def test_keyed_parse_json(self):
    gfile.Glob = self._orig_glob
    filename = self._create_temp_file(
        '{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}\n'
        '{"features": {"feature": {"age": {"int64_list": {"value": [1]}}}}}\n'
        '{"features": {"feature": {"age": {"int64_list": {"value": [2]}}}}}\n')

    batch_size = 1
    queue_capacity = 5
    name = "my_batch"

    with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
      dtypes = {"age": parsing_ops.FixedLenFeature([1], dtypes_lib.int64)}
      parse_fn = lambda example: parsing_ops.parse_single_example(  # pylint: disable=g-long-lambda
          parsing_ops.decode_json_example(example), dtypes)
      keys, inputs = graph_io.read_keyed_batch_examples(
          filename,
          batch_size,
          reader=io_ops.TextLineReader,
          randomize_input=False,
          num_epochs=1,
          queue_capacity=queue_capacity,
          parse_fn=parse_fn,
          name=name)
      self.assertAllEqual((None,), keys.get_shape().as_list())
      self.assertEqual(1, len(inputs))
      self.assertAllEqual((None, 1), inputs["age"].get_shape().as_list())
      session.run(variables.local_variables_initializer())

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(session, coord=coord)

      key, age = session.run([keys, inputs["age"]])
      self.assertAllEqual(age, [[0]])
      self.assertAllEqual(key, [filename.encode("utf-8") + b":1"])
      key, age = session.run([keys, inputs["age"]])
      self.assertAllEqual(age, [[1]])
      self.assertAllEqual(key, [filename.encode("utf-8") + b":2"])
      key, age = session.run([keys, inputs["age"]])
      self.assertAllEqual(age, [[2]])
      self.assertAllEqual(key, [filename.encode("utf-8") + b":3"])
      with self.assertRaises(errors.OutOfRangeError):
        session.run(inputs)

      coord.request_stop()
      coord.join(threads)