Python tensorflow 模块,saturate_cast() 实例源码

我们从Python开源项目中,提取了以下15个代码示例,用于说明如何使用tensorflow.saturate_cast()

项目:ml_gans    作者:imironhead    | 项目源码 | 文件源码
def build_summaries(gan):
    """
    Build loss summaries and a fake-image mosaic summary for a GAN.

    gan['generator_fake'] is reshaped to [1, 64 * 32, 32, 1] (presumably
    64 single-channel 32x32 images stacked vertically). The column is cut
    into 8 strips that are placed side by side, giving a 256x256 mosaic.
    It is mapped with x * 127.5 + 127.5 and saturated to uint8.

    Args:
        gan: dict holding 'generator_loss', 'discriminator_loss' and
            'generator_fake' tensors.

    Returns:
        dict with 'generator_fake_summary', 'generator_loss_summary' and
        'discriminator_loss_summary'.
    """
    summary_g_loss = tf.summary.scalar('generator loss', gan['generator_loss'])
    summary_d_loss = tf.summary.scalar(
        'discriminator loss', gan['discriminator_loss'])

    # One tall column of tiles -> 8 strips -> strips laid side by side.
    column = tf.reshape(gan['generator_fake'], [1, 64 * 32, 32, 1])
    mosaic = tf.concat(tf.split(column, 8, axis=1), axis=2)
    pixels = tf.saturate_cast(mosaic * 127.5 + 127.5, tf.uint8)

    # The image batch holds a single mosaic, so at most one image is emitted.
    summary_fake = tf.summary.image('generated image', pixels, max_outputs=18)

    return dict(
        generator_fake_summary=summary_fake,
        generator_loss_summary=summary_g_loss,
        discriminator_loss_summary=summary_d_loss,
    )
项目:ml_gans    作者:imironhead    | 项目源码 | 文件源码
def build_summaries(model):
    """
    Build an image summary showing [source | target | output] side by side.

    The three tensors are concatenated along the width axis. The result is
    then reshaped into a single image of height FLAGS.batch_size * 256 and
    width 768, so the samples are stacked vertically.
    NOTE(review): this assumes each input is (FLAGS.batch_size, 256, 256, 3)
    — confirm against the caller.

    Args:
        model: dict holding 'source_images', 'target_images' and
            'output_images' tensors.

    Returns:
        dict with a single 'summary' entry.
    """
    triplet = [model[name]
               for name in ('source_images', 'target_images', 'output_images')]
    side_by_side = tf.concat(triplet, axis=2)

    tall = tf.reshape(side_by_side, [1, FLAGS.batch_size * 256, 768, 3])
    as_uint8 = tf.saturate_cast(tall * 127.5 + 127.5, tf.uint8)

    return {'summary': tf.summary.image('images', as_uint8, max_outputs=4)}
项目:ml_gans    作者:imironhead    | 项目源码 | 文件源码
def build_summaries(gan_graph):
    """
    Build loss summaries and an 8x8 fake-image mosaic summary.

    gan_graph['generator_fake'] is reshaped to [1, 64 * 32, 32, 1]
    (presumably 64 grayscale 32x32 samples). It is regrouped into 8 columns,
    mapped with x * 127.5 + 127.5 and saturated to uint8.

    Args:
        gan_graph: dict holding 'generator_loss', 'discriminator_loss' and
            'generator_fake' tensors.

    Returns:
        dict with 'generator_fake_summary', 'generator_loss_summary' and
        'discriminator_loss_summary'.
    """
    loss_summaries = [
        tf.summary.scalar('generator loss', gan_graph['generator_loss']),
        tf.summary.scalar(
            'discriminator loss', gan_graph['discriminator_loss']),
    ]

    tiles = tf.reshape(gan_graph['generator_fake'], [1, 64 * 32, 32, 1])
    strips = tf.split(tiles, 8, axis=1)
    grid = tf.saturate_cast(
        tf.concat(strips, axis=2) * 127.5 + 127.5, tf.uint8)

    fake_summary = tf.summary.image('generated image', grid, max_outputs=18)

    g_loss_summary, d_loss_summary = loss_summaries
    return {
        'generator_fake_summary': fake_summary,
        'generator_loss_summary': g_loss_summary,
        'discriminator_loss_summary': d_loss_summary,
    }
项目:ml_gans    作者:imironhead    | 项目源码 | 文件源码
def build_summaries(gan_graph):
    """
    Build loss summaries, a fake-image mosaic summary and a PNG of the mosaic.

    gan_graph['generator_fake'] is reshaped to [1, 64 * 64, 64, 3]
    (presumably 64 RGB samples of 64x64). It is cut into 8 strips laid side
    by side, giving a 512x512 mosaic. The mosaic is mapped with
    x * 127.5 + 127.5 and saturated to uint8.

    Args:
        gan_graph: dict holding 'generator_loss', 'discriminator_loss' and
            'generator_fake' tensors.

    Returns:
        dict with 'generated_png' (PNG-encoded mosaic) plus
        'generator_fake_summary', 'generator_loss_summary' and
        'discriminator_loss_summary'.
    """
    g_loss = tf.summary.scalar('generator loss', gan_graph['generator_loss'])
    d_loss = tf.summary.scalar(
        'discriminator loss', gan_graph['discriminator_loss'])

    column = tf.reshape(gan_graph['generator_fake'], [1, 64 * 64, 64, 3])
    strips = tf.split(column, 8, axis=1)
    mosaic = tf.saturate_cast(
        tf.concat(strips, axis=2) * 127.5 + 127.5, tf.uint8)

    mosaic_summary = tf.summary.image(
        'generated image', mosaic, max_outputs=1)
    # Drop the leading batch dimension so encode_png gets an HWC image.
    png_bytes = tf.image.encode_png(mosaic[0])

    return {
        'generated_png': png_bytes,
        'generator_fake_summary': mosaic_summary,
        'generator_loss_summary': g_loss,
        'discriminator_loss_summary': d_loss,
    }
项目:ml_gans    作者:imironhead    | 项目源码 | 文件源码
def build_summaries(network):
    """
    Build a [real | fake | autoencoded fake] comparison image summary.

    The three tensors are concatenated along axis 0. The result is reshaped
    to [1, 3 * S, S, 3] with S = FLAGS.image_size. That reshape only
    succeeds when the three tensors together hold exactly three SxSx3
    images (e.g. batch size 1 each). The stacked column is then split into
    3 pieces laid side by side, mapped with x * 127.5 + 127.5 and saturated
    to uint8.

    Args:
        network: dict holding 'real', 'fake' and 'ae_output_fake' tensors.

    Returns:
        dict with a single 'comparison' summary.
    """
    trio = tf.concat(
        [network['real'], network['fake'], network['ae_output_fake']],
        axis=0)

    size = FLAGS.image_size
    column = tf.reshape(trio, [1, 3 * size, size, 3])
    row = tf.concat(tf.split(column, 3, axis=1), axis=2)
    row = tf.saturate_cast(row * 127.5 + 127.5, tf.uint8)

    return {'comparison': tf.summary.image('comp', row, max_outputs=4)}
项目:ml_gans    作者:imironhead    | 项目源码 | 文件源码
def build_summaries(model):
    """
    Build image and loss summaries for both translation directions.

    For each (real, fake) pair the two tensors are placed side by side
    (width 512) and the batch is stacked vertically. The result is mapped
    with x * 127.5 + 127.5 and saturated to uint8. The reshape assumes
    256x256x3 inputs — TODO confirm against the model.

    Args:
        model: dict holding the image tensors 'xx_real', 'gx_fake',
            'yy_real', 'fy_fake' and the scalar losses 'loss_d',
            'loss_dx', 'loss_dy', 'loss_g', 'loss_gx', 'loss_fy'.

    Returns:
        dict with merged summaries under 'images', 'loss_d' and 'loss_g'.
    """
    image_summaries = []
    for tag, real_key, fake_key in (
            ('summary_x_gx', 'xx_real', 'gx_fake'),
            ('summary_y_fy', 'yy_real', 'fy_fake')):
        pair = tf.concat([model[real_key], model[fake_key]], axis=2)
        pair = tf.reshape(pair, [1, FLAGS.batch_size * 256, 512, 3])
        pair = tf.saturate_cast(pair * 127.5 + 127.5, tf.uint8)
        image_summaries.append(tf.summary.image(tag, pair, max_outputs=4))

    def merged_scalars(tag_key_pairs):
        # One scalar summary per (tag, model key), merged into a single op.
        return tf.summary.merge(
            [tf.summary.scalar(tag, model[key]) for tag, key in tag_key_pairs])

    summary_d = merged_scalars(
        [('d', 'loss_d'), ('dx', 'loss_dx'), ('dy', 'loss_dy')])
    summary_g = merged_scalars(
        [('g', 'loss_g'), ('gx', 'loss_gx'), ('fy', 'loss_fy')])

    return {
        'images': tf.summary.merge(image_summaries),
        'loss_d': summary_d,
        'loss_g': summary_g,
    }
项目:ml_gans    作者:imironhead    | 项目源码 | 文件源码
def translate():
    """
    Translate every image listed by prepare_paths() with a trained CycleGAN.

    Source images are fed through build_cycle_gan in batches of
    FLAGS.batch_size. Each result is written to its destination path as a
    [real | fake] image concatenated along the width.

    Raises:
        ValueError: if FLAGS.ckpt_dir_path contains no checkpoint.
    """
    # Presumably (source_path, destination_path) pairs: p[0] is read and
    # p[1] is written below.
    image_path_pairs = prepare_paths()

    # The network input is fixed to 256x256x3 uint8 images.
    reals = tf.placeholder(shape=[None, 256, 256, 3], dtype=tf.uint8)

    # uint8 [0, 255] -> float [-1, 1].
    flow = tf.cast(reals, dtype=tf.float32) / 127.5 - 1.0

    model = build_cycle_gan(flow, flow, FLAGS.mode)

    # Inverse of the mapping above, clamped into uint8.
    fakes = tf.saturate_cast(model['fake'] * 127.5 + 127.5, tf.uint8)

    # Fail early with a clear message instead of inside Saver.restore.
    ckpt_source_path = tf.train.latest_checkpoint(FLAGS.ckpt_dir_path)
    if ckpt_source_path is None:
        raise ValueError(
            'no checkpoint found in {}'.format(FLAGS.ckpt_dir_path))

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        session.run(tf.local_variables_initializer())

        tf.train.Saver().restore(session, ckpt_source_path)

        for start in range(0, len(image_path_pairs), FLAGS.batch_size):
            path_pairs = image_path_pairs[start:start + FLAGS.batch_size]

            # mode='RGB' converts grayscale/RGBA/palette files to 3 channels
            # so they match the placeholder and the generated images.
            real_images = [
                scipy.misc.imread(p[0], mode='RGB') for p in path_pairs]

            fake_images = session.run(fakes, feed_dict={reals: real_images})

            for pair, real, fake in zip(path_pairs, real_images, fake_images):
                scipy.misc.imsave(
                    pair[1], np.concatenate([real, fake], axis=1))
项目:ml_gans    作者:imironhead    | 项目源码 | 文件源码
def reshape_batch_images(batch_images):
    """
    Tile a batch of images into a single uint8 grid image for summaries.

    The batch is reshaped into one column of FLAGS.batch_size images of
    FLAGS.image_size x FLAGS.image_size x 3. The column is cut into
    FLAGS.summary_row_size strips, which are placed side by side. Values are
    mapped with x * 127.5 + 127.5 and saturated to uint8.

    Args:
        batch_images: image batch tensor.

    Returns:
        uint8 tensor of shape
        [1, batch_size * image_size / summary_row_size,
         summary_row_size * image_size, 3].
    """
    size = FLAGS.image_size
    tall = tf.reshape(batch_images, [1, FLAGS.batch_size * size, size, 3])
    strips = tf.split(tall, FLAGS.summary_row_size, axis=1)
    return tf.saturate_cast(
        tf.concat(strips, axis=2) * 127.5 + 127.5, tf.uint8)
项目:ml_gans    作者:imironhead    | 项目源码 | 文件源码
def build_image_grid(image_batch, row, col):
    """
    Build a uint8 image grid of `row` x `col` tiles from an image batch.

    Images fill the grid column by column: image i lands in column i // row,
    row i % row. Values are mapped with x * 127.5 + 127.5 and saturated to
    uint8.

    Args:
        image_batch: tensor holding row * col images of
            FLAGS.image_size x FLAGS.image_size x 3.
        row: number of tile rows.
        col: number of tile columns.

    Returns:
        uint8 tensor of shape [row * image_size, col * image_size, 3].
    """
    side = FLAGS.image_size
    column = tf.reshape(image_batch, [1, row * col * side, side, 3])
    mosaic = tf.concat(tf.split(column, col, axis=1), axis=2)
    mosaic = tf.saturate_cast(mosaic * 127.5 + 127.5, tf.uint8)
    # Drop the leading batch dimension of 1.
    return tf.reshape(mosaic, [row * side, col * side, 3])
项目:ml_styles    作者:imironhead    | 项目源码 | 文件源码
def build_summaries(network):
    """
    Build a side-by-side [content | styled] image summary.

    Both images are cropped to 256x256 starting at offset
    (FLAGS.padding, FLAGS.padding). Each batch is stacked vertically and the
    two columns are placed side by side. The result is mapped with
    x * 127.5 + 127.5, channel-reversed (presumably BGR -> RGB — confirm)
    and saturated to uint8.

    Args:
        network: dict holding 'image_content' and 'image_styled' tensors.

    Returns:
        dict with the image summary under 'summary_plus'.
    """
    begin = [0, FLAGS.padding, FLAGS.padding, 0]
    extent = [-1, 256, 256, -1]
    tall_shape = [1, FLAGS.batch_size * 256, 256, 3]

    cropped = [tf.slice(network[key], begin, extent)
               for key in ('image_content', 'image_styled')]
    stacked = [tf.reshape(image, tall_shape) for image in cropped]

    both = tf.concat(stacked, axis=2) * 127.5 + 127.5
    both = tf.saturate_cast(tf.reverse(both, [3]), tf.uint8)

    return {'summary_plus': tf.summary.image('all', both, max_outputs=4)}
项目:ml_styles    作者:imironhead    | 项目源码 | 文件源码
def transfer_summary(vgg, loss, content_shape):
    """
    Merge an image summary of the VGG output with a loss scalar summary.

    The image is vgg.upstream plus VggNet.mean_color_bgr(), resized to
    content_shape, saturated to uint8 and reversed along the channel axis
    (presumably BGR -> RGB for display — confirm).

    Args:
        vgg: network object exposing an `upstream` image tensor.
        loss: scalar loss tensor.
        content_shape: target [height, width] for the resize.

    Returns:
        merged summary op.
    """
    with_mean = tf.add(vgg.upstream, VggNet.mean_color_bgr())
    resized = tf.image.resize_images(with_mean, content_shape)
    image_u8 = tf.reverse(tf.saturate_cast(resized, tf.uint8), [3])

    return tf.summary.merge([
        tf.summary.image('generated image', image_u8, max_outputs=1),
        tf.summary.scalar('transfer loss', loss),
    ])
项目:neural_style_tensorflow    作者:burness    | 项目源码 | 文件源码
def main(argv=None):
    """
    Run a trained fast-style-transfer model over FLAGS.CONTENT_IMAGES_PATH.

    Content images are read for a single epoch without shuffling or
    cropping, scaled by 1/255 and passed through model.net. The output is
    offset by reader.mean_pixel, saturated to uint8 and written as
    out0001.png, out0002.png, ... in the current working directory.

    Args:
        argv: unused.
    """
    if not FLAGS.CONTENT_IMAGES_PATH:
        print("train a fast nerual style need to set the Content images path")
        return
    content_images = reader.image(
            FLAGS.BATCH_SIZE,
            FLAGS.IMAGE_SIZE,
            FLAGS.CONTENT_IMAGES_PATH,
            epochs=1,
            shuffle=False,
            crop=False)
    generated_images = model.net(content_images / 255.)

    output_format = tf.saturate_cast(
        generated_images + reader.mean_pixel, tf.uint8)
    with tf.Session() as sess:
        ckpt_file = tf.train.latest_checkpoint(FLAGS.MODEL_PATH)
        if not ckpt_file:
            print('Could not find trained model in {0}'.format(FLAGS.MODEL_PATH))
            return
        print('Using model from {}'.format(ckpt_file))
        saver = tf.train.Saver()
        saver.restore(sess, ckpt_file)
        # Local variables are initialized separately from the restore;
        # presumably the epoch-limited input pipeline keeps its counter there.
        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        i = 0
        start_time = time.time()
        try:
            while not coord.should_stop():
                print(i)
                images_t = sess.run(output_format)
                elapsed = time.time() - start_time
                start_time = time.time()
                print('Time for one batch: {}'.format(elapsed))

                for raw_image in images_t:
                    i += 1
                    misc.imsave('out{0:04d}.png'.format(i), raw_image)
        except tf.errors.OutOfRangeError:
            # Raised once the single epoch of input has been consumed.
            print('Done training -- epoch limit reached')
        finally:
            coord.request_stop()

        coord.join(threads)
项目:ml_gans    作者:imironhead    | 项目源码 | 文件源码
def build_summaries(gan):
    """
    Build generator and discriminator summaries.

    Both sides get a loss scalar plus two histograms per entry of their
    *_variables_gradients list. The generator side additionally gets a grid
    of fake images.

    Each *_variables_gradients entry is indexed as pair[0] (logged under
    '<pair[0].name>/variable') and pair[1] (logged as '.../gradient').
    NOTE(review): Optimizer.compute_gradients returns (gradient, variable)
    pairs; confirm the caller builds these as (variable, gradient),
    otherwise the two tags are swapped.

    The fake grid uses 64x64x3 tiles when FLAGS.use_lsun is set and 32x32x1
    tiles otherwise. Tiles are laid out in FLAGS.summary_col_size strips and
    mapped with x * 127.5 + 127.5 to uint8.

    Args:
        gan: dict holding 'generator_loss', 'discriminator_loss',
            'generator_variables_gradients',
            'discriminator_variables_gradients' and 'generator_fake'.

    Returns:
        dict with 'summary_generator' (loss + histograms),
        'summary_generator_plus' (the same plus the fake grid) and
        'summary_discriminator'.
    """
    def histograms(pairs):
        # Two histogram summaries per pair, both named after pair[0].
        result = []
        for pair in pairs:
            base = pair[0].name
            result.append(
                tf.summary.histogram('{}/variable'.format(base), pair[0]))
            result.append(
                tf.summary.histogram('{}/gradient'.format(base), pair[1]))
        return result

    g_summaries = [
        tf.summary.scalar('generator loss', gan['generator_loss'])]
    d_summaries = [
        tf.summary.scalar('discriminator loss', gan['discriminator_loss'])]

    g_summaries += histograms(gan['generator_variables_gradients'])
    d_summaries += histograms(gan['discriminator_variables_gradients'])

    if FLAGS.use_lsun:
        width, depth = 64, 3
    else:
        width, depth = 32, 1

    column = tf.reshape(
        gan['generator_fake'],
        [1, FLAGS.batch_size * width, width, depth])
    grid = tf.concat(
        tf.split(column, FLAGS.summary_col_size, axis=1), axis=2)
    grid = tf.saturate_cast(grid * 127.5 + 127.5, tf.uint8)

    fake_image_summary = tf.summary.image(
        'generated image', grid, max_outputs=1)

    return {
        'summary_generator': tf.summary.merge(g_summaries),
        'summary_generator_plus':
            tf.summary.merge(g_summaries + [fake_image_summary]),
        'summary_discriminator': tf.summary.merge(d_summaries),
    }
项目:prisma    作者:hijkzzz    | 项目源码 | 文件源码
def generate():
    """
    Stylize FLAGS.CONTENT_IMAGE with the transform network at FLAGS.MODEL_PATH.

    The result is written into FLAGS.OUTPUT_FOLDER as
    '<image stem>-<model file name><image extension>'. The folder and any
    missing parents are created if needed.
    """
    if not FLAGS.CONTENT_IMAGE:
        tf.logging.info("train a fast nerual style need to set the Content images path")
        return

    if not os.path.exists(FLAGS.OUTPUT_FOLDER):
        # makedirs (unlike mkdir) also creates missing parent directories.
        os.makedirs(FLAGS.OUTPUT_FOLDER)

    path = FLAGS.CONTENT_IMAGE
    png = path.lower().endswith('png')

    # Decode the image once just to log its size. A private graph keeps the
    # decode op out of the global default graph. `with tf.Session()` closes
    # the session on exit, which `Session().as_default()` does not.
    with open(path, 'rb') as img:
        raw_bytes = img.read()
    with tf.Graph().as_default(), tf.Session() as sess:
        decode = tf.image.decode_png if png else tf.image.decode_jpeg
        image = sess.run(decode(raw_bytes))
    height = image.shape[0]
    width = image.shape[1]
    tf.logging.info('Image size: %dx%d' % (width, height))

    with tf.Graph().as_default(), tf.Session() as sess:
        # Read the content image file.
        img_bytes = tf.read_file(path)

        # Decode as 3 channels, convert to float and scale to [0, 255].
        if png:
            content_image = tf.image.decode_png(img_bytes, channels=3)
        else:
            content_image = tf.image.decode_jpeg(img_bytes, channels=3)
        content_image = tf.image.convert_image_dtype(content_image, tf.float32) * 255.0
        # Add a leading batch dimension of 1.
        content_image = tf.expand_dims(content_image, 0)

        generated_images = transform.net(content_image - vgg.MEAN_PIXEL, training=False)
        output_format = tf.saturate_cast(generated_images, tf.uint8)

        # Restore the trained transform network.
        saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V1)
        sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
        model_path = os.path.abspath(FLAGS.MODEL_PATH)
        tf.logging.info('Usage model {}'.format(model_path))
        saver.restore(sess, model_path)

        filename = os.path.basename(FLAGS.CONTENT_IMAGE)
        (shotname, extension) = os.path.splitext(filename)
        filename = shotname + '-' + os.path.basename(FLAGS.MODEL_PATH) + extension

        tf.logging.info("image {}".format(filename))
        images_t = sess.run(output_format)

        # expand_dims above produced a batch of exactly one image.
        assert len(images_t) == 1
        misc.imsave(os.path.join(FLAGS.OUTPUT_FOLDER, filename), images_t[0])
项目:ml_styles    作者:imironhead    | 项目源码 | 文件源码
def transfer():
    """
    Stylize FLAGS.content_path with a trained network and write a JPEG.

    The styled output has FLAGS.padding pixels cropped from every spatial
    border. It is mapped with x * 127.5 + 127.5, channel-reversed
    (presumably BGR -> RGB — confirm), saturated to uint8, JPEG-encoded and
    written to FLAGS.styled_path.

    Raises:
        ValueError: if FLAGS.ckpt_path is neither a directory containing a
            checkpoint nor an existing file, or if FLAGS.content_path is
            missing or is a directory. Explicit raises (instead of asserts)
            keep these checks active under `python -O`.
    """
    if tf.gfile.IsDirectory(FLAGS.ckpt_path):
        # latest_checkpoint returns None when the directory has no checkpoint.
        ckpt_source_path = tf.train.latest_checkpoint(FLAGS.ckpt_path)
    elif tf.gfile.Exists(FLAGS.ckpt_path):
        ckpt_source_path = FLAGS.ckpt_path
    else:
        ckpt_source_path = None

    if ckpt_source_path is None:
        raise ValueError('bad checkpoint')

    if (not tf.gfile.Exists(FLAGS.content_path) or
            tf.gfile.IsDirectory(FLAGS.content_path)):
        raise ValueError('bad content_path')

    image_contents = build_contents_reader()

    network = build_style_transfer_network(image_contents, training=False)

    # NHWC layout: axis 1 is height, axis 2 is width.
    shape = tf.shape(network['image_styled'])

    new_h = shape[1] - 2 * FLAGS.padding
    new_w = shape[2] - 2 * FLAGS.padding

    image_styled = tf.slice(
        network['image_styled'],
        [0, FLAGS.padding, FLAGS.padding, 0],
        [-1, new_h, new_w, -1])

    # Drop the batch dimension (requires batch size 1) -> HWC for encode_jpeg.
    image_styled = tf.squeeze(image_styled, [0])
    image_styled = image_styled * 127.5 + 127.5
    image_styled = tf.reverse(image_styled, [2])
    image_styled = tf.saturate_cast(image_styled, tf.uint8)
    image_styled = tf.image.encode_jpeg(image_styled)

    image_styled_writer = tf.write_file(FLAGS.styled_path, image_styled)

    with tf.Session() as session:
        tf.train.Saver().restore(session, ckpt_source_path)

        # Start the queue runners that feed the contents reader.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        try:
            session.run(image_styled_writer)
        finally:
            # Stop and join the reader threads even if the run fails.
            coord.request_stop()
            coord.join(threads)