Python tensorflow 模块,Saver() 实例源码


项目:muffnn    作者:civisanalytics    | 项目源码 | 文件源码
def _build_tf_graph(self):
    """Build the TF graph, set up model saving and set up a TF session.

    This method initializes a TF Saver and a TF Session via

        self._saver = tf.train.Saver()
        self._session = tf.Session()

    These calls are made after ``self._set_up_graph()`` is called.

    See the main class docs for how to properly call this method from a
    child class.
    """
    # The docstring requires the graph to exist before the Saver and the
    # Session are created, so build it first.
    # NOTE(review): this call was absent from the listing and is restored
    # from the docstring's description -- confirm against the class.
    self._set_up_graph()
    self._saver = tf.train.Saver()
    self._session = tf.Session()


项目:tf_rnnlm    作者:Ubiqus    | 项目源码 | 文件源码
def saver(self):
    """Return the Saver cached on ``self._saver``.

    The Saver is created with ``tf.train.Saver()`` the first time this is
    called while ``self._saver`` is still ``None``; later calls reuse it.
    """
    cached = self._saver
    if cached is not None:
        return cached
    # First access: build the Saver and remember it for subsequent calls.
    self._saver = tf.train.Saver()
    return self._saver
项目:tfutils    作者:neuroailab    | 项目源码 | 文件源码
def tf_saver(self):
    """Return the Saver stored on ``self._tf_saver``, creating it if absent.

    On first use the Saver is built from ``self.tfsaver_args`` (positional)
    and ``self.tfsaver_kwargs`` (keyword) and cached on the instance.
    """
    if hasattr(self, '_tf_saver'):
        return self._tf_saver
    # No cached Saver yet: construct one from the configured arguments.
    self._tf_saver = tf.train.Saver(*self.tfsaver_args,
                                    **self.tfsaver_kwargs)
    return self._tf_saver
项目:mnist_LeNet    作者:LuxxxLucy    | 项目源码 | 文件源码
def __init__(self, session, saver, args):
    """
    Create a model session.

    Do not call this constructor directly. To instantiate a ModelSession
    object, use the create and restore class methods.

    :param session: the session in which this model is running
    :type session: tf.Session
    :param saver: object used to serialize this session
    :type saver: tf.train.Saver
    :param args: stored unchanged on ``self.args``
    """
    self.session, self.saver, self.args = session, saver, args
项目:mnist_LeNet    作者:LuxxxLucy    | 项目源码 | 文件源码
def create(cls, **kwargs):
    """
    Create a new model session.

    :param kwargs: optional graph parameters
    :type kwargs: dict
    :return: new model session
    :rtype: ModelSession
    """
    session = tf.Session()
    with session.graph.as_default():
        # NOTE(review): the statement(s) that originally filled this block
        # are missing from the listing (presumably the graph construction
        # that consumes ``kwargs``) -- restore them from the upstream
        # project. Creating the Saver here keeps the block syntactically
        # valid and still builds it against the session's graph.
        saver = tf.train.Saver()
    # NOTE(review): the ``__init__`` listed for this project takes
    # ``(session, saver, args)``; confirm whether a third argument is
    # expected here.
    return cls(session, saver)
项目:hart    作者:akosiorek    | 项目源码 | 文件源码
def saver(self, **kwargs):
    """Returns a Saver for all (trainable and model) variables used by the model.

    Model variables include e.g. moving mean and average in BatchNorm.

    :param kwargs: extra keyword arguments forwarded to ``tf.train.Saver``
    :return: tf.train.Saver built over ``self.vars``
    """
    return tf.train.Saver(self.vars, **kwargs)
项目:tfutils    作者:neuroailab    | 项目源码 | 文件源码
def initialize(self, no_scratch=False):
    """Restore variables from a checkpoint, or initialize them from scratch.

    When ``self.do_restore`` is set, the checkpoint is ``self.from_ckpt`` if
    given; otherwise ``self.load_rec()`` is called (unless ``self.load_data``
    is already populated) and the checkpoint comes from ``self.load_data``.
    Variables selected by ``self.get_restore_vars`` are restored through a
    ``tf.train.Saver``; every other variable is re-initialized.

    When nothing is restored, all global and local variables are
    initialized instead.

    NOTE(review): the listing this block came from had lost the logging call
    names, both ``else:`` lines and every ``self.sess.run(...)`` statement.
    They are reconstructed below; the ``log.info`` name and the arguments of
    the restored ``self.sess.run`` calls are assumptions -- confirm against
    the upstream project.

    :param no_scratch: not used by the visible body of this method.
    """
    if self.do_restore:

        # First, determine which checkpoint to use.
        if self.from_ckpt is not None:
            # Use a cached checkpoint file.
            ckpt_filename = self.from_ckpt
            log.info('Restoring variables from checkpoint %s ...' % ckpt_filename)
        else:
            # Otherwise, use a database checkpoint.
            if self.load_data is None:
                self.load_rec()
            if self.load_data is not None:
                rec, ckpt_filename = self.load_data
                log.info('Restoring variables from record %s (step %d)...' %
                         (str(rec['_id']), rec['step']))
            else:
                # No db checkpoint to load.
                ckpt_filename = None

        if ckpt_filename is not None:
            prefix = self.params['model_params']['prefix']

            # Get all variables, keyed by their prefix-stripped names.
            all_vars = tf.global_variables() + tf.local_variables()
            self.all_vars = strip_prefix(prefix, all_vars)

            # Next, determine which vars should be restored from the
            # specified checkpoint.
            restore_vars = self.get_restore_vars(ckpt_filename, self.all_vars)
            restore_stripped = strip_prefix(prefix, list(restore_vars.values()))
            restore_names = list(restore_stripped)

            # Actually load the vars.
            log.info('Restored Vars:\n' + str(restore_names))
            tf_saver_restore = tf.train.Saver(restore_vars)
            tf_saver_restore.restore(self.sess, ckpt_filename)
            log.info('... done restoring.')

            # Reinitialize all other, unrestored vars. A set makes each
            # membership test O(1) instead of scanning the name list.
            restored = set(restore_names)
            unrestored = {name: var for name, var in self.all_vars.items()
                          if name not in restored}
            unrestored_vars = list(unrestored.values())
            unrestored_var_names = list(unrestored)
            log.info('Unrestored Vars:\n' + str(unrestored_var_names))
            # Initialize variables not restored.
            self.sess.run(tf.variables_initializer(unrestored_vars))
            # After restore + re-init nothing may be left uninitialized.
            uninitialized = self.sess.run(tf.report_uninitialized_variables())
            assert len(uninitialized) == 0, uninitialized

    if not self.do_restore or (self.load_data is None and self.from_ckpt is None):
        # Nothing was restored: initialize every variable from scratch.
        init_op_global = tf.global_variables_initializer()
        self.sess.run(init_op_global)
        init_op_local = tf.local_variables_initializer()
        self.sess.run(init_op_local)
项目:tfutils    作者:neuroailab    | 项目源码 | 文件源码
def get_restore_vars(self, save_file, all_vars=None):
    """Create the `var_list` init argument to tf.train.Saver from save_file.

    Extracts the subset of current variables that match the name and shape
    of variables saved in the checkpoint file, and returns these as a dict
    of variables to restore.

    To support multi-model training, a model prefix is prepended to all
    tf global_variable names, although this prefix is stripped from
    all variables before they are saved to a checkpoint. Thus, the prefix is
    stripped from both the checkpoint names and the current variable names
    before the two are matched.

    NOTE(review): the logging call name was missing from the listing this
    block came from; ``log.info`` is assumed -- confirm against the module.

    Args:
        save_file: path of tf.train.Saver checkpoint.
        all_vars: mapping of prefix-stripped name to variable. When ``None``
            it is built from ``tf.global_variables()`` and
            ``tf.local_variables()``.

    Returns:
        dict: checkpoint variables, keyed by name, whose shape equals the
        shape recorded for that name in the checkpoint.
    """
    prefix = self.params['model_params']['prefix']

    reader = tf.train.NewCheckpointReader(save_file)
    var_shapes = reader.get_variable_to_shape_map()
    log.info('Saved Vars:\n' + str(var_shapes.keys()))

    # Strip the prefix off saved var names.
    var_shapes = {strip_prefix_from_name(prefix, name): shape
                  for name, shape in var_shapes.items()}

    # Map old vars from checkpoint to new vars via load_param_dict.
    mapped_var_shapes = self.remap_var_list(var_shapes)
    log.info('Saved shapes:\n' + str(mapped_var_shapes))

    if all_vars is None:
        # Get all variables, keyed by their prefix-stripped names.
        all_vars = tf.global_variables() + tf.local_variables()
        all_vars = strip_prefix(prefix, all_vars)

    # Specify which vars are to be restored vs. reinitialized.
    if self.load_param_dict is None:
        restore_vars = {name: var for name, var in all_vars.items()
                        if name in mapped_var_shapes}
    else:
        # Associate checkpoint names with actual variables. ``all_vars`` is
        # keyed by name, so a direct lookup replaces scanning every entry.
        restore_vars = {
            ckpt_var_name: all_vars[curr_var_name]
            for ckpt_var_name, curr_var_name in self.load_param_dict.items()
            if curr_var_name in all_vars}

    restore_vars = self.filter_var_list(restore_vars)

    # Ensure the vars to be restored have the correct shape.
    return {name: var for name, var in restore_vars.items()
            if var.get_shape().as_list() == mapped_var_shapes[name]}
项目:tfutils    作者:neuroailab    | 项目源码 | 文件源码
def test(sess,
         queues,
         dbinterface,
         validation_targets,
         save_intermediate_freq=None):
    """Actually runs the testing evaluation loop.

    NOTE(review): the listing this block came from had lost the rest of the
    signature, the remaining ``run_targets_dict`` arguments and the
    statements that fill ``res``. The parameter list is rebuilt from the
    argument documentation below; the restored call arguments, the
    ``save_intermediate_freq`` default and the result collection are
    assumptions -- confirm against the upstream project.

    Args:
        sess (tensorflow.Session): Object in which to run calculations
        queues (list of CustomQueue): Objects containing asynchronously queued data iterators
        dbinterface (DBInterface object): Saver through which to save results
        validation_targets (dict of tensorflow objects): Objects on which validation will be computed.
        save_intermediate_freq (None or int): How frequently to save intermediate results captured during test
            None means no intermediate saving will be saved

    Every argument except ``sess`` is indexed once per entry of ``queues``
    by the code below, so each is presumably a per-queue sequence --
    verify against callers.

    Returns:
        dict: Validation summary.
        dict: Results.
    """
    # Collect args in a dict of lists
    test_args = {
        'queues': queues,
        'dbinterface': dbinterface,
        'validation_targets': validation_targets,
        'save_intermediate_freq': save_intermediate_freq}

    # One argument dict per queue: entry i holds the i-th item of each list.
    _ttargs = [{key: value[i] for (key, value) in test_args.items()}
               for i in range(len(queues))]

    for ttarg in _ttargs:

        ttarg['coord'], ttarg['threads'] = start_queues(sess)
        ttarg['dbinterface'].start_time_step = time.time()
        # NOTE(review): every argument after ``sess`` is reconstructed.
        validation_summary = run_targets_dict(
            sess,
            ttarg['validation_targets'],
            save_intermediate_freq=ttarg['save_intermediate_freq'],
            dbinterface=ttarg['dbinterface'],
            validation_only=True)

    res = []
    for ttarg in _ttargs:
        # NOTE(review): reconstructed -- the listing never filled ``res``.
        ttarg['dbinterface'].sync_with_host()
        res.append(ttarg['dbinterface'].outrecs)
        stop_queues(sess, ttarg['queues'], ttarg['coord'], ttarg['threads'])

    # ``validation_summary`` is the value from the last loop iteration.
    return validation_summary, res