Python model module: DCGAN example source code

The following 9 code examples, extracted from open-source Python projects, illustrate how to use model.DCGAN.
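
Each main(_) entry point below assumes a shared preamble of imports and tf.app.flags definitions that these excerpts omit. The sketch below is a minimal, hypothetical reconstruction of such a preamble for TensorFlow 1.x; flag names and defaults not actually referenced in the excerpts are illustrative only.

import os
import pprint

import tensorflow as tf

from model import DCGAN

pp = pprint.PrettyPrinter()

flags = tf.app.flags
flags.DEFINE_string("dataset", "celebA", "Name of the dataset to train on")
flags.DEFINE_integer("batch_size", 64, "Number of images per batch")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory for saving/restoring checkpoints")
flags.DEFINE_string("sample_dir", "samples", "Directory for generated sample images")
flags.DEFINE_boolean("is_train", False, "True for training, False for testing")
FLAGS = flags.FLAGS


if __name__ == '__main__':
    tf.app.run()  # parses the flags and dispatches to main(_)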

Project: WassersteinGAN-TensorFlow    Author: MustafaMustafa
def main(_):
    pp.pprint(flags.FLAGS.__flags)

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    with tf.Session() as sess:
        dcgan = DCGAN(sess, 
                      dataset=FLAGS.dataset,
                      batch_size=FLAGS.batch_size,
                      output_size=FLAGS.output_size,
                      c_dim=FLAGS.c_dim,
                      z_dim=FLAGS.z_dim)

        if FLAGS.is_train:
            if FLAGS.preload_data:
                data = get_data_arr(FLAGS)
            else:
                data = glob(os.path.join('./data', FLAGS.dataset, '*.jpg'))
            train.train_wasserstein(sess, dcgan, data, FLAGS)
        else:
            dcgan.load(FLAGS.checkpoint_dir)
Project: Magic-Pixel    Author: zhwhong
def main(_):
    pp.pprint(flags.FLAGS.__flags)

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    with tf.Session() as sess:
        if FLAGS.dataset == 'mnist':
            dcgan = DCGAN(sess, image_size=FLAGS.image_size, batch_size=FLAGS.batch_size, y_dim=10,
                    dataset_name=FLAGS.dataset, is_crop=FLAGS.is_crop, checkpoint_dir=FLAGS.checkpoint_dir)
        else:
            dcgan = DCGAN(sess, image_size=FLAGS.image_size, batch_size=FLAGS.batch_size,
                    dataset_name=FLAGS.dataset, is_crop=FLAGS.is_crop, checkpoint_dir=FLAGS.checkpoint_dir)

        if FLAGS.is_train:
            dcgan.train(FLAGS)
        else:
            if FLAGS.is_single:
                dcgan.single_test(FLAGS.checkpoint_dir, FLAGS.file_name)
            elif FLAGS.is_small:
                dcgan.batch_test2(FLAGS.checkpoint_dir)
            else:
                dcgan.batch_test(FLAGS.checkpoint_dir, FLAGS.file_name)
            # dcgan.load(FLAGS.checkpoint_dir)
            # dcgan.single_test(FLAGS.checkpoint_dir)
            # dcgan.batch_test(FLAGS.checkpoint_dir)

        """
        if FLAGS.visualize:
            to_json("./web/js/layers.js", [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0],
                                          [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1],
                                          [dcgan.h2_w, dcgan.h2_b, dcgan.g_bn2],
                                          [dcgan.h3_w, dcgan.h3_b, dcgan.g_bn3],
                                          [dcgan.h4_w, dcgan.h4_b, None])

            # Below is code for visualization
            OPTION = 2
            visualize(sess, dcgan, FLAGS, OPTION)
        """
Project: DeepLearning    Author: Wanwannodao
def main(_):
    with tf.Session(config=tf.ConfigProto(log_device_placement=FLAGS.log_dev_placement)) as sess:
        dcgan = DCGAN(sess, batch_size=FLAGS.batch_size,
                      #in_dim=[28,28,1], z_dim=100)
                      in_dim=[112,112,3], z_dim=100)
        dcgan.train(FLAGS)
Project: easygen    Author: markriedl
def train(epoch=25, learning_rate=0.0002, beta1=0.5, train_size=np.inf,
          batch_size=64, input_height=108, input_width=None,
          output_height=64, output_width=None, dataset='celebA',
          input_fname_pattern='*.jpg', checkpoint_dir='checkpoints',
          sample_dir='samples', output_dir='output', crop=True,
          model_dir='temp', model_filename='dcgan'):
  #pp.pprint(flags.FLAGS.__flags)

  if input_width is None:
    input_width = input_height
  if output_width is None:
    output_width = output_height

  #if not os.path.exists(checkpoint_dir):
  #  os.makedirs(checkpoint_dir)
  if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)
  if not os.path.exists(output_dir):
    os.makedirs(output_dir)

  #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
  run_config = tf.ConfigProto()
  run_config.gpu_options.allow_growth=True

  with tf.Session(config=run_config) as sess:
    dcgan = DCGAN(
        sess,
        input_width=input_width,
        input_height=input_height,
        output_width=output_width,
        output_height=output_height,
        batch_size=batch_size,
        sample_num=batch_size,
        dataset_name=dataset,
        input_fname_pattern=input_fname_pattern,
        crop=crop,
        checkpoint_dir=checkpoint_dir,
        sample_dir=sample_dir,
        output_dir=output_dir)

    show_all_variables()

    dcgan.train(epoch=epoch, learning_rate=learning_rate, beta1=beta1,
                train_size=train_size, batch_size=batch_size,
                input_height=input_height, input_width=input_width,
                output_height=output_height, output_width=output_width,
                dataset=dataset, input_fname_pattern=input_fname_pattern,
                checkpoint_dir=checkpoint_dir, sample_dir=sample_dir,
                output_dir=output_dir, train=train, crop=crop)

    dcgan.save(model_dir, dcgan.global_training_steps, model_filename)
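
A hypothetical invocation of the train() helper above might look like the following; the dataset name, image sizes, and directory paths are illustrative values only.

train(epoch=25,
      dataset='celebA',
      input_height=108,
      output_height=64,
      checkpoint_dir='checkpoints',
      sample_dir='samples',
      output_dir='output',
      model_dir='temp',
      model_filename='dcgan')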
Project: Mendelssohn    Author: diggerdu
def main(_):
    pp.pprint(flags.FLAGS.__flags)

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    with tf.Session() as sess:
        dcgan = DCGAN(sess, image_size=FLAGS.image_size, output_size=FLAGS.output_size,
                      batch_size=FLAGS.batch_size, sample_size=FLAGS.sample_size)

        if FLAGS.is_train:
            dcgan.train(FLAGS)
        else:
            dcgan.load(FLAGS.checkpoint_dir)

        if FLAGS.visualize:
            # to_json("./web/js/layers.js", [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0],
            #                               [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1],
            #                               [dcgan.h2_w, dcgan.h2_b, dcgan.g_bn2],
            #                               [dcgan.h3_w, dcgan.h3_b, dcgan.g_bn3],
            #                               [dcgan.h4_w, dcgan.h4_b, None])

            # Below is code for visualization
            OPTION = 2
            visualize(sess, dcgan, FLAGS, OPTION)
Project: streetview    Author: ydnaandy123
def main(_):
    pp.pprint(flags.FLAGS.__flags)

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    with tf.Session(config=tf.ConfigProto(
              allow_soft_placement=True, log_device_placement=False)) as sess:
        if FLAGS.dataset == 'mnist':
            assert False
        dcgan = DCGAN(sess, image_size=FLAGS.image_size, batch_size=FLAGS.batch_size,
                      sample_size=64,
                      z_dim=8192,
                      d_label_smooth=.25,
                      generator_target_prob=.75 / 2.,
                      out_stddev=.075,
                      out_init_b=-.45,
                      image_shape=[FLAGS.image_width, FLAGS.image_width, 3],
                      dataset_name=FLAGS.dataset, is_crop=FLAGS.is_crop,
                      checkpoint_dir=FLAGS.checkpoint_dir,
                      sample_dir=FLAGS.sample_dir,
                      generator=Generator(),
                      train_func=train, discriminator_func=discriminator,
                      build_model_func=build_model, config=FLAGS,
                      devices=["gpu:0", "gpu:1", "gpu:2", "gpu:3"] #, "gpu:4"]
                      )

        if FLAGS.is_train:
            print("TRAINING")
            dcgan.train(FLAGS)
            print("DONE TRAINING")
        else:
            dcgan.load(FLAGS.checkpoint_dir)

        OPTION = 2
        visualize(sess, dcgan, FLAGS, OPTION)
Project: easygen    Author: markriedl
def run(checkpoint_dir='checkpoints', batch_size=64, input_height=108,
        input_width=None, output_height=64, output_width=None,
        dataset='celebA', input_fname_pattern='*.jpg', output_dir='output',
        sample_dir='samples', crop=True):
  #pp.pprint(flags.FLAGS.__flags)

  if input_width is None:
    input_width = input_height
  if output_width is None:
    output_width = output_height

  #if not os.path.exists(checkpoint_dir):
  #  os.makedirs(checkpoint_dir)
  #if not os.path.exists(output_dir):
  #  os.makedirs(output_dir)

  #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
  run_config = tf.ConfigProto()
  run_config.gpu_options.allow_growth=True

  with tf.Session(config=run_config) as sess:
    dcgan = DCGAN(
        sess,
        input_width=input_width,
        input_height=input_height,
        output_width=output_width,
        output_height=output_height,
        batch_size=batch_size,
        sample_num=batch_size,
        dataset_name=dataset,
        input_fname_pattern=input_fname_pattern,
        crop=crop,
        checkpoint_dir=checkpoint_dir,
        sample_dir=sample_dir,
        output_dir=output_dir)

    show_all_variables()

    # Fall back to the pre-1.0 initializer name on older TensorFlow versions.
    try:
      tf.global_variables_initializer().run()
    except AttributeError:
      tf.initialize_all_variables().run()

    # Below is code for visualization
    visualize(sess, dcgan, batch_size = batch_size, input_height = input_height, input_width = input_width, output_dir = output_dir)
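
Similarly, a hypothetical call to the run() helper above; all argument values are illustrative.

run(checkpoint_dir='checkpoints', dataset='celebA', output_dir='output', sample_dir='samples')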
Project: streetview    Author: ydnaandy123
def main(_):

    pp.pprint(flags.FLAGS.__flags)

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    # Do not take all memory
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.30)
    # sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        # w/ y label
        if FLAGS.dataset == 'mnist':
            dcgan = DCGAN(sess, image_size=FLAGS.image_size, batch_size=FLAGS.batch_size, y_dim=10, output_size=28,
                          c_dim=1, dataset_name=FLAGS.dataset,
                          checkpoint_dir=FLAGS.checkpoint_dir)
        # w/o y label
        else:
            if FLAGS.dataset == 'cityscapes':
                print('Select CITYSCAPES')
                mask_dir = CITYSCAPES_mask_dir
                syn_dir = CITYSCAPES_syn_dir_2
                FLAGS.output_size_h, FLAGS.output_size_w, FLAGS.is_crop = 192, 512, False
                FLAGS.dataset_dir = CITYSCAPES_dir
            elif FLAGS.dataset == 'inria':
                print('Select INRIAPerson')
                FLAGS.output_size_h, FLAGS.output_size_w, FLAGS.is_crop = 160, 96, False
                FLAGS.dataset_dir = INRIA_dir

            discriminator = Discriminator(sess, batch_size=FLAGS.batch_size, output_size_h=FLAGS.output_size_h, output_size_w=FLAGS.output_size_w, c_dim=FLAGS.c_dim,
                          dataset_name=FLAGS.dataset,
                          checkpoint_dir=FLAGS.checkpoint_dir, dataset_dir=FLAGS.dataset_dir)

        if FLAGS.mode == 'test':
            print('Testing!')
            discriminator.test(FLAGS, syn_dir)
        elif FLAGS.mode == 'train':
            print('Train!')
            discriminator.train(FLAGS, syn_dir)
        elif FLAGS.mode == 'complete':
            print('Complete!')
Project: streetview    Author: ydnaandy123
def main(_):
    pp.pprint(flags.FLAGS.__flags)

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    # Do not take all memory
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.80)
    # sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        # w/ y label
        if FLAGS.dataset == 'mnist':
            dcgan = DCGAN(sess, image_size=FLAGS.image_size, batch_size=FLAGS.batch_size, y_dim=10, output_size=28,
                          c_dim=1, dataset_name=FLAGS.dataset, is_crop=FLAGS.is_crop,
                          checkpoint_dir=FLAGS.checkpoint_dir)
        # w/o y label
        else:
            if FLAGS.dataset == 'cityscapes':
                print('Select CITYSCAPES')
                mask_dir = CITYSCAPES_mask_dir
                FLAGS.output_size_h, FLAGS.output_size_w, FLAGS.is_crop = 192, 512, False
                FLAGS.dataset_dir = CITYSCAPES_dir
            elif FLAGS.dataset == 'inria':
                print('Select INRIAPerson')
                FLAGS.output_size_h, FLAGS.output_size_w, FLAGS.is_crop = 160, 96, False
                FLAGS.dataset_dir = INRIA_dir
            elif FLAGS.dataset == 'indoor':
                print('Select indoor')
                FLAGS.output_size_h, FLAGS.output_size_w, FLAGS.is_crop = 256, 256, False
                FLAGS.dataset_dir = indoor_dir
            elif FLAGS.dataset == 'indoor_bedroom':
                print('Select indoor bedroom')
                FLAGS.output_size_h, FLAGS.output_size_w, FLAGS.is_crop = 256, 256, False
                FLAGS.dataset_dir = indoor_bedroom_dir
            elif FLAGS.dataset == 'indoor_dining':
                print('Select indoor dining')
                FLAGS.output_size_h, FLAGS.output_size_w, FLAGS.is_crop = 256, 256, False
                FLAGS.dataset_dir = indoor_bedroom_dir

            dcgan = DCGAN(sess, batch_size=FLAGS.batch_size, output_size_h=FLAGS.output_size_h, output_size_w=FLAGS.output_size_w, c_dim=FLAGS.c_dim,
                          dataset_name=FLAGS.dataset, is_crop=FLAGS.is_crop,
                          checkpoint_dir=FLAGS.checkpoint_dir, dataset_dir=FLAGS.dataset_dir)

        if FLAGS.mode == 'test':
            print('Testing!')
            dcgan.test(FLAGS)
        elif FLAGS.mode == 'train':
            print('Train!')
            dcgan.train(FLAGS)
        elif FLAGS.mode == 'complete':
            print('Complete!')
            dcgan.complete(FLAGS, mask_dir)