Python tensorflow module: Example() code examples

We have extracted the following 50 code examples from open-source Python projects to illustrate how to use tensorflow.Example().
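
For orientation, here is a minimal, self-contained sketch (TF 1.x API; the feature names are chosen for illustration) of building a tf.train.Example, serializing it, and parsing it back:

import tensorflow as tf

# Build an Example from a few features (hypothetical field names).
example = tf.train.Example(features=tf.train.Features(feature={
    'age': tf.train.Feature(int64_list=tf.train.Int64List(value=[42])),
    'name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'alice'])),
}))
serialized = example.SerializeToString()

# Parse it back with matching feature specs.
parsed = tf.parse_single_example(serialized, features={
    'age': tf.FixedLenFeature([1], tf.int64),
    'name': tf.FixedLenFeature([], tf.string),
})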

Project: transform    Author: tensorflow
def encode(self, instance):
    """Encode a tf.transform encoded dict as serialized tf.Example."""
    if self._encode_example_cache is None:
      # Initialize the encode Example cache (used by this and all subsequent
      # calls to encode).
      example = tf.train.Example()
      for feature_handler in self._feature_handlers:
        feature_handler.initialize_encode_cache(example)
      self._encode_example_cache = example

    # Encode and serialize using the Example cache.
    for feature_handler in self._feature_handlers:
      value = instance[feature_handler.name]
      try:
        feature_handler.encode_value(value)
      except TypeError as e:
        raise TypeError('%s while encoding feature "%s"' %
                        (e, feature_handler.name))

    return self._encode_example_cache.SerializeToString()
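
A hedged usage sketch for this coder (the surrounding class is tf.transform's ExampleProtoCoder; the schema construction via dataset_schema.from_feature_spec and the instance values below are assumptions):

import tensorflow as tf
from tensorflow_transform.tf_metadata import dataset_schema
from tensorflow_transform.coders import example_proto_coder

# Describe the instance layout, then encode one instance dict.
schema = dataset_schema.from_feature_spec({
    'x': tf.FixedLenFeature([], tf.float32),
    'label': tf.FixedLenFeature([], tf.int64),
})
coder = example_proto_coder.ExampleProtoCoder(schema)
serialized = coder.encode({'x': 1.5, 'label': 3})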
Project: easy-tensorflow    Author: khanhptnk
def __init__(self, model, loss_fn, data_path, log_dir, graph, input_reader):
    """Initialize a `TrainEvalBase` object.
      Args:
        model: an instance of a subclass of the `ModelBase` class (defined in
          `model_base.py`).
        loss_fn: a tensorflow op, a loss function for training a model. See:
            https://www.tensorflow.org/code/tensorflow/contrib/losses/python/losses/loss_ops.py
          for a list of available loss functions.
        data_path: a string, path to files of tf.Example protos containing data.
        log_dir: a string, logging directory.
        graph: a tensorflow computation graph.
        input_reader: an instance of a subclass of the `InputReaderBase` class
          (defined in `input_reader_base.py`).
    """
    self._data_path = data_path
    self._log_dir = log_dir
    self._train_log_dir = os.path.join(self._log_dir, "train")
    self._eval_log_dir = os.path.join(self._log_dir, "eval")

    self._model = model
    self._loss_fn = loss_fn
    self._graph = graph
    self._input_reader = input_reader

    self._summary_ops = []
Project: transform    Author: tensorflow
def as_feature_spec(self, column):
    ind = self.index_fields
    if len(ind) != 1 or len(column.axes) != 1:
      raise ValueError('tf.Example parser supports only 1-d sparse features.')
    index = ind[0]

    if column.domain.dtype not in _TF_EXAMPLE_ALLOWED_TYPES:
      raise ValueError('tf.Example parser supports only types {}, so it is '
                       'invalid to generate a feature_spec with type '
                       '{}.'.format(
                           _TF_EXAMPLE_ALLOWED_TYPES,
                           repr(column.domain.dtype)))

    return tf.SparseFeature(index.name,
                            self._value_field_name,
                            column.domain.dtype,
                            column.axes[0].size,
                            index.is_sorted)
Project: transform    Author: tensorflow
def segment_indices(segment_ids, name=None):
  """Returns a `Tensor` of indices within each segment.

  segment_ids should be a sequence of non-decreasing non-negative integers that
  define a set of segments, e.g. [0, 0, 1, 2, 2, 2] defines 3 segments of length
  2, 1 and 3.  The return value is a `Tensor` containing the indices within each
  segment.

  Example input: [0, 0, 1, 2, 2, 2]
  Example output: [0, 1, 0, 0, 1, 2]

  Args:
    segment_ids: A 1-d `Tensor` containing a non-decreasing sequence of
        non-negative integers with type `tf.int32` or `tf.int64`.
    name: (Optional) A name for this operation.

  Returns:
    A `Tensor` containing the indices within each segment.
  """
  with tf.name_scope(name, 'segment_indices'):
    segment_lengths = tf.segment_sum(tf.ones_like(segment_ids), segment_ids)
    segment_starts = tf.gather(tf.concat([[0], tf.cumsum(segment_lengths)], 0),
                               segment_ids)
    return (tf.range(tf.size(segment_ids, out_type=segment_ids.dtype)) -
            segment_starts)
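
The docstring's example can be verified directly in a TF 1.x session (a minimal sketch):

import tensorflow as tf

segment_ids = tf.constant([0, 0, 1, 2, 2, 2], dtype=tf.int64)
with tf.Session() as sess:
    print(sess.run(segment_indices(segment_ids)))  # [0 1 0 0 1 2]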
Project: text2text    Author: google
def _GetExFeatureText(self, example, key):
    """Extracts text for a feature from tf.Example.

    Args:
      example: tf.Example.
      key: Key of the feature to be extracted.

    Returns:
      A list of the decoded text values of the feature.
    """

    values = []
    for value in example.features.feature[key].bytes_list.value:
      values.append(value.decode("utf-8"))

    return values
Project: text2text    Author: google
def _GetExFeatureText(self, example, key):
    """Extracts text for a feature from tf.Example.

    Args:
      example: tf.Example.
      key: Key of the feature to be extracted.

    Returns:
      A list of the decoded text values of the feature.
    """

    values = []
    for value in example.features.feature[key].bytes_list.value:
      values.append(value.decode("utf-8"))

    return values
Project: tefla    Author: openAGI
def to_example(self, dictionary):
        """Helper: build tf.Example from (string -> int/float/str list) dictionary."""
        features = {}
        for (k, v) in six.iteritems(dictionary):
            if not v:
                raise ValueError("Empty generated field: %s", str((k, v)))
            if isinstance(v[0], six.integer_types):
                features[k] = tf.train.Feature(
                    int64_list=tf.train.Int64List(value=v))
            elif isinstance(v[0], float):
                features[k] = tf.train.Feature(
                    float_list=tf.train.FloatList(value=v))
            elif isinstance(v[0], six.string_types):
                if not six.PY2:  # Convert in python 3.
                    v = [bytes(x, "utf-8") for x in v]
                features[k] = tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=v))
            elif isinstance(v[0], bytes):
                features[k] = tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=v))
            else:
                raise ValueError("Value for %s is not a recognized type; v: %s type: %s" %
                                 (k, str(v[0]), str(type(v[0]))))
        return tf.train.Example(features=tf.train.Features(feature=features))
Project: kaggle-youtube-8m    Author: liufuyang
def get_placeholder_input_fn(config, model_type, vocab_sizes, use_crosses):
  """Wrap the get input features function to provide the metadata."""

  def get_input_features():
    """Read the input features from the given placeholder."""
    columns = feature_columns(config, model_type, vocab_sizes, use_crosses)
    feature_spec = tf.contrib.layers.create_feature_spec_for_parsing(columns)

    # Add a dense feature for the keys, use '' if not on the tf.Example proto.
    feature_spec[KEY_FEATURE_COLUMN] = tf.FixedLenFeature(
        [1], dtype=tf.string, default_value='')

    # Add a placeholder for the serialized tf.Example proto input.
    examples = tf.placeholder(tf.string, shape=(None,))

    features = tf.parse_example(examples, feature_spec)
    # Pass the input tensor so it can be used for export.
    features[EXAMPLES_PLACEHOLDER_KEY] = examples
    return features, None

  # Return a function to input the features into the model from a placeholder.
  return get_input_features
Project: tf_oreilly    Author: chiphuyen
def read_from_tfrecord(filenames):
    tfrecord_file_queue = tf.train.string_input_producer(filenames, name='queue')
    reader = tf.TFRecordReader()
    _, tfrecord_serialized = reader.read(tfrecord_file_queue)

    # label and image are stored as bytes but could be stored as 
    # int64 or float64 values in a serialized tf.Example protobuf.
    tfrecord_features = tf.parse_single_example(tfrecord_serialized,
                        features={
                            'label': tf.FixedLenFeature([], tf.int64),
                            'shape': tf.FixedLenFeature([], tf.string),
                            'image': tf.FixedLenFeature([], tf.string),
                        }, name='features')
    # image was saved as uint8, so we have to decode as uint8.
    image = tf.decode_raw(tfrecord_features['image'], tf.uint8)
    shape = tf.decode_raw(tfrecord_features['shape'], tf.int32)
    # the image tensor is flattened out, so we have to reconstruct the shape
    image = tf.reshape(image, shape)
    label = tfrecord_features['label']
    return label, shape, image
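
Because this reader uses the TF 1.x queue-based pipeline, the returned ops must be driven by queue runners; a minimal sketch (the file name is hypothetical):

label, shape, image = read_from_tfrecord(['images.tfrecord'])
with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    label_val, image_val = sess.run([label, image])
    coord.request_stop()
    coord.join(threads)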
Project: tensor2tensor    Author: tensorflow
def to_example(dictionary):
  """Helper: build tf.Example from (string -> int/float/str list) dictionary."""
  features = {}
  for (k, v) in six.iteritems(dictionary):
    if not v:
      raise ValueError("Empty generated field: %s", str((k, v)))
    if isinstance(v[0], six.integer_types):
      features[k] = tf.train.Feature(int64_list=tf.train.Int64List(value=v))
    elif isinstance(v[0], float):
      features[k] = tf.train.Feature(float_list=tf.train.FloatList(value=v))
    elif isinstance(v[0], six.string_types):
      if not six.PY2:  # Convert in python 3.
        v = [bytes(x, "utf-8") for x in v]
      features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=v))
    elif isinstance(v[0], bytes):
      features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=v))
    else:
      raise ValueError("Value for %s is not a recognized type; v: %s type: %s" %
                       (k, str(v[0]), str(type(v[0]))))
  return tf.train.Example(features=tf.train.Features(feature=features))
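
For instance, calling the helper above on a small dictionary (a sketch; the field names are made up):

example = to_example({
    "inputs": [1, 2, 3],        # becomes an int64_list
    "score": [0.5],             # becomes a float_list
    "text": ["hello world"],    # becomes a bytes_list (UTF-8 encoded on Python 3)
})
serialized = example.SerializeToString()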
Project: scalable_analytics    Author: broadinstitute
def _predict_input_fn():
  """Supplies the input to the model.

  Returns:
    A tuple consisting of 1) a dictionary of tensors whose keys are
    the feature names, and 2) a tensor of target labels which for
    clustering must be 'None'.
  """

  # Add a placeholder for the serialized tf.Example proto input.
  examples = tf.placeholder(tf.string, shape=(None,), name="examples")

  raw_features = tf.parse_example(examples, _get_feature_columns())

  dense = _raw_features_to_dense_tensor(raw_features)

  return input_fn_utils.InputFnOps(
      features={DENSE_KEY: dense},
      labels=None,
      default_inputs={EXAMPLE_KEY: examples})
Project: scalable_analytics    Author: broadinstitute
def measurements_to_examples(input_data):
  """Converts sparse measurements to TensorFlow Example protos.

  Args:
    input_data: dictionary objects with keys from
      DATA_QUERY_REPLACEMENTS

  Returns:
    TensorFlow Example protos.
  """
  meas_kvs = input_data | 'BucketMeasurements' >> beam.Map(
      lambda row: (row[SAMPLE_COLUMN], row))

  sample_meas_kvs = meas_kvs | 'GroupBySample' >> beam.GroupByKey()

  examples = (
      sample_meas_kvs
      | 'SamplesToExamples' >>
      beam.Map(lambda (key, vals): sample_measurements_to_example(key, vals)))

  return examples
Project: cloudml-examples    Author: googlegenomics
def _predict_input_fn():
  """Supplies the input to the model.

  Returns:
    A tuple consisting of 1) a dictionary of tensors whose keys are
    the feature names, and 2) a tensor of target labels if the mode
    is not INFER (and None, otherwise).
  """
  feature_spec = tf.contrib.layers.create_feature_spec_for_parsing(
      feature_columns=_get_feature_columns(include_target_column=False))

  feature_spec[FLAGS.id_field] = tf.FixedLenFeature([], dtype=tf.string)
  feature_spec[FLAGS.target_field + "_string"] = tf.FixedLenFeature(
      [], dtype=tf.string)

  # Add a placeholder for the serialized tf.Example proto input.
  examples = tf.placeholder(tf.string, shape=(None,), name="examples")

  features = tf.parse_example(examples, feature_spec)
  features[PREDICTION_KEY] = features[FLAGS.id_field]

  inputs = {PREDICTION_EXAMPLES: examples}

  return input_fn_utils.InputFnOps(
      features=features, labels=None, default_inputs=inputs)
Project: CRF-image-segmentation    Author: therealnidhin
def get_class_name_from_filename(file_name):
  """Gets the class name from a file.

  Args:
    file_name: The file name to get the class name from.
               ie. "american_pit_bull_terrier_105.jpg"

  Returns:
    example: The converted tf.Example.
  """
  match = re.match(r'([A-Za-z_]+)(-[0-9]+\.jpg)', file_name, re.I)
  return match.groups()[0]
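
With the pattern above, a quick check (file name taken from the docstring):

print(get_class_name_from_filename('american_pit_bull_terrier_105.jpg'))
# -> american_pit_bull_terrier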
Project: easy-tensorflow    Author: khanhptnk
def convert(data, filename):
  images = data["data"]
  labels = data["labels"]
  num_examples = images.shape[0]
  with tf.python_io.TFRecordWriter(filename) as writer:
    for i in xrange(num_examples):
      logging.info("Writing batch " + str(i) + "/" + str(num_examples))
      image = [int(x) for x in images[i, :]]
      label = labels[i]
      example = tf.train.Example()
      features_map = example.features.feature
      features_map["image"].int64_list.value.extend(list(image))
      features_map["label"].int64_list.value.append(label)
      writer.write(example.SerializeToString())
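
A hedged usage sketch, assuming `data` holds CIFAR-style flattened images under "data" and integer labels under "labels" (the shapes are assumptions; note the snippet itself is Python 2 code, using xrange):

import numpy as np

data = {
    "data": np.random.randint(0, 256, size=(10, 3072)),   # 10 flattened 32x32x3 images
    "labels": np.random.randint(0, 10, size=(10,)),
}
convert(data, "cifar_sample.tfrecord")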
Project: easy-tensorflow    Author: khanhptnk
def _load_data(self):
    """Load data from files of tf.Example protos."""
    keys, examples = self._input_reader.read_input(
        self._data_path,
        self._config.batch_size,
        randomize_input=self._model.is_training,
        distort_inputs=self._model.is_training)

    self._observations = examples["decoded_observation"]
    self._labels = examples["decoded_label"]
Project: magenta    Author: tensorflow
def get_example(self, batch_size):
    """Get a single example from the tfrecord file.

    Args:
      batch_size: Int, minibatch size.

    Returns:
      tf.Example protobuf parsed from tfrecord.
    """
    reader = tf.TFRecordReader()
    num_epochs = None if self.is_training else 1
    capacity = batch_size
    path_queue = tf.train.input_producer(
        [self.record_path],
        num_epochs=num_epochs,
        shuffle=self.is_training,
        capacity=capacity)
    unused_key, serialized_example = reader.read(path_queue)
    features = {
        "note_str": tf.FixedLenFeature([], dtype=tf.string),
        "pitch": tf.FixedLenFeature([1], dtype=tf.int64),
        "velocity": tf.FixedLenFeature([1], dtype=tf.int64),
        "audio": tf.FixedLenFeature([64000], dtype=tf.float32),
        "qualities": tf.FixedLenFeature([10], dtype=tf.int64),
        "instrument_source": tf.FixedLenFeature([1], dtype=tf.int64),
        "instrument_family": tf.FixedLenFeature([1], dtype=tf.int64),
    }
    example = tf.parse_single_example(serialized_example, features)
    return example
Project: transform    Author: tensorflow
def encode_value(self, values):
    """Encodes a feature into its Example proto representation."""
    del self._value[:]
    if self._rank == 0:
      self._value.append(self._cast_fn(values))
    else:
      flattened_values = (values if self._rank == 1 else
                          np.asarray(values).reshape(-1))
      if len(flattened_values) != self._size:
        raise ValueError('FixedLenFeature %r got wrong number of values. '
                         'Expected %d but got %d' %
                         (self._name, self._size, len(flattened_values)))
      self._value.extend(self._cast_fn(flattened_values))
Project: transform    Author: tensorflow
def decode(self, serialized_example_proto):
    """Decode serialized tf.Example as a tf.transform encoded dict."""
    if self._decode_example_cache is None:
      # Initialize the decode Example cache (used by this and all subsequent
      # calls to decode).
      self._decode_example_cache = tf.train.Example()

    example = self._decode_example_cache
    example.ParseFromString(serialized_example_proto)
    feature_map = example.features.feature
    return {feature_handler.name: feature_handler.parse_value(feature_map)
            for feature_handler in self._feature_handlers}
Project: transform    Author: tensorflow
def as_feature_spec(self, column):
    if not column.is_fixed_size():
      raise ValueError('A column of unknown size cannot be represented as '
                       'fixed-size.')
    if column.domain.dtype not in _TF_EXAMPLE_ALLOWED_TYPES:
      raise ValueError('tf.Example parser supports only types {}, so it is '
                       'invalid to generate a feature_spec with type '
                       '{}.'.format(
                           _TF_EXAMPLE_ALLOWED_TYPES,
                           repr(column.domain.dtype)))
    return tf.FixedLenFeature(column.tf_shape().as_list(),
                              column.domain.dtype,
                              self.default_value)
Project: transform    Author: tensorflow
def test_feature_spec_unsupported_dtype(self):
    schema = sch.Schema()
    schema.column_schemas['fixed_float_with_default'] = (
        sch.ColumnSchema(tf.float64, [1], sch.FixedColumnRepresentation(0.0)))

    with self.assertRaisesRegexp(ValueError,
                                 'tf.Example parser supports only types '
                                 r'\[tf.string, tf.int64, tf.float32, tf.bool\]'
                                 ', so it is invalid to generate a feature_spec'
                                 ' with type tf.float64.'):
      schema.as_feature_spec()
Project: tensorflow_object_detection_create_coco_tfrecord    Author: MetaPeak
def dict_to_coco_example(img_data):
    """Convert python dictionary formath data of one image to tf.Example proto.
    Args:
        img_data: infomation of one image, inclue bounding box, labels of bounding box,\
            height, width, encoded pixel data.
    Returns:
        example: The converted tf.Example
    """
    bboxes = img_data['bboxes']
    xmin, xmax, ymin, ymax = [], [], [], []
    for bbox in bboxes:
        xmin.append(bbox[0])
        xmax.append(bbox[0] + bbox[2])
        ymin.append(bbox[1])
        ymax.append(bbox[1] + bbox[3])

    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(img_data['height']),
        'image/width': dataset_util.int64_feature(img_data['width']),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
        'image/object/class/label': dataset_util.int64_list_feature(img_data['labels']),
        'image/encoded': dataset_util.bytes_feature(img_data['pixel_data']),
        'image/format': dataset_util.bytes_feature('jpeg'.encode('utf-8')),
    }))
    return example
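
A hypothetical img_data payload for the converter above (dataset_util refers to the TensorFlow Object Detection API helpers imported by the original file; the image path and values are made up):

with open('image.jpg', 'rb') as f:        # hypothetical image path
    encoded_jpeg = f.read()

img_data = {
    'height': 480,
    'width': 640,
    'bboxes': [[10, 20, 100, 50]],        # COCO-style [x, y, width, height]
    'labels': [1],
    'pixel_data': encoded_jpeg,
}
example = dict_to_coco_example(img_data)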
Project: FYP-AutoTextSum    Author: MrRexZ
def _TextGenerator(self, example_gen):
    """Generates article and abstract text from tf.Example."""
    while True:
      e = next(example_gen)
      try:
        article_text = self._GetExFeatureText(e, self._article_key)
        abstract_text = self._GetExFeatureText(e, self._abstract_key)
      except ValueError:
        tf.logging.error('Failed to get article or abstract from example')
        continue

      yield (article_text, abstract_text)
Project: FYP-AutoTextSum    Author: MrRexZ
def _GetExFeatureText(self, ex, key):
    """Extract text for a feature from td.Example.

    Args:
      ex: tf.Example.
      key: key of the feature to be extracted.
    Returns:
      feature: a feature text extracted.
    """
    return ex.features.feature[key].bytes_list.value[0]
Project: text2text    Author: google
def __init__(self, data_path, config):
    """Batcher initializer.

    Args:
      data_path: tf.Example filepattern.
      config: model hyperparameters.
    """
    self._data_path = data_path
    self._config = config
    self._input_vocab = config.input_vocab
    self._output_vocab = config.output_vocab
    self._source_key = config.source_key
    self._target_key = config.target_key
    self.use_bucketing = config.use_bucketing
    self._truncate_input = config.truncate_input
    self._input_queue = queue.Queue(QUEUE_NUM_BATCH * config.batch_size)
    self._bucket_input_queue = queue.Queue(QUEUE_NUM_BATCH)
    self._input_threads = []
    for _ in range(DAEMON_READER_THREADS):
      self._input_threads.append(Thread(target=self._FillInputQueue))
      self._input_threads[-1].daemon = True
      self._input_threads[-1].start()
    self._bucketing_threads = []
    for _ in range(BUCKETING_THREADS):
      self._bucketing_threads.append(Thread(target=self._FillBucketInputQueue))
      self._bucketing_threads[-1].daemon = True
      self._bucketing_threads[-1].start()

    self._watch_thread = Thread(target=self._WatchThreads)
    self._watch_thread.daemon = True
    self._watch_thread.start()
Project: text2text    Author: google
def __init__(self, data_path, config):
    """Batcher initializer.

    Args:
      data_path: tf.Example filepattern.
      config: model hyperparameters.
    """
    self._data_path = data_path
    self._config = config
    self._input_vocab = config.input_vocab
    self._output_vocab = config.output_vocab
    self._source_key = config.source_key
    self._target_key = config.target_key
    self.use_bucketing = config.use_bucketing
    self._truncate_input = config.truncate_input
    self._input_queue = queue.Queue(QUEUE_NUM_BATCH * config.batch_size)
    self._bucket_input_queue = queue.Queue(QUEUE_NUM_BATCH)
    self._input_threads = []
    for _ in range(DAEMON_READER_THREADS):
      self._input_threads.append(Thread(target=self._FillInputQueue))
      self._input_threads[-1].daemon = True
      self._input_threads[-1].start()
    self._bucketing_threads = []
    for _ in range(BUCKETING_THREADS):
      self._bucketing_threads.append(Thread(target=self._FillBucketInputQueue))
      self._bucketing_threads[-1].daemon = True
      self._bucketing_threads[-1].start()

    self._watch_thread = Thread(target=self._WatchThreads)
    self._watch_thread.daemon = True
    self._watch_thread.start()
Project: TensorflowFramework    Author: vahidk
def parallel_record_writer(iterator, create_example, path, num_threads=4):
  """Create a RecordIO file from data for efficient reading."""

  def _queue(inputs):
    for item in iterator:
      inputs.put(item)
    for _ in range(num_threads):
      inputs.put(None)

  def _map_fn(inputs, outputs):
    while True:
      item = inputs.get()
      if item is None:
        break
      example = create_example(item)
      outputs.put(example)
    outputs.put(None)

  # Read the inputs.
  inputs = mp.Queue()
  mp.Process(target=_queue, args=(inputs,)).start()

  # Convert to tf.Example
  outputs = mp.Queue()
  for _ in range(num_threads):
    mp.Process(target=_map_fn, args=(inputs, outputs)).start()

  # Write the output to file.
  writer = tf.python_io.TFRecordWriter(path)
  counter = 0
  while True:
    example = outputs.get()
    if example is None:
      counter += 1
      if counter == num_threads:
        break
      else:
        continue
    writer.write(example.SerializeToString())
  writer.close()
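
A usage sketch under assumptions: the iterator yields plain feature dictionaries, and create_example is any dict-to-Example helper such as the to_example function shown earlier on this page:

items = ({"label": [i], "value": [float(i)]} for i in range(100))
parallel_record_writer(items, create_example=to_example, path="out.tfrecord")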
Project: tefla    Author: openAGI
def generate_files(self, generator, output_filenames, max_cases=None):
        """Generate cases from a generator and save as TFRecord files.

        Generated cases are transformed to tf.Example protos and saved as TFRecords
        in sharded files named output_dir/output_name-00..N-of-00..M=num_shards.

        Args:
          generator: a generator yielding (string -> int/float/str list) dictionaries.
          output_filenames: List of output file paths.
          max_cases: maximum number of cases to get from the generator;
            if None (default), we use the generator until StopIteration is raised.
        """
        num_shards = len(output_filenames)
        writers = [tf.python_io.TFRecordWriter(
            fname) for fname in output_filenames]
        counter, shard = 0, 0
        for case in generator:
            if counter > 0 and counter % 100000 == 0:
                tf.logging.info("Generating case %d." % counter)
            counter += 1
            if max_cases and counter > max_cases:
                break
            sequence_example = self.to_example(case)
            writers[shard].write(sequence_example.SerializeToString())
            shard = (shard + 1) % num_shards

        for writer in writers:
            writer.close()
Project: savchenko    Author: JuleLaryushina
def _TextGenerator(self, example_gen):
    """Generates article and abstract text from tf.Example."""
    while True:
      e = example_gen.next()
      try:
        article_text = self._GetExFeatureText(e, self._article_key)
        abstract_text = self._GetExFeatureText(e, self._abstract_key)
      except ValueError:
        tf.logging.error('Failed to get article or abstract from example')
        continue

      yield (article_text, abstract_text)
Project: savchenko    Author: JuleLaryushina
def _GetExFeatureText(self, ex, key):
    """Extract text for a feature from td.Example.

    Args:
      ex: tf.Example.
      key: key of the feature to be extracted.
    Returns:
      feature: a feature text extracted.
    """
    return ex.features.feature[key].bytes_list.value[0]
Project: tensorfx    Author: TensorLab
def parse_instances(self, instances, prediction=False):
    """Parses input instances according to the associated schema.

    Arguments:
      instances: The tensor containing input strings.
      prediction: Whether the instances are being parsed for producing predictions or not.
    Returns:
      A dictionary of tensors keyed by field names.
    """
    # Convert the schema into an equivalent Example schema (expressed as features in Example
    # terminology).
    features = {}
    for field in self.schema:
      if field.type == SchemaFieldType.integer:
        dtype = tf.int64
        default_value = [0]
      elif field.type == SchemaFieldType.real:
        dtype = tf.float32
        default_value = [0.0]
      else:
        # discrete
        dtype = tf.string
        default_value = ['']

      if field.length == 0:
        feature = tf.VarLenFeature(dtype=dtype)
      else:
        if field.length != 1:
          default_value = default_value * field.length
        feature = tf.FixedLenFeature(shape=[field.length], dtype=dtype, default_value=default_value)

      features[field.name] = feature

    return tf.parse_example(instances, features, name='examples')
Project: NMT    Author: keon
def _TextGenerator(self, example_gen):
    """Generates article and abstract text from tf.Example."""
    while True:
      e = example_gen.next()
      try:
        article_text = self._GetExFeatureText(e, self._article_key)
        abstract_text = self._GetExFeatureText(e, self._abstract_key)
      except ValueError:
        tf.logging.error('Failed to get article or abstract from example')
        continue

      yield (article_text, abstract_text)
项目:NMT    作者:keon    | 项目源码 | 文件源码
def _GetExFeatureText(self, ex, key):
    """Extract text for a feature from td.Example.

    Args:
      ex: tf.Example.
      key: key of the feature to be extracted.
    Returns:
      feature: a feature text extracted.
    """
    return ex.features.feature[key].bytes_list.value[0]
Project: dynamic-coattention-network    Author: marshmelloX
def tf_Examples(data_path, num_epochs=None):
  """Generates tf.Examples from path of data files.
    Binary data format: <length><blob>. <length> represents the byte size
    of <blob>. <blob> is serialized tf.Example proto. The tf.Example contains
    the tokenized article text and summary.
  Args:
    data_path: path to tf.Example data files.
    num_epochs: Number of times to go through the data. None means infinite.
  Yields:
    Deserialized tf.Example.
  If there are multiple files specified, they are accessed in a random order.
  """
  epoch = 0
  while True:
    if num_epochs is not None and epoch >= num_epochs:
      break
    filelist = glob.glob(data_path)
    assert filelist, 'Empty filelist.'
    shuffle(filelist)
    for f in filelist:
      reader = open(f, 'rb')
      while True:
        len_bytes = reader.read(8)
        if not len_bytes: break
        str_len = struct.unpack('q', len_bytes)[0]
        example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
        yield example_pb2.Example.FromString(example_str)

    epoch += 1
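
The matching writer for this <length><blob> format is symmetric with the reader's struct.unpack calls (a sketch):

import struct

def write_tf_example(f, example):
    """Append one serialized tf.Example to file object f in <length><blob> form."""
    blob = example.SerializeToString()
    f.write(struct.pack('q', len(blob)))
    f.write(struct.pack('%ds' % len(blob), blob))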
Project: tf_oreilly    Author: chiphuyen
def write_to_tfrecord(label, shape, binary_image, tfrecord_file):
    """ This example is to write a sample to TFRecord file. If you want to write
    more samples, just use a loop.
    """
    writer = tf.python_io.TFRecordWriter(tfrecord_file)
    # write label, shape, and image content to the TFRecord file
    example = tf.train.Example(features=tf.train.Features(feature={
                'label': _int64_feature(label),
                'shape': _bytes_feature(shape),
                'image': _bytes_feature(binary_image)
                }))
    writer.write(example.SerializeToString())
    writer.close()
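
The _int64_feature and _bytes_feature helpers are not shown in this snippet; a common definition (an assumption here) is:

def _int64_feature(value):
    # Assumed helper: wrap a single int in an Int64List feature.
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    # Assumed helper: wrap a single bytes value in a BytesList feature.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))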
Project: TensorFlowOnSpark    Author: yahoo
def loadTFRecords(sc, input_dir, binary_features=[]):
  """Load TFRecords from disk into a Spark DataFrame.

  This will attempt to automatically convert the tf.train.Example features into Spark DataFrame columns of equivalent types.

  Note: TensorFlow represents both strings and binary types as tf.train.BytesList, and we need to
  disambiguate these types for Spark DataFrame DTypes (StringType and BinaryType), so we require a "hint"
  from the caller in the ``binary_features`` argument.

  Args:
    :sc: SparkContext
    :input_dir: location of TFRecords on disk.
    :binary_features: a list of tf.train.Example features which are expected to be binary/bytearrays.

  Returns:
    A Spark DataFrame mirroring the tf.train.Example schema.
  """
  import tensorflow as tf

  tfr_rdd = sc.newAPIHadoopFile(input_dir, "org.tensorflow.hadoop.io.TFRecordFileInputFormat",
                              keyClass="org.apache.hadoop.io.BytesWritable",
                              valueClass="org.apache.hadoop.io.NullWritable")

  # infer Spark SQL types from tf.Example
  record = tfr_rdd.take(1)[0]
  example = tf.train.Example()
  example.ParseFromString(bytes(record[0]))
  schema = infer_schema(example, binary_features)

  # convert serialized protobuf to tf.Example to Row
  example_rdd = tfr_rdd.mapPartitions(lambda x: fromTFExample(x, binary_features))

  # create a Spark DataFrame from RDD[Row]
  df = example_rdd.toDF(schema)

  # save reference of this dataframe
  loadedDF[df] = input_dir
  return df
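
Hypothetical usage, assuming the tensorflow-hadoop input format jar is available to Spark and the paths are made up:

df = loadTFRecords(sc, "hdfs:///data/examples", binary_features=["image_raw"])
df.printSchema()
df.show(5)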
Project: TensorFlowOnSpark    Author: yahoo
def infer_schema(example, binary_features=[]):
  """Given a tf.train.Example, infer the Spark DataFrame schema (StructFields).

  Note: TensorFlow represents both strings and binary types as tf.train.BytesList, and we need to
  disambiguate these types for Spark DataFrame DTypes (StringType and BinaryType), so we require a "hint"
  from the caller in the ``binary_features`` argument.

  Args:
    :example: a tf.train.Example
    :binary_features: a list of tf.train.Example features which are expected to be binary/bytearrays.

  Returns:
    A DataFrame StructType schema
  """
  def _infer_sql_type(k, v):
    # special handling for binary features
    if k in binary_features:
      return BinaryType()

    if v.int64_list.value:
      result = v.int64_list.value
      sql_type = LongType()
    elif v.float_list.value:
      result = v.float_list.value
      sql_type = DoubleType()
    else:
      result = v.bytes_list.value
      sql_type = StringType()

    if len(result) > 1:             # represent multi-item tensors as Spark SQL ArrayType() of base types
      return ArrayType(sql_type)
    else:                           # represent everything else as base types (and empty tensors as StringType())
      return sql_type

  return StructType([ StructField(k, _infer_sql_type(k, v), True) for k,v in sorted(example.features.feature.items()) ])
Project: tpu-demos    Author: tensorflow
def get_input_fn(filename):
  """Returns an `input_fn` for train and eval."""

  def input_fn(params):
    """A simple input_fn using the experimental input pipeline."""
    # Retrieves the batch size for the current shard. The # of shards is
    # computed according to the input pipeline deployment. See
    # `tf.contrib.tpu.RunConfig` for details.
    batch_size = params["batch_size"]

    def parser(serialized_example):
      """Parses a single tf.Example into image and label tensors."""
      features = tf.parse_single_example(
          serialized_example,
          features={
              "image_raw": tf.FixedLenFeature([], tf.string),
              "label": tf.FixedLenFeature([], tf.int64),
          })
      image = tf.decode_raw(features["image_raw"], tf.uint8)
      image.set_shape([28 * 28])
      # Normalize the values of the image from the range [0, 255] to [-0.5, 0.5]
      image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
      label = tf.cast(features["label"], tf.int32)
      return image, label

    dataset = tf.data.TFRecordDataset(
        filename, buffer_size=FLAGS.dataset_reader_buffer_size)
    dataset = dataset.map(parser).cache().repeat()
    dataset = dataset.apply(
        tf.contrib.data.batch_and_drop_remainder(batch_size))
    images, labels = dataset.make_one_shot_iterator().get_next()
    return images, labels
  return input_fn
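
A sketch of driving this directly (the TFRecord path is hypothetical, and FLAGS.dataset_reader_buffer_size must be defined as in the original file):

input_fn = get_input_fn("mnist_train.tfrecord")
images, labels = input_fn({"batch_size": 64})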
Project: dialog_research    Author: wjbianjason
def __init__(self, dataGenerator, bucketing=True, truncate_input=False):
    """Batcher constructor.

    Args:
      dataGenerator: a data generator object exposing `vocab` and `_hps`
        (Seq2SeqAttention model hyperparameters).
      bucketing: Whether to bucket articles of similar length into the same batch.
      truncate_input: Whether to truncate input that is too long. Alternative is
        to discard such examples.
    """
    self._data_generator = dataGenerator
    self._vocab = dataGenerator.vocab
    self._hps = dataGenerator._hps

    # self._max_article_sentences = self.
    # self._max_abstract_sentences = max_abstract_sentences
    self._bucketing = bucketing
    self._truncate_input = truncate_input
    self._input_queue = Queue.Queue(QUEUE_NUM_BATCH * self._hps.batch_size)
    self._bucket_input_queue = Queue.Queue(QUEUE_NUM_BATCH)
    self._input_threads = []
    for _ in xrange(8):
      self._input_threads.append(Thread(target=self._FillInputQueue))
      self._input_threads[-1].daemon = True
      self._input_threads[-1].start()
    self._bucketing_threads = []
    for _ in xrange(2):
      self._bucketing_threads.append(Thread(target=self._FillBucketInputQueue))
      self._bucketing_threads[-1].daemon = True
      self._bucketing_threads[-1].start()

    # self._watch_thread = Thread(target=self._WatchThreads)
    # self._watch_thread.daemon = True
    # self._watch_thread.start()
Project: tensor2tensor    Author: tensorflow
def generate_files(generator, output_filenames, max_cases=None):
  """Generate cases from a generator and save as TFRecord files.

  Generated cases are transformed to tf.Example protos and saved as TFRecords
  in sharded files named output_dir/output_name-00..N-of-00..M=num_shards.

  Args:
    generator: a generator yielding (string -> int/float/str list) dictionaries.
    output_filenames: List of output file paths.
    max_cases: maximum number of cases to get from the generator;
      if None (default), we use the generator until StopIteration is raised.
  """
  if outputs_exist(output_filenames):
    tf.logging.info("Skipping generator because outputs files exist")
    return
  num_shards = len(output_filenames)
  writers = [tf.python_io.TFRecordWriter(fname) for fname in output_filenames]
  counter, shard = 0, 0
  for case in generator:
    if counter > 0 and counter % 100000 == 0:
      tf.logging.info("Generating case %d." % counter)
    counter += 1
    if max_cases and counter > max_cases:
      break
    sequence_example = to_example(case)
    writers[shard].write(sequence_example.SerializeToString())
    shard = (shard + 1) % num_shards

  for writer in writers:
    writer.close()
Project: scalable_analytics    Author: broadinstitute
def sample_measurements_to_example(sample, sample_measurements):
  """Convert sparse measurements to TensorFlow Example protocol buffers.

  See also
  https://www.tensorflow.org/versions/r0.10/how_tos/reading_data/index.html

  Args:
    sample: the identifier for the sample
    sample_measurements: list of the sample's sparse measurements

  Returns:
    A filled in TensorFlow Example proto for this sample.
  """
  feature_tuples = [(str(cnt[MEASUREMENT_COLUMN]), cnt[VALUE_COLUMN])
                    for cnt in sample_measurements]
  measurements, values = map(list, zip(*feature_tuples))
  features = {
      SAMPLE_NAME_FEATURE:
          tf.train.Feature(bytes_list=tf.train.BytesList(value=[str(sample)])),
      # These are tf.VarLenFeature.
      MEASUREMENTS_FEATURE:
          tf.train.Feature(bytes_list=tf.train.BytesList(value=measurements)),
      VALUES_FEATURE:
          tf.train.Feature(float_list=tf.train.FloatList(value=values))
  }

  return tf.train.Example(features=tf.train.Features(feature=features))
Project: scalable_analytics    Author: broadinstitute
def run(argv=None):
  """Runs the sparse measurements preprocess pipeline.

  Args:
    argv: Pipeline options as a list of arguments.
  """
  pipeline_options = PipelineOptions(flags=argv)
  preprocess_options = pipeline_options.view_as(PreprocessOptions)
  cloud_options = pipeline_options.view_as(GoogleCloudOptions)
  output_dir = os.path.join(preprocess_options.output,
                            datetime.datetime.now().strftime('%Y%m%d-%H%M%S'))
  pipeline_options.view_as(SetupOptions).save_main_session = True
  pipeline_options.view_as(
      WorkerOptions).autoscaling_algorithm = 'THROUGHPUT_BASED'
  cloud_options.staging_location = os.path.join(output_dir, 'tmp', 'staging')
  cloud_options.temp_location = os.path.join(output_dir, 'tmp')
  cloud_options.job_name = 'preprocess-measurements-%s' % (
      datetime.datetime.now().strftime('%y%m%d-%H%M%S'))

  data_query = str(
      Template(open(preprocess_options.input, 'r').read()).render(
          DATA_QUERY_REPLACEMENTS))
  logging.info('data query : %s', data_query)

  with beam.Pipeline(options=pipeline_options) as p:
    # Read the table rows into a PCollection.
    rows = p | 'ReadMeasurements' >> beam.io.Read(
        beam.io.BigQuerySource(query=data_query, use_standard_sql=True))

    # Convert the data into TensorFlow Example Protocol Buffers.
    examples = measurements_to_examples(rows)

    # Write the serialized compressed protocol buffers to Cloud Storage.
    _ = (examples
         | 'EncodeExamples'
         >> beam.Map(lambda example: example.SerializeToString())
         | 'WriteExamples' >> tfrecordio.WriteToTFRecord(
             file_path_prefix=os.path.join(output_dir, 'examples'),
             compression_type=CompressionTypes.GZIP,
             file_name_suffix='.tfrecord.gz'))
Project: tensorflow    Author: luyishisi
def get_class_name_from_filename(file_name):
  """Gets the class name from a file.

  Args:
    file_name: The file name to get the class name from.
               ie. "american_pit_bull_terrier_105.jpg"

  Returns:
    example: The converted tf.Example.
  """
  match = re.match(r'([A-Za-z_]+)(_[0-9]+\.jpg)', file_name, re.I)
  return match.groups()[0]
Project: cloudml-examples    Author: googlegenomics
def filter_and_revise_example(serialized_example, samples_metadata):
  """Filter and revise a collection of existing TensorFlow examples.

  Args:
    serialized_example: the example to be revised and/or filtered
    samples_metadata: dictionary of metadata for all samples

  Returns:
    A list containing the revised example or the empty list if the
    example should be removed from the collection.
  """
  example = tf.train.Example.FromString(serialized_example)
  sample_name = example.features.feature[
      encoder.SAMPLE_NAME_FEATURE].bytes_list.value[0]
  logging.info('Checking ' + sample_name)
  if sample_name not in samples_metadata:
    logging.info('Omitting ' + sample_name)
    return []

  revised_features = {}
  # Initialize with current example features.
  revised_features.update(example.features.feature)
  # Overwrite metadata features.
  revised_features.update(
      metadata_encoder.metadata_to_ancestry_features(
          samples_metadata[sample_name]))
  return [
      tf.train.Example(features=tf.train.Features(feature=revised_features))
  ]
Project: easy-tensorflow    Author: khanhptnk
def read_input(self, data_path, batch_size, randomize_input=True,
                 distort_inputs=True, name="read_input"):
    """Read input labeled images and make a batch of examples.

      Labeled images are read from files of tf.Example protos. This proto has
      to contain two features: `image` and `label`, corresponding to an image
      and its label. After being read, the labeled images are put into queues
      to make a batch of examples every time the batching op is executed.

      Args:
        data_path: a string, path to files of tf.Example protos containing
          labeled images.
        batch_size: an int, number of labeled images in a batch.
        randomize_input: a bool, whether the images in the batch are randomized.
        distort_inputs: a bool, whether to distort the images.
        name: a string, name of the op.
      Returns:
        keys: a tensorflow op, the keys of tf.Example protos.
        examples: a tensorflow op, a batch of examples containing labeled
          images. After being materialized, this op becomes a dict, in which the
          `decoded_observation` key is an image and the `decoded_label` is the
          label of that image.
    """
    feature_types = {}
    feature_types["image"] = tf.FixedLenFeature(
        shape=[3072,], dtype=tf.int64, default_value=None)

    feature_types["label"] = tf.FixedLenFeature(
        shape=[1,], dtype=tf.int64, default_value=None)

    keys, examples = tf.contrib.learn.io.graph_io.read_keyed_batch_examples(
        file_pattern=data_path,
        batch_size=batch_size,
        reader=tf.TFRecordReader,
        randomize_input=randomize_input,
        queue_capacity=batch_size * 4,
        num_threads=10 if randomize_input else 1,
        parse_fn=lambda example_proto: self._preprocess_input(example_proto,
                                                              feature_types,
                                                              distort_inputs),
        name=name)

    return keys, examples
Project: easy-tensorflow    Author: khanhptnk
def _preprocess_input(self, example_proto, feature_types, distort_inputs):
    """Parse an tf.Example proto and preprocess its image and label.

      Args:
        example_proto: a tensorflow op, a tf.Example proto.
        feature_types: a dict, used for parsing a tf.Example proto. This is the
          same `feature_types` dict constructed in the `read_input` method.
        distort_inputs: a bool, whether to distort the images.
      Returns:
        example: a tensorflow op, after being materialized becomes a dict, in
          which the `decoded_observation` key is a processed image, a tensor
          of size InputReaderCifar10.IMAGE_SIZE x
          InputReaderCifar10.IMAGE_SIZE x InputReaderCifar10.NUM_CHANNELS and
          the `decoded_label` is the label of that image, a vector of size
          InputReaderCifar10.NUM_CLASSES.
    """
    example = tf.parse_single_example(example_proto, feature_types)
    image = tf.reshape(example["image"], [InputReaderCifar10.NUM_CHANNELS,
                                          InputReaderCifar10.IMAGE_SIZE,
                                          InputReaderCifar10.IMAGE_SIZE])
    image = tf.transpose(image, perm=[1, 2, 0])
    image = tf.cast(image, tf.float32)
    if distort_inputs:
      image = tf.random_crop(image, [InputReaderCifar10.IMAGE_CROPPED_SIZE,
                                     InputReaderCifar10.IMAGE_CROPPED_SIZE,
                                     3])
      image = tf.image.random_flip_left_right(image)
      image = tf.image.random_brightness(image, max_delta=63)
      image = tf.image.random_contrast(image, lower=0.2, upper=1.8)
    else:
      image = tf.image.resize_image_with_crop_or_pad(image,
          InputReaderCifar10.IMAGE_CROPPED_SIZE,
          InputReaderCifar10.IMAGE_CROPPED_SIZE)
    image = tf.image.per_image_whitening(image)
    example["decoded_observation"] = image

    label = tf.one_hot(example["label"], InputReaderCifar10.NUM_CLASSES, on_value=1, off_value=0)
    label = tf.reshape(label, [InputReaderCifar10.NUM_CLASSES,])
    label = tf.cast(label, tf.int64)
    example["decoded_label"] = label

    return example
Project: transform    Author: tensorflow
def __init__(self, schema):
    """Build an ExampleProtoCoder.

    Args:
      schema: A `Schema` object.
    Raises:
      ValueError: If `schema` is invalid.
    """
    self._schema = schema

    # Using pre-allocated tf.train.Example objects for performance reasons.
    #
    # The _encode_example_cache is used solely by "encode" paths while the
    # _decode_example_cache is used solely by "decode" paths, since the
    # caching strategies are incompatible with each other (due to proto
    # parsing/merging implementation).
    #
    # Since the output of both "encode" and "decode" are deep as opposed to
    # shallow transformations, and since the schema always fully defines the
    # Example's FeatureMap (i.e. all fields are always cleared/assigned or
    # copied), the optimizations and implementation are correct and
    # thread-compatible.
    #
    # Due to pickling issues actual initialization of this will happen lazily
    # in encode or decode respectively.
    self._encode_example_cache = None
    self._decode_example_cache = None

    self._feature_handlers = []
    for name, feature_spec in six.iteritems(schema.as_feature_spec()):
      if isinstance(feature_spec, tf.FixedLenFeature):
        self._feature_handlers.append(
            _FixedLenFeatureHandler(name, feature_spec))
      elif isinstance(feature_spec, tf.VarLenFeature):
        self._feature_handlers.append(
            _VarLenFeatureHandler(name, feature_spec))
      elif isinstance(feature_spec, tf.SparseFeature):
        self._feature_handlers.append(
            _SparseFeatureHandler(name, feature_spec))
      else:
        raise ValueError('feature_spec should be one of tf.FixedLenFeature, '
                         'tf.VarLenFeature or tf.SparseFeature: %s was %s' %
                         (name, type(feature_spec)))
Project: transform    Author: tensorflow
def tfidf(x, vocab_size, smooth=True, name=None):
  """Maps the terms in x to their term frequency * inverse document frequency.

  The inverse document frequency of a term is calculated as 1+
  log((corpus size + 1) / (document frequency of term + 1)) by default.

  Example usage:
    example strings [["I", "like", "pie", "pie", "pie"], ["yum", "yum", "pie"]]
    in: SparseTensor(indices=[[0, 0], [0, 1], [0, 2], [0, 3], [0, 4],
                              [1, 0], [1, 1], [1, 2]],
                     values=[1, 2, 0, 0, 0, 3, 3, 0])
    out: SparseTensor(indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]],
                      values=[1, 2, 0, 3, 0])
         SparseTensor(indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]],
                      values=[(1/5)*(log(3/2)+1), (1/5)*(log(3/2)+1), (1/5),
                              (1/3), (2/3)*(log(3/2)+1)])
    NOTE that the first doc's duplicate "pie" strings have been combined to
    one output, as have the second doc's duplicate "yum" strings.

  Args:
    x: A `SparseTensor` representing int64 values (most likely that are the
        result of calling string_to_int on a tokenized string).
    vocab_size: An int - the count of vocab used to turn the string into int64s
        including any OOV buckets.
    smooth: A bool indicating if the inverse document frequency should be
        smoothed. If True, which is the default, then the idf is calculated as
        1 + log((corpus size + 1) / (document frequency of term + 1)).
        Otherwise, the idf is
        1 + log((corpus size) / (document frequency of term)), which could
        result in a division by zero error.
    name: (Optional) A name for this operation.

  Returns:
    Two `SparseTensor`s with indices [index_in_batch, index_in_bag_of_words].
    The first has values vocab_index, which is taken from input `x`.
    The second has values tfidf_weight.
  """

  def _to_vocab_range(x):
    """Enforces that the vocab_ids in x are positive."""
    return tf.SparseTensor(
        indices=x.indices,
        values=tf.mod(x.values, vocab_size),
        dense_shape=x.dense_shape)

  with tf.name_scope(name, 'tfidf'):
    cleaned_input = _to_vocab_range(x)

    term_frequencies = _to_term_frequency(cleaned_input, vocab_size)

    count_docs_with_term_column = _count_docs_with_term(term_frequencies)
    # Expand dims to get around the min_tensor_rank checks
    sizes = tf.expand_dims(tf.shape(cleaned_input)[0], 0)
    # [batch, vocab] - tfidf
    tfidfs = _to_tfidf(term_frequencies,
                       analyzers.sum(count_docs_with_term_column,
                                     reduce_instance_dims=False),
                       analyzers.sum(sizes),
                       smooth)
    return _split_tfidfs_to_outputs(tfidfs)
Project: transform    Author: tensorflow
def bucketize(x, num_buckets, epsilon=None, name=None):
  """Returns a bucketized column, with a bucket index assigned to each input.

  Args:
    x: A numeric input `Tensor` whose values should be mapped to buckets.
    num_buckets: Values in the input `x` are divided into approximately
      equal-sized buckets, where the number of buckets is num_buckets.
    epsilon: (Optional) Error tolerance, typically a small fraction close to
      zero. If a value is not specified by the caller, a suitable value is
      computed based on experimental results.  For `num_buckets` less than 100,
      the value of 0.01 is chosen to handle a dataset of up to ~1 trillion input
      data values.  If `num_buckets` is larger, then epsilon is set to
      (1/`num_buckets`) to enforce a stricter error tolerance, because more
      buckets will result in smaller range for each bucket, and so we want the
      boundaries to be less fuzzy.
      See analyzers.quantiles() for details.
    name: (Optional) A name for this operation.

  Returns:
    A `Tensor` of the same shape as `x`, with each element in the
    returned tensor representing the bucketized value. Bucketized value is
    in the range [0, num_buckets).

  Raises:
    ValueError: If num_buckets is less than 1.
  """
  with tf.name_scope(name, 'bucketize'):
    if not isinstance(num_buckets, int):
      raise TypeError('num_buckets must be an int, got %s' % type(num_buckets))

    if num_buckets < 1:
      raise ValueError('Invalid num_buckets %d' % num_buckets)

    if epsilon is None:
      # See explanation in args documentation for epsilon.
      epsilon = min(1.0 / num_buckets, 0.01)

    bucket_boundaries = analyzers.quantiles(x, num_buckets, epsilon)
    buckets = quantile_ops.bucketize_with_input_boundaries(
        x,
        boundaries=bucket_boundaries,
        name='assign_buckets')

    # Convert to int64 because int32 is not compatible with tf.Example parser.
    # See _TF_EXAMPLE_ALLOWED_TYPES in FixedColumnRepresentation()
    # in tf_metadata/dataset_schema.py
    return tf.to_int64(buckets)
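
In a tf.transform preprocessing_fn this is typically applied per column; a minimal sketch (assuming tensorflow_transform is imported as tft, and a numeric input column named 'price'):

import tensorflow_transform as tft

def preprocessing_fn(inputs):
    # Map the numeric 'price' column into 10 approximately equal-sized buckets.
    return {'price_bucket': tft.bucketize(inputs['price'], num_buckets=10)}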
Project: FYP-AutoTextSum    Author: MrRexZ
def __init__(self, data_path, vocab, hps,
               article_key, abstract_key, max_article_sentences,
               max_abstract_sentences, bucketing=True, truncate_input=False):
    """Batcher constructor.

    Args:
      data_path: tf.Example filepattern.
      vocab: Vocabulary.
      hps: Seq2SeqAttention model hyperparameters.
      article_key: article feature key in tf.Example.
      abstract_key: abstract feature key in tf.Example.
      max_article_sentences: Max number of sentences used from article.
      max_abstract_sentences: Max number of sentences used from abstract.
      bucketing: Whether bucket articles of similar length into the same batch.
      truncate_input: Whether to truncate input that is too long. Alternative is
        to discard such examples.
    """
    self._data_path = data_path
    self._vocab = vocab
    self._hps = hps
    self._article_key = article_key
    self._abstract_key = abstract_key
    self._max_article_sentences = max_article_sentences
    self._max_abstract_sentences = max_abstract_sentences
    self._bucketing = bucketing
    self._truncate_input = truncate_input
    self._input_queue = Queue.Queue(QUEUE_NUM_BATCH * self._hps.batch_size)
    self._bucket_input_queue = Queue.Queue(QUEUE_NUM_BATCH)
    self._input_threads = []
    for _ in xrange(16):
      self._input_threads.append(Thread(target=self._FillInputQueue))
      self._input_threads[-1].daemon = True
      self._input_threads[-1].start()
    self._bucketing_threads = []
    for _ in xrange(4):
      self._bucketing_threads.append(Thread(target=self._FillBucketInputQueue))
      self._bucketing_threads[-1].daemon = True
      self._bucketing_threads[-1].start()

    self._watch_thread = Thread(target=self._WatchThreads)
    self._watch_thread.daemon = True
    self._watch_thread.start()