Python model 模块,model_fn() 实例源码

我们从 Python 开源项目中,提取了以下 4 个代码示例,用于说明如何使用 model.model_fn()。

项目:ISLES2017    作者:MiguelMonteiro    | 项目源码 | 文件源码
def _run_export(self):
    """Export the latest checkpoint as a SavedModel tagged for serving.

    The export directory is ``export_ckpt_<N>`` inside
    ``self._checkpoint_dir``, where ``<N>`` is the last run of digits found
    in ``self._latest_checkpoint``. If the SavedModel builder cannot be
    created for that directory, the checkpoint is treated as already
    exported: a message is logged and the method returns without exporting.

    Raises:
        IndexError: if ``self._latest_checkpoint`` contains no digits.
    """
    # Raw string: '\d' in a plain literal is an invalid escape sequence.
    checkpoint_number = re.findall(r'\d+', self._latest_checkpoint)[-1]
    export_dir = 'export_ckpt_' + checkpoint_number
    tf.logging.info('Exporting model from checkpoint {0}'.format(self._latest_checkpoint))
    prediction_graph = tf.Graph()
    try:
        exporter = tf.saved_model.builder.SavedModelBuilder(os.path.join(self._checkpoint_dir, export_dir))
    except (IOError, AssertionError):
        # NOTE(review): TF 1.x SavedModelBuilder appears to report an existing
        # export directory with AssertionError rather than IOError; both are
        # handled so the skip works either way -- confirm against the
        # installed TensorFlow version.
        tf.logging.info('Checkpoint {0} already exported, continuing...'.format(self._latest_checkpoint))
        return

    with prediction_graph.as_default():
        image, name, inputs_dict = model.serving_input_fn()
        # Positional arguments after the mode: name, image, None, 6, None.
        # NOTE(review): 6 is presumably the number of input channels and the
        # two None values the unused labels / learning rate -- confirm against
        # model.model_fn.
        prediction_dict = model.model_fn(model.PREDICT, name, image, None, 6, None)

        # Created inside the prediction graph so it covers that graph's variables.
        saver = tf.train.Saver()

        # .items() works on both Python 2 and 3 (iteritems() is Python 2 only);
        # 'key' avoids shadowing the 'name' tensor unpacked above.
        inputs_info = {key: tf.saved_model.utils.build_tensor_info(tensor)
                       for key, tensor in inputs_dict.items()}

        output_info = {key: tf.saved_model.utils.build_tensor_info(tensor)
                       for key, tensor in prediction_dict.items()}

        signature_def = tf.saved_model.signature_def_utils.build_signature_def(
            inputs=inputs_info,
            outputs=output_info,
            method_name=sig_constants.PREDICT_METHOD_NAME
        )

    with tf.Session(graph=prediction_graph) as session:
        # Load the trained weights before they are written into the export.
        saver.restore(session, self._latest_checkpoint)
        exporter.add_meta_graph_and_variables(
            session,
            tags=[tf.saved_model.tag_constants.SERVING],
            signature_def_map={sig_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def},
            legacy_init_op=my_main_op()
        )

    exporter.save()
项目:ISLES2017    作者:MiguelMonteiro    | 项目源码 | 文件源码
def run(target, is_chief, train_steps, job_dir, train_files, eval_files, num_epochs, learning_rate):
    """Build the training graph and run it inside a MonitoredTrainingSession.

    On the chief worker an additional, separate evaluation graph is built and
    evaluation / checkpoint-export hooks are registered alongside the
    training hooks.

    Args:
        target: Passed as ``master`` to the monitored session.
        is_chief: Whether this worker is the chief; enables the evaluation
            and checkpoint-export hooks.
        train_steps: Step count handed to ``tf.train.StopAtStepHook``.
        job_dir: Checkpoint directory; also used by the evaluation, export
            and summary hooks.
        train_files: Input files for the training input pipeline.
        eval_files: Input files for the evaluation input pipeline.
        num_epochs: Number of epochs for the training input pipeline.
        learning_rate: Learning rate forwarded to ``model.model_fn``.
    """
    num_channels = 6

    # Original author's caveat: the stop hook may not behave well in
    # distributed mode because it possibly only counts local steps.
    hooks = [tf.train.StopAtStepHook(train_steps)]

    if is_chief:
        evaluation_graph = tf.Graph()
        with evaluation_graph.as_default():
            # Evaluation inputs: a single unshuffled pass over eval_files.
            eval_image, eval_truth, eval_name = model.input_fn(eval_files, 1, shuffle=False, shared_name=None)
            # In EVAL mode model_fn yields the dictionary of metric tensors.
            metric_dict = model.model_fn(model.EVAL, eval_name, eval_image, eval_truth,
                                         num_channels, learning_rate)
            # Evaluation runs from a hook, on its own graph, apart from training.
            hooks.append(EvaluationRunHook(job_dir, metric_dict, evaluation_graph))
        hooks.append(CheckpointExporterHook(job_dir))

    # The training graph lives in a fresh graph installed as the default.
    with tf.Graph().as_default():
        with tf.device(tf.train.replica_device_setter()):
            # Training inputs, read through the shared 'train_queue'.
            train_image, train_truth, train_name = model.input_fn(
                train_files, num_epochs, shuffle=True, shared_name='train_queue')

            # In TRAIN mode model_fn yields the train op, a console-logging
            # hook and the summary op.
            train_op, log_hook, train_summaries = model.model_fn(
                model.TRAIN, train_name, train_image, train_truth, num_channels, learning_rate)

            summary_hook = tf.train.SummarySaverHook(
                save_steps=1,
                output_dir=get_summary_dir(job_dir),
                summary_op=train_summaries)
            hooks.extend([log_hook, summary_hook])

        # MonitoredTrainingSession takes care of initialization, recovery
        # and running the hooks:
        # https://www.tensorflow.org/api_docs/python/tf/train/MonitoredTrainingSession
        session_kwargs = {
            'master': target,
            'is_chief': is_chief,
            'checkpoint_dir': job_dir,
            'hooks': hooks,
            'save_checkpoint_secs': 60 * 3,
            'save_summaries_steps': 1,
            'log_step_count_steps': 5,
        }
        with tf.train.MonitoredTrainingSession(**session_kwargs) as session:
            # Keep stepping until the session reports that it should stop.
            while True:
                if session.should_stop():
                    break
                session.run(train_op)
项目:cloudml-samples    作者:GoogleCloudPlatform    | 项目源码 | 文件源码
def build_and_run_exports(latest, job_dir, serving_input_fn, hidden_units):
  """Given the latest checkpoint file export the saved model.

  The model is rebuilt in prediction mode in a fresh graph, its variables are
  restored from `latest`, and the result is written as a SavedModel to
  `<job_dir>/export` with the default serving signature.

  Args:
    latest (string): Latest checkpoint file
    job_dir (string): Location of checkpoints and model files
    serving_input_fn (callable): Called with no arguments; must return a
      `(features, inputs_dict)` pair, where `inputs_dict` maps input names to
      the tensors exposed in the serving signature.
    hidden_units (list): Number of hidden units
  """

  prediction_graph = tf.Graph()
  exporter = tf.saved_model.builder.SavedModelBuilder(
      os.path.join(job_dir, 'export'))
  with prediction_graph.as_default():
    features, inputs_dict = serving_input_fn()
    prediction_dict = model.model_fn(
        model.PREDICT,
        # A copy is passed, presumably so model_fn cannot modify the
        # features returned by serving_input_fn -- confirm.
        features.copy(),
        None,  # labels
        hidden_units=hidden_units,
        learning_rate=None  # learning_rate unused in prediction mode
    )
    # Created inside the prediction graph so it covers that graph's variables.
    saver = tf.train.Saver()

    # .items() works on both Python 2 and 3 (iteritems() is Python 2 only).
    inputs_info = {
        key: tf.saved_model.utils.build_tensor_info(tensor)
        for key, tensor in inputs_dict.items()
    }
    output_info = {
        key: tf.saved_model.utils.build_tensor_info(tensor)
        for key, tensor in prediction_dict.items()
    }
    signature_def = tf.saved_model.signature_def_utils.build_signature_def(
        inputs=inputs_info,
        outputs=output_info,
        method_name=sig_constants.PREDICT_METHOD_NAME
    )

  with tf.Session(graph=prediction_graph) as session:
    # Initialize local variables and tables, then load the trained weights
    # before they are written into the export.
    session.run([tf.local_variables_initializer(), tf.tables_initializer()])
    saver.restore(session, latest)
    exporter.add_meta_graph_and_variables(
        session,
        tags=[tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            sig_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def
        },
        legacy_init_op=main_op()
    )

  exporter.save()
项目:kaggle-youtube-8m    作者:liufuyang    | 项目源码 | 文件源码
def build_and_run_exports(latest, job_dir, name, serving_input_fn, hidden_units):
  """Given the latest checkpoint file export the saved model.

  The model is rebuilt in prediction mode in a fresh graph, its variables are
  restored from `latest`, and the result is written as a SavedModel to
  `<job_dir>/export/<name>` with the default serving signature.

  Args:
    latest (string): Latest checkpoint file
    job_dir (string): Location of checkpoints and model files
    name (string): Name of the checkpoint to be exported. Used in building the
      export path.
    serving_input_fn (callable): Called with no arguments; must return a
      `(features, inputs_dict)` pair, where `inputs_dict` maps input names to
      the tensors exposed in the serving signature.
    hidden_units (list): Number of hidden units
  """

  prediction_graph = tf.Graph()
  exporter = tf.saved_model.builder.SavedModelBuilder(
      os.path.join(job_dir, 'export', name))
  with prediction_graph.as_default():
    features, inputs_dict = serving_input_fn()
    prediction_dict = model.model_fn(
        model.PREDICT,
        features,
        None,  # labels
        hidden_units=hidden_units,
        learning_rate=None  # learning_rate unused in prediction mode
    )
    # Created inside the prediction graph so it covers that graph's variables.
    saver = tf.train.Saver()

    # .items() works on both Python 2 and 3 (iteritems() is Python 2 only);
    # 'key' avoids shadowing the 'name' parameter used for the export path.
    inputs_info = {
        key: tf.saved_model.utils.build_tensor_info(tensor)
        for key, tensor in inputs_dict.items()
    }
    output_info = {
        key: tf.saved_model.utils.build_tensor_info(tensor)
        for key, tensor in prediction_dict.items()
    }
    signature_def = tf.saved_model.signature_def_utils.build_signature_def(
        inputs=inputs_info,
        outputs=output_info,
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
    )

  with tf.Session(graph=prediction_graph) as session:
    # Initialize local variables and tables, then load the trained weights
    # before they are written into the export.
    session.run([tf.local_variables_initializer(), tf.tables_initializer()])
    saver.restore(session, latest)
    exporter.add_meta_graph_and_variables(
        session,
        tags=[tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def
        },
    )

  exporter.save()