Python tensorflow module: tables_initializer() example source code

We have extracted the following 44 code examples from open-source Python projects to illustrate how to use tensorflow.tables_initializer().
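
The project examples below all follow the same pattern: build the graph, run tf.tables_initializer() in the session (alongside the variable initializers, or after restoring a checkpoint), and only then evaluate lookups or iterators. As a minimal, self-contained sketch of that pattern (assuming TensorFlow 1.x and the tf.contrib.lookup API; not taken from any of the projects below):

import tensorflow as tf

# Lookup tables are stateful: they must be initialized with
# tf.tables_initializer() before any lookup op can be evaluated.
with tf.Graph().as_default():
  table = tf.contrib.lookup.index_table_from_tensor(
      tf.constant(["cat", "dog", "giraffe"]))  # ids 0, 1, 2; OOV -> -1
  ids = table.lookup(tf.constant(["dog", "bird"]))

  with tf.Session() as sess:
    sess.run(tf.tables_initializer())  # initializes every table in the graph
    print(sess.run(ids))  # [ 1 -1]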

Project: transform    Author: tensorflow    | Project source | File source
def __init__(self, saved_model_dir, input_schema, exclude_outputs,
                 tf_config):
      self.saved_model_dir = saved_model_dir
      self.session = tf.Session(graph=tf.Graph(), config=tf_config)
      with self.session.graph.as_default():
        with tf.Session(config=tf_config):
          inputs, outputs = saved_transform_io.partially_apply_saved_transform(
              saved_model_dir, {})
        self.session.run(tf.tables_initializer())

        input_schema_keys = input_schema.column_schemas.keys()
        extra_input_keys = set(input_schema_keys).difference(inputs.keys())
        if extra_input_keys:
          raise ValueError('Input schema contained keys not in graph: %s' %
                           input_schema_keys)
        extra_output_keys = set(exclude_outputs).difference(outputs.keys())
        if extra_output_keys:
          raise ValueError('Excluded outputs contained keys not in graph: %s' %
                           exclude_outputs)
        non_excluded_output_keys = set(outputs.keys()).difference(
            exclude_outputs)
        self.inputs = {key: inputs[key] for key in input_schema_keys}
        self.outputs = {key: outputs[key] for key in non_excluded_output_keys}
Project: transform    Author: tensorflow    | Project source | File source
def test_table_roundtrip(self):
    export_path = os.path.join(tempfile.mkdtemp(), 'export')

    with tf.Graph().as_default():
      with tf.Session().as_default() as session:
        input_string = tf.placeholder(tf.string)
        # Map string through a table, in this case based on a constant tensor.
        table = lookup.index_table_from_tensor(
            tf.constant(['cat', 'dog', 'giraffe']))
        output = table.lookup(input_string)
        inputs = {'input': input_string}
        outputs = {'output': output}
        saved_transform_io.write_saved_transform_from_session(
            session, inputs, outputs, export_path)

    with tf.Graph().as_default():
      with tf.Session().as_default() as session:
        # Using a computed input gives confidence that the graphs are fused.
        input_string = tf.constant('dog')
        inputs = {'input': input_string}
        outputs = saved_transform_io.apply_saved_transform(export_path, inputs)
        session.run(tf.tables_initializer())
        result = session.run(outputs['output'])
        self.assertEqual(1, result)
Project: nmt    Author: tensorflow    | Project source | File source
def _createTestInferModel(
      self, m_creator, hparams, sess, init_global_vars=False):
    infer_mode = tf.contrib.learn.ModeKeys.INFER
    (infer_iterator, src_vocab_table,
     tgt_vocab_table, reverse_tgt_vocab_table) = (
         common_test_utils.create_test_iterator(hparams, infer_mode))
    infer_m = m_creator(
        hparams,
        infer_mode,
        infer_iterator,
        src_vocab_table,
        tgt_vocab_table,
        reverse_tgt_vocab_table,
        scope='dynamic_seq2seq')
    if init_global_vars:
      sess.run(tf.global_variables_initializer())
    sess.run(tf.tables_initializer())
    sess.run(infer_iterator.initializer)
    return infer_m
Project: botcycle    Author: D2KLab    | Project source | File source
def __init__(self, model_path, embedding_size, language, nlp):

        # Step 1: restore the meta graph

        with tf.Graph().as_default() as graph:
            saver = tf.train.import_meta_graph(model_path + "model.ckpt.meta")

            self.graph = graph

            # get tensors for inputs and outputs by name
            self.decoder_prediction = graph.get_tensor_by_name('decoder_prediction:0')
            self.intent = graph.get_tensor_by_name('intent:0')
            self.words_inputs = graph.get_tensor_by_name('words_inputs:0')
            self.encoder_inputs_actual_length = graph.get_tensor_by_name('encoder_inputs_actual_length:0')
            # redefine the py_func, which is not serializable and so is not restored with the meta graph
            def static_wrapper(words):
                return spacy_wrapper(embedding_size, language, nlp, words)

            after_py_func = tf.py_func(static_wrapper, [self.words_inputs], tf.float32, stateful=False)

            # Step 2: restore weights
            self.sess = tf.Session()
            self.sess.run(tf.tables_initializer())
            saver.restore(self.sess, model_path + "model.ckpt")
Project: GNMT2    Author: Mingyearn    | Project source | File source
def _createTestInferModel(
      self, m_creator, hparams, sess, init_global_vars=False):
    infer_mode = tf.contrib.learn.ModeKeys.INFER
    infer_iterator, src_vocab_table, tgt_vocab_table, reverse_tgt_vocab_table = (
        common_test_utils.create_test_iterator(hparams, infer_mode))
    infer_m = m_creator(
        hparams,
        infer_mode,
        infer_iterator,
        src_vocab_table,
        tgt_vocab_table,
        reverse_tgt_vocab_table,
        scope='dynamic_seq2seq')
    if init_global_vars:
      sess.run(tf.global_variables_initializer())
    sess.run(tf.tables_initializer())
    sess.run(infer_iterator.initializer)
    return infer_m
Project: GNMT2    Author: Mingyearn    | Project source | File source
def create_or_load_model(model, model_dir, session, out_dir, name):
  """Create translation model and initialize or load parameters in session."""
  start_time = time.time()
  latest_ckpt = tf.train.latest_checkpoint(model_dir)
  if latest_ckpt:
    model.saver.restore(session, latest_ckpt)
    utils.print_out(
        "  loaded %s model parameters from %s, time %.2fs" %
        (name, latest_ckpt, time.time() - start_time))
  else:
    utils.print_out("  created %s model with fresh parameters, time %.2fs." %
                    (name, time.time() - start_time))
    session.run(tf.global_variables_initializer())

  session.run(tf.tables_initializer())

  global_step = model.global_step.eval(session=session)
  return model, global_step
Project: pydatalab    Author: googledatalab    | Project source | File source
def _run_graph(self, analysis_path, features, schema, stats, predict_data):
    """Runs the preprocessing graph.

    Args:
      analysis_path: path to folder containing analysis output. Should contain
          the stats file.
      features: features dict
      schema: schema list
      stats: stats dict
      predict_data: list of csv strings
    """
    stats = {'column_stats': {}}
    with tf.Graph().as_default():
      with tf.Session().as_default() as session:
        outputs, labels, inputs = feature_transforms.build_csv_serving_tensors_for_transform_step(
            analysis_path, features, schema, stats, keep_target=False)
        feed_inputs = {inputs['csv_example']: predict_data}

        session.run(tf.tables_initializer())
        result = session.run(outputs, feed_dict=feed_inputs)
        return result
Project: pydatalab    Author: googledatalab    | Project source | File source
def start_bundle(self, element=None):
    """Build the transfromation graph once."""
    import tensorflow as tf
    from trainer import feature_transforms

    g = tf.Graph()
    session = tf.Session(graph=g)

    # Build the transformation graph
    with g.as_default():
      transformed_features, _, placeholders = (
          feature_transforms.build_csv_serving_tensors_for_transform_step(
              analysis_path=self._analysis_output_dir, 
              features=self._features, 
              schema=self._schema,
              stats=self._stats,
              keep_target=True))
      session.run(tf.tables_initializer())

    self._session = session
    self._transformed_features = transformed_features
    self._input_placeholder_tensor = placeholders['csv_example']
Project: pydatalab    Author: googledatalab    | Project source | File source
def start_bundle(self, element=None):
    """Build the transfromation graph once."""
    import tensorflow as tf
    from trainer import feature_transforms

    g = tf.Graph()
    session = tf.Session(graph=g)

    # Build the transformation graph
    with g.as_default():
      transformed_features, _, placeholders = (
          feature_transforms.build_csv_serving_tensors_for_transform_step(
              analysis_path=self._analysis_output_dir, 
              features=self._features, 
              schema=self._schema,
              stats=self._stats,
              keep_target=True))
      session.run(tf.tables_initializer())

    self._session = session
    self._transformed_features = transformed_features
    self._input_placeholder_tensor = placeholders['csv_example']
Project: seq2seq    Author: google    | Project source | File source
def _test_pipeline(self, mode, params=None):
    """Helper function to test the full model pipeline.
    """
    # Create source and target example
    source_len = self.sequence_length + 5
    target_len = self.sequence_length + 10
    source = " ".join(np.random.choice(self.vocab_list, source_len))
    target = " ".join(np.random.choice(self.vocab_list, target_len))
    sources_file, targets_file = test_utils.create_temp_parallel_data(
        sources=[source], targets=[target])

    # Build model graph
    model = self.create_model(mode, params)
    input_pipeline_ = input_pipeline.ParallelTextInputPipeline(
        params={
            "source_files": [sources_file.name],
            "target_files": [targets_file.name]
        },
        mode=mode)
    input_fn = training_utils.create_input_fn(
        pipeline=input_pipeline_, batch_size=self.batch_size)
    features, labels = input_fn()
    fetches = model(features, labels, None)
    fetches = [_ for _ in fetches if _ is not None]

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      sess.run(tf.tables_initializer())
      with tf.contrib.slim.queues.QueueRunners(sess):
        fetches_ = sess.run(fetches)

    sources_file.close()
    targets_file.close()

    return model, fetches_
Project: seq2seq    Author: google    | Project source | File source
def test_without_counts(self):
    vocab_list = ["Hello", ".", "?"]
    vocab_file = test_utils.create_temporary_vocab_file(vocab_list)

    vocab_to_id_table, id_to_vocab_table, _, vocab_size = \
      vocab.create_vocabulary_lookup_table(vocab_file.name)

    self.assertEqual(vocab_size, 6)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      sess.run(tf.tables_initializer())

      ids = vocab_to_id_table.lookup(
          tf.convert_to_tensor(["Hello", ".", "?", "??", "xxx"]))
      ids = sess.run(ids)
      np.testing.assert_array_equal(ids, [0, 1, 2, 3, 3])

      words = id_to_vocab_table.lookup(
          tf.convert_to_tensor(
              [0, 1, 2, 3], dtype=tf.int64))
      words = sess.run(words)
      np.testing.assert_array_equal(
          np.char.decode(words.astype("S"), "utf-8"),
          ["Hello", ".", "?", "UNK"])
Project: seq2seq    Author: google    | Project source | File source
def test_with_counts(self):
    vocab_list = ["Hello", ".", "?"]
    vocab_counts = [100, 200, 300]
    vocab_file = test_utils.create_temporary_vocab_file(vocab_list,
                                                        vocab_counts)

    vocab_to_id_table, id_to_vocab_table, word_to_count_table, vocab_size = \
      vocab.create_vocabulary_lookup_table(vocab_file.name)

    self.assertEqual(vocab_size, 6)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      sess.run(tf.tables_initializer())

      ids = vocab_to_id_table.lookup(
          tf.convert_to_tensor(["Hello", ".", "?", "??", "xxx"]))
      ids = sess.run(ids)
      np.testing.assert_array_equal(ids, [0, 1, 2, 3, 3])

      words = id_to_vocab_table.lookup(
          tf.convert_to_tensor(
              [0, 1, 2, 3], dtype=tf.int64))
      words = sess.run(words)
      np.testing.assert_array_equal(
          np.char.decode(words.astype("S"), "utf-8"),
          ["Hello", ".", "?", "UNK"])

      counts = word_to_count_table.lookup(
          tf.convert_to_tensor(["Hello", ".", "?", "??", "xxx"]))
      counts = sess.run(counts)
      np.testing.assert_array_equal(counts, [100, 200, 300, -1, -1])
Project: seq2seq    Author: google    | Project source | File source
def test_sampling(self):
    hook = hooks.TrainSampleHook(
        params={"every_n_steps": 10}, model_dir=self.model_dir,
        run_config=tf.contrib.learn.RunConfig())

    global_step = tf.contrib.framework.get_or_create_global_step()
    no_op = tf.no_op()
    hook.begin()
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      sess.run(tf.tables_initializer())

      #pylint: disable=W0212
      mon_sess = monitored_session._HookedSession(sess, [hook])
      # Should trigger for step 0
      sess.run(tf.assign(global_step, 0))
      mon_sess.run(no_op)

      outfile = os.path.join(self.sample_dir, "samples_000000.txt")
      with open(outfile, "rb") as readfile:
        self.assertIn("Prediction followed by Target @ Step 0",
                      readfile.read().decode("utf-8"))

      # Should not trigger for step 9
      sess.run(tf.assign(global_step, 9))
      mon_sess.run(no_op)
      outfile = os.path.join(self.sample_dir, "samples_000009.txt")
      self.assertFalse(os.path.exists(outfile))

      # Should trigger for step 10
      sess.run(tf.assign(global_step, 10))
      mon_sess.run(no_op)
      outfile = os.path.join(self.sample_dir, "samples_000010.txt")
      with open(outfile, "rb") as readfile:
        self.assertIn("Prediction followed by Target @ Step 10",
                      readfile.read().decode("utf-8"))
Project: ISLES2017    Author: MiguelMonteiro    | Project source | File source
def _run_eval(self):
        """Run model evaluation and generate summaries."""
        coord = tf.train.Coordinator(clean_stop_exception_types=(
            tf.errors.CancelledError, tf.errors.OutOfRangeError))

        with tf.Session(graph=self._graph) as session:
            # Restores previously saved variables from latest checkpoint
            self._saver.restore(session, self._latest_checkpoint)

            session.run([tf.tables_initializer(), tf.local_variables_initializer()])
            tf.train.start_queue_runners(coord=coord, sess=session)
            train_step = session.run(self._gs)

            tf.logging.info('Starting evaluation')
            d = {key: [] for key in ['loss', 'accuracy', 'dice_coefficient', 'hausdorff_distance',
                                     'average_symmetric_surface_distance']}
            with coord.stop_on_exception():
                while not coord.should_stop():
                    metric_dict = session.run(self._metric_dict)

                    prediction = metric_dict.pop('prediction')
                    ground_truth = metric_dict.pop('ground_truth')

                    d['loss'].append(metric_dict.pop('loss'))
                    d['accuracy'].append(metric_dict.pop('accuracy'))
                    d['dice_coefficient'].append(metric_dict.pop('dice_coefficient'))
                    d['hausdorff_distance'].append(hd(prediction, ground_truth))
                    d['average_symmetric_surface_distance'].append(assd(prediction, ground_truth))

            # Save histogram, mean and std for each variable
            for key, value in d.iteritems():
                self.logger.log_histogram(tag=key, values=value, step=train_step, bins=15)
                self.logger.log_random_variable(tag='eval_'+key, var=value, step=train_step)
            tf.logging.info('Finished evaluation')
Project: cloudml-samples    Author: GoogleCloudPlatform    | Project source | File source
def _run_eval(self):
    """Run model evaluation and generate summaries."""
    coord = tf.train.Coordinator(clean_stop_exception_types=(
        tf.errors.CancelledError, tf.errors.OutOfRangeError))

    with tf.Session(graph=self._graph) as session:
      # Restores previously saved variables from latest checkpoint
      self._saver.restore(session, self._latest_checkpoint)

      session.run([
          tf.tables_initializer(),
          tf.local_variables_initializer()
      ])
      tf.train.start_queue_runners(coord=coord, sess=session)
      train_step = session.run(self._gs)

      tf.logging.info('Starting Evaluation For Step: {}'.format(train_step))
      with coord.stop_on_exception():
        eval_step = 0
        while not coord.should_stop() and (self._eval_steps is None or
                                           eval_step < self._eval_steps):
          summaries, final_values, _ = session.run(
              [self._summary_op, self._final_ops_dict, self._eval_ops])
          if eval_step % 100 == 0:
            tf.logging.info("On Evaluation Step: {}".format(eval_step))
          eval_step += 1

      # Write the summaries
      self._file_writer.add_summary(summaries, global_step=train_step)
      self._file_writer.flush()
      tf.logging.info(final_values)
Project: cloudml-samples    Author: GoogleCloudPlatform    | Project source | File source
def main_op():
  init_local = variables.local_variables_initializer()
  init_tables = lookup_ops.tables_initializer()
  return control_flow_ops.group(init_local, init_tables)
Project: conv_seq2seq    Author: tobyyouup    | Project source | File source
def _test_pipeline(self, mode, params=None):
    """Helper function to test the full model pipeline.
    """
    # Create source and target example
    source_len = self.sequence_length + 5
    target_len = self.sequence_length + 10
    source = " ".join(np.random.choice(self.vocab_list, source_len))
    target = " ".join(np.random.choice(self.vocab_list, target_len))
    sources_file, targets_file = test_utils.create_temp_parallel_data(
        sources=[source], targets=[target])

    # Build model graph
    model = self.create_model(mode, params)
    input_pipeline_ = input_pipeline.ParallelTextInputPipeline(
        params={
            "source_files": [sources_file.name],
            "target_files": [targets_file.name]
        },
        mode=mode)
    input_fn = training_utils.create_input_fn(
        pipeline=input_pipeline_, batch_size=self.batch_size)
    features, labels = input_fn()
    fetches = model(features, labels, None)
    fetches = [_ for _ in fetches if _ is not None]

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      sess.run(tf.tables_initializer())
      with tf.contrib.slim.queues.QueueRunners(sess):
        fetches_ = sess.run(fetches)

    sources_file.close()
    targets_file.close()

    return model, fetches_
Project: conv_seq2seq    Author: tobyyouup    | Project source | File source
def test_without_counts(self):
    vocab_list = ["Hello", ".", "?"]
    vocab_file = test_utils.create_temporary_vocab_file(vocab_list)

    vocab_to_id_table, id_to_vocab_table, _, vocab_size = \
      vocab.create_vocabulary_lookup_table(vocab_file.name)

    self.assertEqual(vocab_size, 6)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      sess.run(tf.tables_initializer())

      ids = vocab_to_id_table.lookup(
          tf.convert_to_tensor(["Hello", ".", "?", "??", "xxx"]))
      ids = sess.run(ids)
      np.testing.assert_array_equal(ids, [0, 1, 2, 3, 3])

      words = id_to_vocab_table.lookup(
          tf.convert_to_tensor(
              [0, 1, 2, 3], dtype=tf.int64))
      words = sess.run(words)
      np.testing.assert_array_equal(
          np.char.decode(words.astype("S"), "utf-8"),
          ["Hello", ".", "?", "UNK"])
Project: conv_seq2seq    Author: tobyyouup    | Project source | File source
def test_with_counts(self):
    vocab_list = ["Hello", ".", "?"]
    vocab_counts = [100, 200, 300]
    vocab_file = test_utils.create_temporary_vocab_file(vocab_list,
                                                        vocab_counts)

    vocab_to_id_table, id_to_vocab_table, word_to_count_table, vocab_size = \
      vocab.create_vocabulary_lookup_table(vocab_file.name)

    self.assertEqual(vocab_size, 6)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      sess.run(tf.tables_initializer())

      ids = vocab_to_id_table.lookup(
          tf.convert_to_tensor(["Hello", ".", "?", "??", "xxx"]))
      ids = sess.run(ids)
      np.testing.assert_array_equal(ids, [0, 1, 2, 3, 3])

      words = id_to_vocab_table.lookup(
          tf.convert_to_tensor(
              [0, 1, 2, 3], dtype=tf.int64))
      words = sess.run(words)
      np.testing.assert_array_equal(
          np.char.decode(words.astype("S"), "utf-8"),
          ["Hello", ".", "?", "UNK"])

      counts = word_to_count_table.lookup(
          tf.convert_to_tensor(["Hello", ".", "?", "??", "xxx"]))
      counts = sess.run(counts)
      np.testing.assert_array_equal(counts, [100, 200, 300, -1, -1])
Project: conv_seq2seq    Author: tobyyouup    | Project source | File source
def test_sampling(self):
    hook = hooks.TrainSampleHook(
        params={"every_n_steps": 10}, model_dir=self.model_dir,
        run_config=tf.contrib.learn.RunConfig())

    global_step = tf.contrib.framework.get_or_create_global_step()
    no_op = tf.no_op()
    hook.begin()
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      sess.run(tf.tables_initializer())

      #pylint: disable=W0212
      mon_sess = monitored_session._HookedSession(sess, [hook])
      # Should trigger for step 0
      sess.run(tf.assign(global_step, 0))
      mon_sess.run(no_op)

      outfile = os.path.join(self.sample_dir, "samples_000000.txt")
      with open(outfile, "rb") as readfile:
        self.assertIn("Prediction followed by Target @ Step 0",
                      readfile.read().decode("utf-8"))

      # Should not trigger for step 9
      sess.run(tf.assign(global_step, 9))
      mon_sess.run(no_op)
      outfile = os.path.join(self.sample_dir, "samples_000009.txt")
      self.assertFalse(os.path.exists(outfile))

      # Should trigger for step 10
      sess.run(tf.assign(global_step, 10))
      mon_sess.run(no_op)
      outfile = os.path.join(self.sample_dir, "samples_000010.txt")
      with open(outfile, "rb") as readfile:
        self.assertIn("Prediction followed by Target @ Step 10",
                      readfile.read().decode("utf-8"))
Project: transform    Author: tensorflow    | Project source | File source
def assertSparseOutput(self, expected_indices, expected_values,
                         expected_shape, actual_sparse_tensor, close_values):
    with tf.Session() as sess:
      sess.run(tf.tables_initializer())
      actual = actual_sparse_tensor.eval()
      self.assertAllEqual(expected_indices, actual.indices)
      self.assertAllEqual(expected_shape, actual.dense_shape)
      if close_values:
        self.assertAllClose(expected_values, actual.values)
      else:
        self.assertAllEqual(expected_values, actual.values)
Project: nmt    Author: tensorflow    | Project source | File source
def _createTestTrainModel(self, m_creator, hparams, sess):
    train_mode = tf.contrib.learn.ModeKeys.TRAIN
    train_iterator, src_vocab_table, tgt_vocab_table = (
        common_test_utils.create_test_iterator(hparams, train_mode))
    train_m = m_creator(
        hparams,
        train_mode,
        train_iterator,
        src_vocab_table,
        tgt_vocab_table,
        scope='dynamic_seq2seq')
    sess.run(tf.global_variables_initializer())
    sess.run(tf.tables_initializer())
    sess.run(train_iterator.initializer)
    return train_m
Project: nmt    Author: tensorflow    | Project source | File source
def _createTestEvalModel(self, m_creator, hparams, sess):
    eval_mode = tf.contrib.learn.ModeKeys.EVAL
    eval_iterator, src_vocab_table, tgt_vocab_table = (
        common_test_utils.create_test_iterator(hparams, eval_mode))
    eval_m = m_creator(
        hparams,
        eval_mode,
        eval_iterator,
        src_vocab_table,
        tgt_vocab_table,
        scope='dynamic_seq2seq')
    sess.run(tf.tables_initializer())
    sess.run(eval_iterator.initializer)
    return eval_m
Project: nmt    Author: tensorflow    | Project source | File source
def load_model(model, ckpt, session, name):
  start_time = time.time()
  model.saver.restore(session, ckpt)
  session.run(tf.tables_initializer())
  utils.print_out(
      "  loaded %s model parameters from %s, time %.2fs" %
      (name, ckpt, time.time() - start_time))
  return model
Project: tefla    Author: openAGI    | Project source | File source
def test_sampling(self):
        hook = learner_hooks.TrainSampleHook(
            params={"every_n_steps": 10}, model_dir=self.model_dir,
            run_config=tf.contrib.learn.RunConfig())

        global_step = tf.contrib.framework.get_or_create_global_step()
        no_op = tf.no_op()
        hook.begin()
        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            sess.run(tf.tables_initializer())

            mon_sess = monitored_session._HookedSession(sess, [hook])
            # Should trigger for step 0
            sess.run(tf.assign(global_step, 0))
            mon_sess.run(no_op)

            outfile = os.path.join(self.sample_dir, "samples_000000.txt")
            with open(outfile, "rb") as readfile:
                self.assertIn("Prediction followed by Target @ Step 0",
                              readfile.read().decode("utf-8"))

            # Should not trigger for step 9
            sess.run(tf.assign(global_step, 9))
            mon_sess.run(no_op)
            outfile = os.path.join(self.sample_dir, "samples_000009.txt")
            self.assertFalse(os.path.exists(outfile))

            # Should trigger for step 10
            sess.run(tf.assign(global_step, 10))
            mon_sess.run(no_op)
            outfile = os.path.join(self.sample_dir, "samples_000010.txt")
            with open(outfile, "rb") as readfile:
                self.assertIn("Prediction followed by Target @ Step 10",
                              readfile.read().decode("utf-8"))
Project: tefla    Author: openAGI    | Project source | File source
def test_without_counts(self):
        vocab_list = ["Hello", ".", "?"]
        vocab_file = create_temporary_vocab_file(vocab_list)

        vocab_to_id_table, id_to_vocab_table, _, vocab_size = \
            vocabulary.create_vocabulary_lookup_table(vocab_file.name)

        self.assertEqual(vocab_size, 6)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            sess.run(tf.tables_initializer())

            ids = vocab_to_id_table.lookup(
                tf.convert_to_tensor(["Hello", ".", "?", "??", "xxx"]))
            ids = sess.run(ids)
            self.assertAllEqual(ids, [0, 1, 2, 3, 3])

            words = id_to_vocab_table.lookup(
                tf.convert_to_tensor(
                    [0, 1, 2, 3], dtype=tf.int64))
            words = sess.run(words)
            self.assertAllEqual(
                np.char.decode(words.astype("S"), "utf-8"),
                ["Hello", ".", "?", "UNK"])
Project: tefla    Author: openAGI    | Project source | File source
def test_with_counts(self):
        vocab_list = ["Hello", ".", "?"]
        vocab_counts = [100, 200, 300]
        vocab_file = create_temporary_vocab_file(vocab_list,
                                                 vocab_counts)

        vocab_to_id_table, id_to_vocab_table, word_to_count_table, vocab_size = \
            vocabulary.create_vocabulary_lookup_table(vocab_file.name)

        self.assertEqual(vocab_size, 6)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            sess.run(tf.tables_initializer())

            ids = vocab_to_id_table.lookup(
                tf.convert_to_tensor(["Hello", ".", "?", "??", "xxx"]))
            ids = sess.run(ids)
            self.assertAllEqual(ids, [0, 1, 2, 3, 3])

            words = id_to_vocab_table.lookup(
                tf.convert_to_tensor(
                    [0, 1, 2, 3], dtype=tf.int64))
            words = sess.run(words)
            self.assertAllEqual(
                np.char.decode(words.astype("S"), "utf-8"),
                ["Hello", ".", "?", "UNK"])

            counts = word_to_count_table.lookup(
                tf.convert_to_tensor(["Hello", ".", "?", "??", "xxx"]))
            counts = sess.run(counts)
            self.assertAllEqual(counts, [100, 200, 300, -1, -1])
Project: GNMT2    Author: Mingyearn    | Project source | File source
def _createTestTrainModel(self, m_creator, hparams, sess):
    train_mode = tf.contrib.learn.ModeKeys.TRAIN
    train_iterator, src_vocab_table, tgt_vocab_table = common_test_utils.create_test_iterator(
        hparams, train_mode)
    train_m = m_creator(
        hparams,
        train_mode,
        train_iterator,
        src_vocab_table,
        tgt_vocab_table,
        scope='dynamic_seq2seq')
    sess.run(tf.global_variables_initializer())
    sess.run(tf.tables_initializer())
    sess.run(train_iterator.initializer)
    return train_m
Project: GNMT2    Author: Mingyearn    | Project source | File source
def _createTestEvalModel(self, m_creator, hparams, sess):
    eval_mode = tf.contrib.learn.ModeKeys.EVAL
    eval_iterator, src_vocab_table, tgt_vocab_table = common_test_utils.create_test_iterator(
        hparams, eval_mode)
    eval_m = m_creator(
        hparams,
        eval_mode,
        eval_iterator,
        src_vocab_table,
        tgt_vocab_table,
        scope='dynamic_seq2seq')
    sess.run(tf.tables_initializer())
    sess.run(eval_iterator.initializer)
    return eval_m
Project: tensorfx    Author: TensorLab    | Project source | File source
def build_init(self):
    """Builds the initialization sub-graph.

    The default implementation creates an init op that initializes all global
    variables, plus a local init op that initializes local variables,
    non-trainable variables, and tables.

    Initialization is run when the graph is first created, before training. Local initialization is
    performed after a previously trained model is loaded.

    Returns:
      A tuple containing the init op and local init op to use to initialize the graph.
    """
    init_op = tf.variables_initializer(tf.global_variables(), name='init')

    # For some reason not all local variables are in the local variables collection, but some are in
    # the global variables collection (such as those set up by reader ops).
    # So in addition to initializing local variables in the local_init_op, we also initialize the
    # set of variables in the global variables, that are not trainable.
    # Just to add to the mix, tables are neither, and so must be explicitly included as well.
    # All of these will be initialized after restoring from a checkpoint.
    variables = tf.global_variables()
    for trainable in tf.trainable_variables():
      variables.remove(trainable)

    local_init_op = tf.group(tf.variables_initializer(variables),
                             tf.variables_initializer(tf.local_variables()),
                             tf.tables_initializer(),
                             name='local_init_op')

    # Add the local initialization op to the main op collection, which is looked up at model loading
    # time, and is automatically invoked after it has been loaded.
    tf.add_to_collection('saved_model_main_op', local_init_op)

    return init_op, local_init_op
Project: automatic-summarization    Author: mozilla    | Project source | File source
def _test_pipeline(self, mode, params=None):
    """Helper function to test the full model pipeline.
    """
    # Create source and target example
    source_len = self.sequence_length + 5
    target_len = self.sequence_length + 10
    source = " ".join(np.random.choice(self.vocab_list, source_len))
    target = " ".join(np.random.choice(self.vocab_list, target_len))
    sources_file, targets_file = test_utils.create_temp_parallel_data(
        sources=[source], targets=[target])

    # Build model graph
    model = self.create_model(mode, params)
    input_pipeline_ = input_pipeline.ParallelTextInputPipeline(
        params={
            "source_files": [sources_file.name],
            "target_files": [targets_file.name]
        },
        mode=mode)
    input_fn = training_utils.create_input_fn(
        pipeline=input_pipeline_, batch_size=self.batch_size)
    features, labels = input_fn()
    fetches = model(features, labels, None)
    fetches = [_ for _ in fetches if _ is not None]

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      sess.run(tf.tables_initializer())
      with tf.contrib.slim.queues.QueueRunners(sess):
        fetches_ = sess.run(fetches)

    sources_file.close()
    targets_file.close()

    return model, fetches_
Project: automatic-summarization    Author: mozilla    | Project source | File source
def test_without_counts(self):
    vocab_list = ["Hello", ".", "?"]
    vocab_file = test_utils.create_temporary_vocab_file(vocab_list)

    vocab_to_id_table, id_to_vocab_table, _, vocab_size = \
      vocab.create_vocabulary_lookup_table(vocab_file.name)

    self.assertEqual(vocab_size, 6)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      sess.run(tf.tables_initializer())

      ids = vocab_to_id_table.lookup(
          tf.convert_to_tensor(["Hello", ".", "?", "??", "xxx"]))
      ids = sess.run(ids)
      np.testing.assert_array_equal(ids, [0, 1, 2, 3, 3])

      words = id_to_vocab_table.lookup(
          tf.convert_to_tensor(
              [0, 1, 2, 3], dtype=tf.int64))
      words = sess.run(words)
      np.testing.assert_array_equal(
          np.char.decode(words.astype("S"), "utf-8"),
          ["Hello", ".", "?", "UNK"])
Project: automatic-summarization    Author: mozilla    | Project source | File source
def test_with_counts(self):
    vocab_list = ["Hello", ".", "?"]
    vocab_counts = [100, 200, 300]
    vocab_file = test_utils.create_temporary_vocab_file(vocab_list,
                                                        vocab_counts)

    vocab_to_id_table, id_to_vocab_table, word_to_count_table, vocab_size = \
      vocab.create_vocabulary_lookup_table(vocab_file.name)

    self.assertEqual(vocab_size, 6)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      sess.run(tf.tables_initializer())

      ids = vocab_to_id_table.lookup(
          tf.convert_to_tensor(["Hello", ".", "?", "??", "xxx"]))
      ids = sess.run(ids)
      np.testing.assert_array_equal(ids, [0, 1, 2, 3, 3])

      words = id_to_vocab_table.lookup(
          tf.convert_to_tensor(
              [0, 1, 2, 3], dtype=tf.int64))
      words = sess.run(words)
      np.testing.assert_array_equal(
          np.char.decode(words.astype("S"), "utf-8"),
          ["Hello", ".", "?", "UNK"])

      counts = word_to_count_table.lookup(
          tf.convert_to_tensor(["Hello", ".", "?", "??", "xxx"]))
      counts = sess.run(counts)
      np.testing.assert_array_equal(counts, [100, 200, 300, -1, -1])
Project: automatic-summarization    Author: mozilla    | Project source | File source
def test_sampling(self):
    hook = hooks.TrainSampleHook(
        params={"every_n_steps": 10}, model_dir=self.model_dir,
        run_config=tf.contrib.learn.RunConfig())

    global_step = tf.contrib.framework.get_or_create_global_step()
    no_op = tf.no_op()
    hook.begin()
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      sess.run(tf.tables_initializer())

      #pylint: disable=W0212
      mon_sess = monitored_session._HookedSession(sess, [hook])
      # Should trigger for step 0
      sess.run(tf.assign(global_step, 0))
      mon_sess.run(no_op)

      outfile = os.path.join(self.sample_dir, "samples_000000.txt")
      with open(outfile, "rb") as readfile:
        self.assertIn("Prediction followed by Target @ Step 0",
                      readfile.read().decode("utf-8"))

      # Should not trigger for step 9
      sess.run(tf.assign(global_step, 9))
      mon_sess.run(no_op)
      outfile = os.path.join(self.sample_dir, "samples_000009.txt")
      self.assertFalse(os.path.exists(outfile))

      # Should trigger for step 10
      sess.run(tf.assign(global_step, 10))
      mon_sess.run(no_op)
      outfile = os.path.join(self.sample_dir, "samples_000010.txt")
      with open(outfile, "rb") as readfile:
        self.assertIn("Prediction followed by Target @ Step 10",
                      readfile.read().decode("utf-8"))
Project: kaggle-youtube-8m    Author: liufuyang    | Project source | File source
def _run_eval(self):
    """Run model evaluation and generate summaries."""
    coord = tf.train.Coordinator(clean_stop_exception_types=(
        tf.errors.CancelledError, tf.errors.OutOfRangeError))

    with tf.Session(graph=self._graph) as session:
      # Restores previously saved variables from latest checkpoint
      self._saver.restore(session, self._latest_checkpoint)

      session.run([
        tf.tables_initializer(),
        tf.local_variables_initializer()
      ])
      tf.train.start_queue_runners(coord=coord, sess=session)
      train_step = session.run(self._gs)

      tf.logging.info('Starting Evaluation For Step: {}'.format(train_step))
      with coord.stop_on_exception():
        eval_step = 0
        while self._eval_steps is None or eval_step < self._eval_steps:
          summaries, final_values, _ = session.run(
              [self._summary_op, self._final_ops_dict, self._eval_ops])
          if eval_step % 100 == 0:
            tf.logging.info("On Evaluation Step: {}".format(eval_step))
          eval_step += 1

      # Write the summaries
      self._file_writer.add_summary(summaries, global_step=train_step)
      self._file_writer.flush()
      tf.logging.info(final_values)
Project: nmt_v2    Author: rpryzant    | Project source | File source
def load_model(model, ckpt, session, name):
    start_time = time.time()
    model.saver.restore(session, ckpt)
    session.run(tf.tables_initializer())
    print "  loaded %s model parameters from %s, time %.2fs" % \
        (name, ckpt, time.time() - start_time)
    return model
Project: nmt_v2    Author: rpryzant    | Project source | File source
def create_or_load_model(model, model_dir, session, name):
    latest_ckpt = tf.train.latest_checkpoint(model_dir)

    if latest_ckpt:
        model = load_model(model, latest_ckpt, session, name)
    else:
        start_time = time.time()
        session.run(tf.global_variables_initializer())
        session.run(tf.tables_initializer())
        print "  created %s model with fresh parameters, time %.2fs" % \
                        (name, time.time() - start_time)

    global_step = model.global_step.eval(session=session)
    return model, global_step
Project: pydatalab    Author: googledatalab    | Project source | File source
def export(self, last_checkpoint, output_dir):
    """Builds a prediction graph and xports the model.

    Args:
      last_checkpoint: Path to the latest checkpoint file from training.
      output_dir: Path to the folder to be used to output the model.
    """
    logging.info('Exporting prediction graph to %s', output_dir)
    with tf.Session(graph=tf.Graph()) as sess:
      # Build and save prediction meta graph and trained variable values.
      inputs, outputs = self.build_prediction_graph()
      signature_def_map = {
        'serving_default': signature_def_utils.predict_signature_def(inputs, outputs)
      }
      init_op = tf.global_variables_initializer()
      sess.run(init_op)
      self.restore_from_checkpoint(sess, self.inception_checkpoint_file,
                                   last_checkpoint)
      init_op_serving = control_flow_ops.group(
          variables.local_variables_initializer(),
          tf.tables_initializer())

      builder = saved_model_builder.SavedModelBuilder(output_dir)
      builder.add_meta_graph_and_variables(
          sess, [tag_constants.SERVING],
          signature_def_map=signature_def_map,
          legacy_init_op=init_op_serving)
      builder.save(False)
Project: cloudml-samples    Author: GoogleCloudPlatform    | Project source | File source
def build_and_run_exports(latest, job_dir, serving_input_fn, hidden_units):
  """Given the latest checkpoint file export the saved model.

  Args:
    latest (string): Latest checkpoint file
    job_dir (string): Location of checkpoints and model files
    serving_input_fn (function): Returns the serving-time features and input
      placeholders used to build the prediction graph
    hidden_units (list): Number of hidden units
  """

  prediction_graph = tf.Graph()
  exporter = tf.saved_model.builder.SavedModelBuilder(
      os.path.join(job_dir, 'export'))
  with prediction_graph.as_default():
    features, inputs_dict = serving_input_fn()
    prediction_dict = model.model_fn(
        model.PREDICT,
        features.copy(),
        None,  # labels
        hidden_units=hidden_units,
        learning_rate=None  # learning_rate unused in prediction mode
    )
    saver = tf.train.Saver()

    inputs_info = {
        name: tf.saved_model.utils.build_tensor_info(tensor)
        for name, tensor in inputs_dict.iteritems()
    }
    output_info = {
        name: tf.saved_model.utils.build_tensor_info(tensor)
        for name, tensor in prediction_dict.iteritems()
    }
    signature_def = tf.saved_model.signature_def_utils.build_signature_def(
        inputs=inputs_info,
        outputs=output_info,
        method_name=sig_constants.PREDICT_METHOD_NAME
    )

  with tf.Session(graph=prediction_graph) as session:
    session.run([tf.local_variables_initializer(), tf.tables_initializer()])
    saver.restore(session, latest)
    exporter.add_meta_graph_and_variables(
        session,
        tags=[tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            sig_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def
        },
        legacy_init_op=main_op()
    )

  exporter.save()
Project: dynamic-training-bench    Author: galeone    | Project source | File source
def extract_features(self,
                         checkpoint_path,
                         inputs,
                         layer_name,
                         num_classes=0):
        """Restore model parameters from checkpoint_path. Search in the model
        the layer with name `layer_name`. If found places `inputs` as input to the model
        and returns the values extracted by the layer.
        Args:
            checkpoint_path: path of the trained model checkpoint directory
            inputs: a Tensor with a shape compatible with the model's input
            layer_name: a string, the name of the layer to extract from model
            num_classes: number of classes; must equal the number of classes the
            model was trained on if the model is class-aware (e.g. a classifier),
            otherwise 0
        Returns:
            features: a numpy ndarray that contains the extracted features
        """

        # Evaluate the inputs in the current default graph,
        # then use a placeholder to inject the computed values into the new graph
        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            evaluated_inputs = sess.run(inputs)

        # Create a new graph so that subsequent calls do not pollute the
        # default graph
        with tf.Graph().as_default() as graph:
            inputs_ = tf.placeholder(inputs.dtype, shape=inputs.shape)

            # Build a Graph that computes the predictions from the inference model.
            _ = self._model.get(
                inputs_, num_classes, train_phase=False, l2_penalty=0.0)

            # This will raise an exception if layer_name is not found
            layer = graph.get_tensor_by_name(layer_name)

            saver = tf.train.Saver(variables_to_restore())
            init = [
                tf.variables_initializer(
                    tf.global_variables() + tf.local_variables()),
                tf.tables_initializer()
            ]
            features = np.zeros(layer.shape)
            with tf.Session(config=tf.ConfigProto(
                    allow_soft_placement=True)) as sess:
                ckpt = tf.train.get_checkpoint_state(checkpoint_path)
                if ckpt and ckpt.model_checkpoint_path:
                    # Restores from checkpoint
                    saver.restore(sess, ckpt.model_checkpoint_path)
                else:
                    print('[!] No checkpoint file found')
                    return features
                sess.run(init)
                features = sess.run(
                    layer, feed_dict={
                        inputs_: evaluated_inputs
                    })

            return features
Project: nmt    Author: tensorflow    | Project source | File source
def testGetInferIterator(self):
    src_vocab_table = lookup_ops.index_table_from_tensor(
        tf.constant(["a", "b", "c", "eos", "sos"]))
    src_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(["c c a", "c a", "d", "f e a g"]))
    hparams = tf.contrib.training.HParams(
        random_seed=3,
        eos="eos",
        sos="sos")
    batch_size = 2
    src_max_len = 3
    iterator = iterator_utils.get_infer_iterator(
        src_dataset=src_dataset,
        src_vocab_table=src_vocab_table,
        batch_size=batch_size,
        eos=hparams.eos,
        src_max_len=src_max_len)
    table_initializer = tf.tables_initializer()
    source = iterator.source
    seq_len = iterator.source_sequence_length
    self.assertEqual([None, None], source.shape.as_list())
    self.assertEqual([None], seq_len.shape.as_list())
    with self.test_session() as sess:
      sess.run(table_initializer)
      sess.run(iterator.initializer)

      (source_v, seq_len_v) = sess.run((source, seq_len))
      self.assertAllEqual(
          [[2, 2, 0],   # c c a
           [2, 0, 3]],  # c a eos
          source_v)
      self.assertAllEqual([3, 2], seq_len_v)

      (source_v, seq_len_v) = sess.run((source, seq_len))
      self.assertAllEqual(
          [[-1, 3, 3],    # "d" == unknown, eos eos
           [-1, -1, 0]],  # "f" == unknown, "e" == unknown, a
          source_v)
      self.assertAllEqual([1, 3], seq_len_v)

      with self.assertRaisesOpError("End of sequence"):
        sess.run((source, seq_len))
Project: GNMT2    Author: Mingyearn    | Project source | File source
def testGetInferIterator(self):
    src_vocab_table = lookup_ops.index_table_from_tensor(
        tf.constant(["a", "b", "c", "eos", "sos"]))
    src_dataset = tf.contrib.data.Dataset.from_tensor_slices(
        tf.constant(["c c a", "c a", "d", "f e a g"]))
    hparams = tf.contrib.training.HParams(
        random_seed=3,
        source_reverse=False,
        eos="eos",
        sos="sos")
    batch_size = 2
    src_max_len = 3
    iterator = iterator_utils.get_infer_iterator(
        src_dataset=src_dataset,
        src_vocab_table=src_vocab_table,
        batch_size=batch_size,
        eos=hparams.eos,
        source_reverse=hparams.source_reverse,
        src_max_len=src_max_len)
    table_initializer = tf.tables_initializer()
    source = iterator.source
    seq_len = iterator.source_sequence_length
    self.assertEqual([None, None], source.shape.as_list())
    self.assertEqual([None], seq_len.shape.as_list())
    with self.test_session() as sess:
      sess.run(table_initializer)
      sess.run(iterator.initializer)

      (source_v, seq_len_v) = sess.run((source, seq_len))
      self.assertAllEqual(
          [[2, 2, 0],   # c c a
           [2, 0, 3]],  # c a eos
          source_v)
      self.assertAllEqual([3, 2], seq_len_v)

      (source_v, seq_len_v) = sess.run((source, seq_len))
      self.assertAllEqual(
          [[-1, 3, 3],    # "d" == unknown, eos eos
           [-1, -1, 0]],  # "f" == unknown, "e" == unknown, a
          source_v)
      self.assertAllEqual([1, 3], seq_len_v)

      with self.assertRaisesOpError("End of sequence"):
        sess.run((source, seq_len))
Project: tensorflow_fasttext    Author: apcode    | Project source | File source
def test_reading_inputs():
    parse_spec = {
        "text": tf.VarLenFeature(tf.string),
        "label": tf.FixedLenFeature(shape=(1,), dtype=tf.int64,
                                    default_value=None)
    }
    sess = tf.Session()
    reader = tf.python_io.tf_record_iterator(INPUT_FILE)
    ESZ = 4
    HSZ = 100
    NC = 4
    n = 0
    text_lookup_table = tf.contrib.lookup.index_table_from_file(
        VOCAB_FILE, 10, VOCAB_SIZE)
    text_embedding_w = tf.Variable(tf.random_uniform(
        [VOCAB_SIZE, ESZ], -1.0, 1.0))
    sess.run([tf.tables_initializer()])
    for record in reader:
        example = tf.parse_single_example(
            record,
            parse_spec)
        text = example["text"]
        labels = tf.subtract(example["label"], 1)
        text_ids = text_lookup_table.lookup(text)
        dense = tf.sparse_tensor_to_dense(text_ids)
        print dense.shape
        text_embedding = tf.reduce_mean(tf.nn.embedding_lookup(
            text_embedding_w, dense), axis=-2)
        print text_embedding.shape
        text_embedding = tf.expand_dims(text_embedding, -2)
        print text_embedding.shape
        text_embedding_2 = tf.contrib.layers.bow_encoder(
            dense, VOCAB_SIZE, ESZ)
        print text_embedding_2.shape
        num_classes = 2
        logits = tf.contrib.layers.fully_connected(
            inputs=text_embedding, num_outputs=4,
            activation_fn=None)
        sess.run([tf.global_variables_initializer()])
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels, logits=logits)
        x = sess.run([text_embedding, text_embedding_2, logits, labels, loss])
        print(len(x), list(str(x[i]) for i in range(len(x))))
        if n > 2:
            break
        n += 1
Project: kaggle-youtube-8m    Author: liufuyang    | Project source | File source
def build_and_run_exports(latest, job_dir, name, serving_input_fn, hidden_units):
  """Given the latest checkpoint file export the saved model.

  Args:
    latest (string): Latest checkpoint file
    job_dir (string): Location of checkpoints and model files
    name (string): Name of the checkpoint to be exported. Used in building the
      export path.
    serving_input_fn (function): Returns the serving-time features and input
      placeholders used to build the prediction graph
    hidden_units (list): Number of hidden units
  """

  prediction_graph = tf.Graph()
  exporter = tf.saved_model.builder.SavedModelBuilder(
      os.path.join(job_dir, 'export', name))
  with prediction_graph.as_default():
    features, inputs_dict = serving_input_fn()
    prediction_dict = model.model_fn(
        model.PREDICT,
        features,
        None,  # labels
        hidden_units=hidden_units,
        learning_rate=None  # learning_rate unused in prediction mode
    )
    saver = tf.train.Saver()

    inputs_info = {
        name: tf.saved_model.utils.build_tensor_info(tensor)
        for name, tensor in inputs_dict.iteritems()
    }
    output_info = {
        name: tf.saved_model.utils.build_tensor_info(tensor)
        for name, tensor in prediction_dict.iteritems()
    }
    signature_def = tf.saved_model.signature_def_utils.build_signature_def(
        inputs=inputs_info,
        outputs=output_info,
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
    )


  with tf.Session(graph=prediction_graph) as session:
    session.run([tf.local_variables_initializer(), tf.tables_initializer()])
    saver.restore(session, latest)
    exporter.add_meta_graph_and_variables(
        session,
        tags=[tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def
        },
    )

  exporter.save()