Python tensorflow 模块,report_uninitialized_variables() 实例源码

我们从 Python 开源项目中提取了以下 18 个代码示例，用于说明如何使用 tensorflow.report_uninitialized_variables()。

项目:tensorlight    作者:bsautermeister    | 项目源码 | 文件源码
def uninitialized_variables(session, var_list=None):
    """Gets the list of uninitialized variables.

    Runs a ``tf.report_uninitialized_variables`` op in ``session`` and maps the
    reported names back to the variable objects in ``var_list``.

    Parameters
    ----------
    session: tf.Session
        The TensorFlow session to scan for uninitialized variables.
    var_list: list(tf.Variable) or None
        The list of variables to filter for uninitialized ones.
        Defaults to ``tf.global_variables()`` when None.

    Returns
    -------
    list(tf.Variable)
        The members of ``var_list`` that are not yet initialized in
        ``session``, in ``var_list`` order.
    """
    if var_list is None:
        var_list = tf.global_variables()
    var_list = list(var_list)

    reported_var_names = session.run(tf.report_uninitialized_variables(var_list))
    # The report yields op names (no ':0' suffix), as bytes under Python 3.
    # Normalise them to str so they can be compared with ``Variable.op.name``.
    # (The previous ``tf.get_variable(name)`` lookup could not return existing
    # variables, so nothing was ever collected.)
    uninit_names = {
        name.decode('utf-8') if isinstance(name, bytes) else name
        for name in reported_var_names
    }
    return [v for v in var_list if v.op.name in uninit_names]
项目:lang2program    作者:kelvinguu    | 项目源码 | 文件源码
def guarantee_initialized_variables(session, variables=None):
    """Guarantee that all the specified variables are initialized.

    If a variable is already initialized, leave it alone. Otherwise, initialize it.

    If no variables are specified, checks all variables in the default graph.

    Args:
        session (tf.Session): session used both to query initialization status
            and to run the initializer.
        variables (list[tf.Variable] | None): variables to check. ``None`` uses
            the default of ``tf.report_uninitialized_variables``; reported names
            are then resolved against global + local variables.

    Returns:
        list[tf.Variable]: the variables that were uninitialized and have now
        been initialized.
    """
    if variables is None:
        lookup = tf.global_variables() + tf.local_variables()
        report_op = tf.report_uninitialized_variables()
    else:
        # Materialise once so an iterator argument is not exhausted, and resolve
        # names against exactly the requested variables (they may live outside
        # the global/local collections).
        lookup = list(variables)
        report_op = tf.report_uninitialized_variables(lookup)

    name_to_var = {v.op.name: v for v in lookup}
    # Reported names come back as bytes under Python 3; the keys above are str.
    uninitialized_variables = [
        name_to_var[name.decode('utf-8') if isinstance(name, bytes) else name]
        for name in session.run(report_op)
    ]
    init_op = tf.variables_initializer(uninitialized_variables)
    session.run(init_op)
    return uninitialized_variables
项目:lang2program    作者:kelvinguu    | 项目源码 | 文件源码
def guarantee_initialized_variables(session, variables=None):
    """Guarantee that all the specified variables are initialized.

    If a variable is already initialized, leave it alone. Otherwise, initialize it.

    If no variables are specified, checks all variables in the default graph.

    Args:
        session (tf.Session): session used both to query initialization status
            and to run the initializer.
        variables (list[tf.Variable] | None): variables to check. ``None`` uses
            the default of ``tf.report_uninitialized_variables``; reported names
            are then resolved against global + local variables.

    Returns:
        list[tf.Variable]: the variables that were uninitialized and have now
        been initialized.
    """
    if variables is None:
        lookup = tf.global_variables() + tf.local_variables()
        report_op = tf.report_uninitialized_variables()
    else:
        # Materialise once so an iterator argument is not exhausted, and resolve
        # names against exactly the requested variables (they may live outside
        # the global/local collections).
        lookup = list(variables)
        report_op = tf.report_uninitialized_variables(lookup)

    name_to_var = {v.op.name: v for v in lookup}
    # Reported names come back as bytes under Python 3; the keys above are str.
    uninitialized_variables = [
        name_to_var[name.decode('utf-8') if isinstance(name, bytes) else name]
        for name in session.run(report_op)
    ]
    init_op = tf.variables_initializer(uninitialized_variables)
    session.run(init_op)
    return uninitialized_variables
项目:tflearn    作者:tflearn    | 项目源码 | 文件源码
def resetGlobal(self):
    """Zero the running accuracy and loss accumulators stored on ``self``."""
    self.global_acc, self.global_loss = 0.0, 0.0


# def initialize_uninit_variables(session, list_of_variables=None):
#     if list_of_variables is None:
#         list_of_variables = tf.global_variables()
#     uninitialized_variables = list(tf.get_variable(name) for name in
#                                    session.run(tf.report_uninitialized_variables(list_of_variables)))
#     session.run(tf.variables_initializer(uninitialized_variables))
#     return uninitialized_variables
项目:RFHO    作者:lucfra    | 项目源码 | 文件源码
def __init__(self, optimizer, hyper_dict, method, hyper_grad_kwargs=None,
                 hyper_optimizer_class=AdamOptimizer, **optimizers_kwargs):
        """
        Interface instance of gradient-based hyperparameter optimization methods.

        :param optimizer: parameter optimization dynamics (obtained from `Optimizer.create` methods)
        :param hyper_dict: dictionary of validation errors and list of hyperparameters to be optimized
        :param method:  method with which to compute hyper-gradients: ForwardHG
                        or ReverseHG
        :param hyper_grad_kwargs: dictionary of keyword arguments for `HyperGradient` classes (usually None).
                                  It is copied, never modified in place.
        :param hyper_optimizer_class: (default Adam) Optimizer class for optimization of the hyperparameters;
                                      None disables hyper-optimizers
        :param optimizers_kwargs: keyword arguments for hyperparameter optimizers (like hyper-learning rate)
        """
        assert method in [ReverseHG, ForwardHG]
        assert hyper_optimizer_class is None or issubclass(hyper_optimizer_class, Optimizer)
        assert isinstance(hyper_dict, dict)
        assert isinstance(optimizer, Optimizer)

        # Work on a copy: 'global_step' is injected below and must not leak back
        # into a dict the caller may reuse for another instance.
        hyper_grad_kwargs = dict(hyper_grad_kwargs) if hyper_grad_kwargs else {}
        self.hyper_iteration_step = GlobalStep(name='hyper_iteration_step')
        self._report_hyper_it_init = tf.report_uninitialized_variables([self.hyper_iteration_step.var])
        self.hyper_batch_step = GlobalStep(name='batch_step')

        # Automatically link an eventual optimizer global step (like in Adam) to the
        # HyperGradient global step. Only build a fresh GlobalStep when one is really
        # needed; the former eager `.get(..., GlobalStep())` default added a stray
        # variable to the graph even when the caller supplied 'global_step'.
        if 'global_step' not in hyper_grad_kwargs:
            hyper_grad_kwargs['global_step'] = (
                optimizer.global_step if hasattr(optimizer, 'global_step') else GlobalStep())

        # Automatically link an eventual hyper-optimizer global step (like in Adam) to batch_step.
        if hyper_optimizer_class is AdamOptimizer:
            optimizers_kwargs['global_step'] = self.hyper_batch_step
            optimizers_kwargs.setdefault('eps', 1.e-14)

        self.hyper_gradients = method(optimizer, hyper_dict, **hyper_grad_kwargs)

        if hyper_optimizer_class:
            # noinspection PyTypeChecker
            self.hyper_optimizers = create_hyperparameter_optimizers(
                self.hyper_gradients, optimizer_class=hyper_optimizer_class, **optimizers_kwargs)
        else:
            self.hyper_optimizers = None
项目:GPflow    作者:GPflow    | 项目源码 | 文件源码
def _find_initializable_tensors(intializables, session):
    """Yield the entries of ``intializables`` that still need initializing.

    Entries are dispatched on their type:

    * ``tuple``/``list`` -- read as ``(status_tensor, flag_tensor)``. All flags
      are fetched in a single ``session.run`` call, and ``status_tensor`` is
      yielded for every flag whose fetched value is falsy.
    * ``tf.data.Iterator`` -- skipped (see the TODO below).
    * anything else -- treated as a variable and yielded when
      ``tf.report_uninitialized_variables`` lists its name.

    Variables are yielded first, in input order, then status tensors, in
    input order.
    """
    candidates = []
    guarded = []

    for item in intializables:
        if isinstance(item, (tuple, list)):
            guarded.append((item[0], item[1]))
        # TODO(@awav): Tensorflow Iterator must have to be skipped at
        # auto-intialization unless TensorFlow issue #14633 is resolved.
        elif not isinstance(item, tf.data.Iterator):
            candidates.append(item)

    if candidates:
        report = tf.report_uninitialized_variables(var_list=candidates)
        pending = {raw.decode('utf-8') for raw in session.run(report)}
        for var in candidates:
            # Compare without the ':<output index>' suffix of the tensor name.
            if var.name.partition(':')[0] in pending:
                yield var

    if guarded:
        fetched = session.run([flag for _, flag in guarded])
        for (status, _), ready in zip(guarded, fetched):
            if not ready:
                yield status
项目:TensorBase    作者:dancsalo    | 项目源码 | 文件源码
def _init_uninit_vars(self):
        """Initialize the variables that ``self.sess`` reports as uninitialized.

        Variables that are already initialized are left untouched.
        """
        reported = self.sess.run(tf.report_uninitialized_variables())
        # Reported entries are bytes op names (no ':0' suffix). Decode them once
        # into a set so each membership test below is O(1), not a list scan.
        uninit_names = {raw.decode("utf-8") for raw in reported}
        # Per the TF docs, the no-argument report covers local as well as global
        # variables. Search both collections; otherwise uninitialized local
        # variables would be reported but never initialized.
        candidates = tf.global_variables() + tf.local_variables()
        uninit_vars_tf = [v for v in candidates if v.name.split(':')[0] in uninit_names]
        self.sess.run(tf.variables_initializer(var_list=uninit_vars_tf))
项目:tensorflow    作者:luyishisi    | 项目源码 | 文件源码
def test_restore_map_for_classification_ckpt(self):
  """Restoring from a classification checkpoint must fill feature-extractor vars."""
  # Define mock tensorflow classification graph and save variables.
  test_graph_classification = tf.Graph()
  with test_graph_classification.as_default():
    image = tf.placeholder(dtype=tf.float32, shape=[1, 20, 20, 3])
    with tf.variable_scope('mock_model'):
      net = slim.conv2d(image, num_outputs=3, kernel_size=1, scope='layer1')
      slim.conv2d(net, num_outputs=3, kernel_size=1, scope='layer2')

    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()
    save_path = self.get_temp_dir()
    with self.test_session() as sess:
      sess.run(init_op)
      saved_model_path = saver.save(sess, save_path)

  # Create tensorflow detection graph and load variables from
  # classification checkpoint.
  test_graph_detection = tf.Graph()
  with test_graph_detection.as_default():
    model = self._build_model(
        is_training=False, first_stage_only=False, second_stage_batch_size=6)

    inputs_shape = (2, 20, 20, 3)
    inputs = tf.to_float(tf.random_uniform(
        inputs_shape, minval=0, maxval=255, dtype=tf.int32))
    preprocessed_inputs = model.preprocess(inputs)
    prediction_dict = model.predict(preprocessed_inputs)
    model.postprocess(prediction_dict)
    var_map = model.restore_map(from_detection_checkpoint=False)
    self.assertIsInstance(var_map, dict)
    saver = tf.train.Saver(var_map)
    with self.test_session() as sess:
      saver.restore(sess, saved_model_path)
      # report_uninitialized_variables yields names as bytes, not Variable
      # objects (which is why `var.name` failed); decode before the checks.
      for raw_name in sess.run(tf.report_uninitialized_variables()):
        var_name = raw_name.decode('utf-8')
        self.assertNotIn(model.first_stage_feature_extractor_scope, var_name)
        self.assertNotIn(model.second_stage_feature_extractor_scope,
                         var_name)
项目:tensorflow    作者:luyishisi    | 项目源码 | 文件源码
def test_restore_map_for_detection_ckpt(self):
  """Restoring from a detection checkpoint must fill feature-extractor vars."""
  # Define first detection graph and save variables.
  test_graph_detection1 = tf.Graph()
  with test_graph_detection1.as_default():
    model = self._build_model(
        is_training=False, first_stage_only=False, second_stage_batch_size=6)
    inputs_shape = (2, 20, 20, 3)
    inputs = tf.to_float(tf.random_uniform(
        inputs_shape, minval=0, maxval=255, dtype=tf.int32))
    preprocessed_inputs = model.preprocess(inputs)
    prediction_dict = model.predict(preprocessed_inputs)
    model.postprocess(prediction_dict)
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()
    save_path = self.get_temp_dir()
    with self.test_session() as sess:
      sess.run(init_op)
      saved_model_path = saver.save(sess, save_path)

  # Define second detection graph and restore variables.
  test_graph_detection2 = tf.Graph()
  with test_graph_detection2.as_default():
    model2 = self._build_model(is_training=False, first_stage_only=False,
                               second_stage_batch_size=6, num_classes=42)

    inputs_shape2 = (2, 20, 20, 3)
    inputs2 = tf.to_float(tf.random_uniform(
        inputs_shape2, minval=0, maxval=255, dtype=tf.int32))
    preprocessed_inputs2 = model2.preprocess(inputs2)
    prediction_dict2 = model2.predict(preprocessed_inputs2)
    model2.postprocess(prediction_dict2)
    var_map = model2.restore_map(from_detection_checkpoint=True)
    self.assertIsInstance(var_map, dict)
    saver = tf.train.Saver(var_map)
    with self.test_session() as sess:
      saver.restore(sess, saved_model_path)
      # report_uninitialized_variables yields names as bytes, not Variable
      # objects (which is why `var.name` failed); decode before the checks.
      for raw_name in sess.run(tf.report_uninitialized_variables()):
        var_name = raw_name.decode('utf-8')
        self.assertNotIn(model2.first_stage_feature_extractor_scope, var_name)
        self.assertNotIn(model2.second_stage_feature_extractor_scope,
                         var_name)
项目:tensorflow    作者:luyishisi    | 项目源码 | 文件源码
def test_restore_map_for_detection_ckpt(self):
  """restore_map(from_detection_checkpoint=True) must cover FeatureExtractor vars."""
  init_op = tf.global_variables_initializer()
  saver = tf.train.Saver()
  save_path = self.get_temp_dir()
  with self.test_session() as sess:
    sess.run(init_op)
    saved_model_path = saver.save(sess, save_path)
    var_map = self._model.restore_map(from_detection_checkpoint=True)
    self.assertIsInstance(var_map, dict)
    saver = tf.train.Saver(var_map)
    saver.restore(sess, saved_model_path)
    # report_uninitialized_variables yields names as bytes, not Variable
    # objects (which is why `var.name` failed); decode before the check.
    for raw_name in sess.run(tf.report_uninitialized_variables()):
      self.assertNotIn('FeatureExtractor', raw_name.decode('utf-8'))
项目:tensorflow    作者:luyishisi    | 项目源码 | 文件源码
def test_restore_map_for_classification_ckpt(self):
  """Restoring from a classification checkpoint must cover FeatureExtractor vars."""
  # Define mock tensorflow classification graph and save variables.
  test_graph_classification = tf.Graph()
  with test_graph_classification.as_default():
    image = tf.placeholder(dtype=tf.float32, shape=[1, 20, 20, 3])
    with tf.variable_scope('mock_model'):
      net = slim.conv2d(image, num_outputs=32, kernel_size=1, scope='layer1')
      slim.conv2d(net, num_outputs=3, kernel_size=1, scope='layer2')

    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()
    save_path = self.get_temp_dir()
    with self.test_session() as sess:
      sess.run(init_op)
      saved_model_path = saver.save(sess, save_path)

  # Create tensorflow detection graph and load variables from
  # classification checkpoint.
  test_graph_detection = tf.Graph()
  with test_graph_detection.as_default():
    inputs_shape = [2, 2, 2, 3]
    inputs = tf.to_float(tf.random_uniform(
        inputs_shape, minval=0, maxval=255, dtype=tf.int32))
    preprocessed_inputs = self._model.preprocess(inputs)
    prediction_dict = self._model.predict(preprocessed_inputs)
    self._model.postprocess(prediction_dict)
    var_map = self._model.restore_map(from_detection_checkpoint=False)
    self.assertIsInstance(var_map, dict)
    saver = tf.train.Saver(var_map)
    with self.test_session() as sess:
      saver.restore(sess, saved_model_path)
      # report_uninitialized_variables yields names as bytes, not Variable
      # objects (which is why `var.name` failed); decode before the check.
      for raw_name in sess.run(tf.report_uninitialized_variables()):
        self.assertNotIn('FeatureExtractor', raw_name.decode('utf-8'))
项目:tfutils    作者:neuroailab    | 项目源码 | 文件源码
def initialize(self, no_scratch=False):
        """Fetch record then uses tf's saver.restore.

        When ``self.do_restore`` is set, a checkpoint path is taken from
        ``self.from_ckpt``, or else from ``self.load_data``. If ``self.load_data``
        is None, ``self.load_rec()`` is called first; presumably it populates
        ``self.load_data`` (verify against load_rec). The variables chosen by
        ``self.get_restore_vars`` are restored with a ``tf.train.Saver``. Every
        other variable in ``self.all_vars`` is then initialized, and the method
        asserts that nothing remains uninitialized.

        When no restore happens (``do_restore`` is false, or no checkpoint
        source exists), the global and local variable initializers are run.

        Args:
            no_scratch: not referenced in this method body.
        """
        if self.do_restore:

            # First, determine which checkpoint to use.
            if self.from_ckpt is not None:
                # Use a cached checkpoint file.
                ckpt_filename = self.from_ckpt
                log.info('Restoring variables from checkpoint %s ...' % ckpt_filename)
            else:
                # Otherwise, use a database checkpoint.
                if self.load_data is None:
                    self.load_rec()
                if self.load_data is not None:
                    rec, ckpt_filename = self.load_data
                    log.info('Restoring variables from record %s (step %d)...' %
                             (str(rec['_id']), rec['step']))
                else:
                    # No db checkpoint to load.
                    ckpt_filename = None

            if ckpt_filename is not None:
                prefix = self.params['model_params']['prefix']

                all_vars = tf.global_variables() + tf.local_variables()  # get list of all variables
                self.all_vars = strip_prefix(prefix, all_vars)

                # Next, determine which vars should be restored from the specified checkpoint.
                restore_vars = self.get_restore_vars(ckpt_filename, self.all_vars)
                restore_stripped = strip_prefix(prefix, list(restore_vars.values()))
                restore_names = list(restore_stripped)
                # Actually load the vars.
                log.info('Restored Vars:\n' + str(restore_names))
                tf_saver_restore = tf.train.Saver(restore_vars)
                tf_saver_restore.restore(self.sess, ckpt_filename)
                log.info('... done restoring.')

                # Reinitialize all other, unrestored vars. A set gives O(1) membership
                # tests, and a single pass yields both names and variables.
                restored = set(restore_names)
                unrestored = [(name, var) for name, var in self.all_vars.items()
                              if name not in restored]
                unrestored_var_names = [name for name, _ in unrestored]
                log.info('Unrestored Vars:\n' + str(unrestored_var_names))
                # initialize variables not restored
                self.sess.run(tf.variables_initializer([var for _, var in unrestored]))
                # Build and run the report op once, reusing the result as the
                # assertion message.
                still_uninitialized = self.sess.run(tf.report_uninitialized_variables())
                assert len(still_uninitialized) == 0, still_uninitialized

        if not self.do_restore or (self.load_data is None and self.from_ckpt is None):
            init_op_global = tf.global_variables_initializer()
            self.sess.run(init_op_global)
            init_op_local = tf.local_variables_initializer()
            self.sess.run(init_op_local)
项目:CycleGAN-Tensorflow    作者:gitlimlab    | 项目源码 | 文件源码
def run(args):
    """Build a CycleGAN, optionally train it, then test it under a Supervisor.

    Event files and checkpoints go to ``./logs/<model_name>``. Test output goes
    to ``results/<model_name>``. ``model_name`` is ``args.load_model`` when it is
    non-empty, otherwise ``<task>_<timestamp>``.
    """
    logger.info('Read data:')
    train_a, train_b, test_a, test_b = get_data(args.task, args.image_size)

    logger.info('Build graph:')
    model = CycleGAN(args)

    saved_vars = tf.global_variables()
    init_op = tf.variables_initializer(saved_vars)
    init_all_op = tf.global_variables_initializer()
    saver = FastSaver(saved_vars)

    logger.info('Trainable vars:')
    scope_name = tf.get_variable_scope().name
    for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope_name):
        logger.info('  %s %s', var.name, var.get_shape())

    # Reuse the requested run name when given, else derive one from task + time.
    model_name = (args.load_model if args.load_model != ''
                  else '{}_{}'.format(args.task, datetime.now().strftime("%Y-%m-%d_%H-%M-%S")))
    makedirs('./logs')
    logdir = os.path.join('./logs', model_name)
    logger.info('Events directory: %s', logdir)
    summary_writer = tf.summary.FileWriter(logdir)

    def initialize_everything(sess):
        logger.info('Initializing all parameters.')
        sess.run(init_all_op)

    ready_op = tf.report_uninitialized_variables(saved_vars)
    sv = tf.train.Supervisor(is_chief=True,
                             logdir=logdir,
                             saver=saver,
                             summary_op=None,
                             init_op=init_op,
                             init_fn=initialize_everything,
                             summary_writer=summary_writer,
                             ready_op=ready_op,
                             global_step=model.global_step,
                             save_model_secs=300,
                             save_summaries_secs=30)

    if args.train:
        logger.info("Starting training session.")
        with sv.managed_session() as sess:
            model.train(sess, summary_writer, train_a, train_b)

    logger.info("Starting testing session.")
    with sv.managed_session() as sess:
        results_dir = os.path.join('results', model_name)
        makedirs(results_dir)
        model.test(sess, test_a, test_b, results_dir)
项目:tensorfx    作者:TensorLab    | 项目源码 | 文件源码
def build_training_graph(self, dataset):
  """Construct the training graph inside the current default graph.

  Args:
    dataset: The dataset whose 'train' datasource feeds training.
  Returns:
    A dict with the tensors/ops needed for training, keyed by 'global_steps',
    'loss', 'init_op', 'local_init_op', 'ready_op', 'train_op', 'summary_op',
    'saver' and 'scaffold'.
  """
  with tf.name_scope('input'):
    # Shuffled batches from the 'train' datasource. Batch size and epoch count
    # both come from self.args.
    train_inputs = self.build_input(dataset, 'train',
                                    batch=self.args.batch_size,
                                    epochs=self.args.epochs,
                                    shuffle=True)

  with tf.name_scope('inference'):
    predictions = self.build_inference(train_inputs, training=True)

  with tf.name_scope('train'):
    # The step counter is deliberately registered as trainable so the
    # trainable-variables Saver below checkpoints it, letting training resume.
    step_collections = [tf.GraphKeys.GLOBAL_VARIABLES,
                        tf.GraphKeys.GLOBAL_STEP,
                        tf.GraphKeys.TRAINABLE_VARIABLES]
    step_var = tf.Variable(0, name='global_steps', dtype=tf.int64, trainable=True,
                           collections=step_collections)
    loss, train_op = self.build_training(step_var, train_inputs, predictions)

  with tf.name_scope('initialization'):
    # Saver covering the trainable variables (which include the step counter),
    # used for both saving and resuming.
    saver = tf.train.Saver(tf.trainable_variables(), sharded=True)
    init_op, local_init_op = self.build_init()
    ready_op = tf.report_uninitialized_variables(tf.trainable_variables())

  # One summary op merging every summary defined in any sub-graph.
  summary_op = tf.summary.merge_all()

  scaffold = tf.train.Scaffold(init_op=init_op,
                               local_init_op=local_init_op,
                               ready_op=ready_op,
                               ready_for_local_init_op=ready_op,
                               summary_op=summary_op,
                               saver=saver)
  scaffold.finalize()

  return dict(global_steps=step_var,
              loss=loss,
              init_op=init_op,
              local_init_op=local_init_op,
              ready_op=ready_op,
              train_op=train_op,
              summary_op=summary_op,
              saver=saver,
              scaffold=scaffold)
项目:a3c-tensorflow    作者:carpedm20    | 项目源码 | 文件源码
def train(self):
  """Run this worker's A3C training loop under a tf.train.Supervisor.

  Global variables whose names start with "local" are excluded from the saved
  set. Task 0 acts as chief. The loop runs until the supervisor asks to stop
  or the step cap is reached.
  """
  saved_vars = [var for var in tf.global_variables()
                if not var.name.startswith("local")]
  init_op = tf.variables_initializer(saved_vars)
  init_all_op = tf.global_variables_initializer()

  saver = FastSaver(saved_vars)

  trainable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                tf.get_variable_scope().name)
  tf.logging.info('Trainable vars:')
  slim.model_analyzer.analyze_vars(trainable, print_info=True)

  banner = "=" * 30

  def init_fn(ses):
    tf.logging.info(banner)
    tf.logging.info("Initializing all parameters.")
    tf.logging.info(banner)
    ses.run(init_all_op)

  # Only see the parameter servers and this worker's own CPU.
  worker_device = "/job:worker/task:{}/cpu:0".format(self.task)
  sess_config = tf.ConfigProto(device_filters=["/job:ps", worker_device])

  summary_writer = tf.summary.FileWriter("{}_{}".format(self.log_dir, self.task))
  tf.logging.info("Events directory: %s_%s", self.log_dir, self.task)
  ready_op = tf.report_uninitialized_variables(saved_vars)
  sv = tf.train.Supervisor(is_chief=self.task == 0,
                           logdir=self.log_dir,
                           saver=saver,
                           summary_op=None,
                           init_op=init_op,
                           init_fn=init_fn,
                           summary_writer=summary_writer,
                           ready_op=ready_op,
                           save_model_secs=600,
                           save_summaries_secs=30)

  max_steps = 100000000

  with sv.managed_session(self.server.target, config=sess_config) as sess, sess.as_default():
    sess.run(self.agent.sync)
    self.agent.start(sess, summary_writer)

    global_step = sess.run(self.agent.global_step)
    tf.logging.info("Starting training at step=%d", global_step)

    while not sv.should_stop() and (not max_steps or global_step < max_steps):
      self.agent.process(sess)
      global_step = sess.run(self.agent.global_step)

  # Tell every supervisor-managed service to shut down.
  sv.stop()
  tf.logging.info('reached %s steps. worker stopped.', global_step)
项目:BicycleGAN-Tensorflow    作者:gitlimlab    | 项目源码 | 文件源码
def run(args):
    """Build a BicycleGAN, optionally train it, then test it under a Supervisor.

    GPU visibility is set from ``args.gpu`` (PCI-bus ordering). Event files
    and checkpoints go to ``./logs/<model_name>``. Test output goes to
    ``results/<model_name>``.
    """
    # Restrict TensorFlow to the GPU(s) listed in args.gpu, numbered by PCI bus id.
    os.environ.update({'CUDA_DEVICE_ORDER': "PCI_BUS_ID",
                       'CUDA_VISIBLE_DEVICES': args.gpu})

    logger.info('Read data:')
    train_a, train_b, test_a, test_b = get_data(args.task, args.image_size)

    logger.info('Build graph:')
    model = BicycleGAN(args)

    saved_vars = tf.global_variables()
    init_op = tf.variables_initializer(saved_vars)
    init_all_op = tf.global_variables_initializer()
    saver = FastSaver(saved_vars)

    logger.info('Trainable vars:')
    scope_name = tf.get_variable_scope().name
    for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope_name):
        logger.info('  %s %s', var.name, var.get_shape())

    # Reuse the requested run name when given, else derive one from task + time.
    model_name = (args.load_model if args.load_model != ''
                  else '{}_{}'.format(args.task, datetime.now().strftime("%Y-%m-%d_%H-%M-%S")))
    makedirs('./logs')
    logdir = os.path.join('./logs', model_name)
    logger.info('Events directory: %s', logdir)
    summary_writer = tf.summary.FileWriter(logdir)

    makedirs('./results')

    def initialize_everything(sess):
        logger.info('Initializing all parameters.')
        sess.run(init_all_op)

    ready_op = tf.report_uninitialized_variables(saved_vars)
    sv = tf.train.Supervisor(is_chief=True,
                             logdir=logdir,
                             saver=saver,
                             summary_op=None,
                             init_op=init_op,
                             init_fn=initialize_everything,
                             summary_writer=summary_writer,
                             ready_op=ready_op,
                             global_step=model.global_step,
                             save_model_secs=300,
                             save_summaries_secs=30)

    if args.train:
        logger.info("Starting training session.")
        with sv.managed_session() as sess:
            model.train(sess, summary_writer, train_a, train_b)

    logger.info("Starting testing session.")
    with sv.managed_session() as sess:
        results_dir = os.path.join('results', model_name)
        makedirs(results_dir)
        model.test(sess, test_a, test_b, results_dir)
项目:RL-Universe    作者:Bifrost-Research    | 项目源码 | 文件源码
def run(args, server):
    """Run one A3C worker on a ``create_env`` environment under a Supervisor.

    Uses the TF >= 0.12 API names ``tf.global_variables``,
    ``tf.variables_initializer``, ``tf.global_variables_initializer`` and
    ``tf.summary.FileWriter``. The older ``tf.all_variables``,
    ``tf.initialize_variables`` and ``tf.initialize_all_variables`` are
    deprecated, and ``tf.train.SummaryWriter`` no longer exists in TF 1.x.

    Args:
        args: parsed options. This body reads env_id, task, remotes and log_dir.
        server: object whose ``target`` is passed to ``managed_session``
            (presumably a tf.train.Server; verify against the caller).
    """
    env = create_env(args.env_id, client_id=str(args.task), remotes=args.remotes)
    trainer = A3C(env, args.task)

    # Variable names that start with "local" are not saved in checkpoints.
    variables_to_save = [v for v in tf.global_variables() if not v.name.startswith("local")]
    init_op = tf.variables_initializer(variables_to_save)
    init_all_op = tf.global_variables_initializer()
    saver = FastSaver(variables_to_save)

    def init_fn(ses):
        logger.info("Initializing all parameters.")
        ses.run(init_all_op)

    config = tf.ConfigProto(device_filters=["/job:ps", "/job:worker/task:{}/cpu:0".format(args.task)])
    logdir = os.path.join(args.log_dir, 'train')
    summary_writer = tf.summary.FileWriter(logdir + "_%d" % args.task)
    logger.info("Events directory: %s_%s", logdir, args.task)
    sv = tf.train.Supervisor(is_chief=(args.task == 0),
                             logdir=logdir,
                             saver=saver,
                             summary_op=None,
                             init_op=init_op,
                             init_fn=init_fn,
                             summary_writer=summary_writer,
                             ready_op=tf.report_uninitialized_variables(variables_to_save),
                             global_step=trainer.global_step,
                             save_model_secs=30,
                             save_summaries_secs=30)

    num_global_steps = 100000000

    logger.info(
        "Starting session. If this hangs, we're mostly likely waiting to connect to the parameter server. " +
        "One common cause is that the parameter server DNS name isn't resolving yet, or is misspecified.")
    with sv.managed_session(server.target, config=config) as sess, sess.as_default():
        trainer.start(sess, summary_writer)
        global_step = sess.run(trainer.global_step)
        logger.info("Starting training at step=%d", global_step)
        while not sv.should_stop() and (not num_global_steps or global_step < num_global_steps):
            trainer.process(sess)
            global_step = sess.run(trainer.global_step)

    # Ask for all the services to stop.
    sv.stop()
    logger.info('reached %s steps. worker stopped.', global_step)
项目:RL-Universe    作者:Bifrost-Research    | 项目源码 | 文件源码
def run(args, server):
    """Run one A3C worker on an Atari environment under a tf.train.Supervisor.

    Global variables whose names start with "local" are not checkpointed.
    Task 0 is chief. Training continues until the supervisor requests a stop
    or the step cap is hit.
    """
    trainer = A3C(atari_environment.AtariEnvironment(args.game), args.task)

    # Anything named "local..." is kept out of checkpoints.
    saved_vars = [var for var in tf.global_variables()
                  if not var.name.startswith("local")]
    init_op = tf.variables_initializer(saved_vars)
    init_all_op = tf.global_variables_initializer()
    saver = FastSaver(saved_vars)

    def init_fn(ses):
        logger.info("Initializing all parameters.")
        ses.run(init_all_op)

    # Only see the parameter servers and this worker's own CPU.
    worker_device = "/job:worker/task:{}/cpu:0".format(args.task)
    config = tf.ConfigProto(device_filters=["/job:ps", worker_device])
    logdir = os.path.join(args.log_dir, 'train')
    summary_writer = tf.summary.FileWriter(logdir + "_%d" % args.task)
    logger.info("Events directory: %s_%s", logdir, args.task)

    ready_op = tf.report_uninitialized_variables(saved_vars)
    sv = tf.train.Supervisor(is_chief=args.task == 0,
                             logdir=logdir,
                             saver=saver,
                             summary_op=None,
                             init_op=init_op,
                             init_fn=init_fn,
                             summary_writer=summary_writer,
                             ready_op=ready_op,
                             global_step=trainer.global_step,
                             save_model_secs=30,
                             save_summaries_secs=30)

    max_steps = 100000000

    hang_hint = ("Starting session. If this hangs, we're mostly likely waiting to connect to the parameter server. " +
                 "One common cause is that the parameter server DNS name isn't resolving yet, or is misspecified.")
    logger.info(hang_hint)
    with sv.managed_session(server.target, config=config) as sess, sess.as_default():
        trainer.start(sess, summary_writer)
        global_step = sess.run(trainer.global_step)
        logger.info("Starting training at step=%d", global_step)
        while not sv.should_stop() and (not max_steps or global_step < max_steps):
            trainer.process(sess)
            global_step = sess.run(trainer.global_step)

    # Tell every supervisor-managed service to shut down.
    sv.stop()
    logger.info('reached %s steps. worker stopped.', global_step)


##
## @brief      Generates the host and port for each server
##
## @param      num_workers  The number of workers
## @param      num_ps       The number of ps
##
## @return     dict representing the specification of the cluster
##