Python tensorflow.contrib.slim 模块：get_variables() 实例源码

我们从 Python 开源项目中提取了以下 5 个代码示例，用于说明如何使用 tensorflow.contrib.slim.get_variables()。

项目:deepmodels    作者:learningsociety    | 项目源码 | 文件源码
def save_model_for_prediction(self, save_ckpt_fn, vars_to_save=None):
  """Save only the model data needed for prediction.

  Args:
    save_ckpt_fn: path of the checkpoint file to write.
    vars_to_save: optional list of variables to save. When None, every
      variable from slim.get_model_variables() is saved, except those
      returned by slim.get_variables(scope) for any scope listed in
      ``self.dm_model.restore_scope_exclude``.
  """
  if vars_to_save is None:
    model_vars = slim.get_model_variables()
    # Variables under restore-excluded scopes are not saved. Track them by
    # id() in a set: this is O(1) per lookup (the old list membership test
    # was O(n*m)) and does not depend on Variable.__eq__, which TensorFlow
    # may define element-wise rather than as an identity comparison.
    excluded_ids = set()
    for scope in self.dm_model.restore_scope_exclude:
      excluded_ids.update(id(v) for v in slim.get_variables(scope))
    vars_to_save = [v for v in model_vars if id(v) not in excluded_ids]
  base_model.save_model(save_ckpt_fn, self.sess, vars_to_save)
项目:Deep_Learning_In_Action    作者:SunnyMarkLiu    | 项目源码 | 文件源码
def load_pretrained_model(self):
        """
        Copy pretrained checkpoint weights into matching trainable variables.

        Every variable in the TRAINABLE_VARIABLES collection is considered.
        Its checkpoint key is its name with the ':<index>' suffix removed.
        The variable is skipped when that key is listed in
        ``self.skip_layer`` or is absent from the checkpoint at
        ``self.pre_trained_model_cpkt``. Otherwise the checkpoint tensor is
        loaded into the variable within ``self.sess``. Note that the
        variables' trainable flag is NOT changed by this method.
        :return: None
        """
        print('Load the pretrained weights into the non-trainable layer...')
        from tensorflow.python.framework import ops
        trainable_variables = slim.get_variables(
            collection=ops.GraphKeys.TRAINABLE_VARIABLES)

        reader = pywrap_tensorflow.NewCheckpointReader(self.pre_trained_model_cpkt)
        pretrained_model_variables = reader.get_variable_to_shape_map()
        for variable in trainable_variables:
            # 'scope/name:0' -> 'scope/name', the key used by the checkpoint.
            variable_name = variable.name.split(':')[0]
            if variable_name in self.skip_layer:
                continue
            if variable_name not in pretrained_model_variables:
                continue
            print('load ' + variable_name)
            data = reader.get_tensor(variable_name)
            # Load into the variable object we already hold. Re-fetching it
            # with tf.get_variable(reuse=True) fails for variables not
            # created via get_variable. Variable.load feeds the value
            # instead of adding an assign op and a weight-sized constant to
            # the graph on every iteration.
            variable.load(data, self.sess)
项目:deepmodels    作者:learningsociety    | 项目源码 | 文件源码
def load_model_from_checkpoint_fn(self, model_fn):
  """Restore weights from a checkpoint file into ``self.sess``.

  When ``self.vars_to_restore`` is None, it is first set (and kept on the
  instance for later calls) to the result of ``slim.get_variables()``.

  Args:
    model_fn: path of the saved checkpoint to restore from.
  """
  # self.dm_model.use_graph()
  # print() is used as a function: the Python 2 print statement is a
  # SyntaxError on Python 3, and a single parenthesized argument prints the
  # same text on Python 2.
  print("start loading from checkpoint file...")
  if self.vars_to_restore is None:
    self.vars_to_restore = slim.get_variables()
  restore_fn = slim.assign_from_checkpoint_fn(model_fn, self.vars_to_restore)
  print("restoring model from {}".format(model_fn))
  restore_fn(self.sess)
  print("model restored.")
项目:cnn-visualizer    作者:penny4860    | 项目源码 | 文件源码
def load_ckpt(self, sess, ckpt='ckpts/vgg_16.ckpt', scope='vgg_16'):
        """Restore the variables under ``scope`` from checkpoint ``ckpt``.

        Args:
            sess: session used to run the restore op.
            ckpt: path of the checkpoint file to read.
            scope: scope filter passed to slim.get_variables to choose which
                variables to restore. Defaults to 'vgg_16', the value that
                was previously hard-coded.

        Raises:
            ValueError: if no variables match ``scope``. Without this check
                the restore would run successfully while restoring nothing.
        """
        variables = slim.get_variables(scope=scope)
        if not variables:
            raise ValueError(
                'No variables found under scope %r; nothing to restore from %s'
                % (scope, ckpt))
        init_assign_op, init_feed_dict = slim.assign_from_checkpoint(ckpt, variables)
        sess.run(init_assign_op, init_feed_dict)
项目:RL-Universe    作者:Bifrost-Research    | 项目源码 | 文件源码
def build_network(self):
        """Build the conv network graph and keep its input/output tensors.

        Creates an input placeholder of shape [batch, 84, 84, 4] and two conv
        layers feeding one 256-unit fully connected layer. On top of that it
        creates a softmax head with ``self.nb_actions`` outputs and a linear
        single-output head. All layers are scoped under ``self.name``. A
        global-norm scalar summary is registered for each layer's variables.
        """
        prefix = self.name + '/'

        state = tf.placeholder(tf.float32, [None, 84, 84, 4])

        # Convolutional trunk.
        conv_a = slim.conv2d(state, 16, [8, 8], stride=4,
                             scope=prefix + 'cnn_1', activation_fn=nn.relu)
        conv_b = slim.conv2d(conv_a, 32, [4, 4], stride=2,
                             scope=prefix + 'cnn_2', activation_fn=nn.relu)

        # Shared dense layer over the flattened conv features.
        hidden = slim.fully_connected(slim.flatten(conv_b), 256,
                                      scope=prefix + 'fcc_1',
                                      activation_fn=nn.relu)

        # Output heads.
        policy_head = slim.fully_connected(hidden, self.nb_actions,
                                           scope=prefix + 'adv_probas',
                                           activation_fn=nn.softmax)
        value_head = slim.fully_connected(hidden, 1,
                                          scope=prefix + 'value_state',
                                          activation_fn=None)

        # One global-norm summary per layer, emitted in layer order.
        norm_summaries = (
            ("model/cnn1_global_norm", 'cnn_1'),
            ("model/cnn2_global_norm", 'cnn_2'),
            ("model/fcc1_global_norm", 'fcc_1'),
            ("model/adv_probas_global_norm", 'adv_probas'),
            ("model/value_state_global_norm", 'value_state'),
        )
        for tag, layer_scope in norm_summaries:
            layer_vars = slim.get_variables(scope=prefix + layer_scope)
            tf.summary.scalar(tag, tf.global_norm(layer_vars))

        # Expose the graph endpoints.
        self._tf_state = state
        self._tf_adv_probas = policy_head
        self._tf_value_state = value_head