Python tensorflow module: local_variables() code examples

The following code examples, extracted from open-source Python projects, show how to use tensorflow.local_variables().
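Before the project excerpts, a minimal orientation sketch (TF 1.x, written for this page rather than taken from any project below): tf.local_variables() returns the contents of the GraphKeys.LOCAL_VARIABLES collection. Such variables are kept out of GLOBAL_VARIABLES (and therefore out of checkpoints by default) and need their own initializer.

import tensorflow as tf

# A variable registered only in the LOCAL_VARIABLES collection.
counter = tf.get_variable(
    'counter', shape=[], dtype=tf.int32,
    initializer=tf.zeros_initializer(), trainable=False,
    collections=[tf.GraphKeys.LOCAL_VARIABLES])

print(tf.local_variables())   # [<tf.Variable 'counter:0' ...>]
print(tf.global_variables())  # [] -- local variables stay out of GLOBAL_VARIABLES

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())  # initializes LOCAL_VARIABLES only
    print(sess.run(counter))                    # 0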

Project: benchmarks (author: tensorflow)
def get_post_init_ops(self):
    # Copy the initialized values of the parameter-server shadow variables
    # into each tower's local copy.

    local_vars = tf.local_variables()
    local_var_by_name = {
        self._strip_port(v.name): v for v in local_vars}
    post_init_ops = []
    for v in tf.global_variables():
      if v.name.startswith(variable_mgr_util.PS_SHADOW_VAR_PREFIX + '/v0/'):
        prefix = self._strip_port(
            v.name[len(variable_mgr_util.PS_SHADOW_VAR_PREFIX + '/v0'):])
        for i in range(self.benchmark_cnn.num_gpus):
          name = 'v%s%s' % (i, prefix)
          if name in local_var_by_name:
            copy_to = local_var_by_name[name]
            post_init_ops.append(copy_to.assign(v.read_value()))
    return post_init_ops
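These assignment ops are typically grouped and run once, right after global initialization, so every local copy starts from the parameter-server values. A hypothetical usage sketch (variable_mgr and sess are assumptions):

post_init = tf.group(*variable_mgr.get_post_init_ops())
sess.run(tf.global_variables_initializer())
sess.run(post_init)  # local copies now mirror the PS shadow variables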
Project: benchmarks (author: tensorflow)
def savable_variables(self):
    """Returns a list/dict of savable variables to pass to tf.train.Saver."""
    params = {}
    for v in tf.global_variables():
      assert (v.name.startswith(variable_mgr_util.PS_SHADOW_VAR_PREFIX + '/v0/')
              or v.name in ('global_step:0', 'loss_scale:0',
                            'loss_scale_normal_steps:0')), (
                                'Invalid global variable: %s' % v)
      # We store variables in the checkpoint with the shadow variable prefix
      # removed so we can evaluate checkpoints in non-distributed replicated
      # mode. The checkpoints can also be loaded for training in
      # distributed_replicated mode.
      name = self._strip_port(self._remove_shadow_var_prefix_if_present(v.name))
      params[name] = v
    for v in tf.local_variables():
      # Non-trainable variables, such as batch norm moving averages, do not have
      # corresponding global shadow variables, so we add them here. Trainable
      # local variables have corresponding global shadow variables, which were
      # added in the global variable loop above.
      if v.name.startswith('v0/') and v not in tf.trainable_variables():
        params[self._strip_port(v.name)] = v
    return params
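Since tf.train.Saver accepts a dict mapping checkpoint names to variables as its var_list, the result can be passed straight through; a hypothetical sketch (variable_mgr, sess, and the path are assumptions):

saver = tf.train.Saver(var_list=variable_mgr.savable_variables())
saver.save(sess, '/tmp/train_dir/model.ckpt')  # saved under the stripped names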
Project: LTR-DNN (author: kn45)
def accumulate_accuracy(self, sess, l, p):
        """Update the streaming accuracy from a batch of labels and predictions.
        @return: (newly-updated accuracy, result of the update op).
        """
        input_dict = {
            self.labels: l,
            self.preds: p}
        ct = sess.run(self.update_acc, feed_dict=input_dict)
        ac = sess.run(self.acc)
        # Debug print of the metric's local accumulator variables.
        print([(i.name, i.eval(sess)) for i in tf.local_variables()])
        return ac, ct

    def pairwise_accuracy(self, sess, fiter, inp_fn):
        """Evaluate the ratio of correctly ordered pairs.
        @return: correct_pairs / total_pairs
        @fiter:  an iterable yielding instances (qry & pos & neg of each query)
        @inp_fn: a function extracting ([qry], [pos], [neg]) from an instance
        """
        accuracy, ct = None, None
        for inst in fiter:
            qrys, poss, negs = inp_fn(inst)
            for qry, pos, neg in itertools.product(qrys, poss, negs):
                # NOTE: this call passes (qry, pos, neg), while
                # accumulate_accuracy above takes (l, p); the two excerpts
                # appear to come from different revisions of the class.
                accuracy, ct = self.accumulate_accuracy(sess, qry, pos, neg)
        return accuracy, ct
Project: tefla (author: litan)
def show_vars(logger=None, trainable_scopes=None):
    printer = logger.info if logger is not None else print
    all_vars = set(tf.global_variables())
    trainable_vars = set(trainable_variables(trainable_scopes))
    non_trainable_vars = all_vars.difference(trainable_vars)
    local_vars = set(tf.local_variables())

    # 'nonlocal' is a reserved word in Python 3; a plain dict that the nested
    # function mutates (rather than rebinds) serves the same purpose.
    num_params = {}

    def show_var_info(vars_, var_type):
        printer('\n---%s vars in model:' % var_type)
        name_shapes = map(lambda v: (v.name, v.get_shape()), vars_)
        total_params = 0
        for n, s in sorted(name_shapes, key=lambda ns: ns[0]):
            printer('%s %s' % (n, s))
            total_params += np.prod(s.as_list())
        num_params[var_type] = total_params

    show_var_info(trainable_vars, 'Trainable')
    show_var_info(non_trainable_vars, 'Non Trainable')
    show_var_info(local_vars, 'Local')
    printer('Total number of params:')
    printer(pprint.pformat(num_params))
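A hypothetical call (the logger name is an assumption):

import logging

logger = logging.getLogger('model')
show_vars(logger=logger, trainable_scopes=None)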
Project: lsdc (author: febert)
def _test_streaming_sparse_average_precision_at_k(
      self, predictions, labels, k, expected, weights=None):
    with tf.Graph().as_default() as g, self.test_session(g):
      if weights is not None:
        weights = tf.constant(weights, tf.float32)
      predictions = tf.constant(predictions, tf.float32)
      metric, update = metrics.streaming_sparse_average_precision_at_k(
          predictions=predictions, labels=labels, k=k, weights=weights)

      # Fails without initialized vars.
      self.assertRaises(tf.OpError, metric.eval)
      self.assertRaises(tf.OpError, update.eval)
      local_variables = tf.local_variables()
      tf.initialize_variables(local_variables).run()

      # Run per-step op and assert expected values.
      if math.isnan(expected):
        self.assertTrue(math.isnan(update.eval()))
        self.assertTrue(math.isnan(metric.eval()))
      else:
        self.assertAlmostEqual(expected, update.eval())
        self.assertAlmostEqual(expected, metric.eval())
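Note that tf.initialize_variables was deprecated in TF 1.x in favor of tf.variables_initializer; inside the test session, the initialization above can be written equivalently as:

tf.variables_initializer(tf.local_variables()).run()
# or, covering all local variables in one op:
tf.local_variables_initializer().run()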
Project: lsdc (author: febert)
def _test_streaming_sparse_precision_at_top_k(self,
                                                top_k_predictions,
                                                labels,
                                                expected,
                                                class_id=None,
                                                weights=None):
    with tf.Graph().as_default() as g, self.test_session(g):
      if weights is not None:
        weights = tf.constant(weights, tf.float32)
      metric, update = metrics.streaming_sparse_precision_at_top_k(
          top_k_predictions=tf.constant(top_k_predictions, tf.int32),
          labels=labels, class_id=class_id, weights=weights)

      # Fails without initialized vars.
      self.assertRaises(tf.OpError, metric.eval)
      self.assertRaises(tf.OpError, update.eval)
      tf.initialize_variables(tf.local_variables()).run()

      # Run per-step op and assert expected values.
      if math.isnan(expected):
        self.assertTrue(math.isnan(update.eval()))
        self.assertTrue(math.isnan(metric.eval()))
      else:
        self.assertEqual(expected, update.eval())
        self.assertEqual(expected, metric.eval())
Project: lsdc (author: febert)
def _test_streaming_sparse_average_precision_at_k(
      self, predictions, labels, k, expected, weights=None):
    with tf.Graph().as_default() as g, self.test_session(g):
      if weights is not None:
        weights = tf.constant(weights, tf.float32)
      predictions = tf.constant(predictions, tf.float32)
      metric, update = metrics.streaming_sparse_average_precision_at_k(
          predictions, labels, k, weights=weights)

      # Fails without initialized vars.
      self.assertRaises(tf.OpError, metric.eval)
      self.assertRaises(tf.OpError, update.eval)
      local_variables = tf.local_variables()
      tf.initialize_variables(local_variables).run()

      # Run per-step op and assert expected values.
      if math.isnan(expected):
        _assert_nan(self, update.eval())
        _assert_nan(self, metric.eval())
      else:
        self.assertAlmostEqual(expected, update.eval())
        self.assertAlmostEqual(expected, metric.eval())
Project: lang2program (author: kelvinguu)
def guarantee_initialized_variables(session, variables=None):
    """Guarantee that all the specified variables are initialized.

    If a variable is already initialized, leave it alone. Otherwise, initialize it.

    If no variables are specified, checks all variables in the default graph.

    Args:
        variables (list[tf.Variable])
    """
    name_to_var = {v.op.name: v for v in tf.global_variables() + tf.local_variables()}
    # report_uninitialized_variables returns names as bytes; decode before lookup.
    uninitialized_names = session.run(tf.report_uninitialized_variables(variables))
    uninitialized_variables = [name_to_var[name.decode('utf-8')]
                               for name in uninitialized_names]
    init_op = tf.variables_initializer(uninitialized_variables)
    session.run(init_op)
    return uninitialized_variables
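A minimal usage sketch (the session is an assumption):

sess = tf.Session()
newly_initialized = guarantee_initialized_variables(sess)
print('initialized %d variables' % len(newly_initialized))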
Project: stuff (author: yaroslavvb)
def get_post_init_ops(self):
    # Copy the initialized values of the parameter-server shadow variables
    # into each tower's local copy.

    local_vars = tf.local_variables()
    local_var_by_name = {
        self._strip_port(v.name): v for v in local_vars}
    post_init_ops = []
    for v in tf.global_variables():
      if v.name.startswith(PS_SHADOW_VAR_PREFIX + '/v0/'):
        prefix = self._strip_port(
            v.name[len(PS_SHADOW_VAR_PREFIX + '/v0'):])
        for i in range(self.benchmark_cnn.num_gpus):
          name = 'v%s%s' % (i, prefix)
          if name in local_var_by_name:
            copy_to = local_var_by_name[name]
            post_init_ops.append(copy_to.assign(v.read_value()))
    return post_init_ops
Project: stuff (author: yaroslavvb)
def savable_variables(self):
    """Returns a list/dict of savable variables to pass to tf.train.Saver."""
    params = {}
    for v in tf.global_variables():
      assert (v.name.startswith(PS_SHADOW_VAR_PREFIX + '/v0/') or
              v.name == 'global_step:0')
      # We store variables in the checkpoint with the shadow variable prefix
      # removed so we can evaluate checkpoints in non-distributed replicated
      # mode. The checkpoints can also be loaded for training in
      # distributed_replicated mode.
      name = self._strip_port(self._remove_shadow_var_prefix_if_present(v.name))
      params[name] = v
    for v in tf.local_variables():
      # Non-trainable variables, such as batch norm moving averages, do not have
      # corresponding global shadow variables, so we add them here. Trainable
      # local variables have corresponding global shadow variables, which were
      # added in the global variable loop above.
      if v.name.startswith('v0/') and v not in tf.trainable_variables():
        params[self._strip_port(v.name)] = v
    return params
Project: sonnet (author: deepmind)
def log_variables(variables=None):
  """Logs variable information.

  This function logs the name, shape, type, collections, and device for either
  all variables or a given iterable of variables.

  Args:
    variables: iterable of variables; if not provided, then all variables
        (in the default graph) are logged.
  """
  if variables is None:
    variables = tf.global_variables() + tf.local_variables()
  for row in format_variables(variables, join_lines=False):
    tf.logging.info(row)
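For example, a streaming metric registers its accumulators as local variables, so a no-argument call logs them alongside the globals (a sketch using tf.metrics.accuracy):

labels = tf.constant([1, 0, 1])
predictions = tf.constant([1, 1, 1])
accuracy, update_op = tf.metrics.accuracy(labels=labels, predictions=predictions)
log_variables()  # includes the metric's local 'total' and 'count' accumulators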
Project: lsdc (author: febert)
def _test_streaming_sparse_precision_at_k(self,
                                            predictions,
                                            labels,
                                            k,
                                            expected,
                                            class_id=None,
                                            ignore_mask=None,
                                            weights=None):
    with tf.Graph().as_default() as g, self.test_session(g):
      if ignore_mask is not None:
        ignore_mask = tf.constant(ignore_mask, tf.bool)
      if weights is not None:
        weights = tf.constant(weights, tf.float32)
      metric, update = metrics.streaming_sparse_precision_at_k(
          predictions=tf.constant(predictions, tf.float32), labels=labels,
          k=k, class_id=class_id, ignore_mask=ignore_mask, weights=weights)

      # Fails without initialized vars.
      self.assertRaises(tf.OpError, metric.eval)
      self.assertRaises(tf.OpError, update.eval)
      tf.initialize_variables(tf.local_variables()).run()

      # Run per-step op and assert expected values.
      if math.isnan(expected):
        self.assertTrue(math.isnan(update.eval()))
        self.assertTrue(math.isnan(metric.eval()))
      else:
        self.assertEqual(expected, update.eval())
        self.assertEqual(expected, metric.eval())
Project: lsdc (author: febert)
def test_sparse_tensor_value(self):
    predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
    labels = [[0, 0, 0, 1], [0, 0, 1, 0]]
    expected_precision = 0.5
    with self.test_session():
      _, precision = metrics.streaming_sparse_precision_at_k(
          predictions=tf.constant(predictions, tf.float32),
          labels=_binary_2d_label_to_sparse_value(labels), k=1)

      tf.initialize_variables(tf.local_variables()).run()

      self.assertEqual(expected_precision, precision.eval())
Project: lsdc (author: febert)
def _test_streaming_sparse_recall_at_k(self,
                                         predictions,
                                         labels,
                                         k,
                                         expected,
                                         class_id=None,
                                         ignore_mask=None,
                                         weights=None):
    with tf.Graph().as_default() as g, self.test_session(g):
      if ignore_mask is not None:
        ignore_mask = tf.constant(ignore_mask, tf.bool)
      if weights is not None:
        weights = tf.constant(weights, tf.float32)
      metric, update = metrics.streaming_sparse_recall_at_k(
          predictions=tf.constant(predictions, tf.float32),
          labels=labels, k=k, class_id=class_id, ignore_mask=ignore_mask,
          weights=weights)

      # Fails without initialized vars.
      self.assertRaises(tf.OpError, metric.eval)
      self.assertRaises(tf.OpError, update.eval)
      tf.initialize_variables(tf.local_variables()).run()

      # Run per-step op and assert expected values.
      if math.isnan(expected):
        self.assertTrue(math.isnan(update.eval()))
        self.assertTrue(math.isnan(metric.eval()))
      else:
        self.assertEqual(expected, update.eval())
        self.assertEqual(expected, metric.eval())
Project: lsdc (author: febert)
def test_sparse_tensor_value(self):
    predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
    labels = [[0, 0, 1, 0], [0, 0, 0, 1]]
    expected_recall = 0.5
    with self.test_session():
      _, recall = metrics.streaming_sparse_recall_at_k(
          predictions=tf.constant(predictions, tf.float32),
          labels=_binary_2d_label_to_sparse_value(labels), k=1)

      tf.initialize_variables(tf.local_variables()).run()

      self.assertEqual(expected_recall, recall.eval())
Project: lsdc (author: febert)
def test_top_k_rank_invalid(self):
    with self.test_session():
      # top_k_predictions has rank < 2.
      top_k_predictions = [9, 4, 6, 2, 0]
      sp_labels = tf.SparseTensorValue(
          indices=np.array([[0,], [1,], [2,]], np.int64),
          values=np.array([2, 7, 8], np.int64),
          shape=np.array([10,], np.int64))

      with self.assertRaises(ValueError):
        precision, _ = metrics.streaming_sparse_precision_at_top_k(
            top_k_predictions=tf.constant(top_k_predictions, tf.int64),
            labels=sp_labels)
        tf.initialize_variables(tf.local_variables()).run()
        precision.eval()
Project: lsdc (author: febert)
def _test_streaming_sparse_recall_at_k(self,
                                         predictions,
                                         labels,
                                         k,
                                         expected,
                                         class_id=None,
                                         weights=None):
    with tf.Graph().as_default() as g, self.test_session(g):
      if weights is not None:
        weights = tf.constant(weights, tf.float32)
      metric, update = metrics.streaming_sparse_recall_at_k(
          predictions=tf.constant(predictions, tf.float32),
          labels=labels, k=k, class_id=class_id, weights=weights)

      # Fails without initialized vars.
      self.assertRaises(tf.OpError, metric.eval)
      self.assertRaises(tf.OpError, update.eval)
      tf.initialize_variables(tf.local_variables()).run()

      # Run per-step op and assert expected values.
      if math.isnan(expected):
        _assert_nan(self, update.eval())
        _assert_nan(self, metric.eval())
      else:
        self.assertEqual(expected, update.eval())
        self.assertEqual(expected, metric.eval())
Project: nengo_dl (author: nengo)
def save_params(self, path, include_global=True, include_local=False):
        """Save network parameters to the given ``path``.

        Parameters
        ----------
        path : str
            Filepath of parameter output file
        include_global : bool, optional
            If True (default True), save global (trainable) network variables
        include_local : bool, optional
            If True (default False), save local (non-trainable) network
            variables
        """
        if self.closed:
            raise SimulationError("Simulation has been closed, cannot save "
                                  "parameters")

        with self.tensor_graph.graph.as_default():
            vars = []
            if include_global:
                vars.extend(tf.global_variables())
            if include_local:
                vars.extend(tf.local_variables())

            path = tf.train.Saver(vars).save(self.sess, path)

        logger.info("Model parameters saved to %s", path)
Project: nengo_dl (author: nengo)
def load_params(self, path, include_global=True, include_local=False):
        """Load network parameters from the given ``path``.

        Parameters
        ----------
        path : str
            Filepath of parameter input file
        include_global : bool, optional
            If True (default True), load global (trainable) network variables
        include_local : bool, optional
            If True (default False), load local (non-trainable) network
            variables
        """
        if self.closed:
            raise SimulationError("Simulation has been closed, cannot load "
                                  "parameters")

        with self.tensor_graph.graph.as_default():
            vars = []
            if include_global:
                vars.extend(tf.global_variables())
            if include_local:
                vars.extend(tf.local_variables())

            tf.train.Saver(vars).restore(self.sess, path)

        logger.info("Model parameters loaded from %s", path)
Project: keras-fcn (author: JihongJu)
def _initialize_variables():
    """Utility to initialize uninitialized variables on the fly.
    """
    variables = tf.local_variables()
    uninitialized_variables = []
    for v in variables:
        if not hasattr(v, '_keras_initialized') or not v._keras_initialized:
            uninitialized_variables.append(v)
            v._keras_initialized = True
    if uninitialized_variables:
        sess = K.get_session()
        sess.run(tf.variables_initializer(uninitialized_variables))
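Because each variable is flagged with _keras_initialized, repeated calls are idempotent; a sketch:

_initialize_variables()  # initializes any not-yet-flagged local variables
_initialize_variables()  # no-op: every local variable is now flagged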
Project: tfutils (author: neuroailab)
def initialize(self, no_scratch=False):
        """Fetch record then uses tf's saver.restore."""
        if self.do_restore:

            # First, determine which checkpoint to use.
            if self.from_ckpt is not None:
                # Use a cached checkpoint file.
                ckpt_filename = self.from_ckpt
                log.info('Restoring variables from checkpoint %s ...' % ckpt_filename)
            else:
                # Otherwise, use a database checkpoint.
                if self.load_data is None:
                    self.load_rec()
                if self.load_data is not None:
                    rec, ckpt_filename = self.load_data
                    log.info('Restoring variables from record %s (step %d)...' %
                             (str(rec['_id']), rec['step']))
                else:
                    # No db checkpoint to load.
                    ckpt_filename = None

            if ckpt_filename is not None:

                all_vars = tf.global_variables() + tf.local_variables()  # get list of all variables
                self.all_vars = strip_prefix(self.params['model_params']['prefix'], all_vars)

                # Next, determine which vars should be restored from the specified checkpoint.
                restore_vars = self.get_restore_vars(ckpt_filename, self.all_vars)
                restore_stripped = strip_prefix(self.params['model_params']['prefix'], list(restore_vars.values()))
                restore_names = list(restore_stripped.keys())
                # Actually load the vars.
                log.info('Restored Vars:\n' + str(restore_names))
                tf_saver_restore = tf.train.Saver(restore_vars)
                tf_saver_restore.restore(self.sess, ckpt_filename)
                log.info('... done restoring.')

                # Reinitialize all other, unrestored vars.
                unrestored_vars = [var for name, var in self.all_vars.items() if name not in restore_names]
                unrestored_var_names = [name for name, var in self.all_vars.items() if name not in restore_names]
                log.info('Unrestored Vars:\n' + str(unrestored_var_names))
                self.sess.run(tf.variables_initializer(unrestored_vars))  # initialize variables not restored
                assert len(self.sess.run(tf.report_uninitialized_variables())) == 0, (
                    self.sess.run(tf.report_uninitialized_variables()))

        if not self.do_restore or (self.load_data is None and self.from_ckpt is None):
            init_op_global = tf.global_variables_initializer()
            self.sess.run(init_op_global)
            init_op_local = tf.local_variables_initializer()
            self.sess.run(init_op_local)
Project: tfutils (author: neuroailab)
def get_restore_vars(self, save_file, all_vars=None):
        """Create the `var_list` init argument to tf.Saver from save_file.

        Extracts the subset of variables from tf.global_variables that match the
        name and shape of variables saved in the checkpoint file, and returns these
        as a list of variables to restore.

        To support multi-model training, a model prefix is prepended to all
        tf global_variable names, although this prefix is stripped from
        all variables before they are saved to a checkpoint. Thus,


        Args:
            save_file: path of tf.train.Saver checkpoint.

        Returns:
            dict: checkpoint variables.

        """
        reader = tf.train.NewCheckpointReader(save_file)
        var_shapes = reader.get_variable_to_shape_map()
        log.info('Saved Vars:\n' + str(var_shapes.keys()))

        var_shapes = {  # Strip the prefix off saved var names.
            strip_prefix_from_name(self.params['model_params']['prefix'], name): shape
            for name, shape in var_shapes.items()}

        # Map old vars from checkpoint to new vars via load_param_dict.
        mapped_var_shapes = self.remap_var_list(var_shapes)
        log.info('Saved shapes:\n' + str(mapped_var_shapes))

        if all_vars is None:
            all_vars = tf.global_variables() + tf.local_variables()  # get list of all variables
            all_vars = strip_prefix(self.params['model_params']['prefix'], all_vars)

        # Specify which vars are to be restored vs. reinitialized.
        if self.load_param_dict is None:
            restore_vars = {name: var for name, var in all_vars.items() if name in mapped_var_shapes}
        else:
            # associate checkpoint names with actual variables
            load_var_dict = {}
            for ckpt_var_name, curr_var_name in self.load_param_dict.items():
                for curr_name, curr_var in all_vars.items():
                    if curr_name == curr_var_name:
                        load_var_dict[ckpt_var_name] = curr_var
                        break

            restore_vars = load_var_dict

        restore_vars = self.filter_var_list(restore_vars)

        # Ensure the vars to be restored have the correct shape.
        var_list = {}
        for name, var in restore_vars.items():
            var_shape = var.get_shape().as_list()
            if var_shape == mapped_var_shapes[name]:
                var_list[name] = var
        return var_list
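The checkpoint side of this matching relies on tf.train.NewCheckpointReader; a standalone sketch of that API (the path is an assumption):

reader = tf.train.NewCheckpointReader('/tmp/train_dir/model.ckpt')
for name, shape in reader.get_variable_to_shape_map().items():
    print(name, shape)  # e.g. conv1/weights [3, 3, 3, 64]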
Project: dynamic-training-bench (author: galeone)
def extract_features(self,
                         checkpoint_path,
                         inputs,
                         layer_name,
                         num_classes=0):
        """Restore model parameters from checkpoint_path. Search in the model
        the layer with name `layer_name`. If found places `inputs` as input to the model
        and returns the values extracted by the layer.
        Args:
            checkpoint_path: path of the trained model checkpoint directory
            inputs: a Tensor with a shape compatible with the model's input
            layer_name: a string, the name of the layer to extract from model
            num_classes: number of classes the model can classify; if the model
            is class-aware (e.g. a classifier) this must equal the number of
            classes it was trained on, otherwise leave it at 0
        Returns:
            features: a numpy ndarray that contains the extracted features
        """

        # Evaluate the inputs in the current default graph,
        # then use a placeholder to inject the computed values into the new graph.
        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            evaluated_inputs = sess.run(inputs)

        # Build everything in a new graph so that subsequent calls
        # do not pollute the default graph.
        with tf.Graph().as_default() as graph:
            inputs_ = tf.placeholder(inputs.dtype, shape=inputs.shape)

            # Build a Graph that computes the predictions from the inference model.
            _ = self._model.get(
                inputs_, num_classes, train_phase=False, l2_penalty=0.0)

            # This will raise an exception if layer_name is not found
            layer = graph.get_tensor_by_name(layer_name)

            saver = tf.train.Saver(variables_to_restore())
            init = [
                tf.variables_initializer(
                    tf.global_variables() + tf.local_variables()),
                tf.tables_initializer()
            ]
            features = np.zeros(layer.shape)
            with tf.Session(config=tf.ConfigProto(
                    allow_soft_placement=True)) as sess:
                # Initialize first, then restore: running the initializer after
                # the restore would overwrite the restored values.
                sess.run(init)
                ckpt = tf.train.get_checkpoint_state(checkpoint_path)
                if ckpt and ckpt.model_checkpoint_path:
                    # Restore the trained parameters from the checkpoint.
                    saver.restore(sess, ckpt.model_checkpoint_path)
                else:
                    print('[!] No checkpoint file found')
                    return features
                features = sess.run(
                    layer, feed_dict={
                        inputs_: evaluated_inputs
                    })

            return features
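A hypothetical call (the model object, checkpoint directory, and tensor name are assumptions; note that layer_name must be a full tensor name including the output index, e.g. ':0'):

import numpy as np
import tensorflow as tf

inputs = tf.constant(np.random.rand(1, 224, 224, 3), tf.float32)
features = model.extract_features(
    '/tmp/train_dir', inputs, 'model/fc1/Relu:0', num_classes=10)
print(features.shape)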