Python tensorflow module, flags() code examples

The following 9 code examples, extracted from open-source Python projects, illustrate how to use tensorflow.flags()
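Before the project snippets, here is a minimal, self-contained sketch of the tf.flags API they all rely on (TF 1.x; in later 1.x releases tf.flags is a thin wrapper around absl.flags). The flag names here are illustrative only, not taken from any of the projects below.

import tensorflow as tf

# Define flags at module level; each call registers a name, a default,
# and a help string.
tf.flags.DEFINE_string("model_dir", "/tmp/model", "Checkpoint directory.")
tf.flags.DEFINE_integer("batch_size", 32, "Examples per training batch.")

FLAGS = tf.flags.FLAGS

def main(_):
    # By the time main() runs, sys.argv has been parsed into FLAGS.
    print(FLAGS.model_dir, FLAGS.batch_size)

if __name__ == "__main__":
    tf.app.run()  # parses the command line, then calls main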

Project: text-gan-tensorflow    Author: tokestermw
def get_supervisor(model):
    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(FLAGS.model_dir)

    supervisor = tf.train.Supervisor(
        logdir=FLAGS.model_dir,
        is_chief=True,
        saver=saver,
        init_op=set_initial_ops(),
        summary_op=tf.summary.merge_all(),
        summary_writer=summary_writer,
        save_summaries_secs=100,  # TODO: add as flags
        save_model_secs=1000,
        global_step=model.global_step,
    )

    return supervisor
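For context, a Supervisor like the one returned above is usually consumed through managed_session(), which restores or initializes the model and starts the checkpoint and summary services. A hedged usage sketch; model.train_op is an assumption of this example, not part of the snippet:

supervisor = get_supervisor(model)
with supervisor.managed_session() as sess:
    while not supervisor.should_stop():
        sess.run(model.train_op)  # assumed training op on the model object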
Project: spotify-tensorflow    Author: spotify
def register_dataset_flags():
    logging.info("Registering Dataset flags")
    flags.DEFINE_integer("batch_size", 128,
                         "Size of the batch of the dataset iterator.")

    flags.DEFINE_integer("buffer_size", 512,
                         "Size of the buffer of the dataset iterator.")

    flags.DEFINE_integer("take_count", -1,
                         "Creates a `Dataset` with at most `count` batches from this dataset.")

    flags.DEFINE_string("train_subdir", "train",
                        "Location of training TFRecords, within the training set dir.")

    flags.DEFINE_string("eval_subdir", "eval",
                        "Location of eval TFRecords, within the training set dir.")
Project: text-gan-tensorflow    Author: tokestermw
def get_sess_config():
    # gpu_options = tf.GPUOptions(
    #     per_process_gpu_memory_fraction=self.gpu_memory_fraction,
    #     allow_growth=True)  # seems to be not working

    sess_config = tf.ConfigProto(
        # log_device_placement=True,
        inter_op_parallelism_threads=8,  # TODO: add as flags
        # allow_soft_placement=True,
        # gpu_options=gpu_options,
    )

    return sess_config
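The returned proto is then handed straight to the session constructor, e.g.:

sess = tf.Session(config=get_sess_config())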
Project: stuff    Author: yaroslavvb
def run_benchmark_distributed():
  # create_graph, create_done_queue, run_benchmark, session_config, and
  # host are defined elsewhere in this script.
  ops = create_graph("/job:worker/task:0", "/job:worker/task:1")
  queues = [create_done_queue(0), create_done_queue(1)]

  # launch distributed service
  port0, port1 = [portpicker.pick_unused_port() for _ in range(2)]
  flags = " ".join(sys.argv)  # pass parent flags to children

  def run_worker(w):
    my_env = os.environ.copy()
    if not FLAGS.verbose:
      my_env["CUDA_VISIBLE_DEVICES"] = ""
      my_env["TF_CPP_MIN_LOG_LEVEL"] = "2"
    if FLAGS.profile:
      my_env["LD_PRELOAD"] = "/usr/lib/libtcmalloc_and_profiler.so.4"
      my_env["CPUPROFILE"] = "/tmp/profile.out.%s" % w
    cmd = "python %s --task=%d --port0=%s --port1=%s" % (flags, w, port0, port1)
    subprocess.Popen(cmd, shell=True, stderr=subprocess.STDOUT,
                     env=my_env)

  run_worker(0)
  run_worker(1)

  sess = tf.Session("grpc://%s:%s" % (host, port0), config=session_config())
  rate = run_benchmark(sess, *ops)

  # bring down workers
  if FLAGS.verbose:
    print("Killing workers.")
  sess.run(queues[1].enqueue(1))
  sess.run(queues[0].enqueue(1))  # bring down master last

  return rate
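The shutdown at the end uses the classic "done queue" pattern: each worker blocks on a dequeue from a shared queue pinned to its own task, and exits when the parent enqueues a token. create_done_queue is not shown above; this sketch of the conventional helper is an assumption based on the pattern, not the project file itself:

def create_done_queue(i):
    """Shared queue used to signal worker i to shut down."""
    with tf.device("/job:worker/task:%d" % i):
        return tf.FIFOQueue(1, tf.int32, shared_name="done_queue%d" % i)

# Inside worker i, after serving:
#     sess.run(create_done_queue(i).dequeue())  # blocks until parent enqueues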
Project: spotify-tensorflow    Author: spotify
def register_core_flags():
    logging.info("Registering core spotify-tensorflow flags")
    flags.DEFINE_string("training_set", None,
                        "Location of the training set")

    flags.DEFINE_string("job-dir", None,
                        "Where to write data")
Project: vaelm    Author: kaiix
def restore_config(config_file):
    # Binary mode and items() keep this working on both Python 2 and 3;
    # the original used text mode and the Python-2-only iteritems().
    with open(config_file, 'rb') as f:
        flags = pickle.load(f)
    for k, v in flags.items():
        setattr(FLAGS, k, v)
Project: vaelm    Author: kaiix
def save_config(config_file):
    # Binary mode keeps the pickle readable by restore_config() above.
    with open(config_file, 'wb') as f:
        flags = get_flags()
        saved_flags = {}
        for k in _SAVE_FLAGS:
            saved_flags[k] = flags[k]
        pickle.dump(saved_flags, f)
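A hypothetical round trip using the pair above; the file name is an assumption of this example:

save_config("run_config.pkl")     # persist the flags listed in _SAVE_FLAGS
# ... later, in a fresh process with the same flag definitions ...
restore_config("run_config.pkl")  # overwrite FLAGS with the saved values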
Project: sgnmt    Author: ucam-smt
def vocab_size(self):
    return self._vocab_size

# Define flags from the t2t binaries
Project: tensor2tensor    Author: tensorflow
def __init__(self, data_dir, model_dir):
    """Creates the Transformer estimator.

    Args:
      data_dir: The training data directory.
      model_dir: The trained model directory.
    """
    # Do the pre-setup tensor2tensor requires for flags and configurations.
    FLAGS.output_dir = model_dir
    FLAGS.data_dir = data_dir
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    data_dir = os.path.expanduser(data_dir)

    # Create the basic hyperparameters.
    self.hparams = tpu_trainer_lib.create_hparams(
        FLAGS.hparams_set,
        FLAGS.hparams,
        data_dir=data_dir,
        problem_name=FLAGS.problems)

    decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
    decode_hp.add_hparam("shards", 1)
    decode_hp.add_hparam("shard_id", 0)

    # Create the estimator and final hyperparameters.
    self.estimator = tpu_trainer_lib.create_estimator(
        FLAGS.model,
        self.hparams,
        tpu_trainer.create_run_config(),
        decode_hp, use_tpu=False)

    # Fetch the vocabulary and other helpful variables for decoding.
    self.source_vocab = self.hparams.problems[0].vocabulary["inputs"]
    self.targets_vocab = self.hparams.problems[0].vocabulary["targets"]
    self.const_array_size = 10000

    # Prepare the Transformer's debug data directory.
    run_dirs = sorted(glob.glob(os.path.join("/tmp/t2t_server_dump", "run_*")))
    for run_dir in run_dirs:
      shutil.rmtree(run_dir)
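A heavily hedged construction sketch: the enclosing class name is not shown above and is assumed here, and the flag values are illustrative tensor2tensor registry names that would normally be supplied on the command line rather than set in code:

# Hypothetical setup; real values come from the t2t command line.
FLAGS.model = "transformer"
FLAGS.hparams_set = "transformer_base"
FLAGS.problems = "translate_ende_wmt32k"
FLAGS.hparams = ""
FLAGS.decode_hparams = ""

wrapper = TransformerEstimator(  # hypothetical class name
    data_dir="/path/to/t2t_data", model_dir="/path/to/trained_model")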