Python tensorflow.contrib.slim 模块，get_or_create_global_step() 实例源码

我们从 Python 开源项目中，提取了以下 11 个代码示例，用于说明如何使用 tensorflow.contrib.slim.get_or_create_global_step()。

项目:vessel-classification    作者:GlobalFishingWatch    | 项目源码 | 文件源码
def build_training_net(self, features, timestamps, mmsis):
        """Build the training graph and return its optimizer and trainers.

        Args:
            features: Model input, passed to ``self.build_model`` together
                with a constant ``True`` tensor (presumably the training-mode
                flag -- confirm against ``build_model``).
            timestamps: Forwarded to every objective's ``build_trainer``.
            mmsis: Forwarded to every objective's ``build_trainer``.

        Returns:
            A ``TrainNetInfo`` pairing a ``MomentumOptimizer`` with one trainer
            per entry of ``self.objectives``.
        """
        self.build_model(tf.constant(True), features)

        trainers = [objective.build_trainer(timestamps, mmsis)
                    for objective in self.objectives]

        # The decay schedule is driven by the number of examples processed so
        # far (global step times batch size), not by the raw step count.
        examples_seen = slim.get_or_create_global_step() * self.batch_size
        learning_rate = tf.train.exponential_decay(
            self.initial_learning_rate, examples_seen, self.decay_examples,
            self.learning_decay_rate)

        return TrainNetInfo(
            tf.train.MomentumOptimizer(learning_rate, self.momentum), trainers)
项目:deep_sort    作者:nwojke    | 项目源码 | 文件源码
def _create_image_encoder(preprocess_fn, factory_fn, image_shape, batch_size=32,
                         session=None, checkpoint_path=None,
                         loss_mode="cosine"):
    """Build a feature-extraction graph and return a batched encoder callable.

    Args:
        preprocess_fn: Called as ``preprocess_fn(x, is_training=False)`` on
            each image after it has been cast to float32.
        factory_fn: Network builder called as
            ``factory_fn(images, l2_normalize=..., reuse=None)``; the first
            element of its return value is used as the feature tensor.
        image_shape: Per-image shape (tuple, list or other sequence); a
            leading ``None`` batch dimension is prepended to it.
        batch_size: Number of images fed per ``session.run`` call.
        session: Session to run the graph in; a new ``tf.Session`` is created
            when None. The created session is kept alive by the returned
            closure and is never closed here.
        checkpoint_path: If not None, variables are restored from this
            checkpoint into ``session``.
        loss_mode: Features are L2-normalized only when this is ``"cosine"``.

    Returns:
        A function taking a sequence of uint8 images and returning a float32
        numpy array of shape ``(len(data_x), feature_dim)``.
    """
    # tuple() lets callers pass image_shape as a list as well; concatenating
    # a tuple with a list would raise TypeError.
    image_var = tf.placeholder(tf.uint8, (None, ) + tuple(image_shape))

    preprocessed_image_var = tf.map_fn(
        lambda x: preprocess_fn(x, is_training=False),
        tf.cast(image_var, tf.float32))

    l2_normalize = loss_mode == "cosine"
    feature_var, _ = factory_fn(
        preprocessed_image_var, l2_normalize=l2_normalize, reuse=None)
    feature_dim = feature_var.get_shape().as_list()[-1]

    if session is None:
        session = tf.Session()
    if checkpoint_path is not None:
        # Create the global step before collecting variables so that it is
        # part of get_variables_to_restore() (presumably because the
        # checkpoint stores it -- confirm).
        slim.get_or_create_global_step()
        init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
            checkpoint_path, slim.get_variables_to_restore())
        session.run(init_assign_op, feed_dict=init_feed_dict)

    def encoder(data_x):
        """Encode ``data_x`` in chunks of ``batch_size`` images."""
        out = np.zeros((len(data_x), feature_dim), np.float32)
        _run_in_batches(
            lambda x: session.run(feature_var, feed_dict=x),
            {image_var: data_x}, out, batch_size)
        return out

    return encoder
项目:attention    作者:louishenrifranc    | 项目源码 | 文件源码
def get_model_fn(self):
        """Return an Estimator ``model_fn`` built around ``TransformerModule``.

        The returned function behaves as follows:
          * TRAIN: builds the loss and a ``slim.optimize_loss`` train op, with
            learning rate and gradient clipping from ``self.training_params``
            and the optimizer from ``params["optimizer"]``.
          * EVAL: builds only the loss.
          * PREDICT: raises ``NotImplementedError``.
        ``labels`` and ``config`` are accepted but unused.
        """
        def model_fn(features, labels, mode, params=None, config=None):
            # Every EstimatorSpec field starts as None; the branch for `mode`
            # fills in what it needs.
            train_op = loss = eval_metrics = predictions = None

            if mode == ModeKeys.TRAIN:
                transformer_model = TransformerModule(params=self.model_params)
                global_step = slim.get_or_create_global_step()
                loss = transformer_model(features)
                hparams = self.training_params
                train_op = slim.optimize_loss(
                    loss=loss,
                    global_step=global_step,
                    learning_rate=hparams["learning_rate"],
                    clip_gradients=hparams["clip_gradients"],
                    optimizer=params["optimizer"],
                    summaries=slim.OPTIMIZER_SUMMARIES)
            elif mode == ModeKeys.PREDICT:
                raise NotImplementedError
            elif mode == ModeKeys.EVAL:
                loss = TransformerModule(params=self.model_params)(features)

            return EstimatorSpec(mode=mode,
                                 predictions=predictions,
                                 loss=loss,
                                 train_op=train_op,
                                 eval_metric_ops=eval_metrics)

        return model_fn
项目:Bayesian-FlowNet    作者:Johswald    | 项目源码 | 文件源码
def main(_):
    """Train FlowNet using ``slim.learning.train``.

    Builds the input pipeline, augmentation, network, result image summary
    and train op in a fresh graph. It then runs slim's training loop, which
    writes to ``FLAGS.logdir + '/train'``.
    """
    with tf.Graph().as_default():
        # Input pipeline followed by data augmentation.
        imgs_0, imgs_1, flows = flownet_tools.get_data(FLAGS.datadir, True)
        imgs_0, imgs_1, flows = apply_augmentation(imgs_0, imgs_1, flows)

        # Network, plus an image summary of its predicted flows.
        calc_flows = model(imgs_0, imgs_1, flows)
        flownet.image_summary(None, None, "E_result", calc_flows)

        train_op = flownet.create_train_op(slim.get_or_create_global_step())

        session_config = tf.ConfigProto()
        # Allocate GPU memory on demand rather than reserving it all up front.
        session_config.gpu_options.allow_growth = True

        checkpoint_saver = tf_saver.Saver(
            max_to_keep=FLAGS.max_checkpoints,
            keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours)

        slim.learning.train(
            train_op,
            logdir=FLAGS.logdir + '/train',
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            summary_op=tf.summary.merge_all(),
            log_every_n_steps=FLAGS.log_every_n_steps,
            trace_every_n_steps=FLAGS.trace_every_n_steps,
            session_config=session_config,
            saver=checkpoint_saver,
            number_of_steps=FLAGS.max_steps,
        )
项目:vessel-classification    作者:GlobalFishingWatch    | 项目源码 | 文件源码
def build_training_net(self, features, timestamps, mmsis):
        """Build the fishing-localisation training graph.

        Args:
            features, timestamps, mmsis: Passed to ``self._build_net`` (with
                the training flag set to True); ``timestamps`` and ``mmsis``
                also go to the objective's ``build_trainer``.

        Returns:
            A ``TrainNetInfo`` holding an ``AdamOptimizer`` and a single
            trainer for ``self.fishing_localisation_objective``.
        """
        self._build_net(features, timestamps, mmsis, True)

        localisation_trainer = (
            self.fishing_localisation_objective.build_trainer(timestamps,
                                                              mmsis))

        # NOTE(review): the decay is driven by the raw global step here even
        # though the attribute is named `decay_examples`; another model in
        # this codebase multiplies the step by the batch size first. Confirm
        # which unit is intended.
        global_step = slim.get_or_create_global_step()
        learning_rate = tf.train.exponential_decay(
            self.initial_learning_rate, global_step,
            self.decay_examples, self.learning_decay_rate)

        return TrainNetInfo(tf.train.AdamOptimizer(learning_rate=learning_rate),
                            [localisation_trainer])
项目:vessel-classification    作者:GlobalFishingWatch    | 项目源码 | 文件源码
def build_training_net(self, features, timestamps, mmsis):
        """Build the fishing-localisation training graph.

        Args:
            features, timestamps, mmsis: Passed to ``self._build_net`` (with
                the training flag set to True); ``timestamps`` and ``mmsis``
                also go to the objective's ``build_trainer``.

        Returns:
            A ``TrainNetInfo`` holding an ``AdamOptimizer`` and a single
            trainer for ``self.fishing_localisation_objective``.
        """
        self._build_net(features, timestamps, mmsis, True)

        localisation_trainer = (
            self.fishing_localisation_objective.build_trainer(timestamps,
                                                              mmsis))

        # NOTE(review): the decay is driven by the raw global step here even
        # though the attribute is named `decay_examples`; confirm the unit.
        global_step = slim.get_or_create_global_step()
        learning_rate = tf.train.exponential_decay(
            self.initial_learning_rate, global_step,
            self.decay_examples, self.learning_decay_rate)

        return TrainNetInfo(tf.train.AdamOptimizer(learning_rate=learning_rate),
                            [localisation_trainer])
项目:vessel-classification    作者:GlobalFishingWatch    | 项目源码 | 文件源码
def build_training_net(self, features, timestamps, mmsis):
        """Build the training graph for every configured objective.

        Args:
            features, timestamps, mmsis: Passed to ``self._build_model`` with
                ``is_training=True``; ``timestamps`` and ``mmsis`` also go to
                each objective's ``build_trainer``.

        Returns:
            A ``TrainNetInfo`` holding an ``AdamOptimizer`` and one trainer
            per entry of ``self.training_objectives``, in the same order.
        """
        self._build_model(features, timestamps, mmsis, is_training=True)

        trainers = [objective.build_trainer(timestamps, mmsis)
                    for objective in self.training_objectives]

        # NOTE(review): the decay is driven by the raw global step even though
        # the attribute is named `decay_examples`; confirm the unit.
        learning_rate = tf.train.exponential_decay(
            self.initial_learning_rate, slim.get_or_create_global_step(),
            self.decay_examples, self.learning_decay_rate)

        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

        return TrainNetInfo(optimizer, trainers)
项目:SSD_tensorflow_VOC    作者:LevinJ    | 项目源码 | 文件源码
def __get_init_fn(self):
        """Returns a function run by the chief worker to warm-start the training.

        Note that the init_fn is only run when initializing the model during
        the very first global step.

        The function restores the model variables from ``self.checkpoint_path``.
        That path is either a checkpoint file or a directory whose latest
        checkpoint is used. Variables whose op name starts with any scope in
        the comma-separated ``self.checkpoint_exclude_scopes`` are skipped;
        blank entries in that list are ignored. When ``self.fine_tune_vgg16``
        is set, the global step is restored as well.

        Returns:
            An init function run by the supervisor, or None when
            ``self.checkpoint_path`` is unset or ``self.train_dir`` already
            holds a checkpoint.

        Raises:
            ValueError: if ``self.checkpoint_path`` is a directory containing
                no checkpoint.
        """
        if self.checkpoint_path is None:
            return None

        # A checkpoint in train_dir means training resumes from it, so the
        # warm-start checkpoint is ignored; tell the user.
        if tf.train.latest_checkpoint(self.train_dir):
            tf.logging.info(
                    'Ignoring --checkpoint_path because a checkpoint already exists in %s'
                    % self.train_dir)
            return None

        # Drop blank entries (e.g. from a trailing comma): an empty prefix
        # would match every variable name and silently exclude everything.
        exclusions = ()
        if self.checkpoint_exclude_scopes:
            exclusions = tuple(
                scope.strip()
                for scope in self.checkpoint_exclude_scopes.split(',')
                if scope.strip())

        # TODO(sguada) variables.filter_variables()
        all_variables = slim.get_model_variables()
        if self.fine_tune_vgg16:
            all_variables.append(slim.get_or_create_global_step())
        # str.startswith accepts a tuple of prefixes; an empty tuple matches
        # nothing, so no variable is excluded when there are no scopes.
        variables_to_restore = [var for var in all_variables
                                if not var.op.name.startswith(exclusions)]

        if tf.gfile.IsDirectory(self.checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(self.checkpoint_path)
            if checkpoint_path is None:
                raise ValueError(
                    'No checkpoint found in directory %s' % self.checkpoint_path)
        else:
            checkpoint_path = self.checkpoint_path

        tf.logging.info('Fine-tuning from %s' % checkpoint_path)

        return slim.assign_from_checkpoint_fn(
                checkpoint_path,
                variables_to_restore,
                ignore_missing_vars=self.ignore_missing_vars)
项目:sact    作者:mfigurnov    | 项目源码 | 文件源码
def main(_):
  """Re-save the latest checkpoint in FLAGS.input_dir to FLAGS.output_dir.

  Builds the network selected by the flags on a random input so that its
  variables exist in the graph. It then restores them (and the global step)
  from the latest checkpoint in ``FLAGS.input_dir``, and saves them to
  ``FLAGS.output_dir + '/model'`` with a ``tf.train.Saver(write_version=2)``.

  Raises:
    ValueError: if --model is missing, --model_type or --dataset is not
      supported, or FLAGS.input_dir contains no checkpoint.
  """
  # Explicit checks instead of `assert`: asserts are stripped under
  # `python -O`, after which an unknown dataset would surface later as a
  # NameError. Validate before touching the filesystem.
  if FLAGS.model is None:
    raise ValueError('--model is required.')
  if FLAGS.model_type not in ('vanilla', 'act', 'act_early_stopping', 'sact'):
    raise ValueError('Unsupported --model_type: %s' % FLAGS.model_type)
  if FLAGS.dataset not in ('imagenet', 'cifar'):
    raise ValueError('Unsupported --dataset: %s' % FLAGS.dataset)

  checkpoint_path = tf.train.latest_checkpoint(FLAGS.input_dir)
  if checkpoint_path is None:
    raise ValueError('No checkpoint found in %s' % FLAGS.input_dir)

  if not tf.gfile.Exists(FLAGS.output_dir):
    tf.gfile.MakeDirs(FLAGS.output_dir)

  batch_size = 1

  if FLAGS.dataset == 'imagenet':
    height, width = 224, 224
    num_classes = 1001
  else:  # 'cifar'
    height, width = 32, 32
    num_classes = 10

  images = tf.random_uniform((batch_size, height, width, 3))
  model = utils.split_and_int(FLAGS.model)

  # The network is built only so that its variables exist in the graph for
  # the Saver; its outputs are not used.
  if FLAGS.dataset == 'imagenet':
    with slim.arg_scope(imagenet_model.resnet_arg_scope(is_training=False)):
      imagenet_model.get_network(
          images,
          model,
          num_classes,
          model_type=FLAGS.model_type)
  else:
    with slim.arg_scope(cifar_model.resnet_arg_scope(is_training=False)):
      cifar_model.resnet(
          images,
          model=model,
          num_classes=num_classes,
          model_type=FLAGS.model_type)

  # Created before the Saver so the global step is saved and restored too.
  tf_global_step = slim.get_or_create_global_step()

  saver = tf.train.Saver(write_version=2)

  with tf.Session() as sess:
    saver.restore(sess, checkpoint_path)
    saver.save(sess, FLAGS.output_dir + '/model', global_step=tf_global_step)
项目:GestureRecognition    作者:gkchai    | 项目源码 | 文件源码
def main(_):
    """Repeatedly evaluate checkpoints from FLAGS.train_dir.

    Builds the eval-mode model on batches of ``FLAGS.split_name`` and defines
    streaming accuracy, precision and recall metrics. It then runs
    ``slim.evaluation.evaluation_loop``, writing summaries to
    ``FLAGS.summaries_dir``, which is emptied at startup. Each metric summary
    also prints its value.

    Raises:
        ValueError: if --train_dir is not set.
    """
    import math

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not FLAGS.train_dir:
        raise ValueError("--train_dir is required.")

    # Start every run with an empty summaries directory.
    if tf.gfile.Exists(FLAGS.summaries_dir):
        tf.gfile.DeleteRecursively(FLAGS.summaries_dir)
    tf.gfile.MakeDirs(FLAGS.summaries_dir)

    config = configuration.Config()

    dataset_eval = loader.get_split(FLAGS.split_name, dataset_dir=FLAGS.data_dir)
    preprocess_fn = tf.abs if FLAGS.preprocess_abs else None

    # Whether the model takes 2-D input.
    is_2D = common.is_2D(FLAGS.model)

    series, labels, _ = loader.load_batch(dataset_eval, batch_size=config.batch_size, is_2D=is_2D,
                                          preprocess_fn=preprocess_fn)

    # Build lazy model
    model = common.convert_name_to_instance(FLAGS.model, config, 'eval')

    endpoints = model.build(inputs=series, is_training=False)
    predictions = tf.to_int64(tf.argmax(endpoints.logits, 1))

    # Make sure a global step exists in the graph (presumably required by the
    # evaluation loop / checkpoint restore -- confirm).
    slim.get_or_create_global_step()

    # Choose the metrics to compute:
    names_to_values, names_to_updates = metrics.aggregate_metric_map({
        'accuracy': metrics.streaming_accuracy(predictions, labels),
        'precision': metrics.streaming_precision(predictions, labels),
        'recall': metrics.streaming_recall(predictions, labels),
    })

    # Create the summary ops such that they also print out to std output.
    # .items() works on Python 2 and 3; iteritems() is Python-2-only.
    summary_ops = []
    for metric_name, metric_value in names_to_values.items():
        op = tf.summary.scalar(metric_name, metric_value)
        op = tf.Print(op, [metric_value], metric_name)
        summary_ops.append(op)

    # num_evals counts batches, so cap it at the number of batches in one pass
    # over the split rather than at the number of samples.
    batches_per_epoch = int(
        math.ceil(dataset_eval.num_samples / float(config.batch_size)))

    slim.evaluation.evaluation_loop(
        master='',
        checkpoint_dir=FLAGS.train_dir,
        logdir=FLAGS.summaries_dir,
        # A real list: a Python 3 dict view is not an accepted fetch type.
        eval_op=list(names_to_updates.values()),
        num_evals=min(FLAGS.num_batches, batches_per_epoch),
        eval_interval_secs=FLAGS.eval_interval_secs,
        max_number_of_evaluations=FLAGS.num_of_steps,
        summary_op=tf.summary.merge(summary_ops),
        session_config=config.session_config,
        )
项目:vessel-classification    作者:GlobalFishingWatch    | 项目源码 | 文件源码
def run_evaluation(self, master):
        """Run model evaluation on ``master`` in an endless loop.

        Each iteration does the following in a fresh graph:
          * reads the test split and builds the inference net;
          * collects every objective's test metrics into summary and update ops;
          * runs ``slim.evaluation.evaluation_loop`` against
            ``self.checkpoint_dir``, with ``eval_interval_secs=120`` and
            ``timeout=20 * 60``, writing summaries to ``self.eval_dir``.
        When the loop returns, or fails with an ordinary exception (which is
        logged), the graph is rebuilt and evaluation starts again.

        Args:
            master: TensorFlow master address passed to ``evaluation_loop``.

        Raises:
            tf.errors.CancelledError, tf.errors.AbortedError: re-raised from
                the evaluation loop.
        """
        while True:
            with tf.Graph().as_default():

                features, timestamps, time_bounds, mmsis, count = self._feature_data_reader(
                    utility.TEST_SPLIT, False)

                objectives = self.model.build_inference_net(features,
                                                            timestamps, mmsis)

                aggregate_metric_maps = [o.build_test_metrics()
                                         for o in objectives]

                summary_ops = []
                update_ops = []
                for names_to_values, names_to_updates in aggregate_metric_maps:
                    # .items() works on Python 2 and 3; iteritems() is 2-only.
                    for metric_name, metric_value in names_to_values.items():
                        op = tf.summary.scalar(metric_name, metric_value)
                        op = tf.Print(op, [metric_value], metric_name)
                        summary_ops.append(op)
                    update_ops.extend(names_to_updates.values())

                # Clamp the example count into [MIN_TEST_EXAMPLES,
                # MAX_TEST_EXAMPLES], then convert it to whole batches.
                count = min(max(count, MIN_TEST_EXAMPLES), MAX_TEST_EXAMPLES)
                # int(): math.ceil returns a float on Python 2.
                num_evals = int(math.ceil(count / float(self.model.batch_size)))

                # Setup the global step.
                slim.get_or_create_global_step()

                merged_summary_ops = tf.summary.merge(summary_ops)

                try:
                    slim.evaluation.evaluation_loop(
                        master,
                        self.checkpoint_dir,
                        self.eval_dir,
                        num_evals=num_evals,
                        eval_op=update_ops,
                        summary_op=merged_summary_ops,
                        eval_interval_secs=120,
                        timeout=20 * 60,
                        variables_to_restore=variables.
                        get_variables_to_restore())
                except (tf.errors.CancelledError, tf.errors.AbortedError):
                    logging.warning(
                        'Caught cancel/abort while running '
                        '`slim.evaluation.evaluation_loop`; reraising')
                    raise
                except Exception:
                    # Best effort: log and retry with a fresh graph. Not a bare
                    # `except:`, so KeyboardInterrupt/SystemExit still stop
                    # this otherwise endless loop.
                    logging.exception(
                        'Error while running slim.evaluation.evaluation_loop, ignoring')