Python models 模块，Generator() 实例源码

我们从 Python 开源项目中提取了以下 5 个代码示例，用于说明如何使用 models.Generator()（部分示例直接以 Generator() 的形式调用）。

项目:tf-gogh    作者:n-kats    | 项目源码 | 文件源码
def main():
  """Entry point: build the model from CLI options and generate an image.

  Reads options via ``parse_args``/``Config``, makes sure the output
  directory exists, loads the original and style images, and delegates the
  actual generation to ``models.Generator``.
  """
  config = Config(parse_args())

  # Create the output directory (no error if it already exists).
  os.makedirs(config.output_dir, exist_ok=True)

  # Build the network named by the config.
  net = models.generate_model(config.model)

  # Load both input images. load_image receives [width, height] for the
  # original image; for the style image it receives None instead when
  # config.no_resize_style is set.
  original = load_image(config.original_image, [config.width, config.height])
  if config.no_resize_style:
    style_size = None
  else:
    style_size = [config.width, config.height]
  style = load_image(config.style_image, style_size)

  # Run the generator on the loaded images.
  models.Generator(net, original, style, config).generate(config)
项目:tf-gogh    作者:n-kats    | 项目源码 | 文件源码
def main():
  """Generate an image from an original image and a style image.

  Steps: parse options, create the output directory, build the model,
  load the two input images, then run ``models.Generator.generate``.
  """
  cli_args = parse_args()
  conf = Config(cli_args)

  # Output directory; exist_ok avoids failing when it is already there.
  os.makedirs(conf.output_dir, exist_ok=True)

  # Model selected by ``conf.model``.
  model = models.generate_model(conf.model)

  # Input images: the original always gets [width, height]; the style image
  # gets None instead when ``conf.no_resize_style`` is set.
  content_img = load_image(conf.original_image, [conf.width, conf.height])
  style_target = None if conf.no_resize_style else [conf.width, conf.height]
  style_img = load_image(conf.style_image, style_target)

  # Build the generator and produce the output.
  gen = models.Generator(model, content_img, style_img, conf)
  gen.generate(conf)
项目:chainer-wasserstein-gan    作者:hvy    | 项目源码 | 文件源码
def train(args):
    """Train a Wasserstein GAN (generator + critic) on CIFAR-10 with Chainer.

    Args:
        args: Parsed command-line options. Reads ``nz`` (noise size given to
            the noise generator), ``batch_size``, ``epochs`` (training
            length, in epochs) and ``gpu`` (passed as ``device`` to the
            updater).
    """
    nz = args.nz
    batch_size = args.batch_size
    epochs = args.epochs
    gpu = args.gpu
    # RMSprop learning rate shared by the generator and the critic.
    lr = 0.00005

    # CIFAR-10 images in range [-1, 1] (tanh generator outputs):
    # scale=2 maps pixels to [0, 2], then shift everything down by 1.
    # Named train_data so it does not shadow this function's own name.
    train_data, _ = datasets.get_cifar10(withlabel=False, ndim=3, scale=2)
    train_data -= 1.0
    train_iter = iterators.SerialIterator(train_data, batch_size)

    # Noise batches for the generator; GaussianNoiseGenerator(0, 1, nz) is a
    # project helper, presumably (mean, std, size) — confirm.
    z_iter = RandomNoiseIterator(GaussianNoiseGenerator(0, 1, nz), batch_size)

    optimizer_generator = optimizers.RMSprop(lr=lr)
    optimizer_critic = optimizers.RMSprop(lr=lr)
    optimizer_generator.setup(Generator())
    optimizer_critic.setup(Critic())

    updater = WassersteinGANUpdater(
        iterator=train_iter,
        noise_iterator=z_iter,
        optimizer_generator=optimizer_generator,
        optimizer_critic=optimizer_critic,
        device=gpu)

    trainer = training.Trainer(updater, stop_trigger=(epochs, 'epoch'))
    trainer.extend(extensions.ProgressBar())
    trainer.extend(extensions.LogReport(trigger=(1, 'iteration')))
    trainer.extend(GeneratorSample(), trigger=(1, 'epoch'))
    trainer.extend(extensions.PrintReport(
        ['epoch', 'iteration', 'critic/loss', 'critic/loss/real',
         'critic/loss/fake', 'generator/loss']))
    trainer.run()
项目:rarepepes    作者:kendricktan    | 项目源码 | 文件源码
def __init__(self, nic, noc, ngf, ndf, beta=0.5, lamb=100, lr=1e-3, cuda=True, crayon=False):
        """
        Build the pix2pix generator/discriminator pair, their Adam
        optimizers and the loss functions used during training.

        Args:
            nic: Number of input channels
            noc: Number of output channels
            ngf: Number of generator filters
            ndf: Number of discriminator filters
            beta: Adam beta1 for both optimizers (beta2 is fixed at 0.999)
            lamb: Weight on L1 term in objective
            lr: Adam learning rate for both optimizers
            cuda: Stored as ``self.cuda``; presumably controls whether
                ``self.cudafy`` moves its argument to the GPU — confirm
            crayon: If True, log scalars to a PyCrayon server at
                localhost:8889 under the experiment name 'pix2pix'
        """
        self.cuda = cuda
        self.start_epoch = 0

        self.crayon = crayon
        if crayon:
            self.cc = CrayonClient(hostname="localhost", port=8889)

            try:
                self.logger = self.cc.create_experiment('pix2pix')
            except Exception:
                # Presumably the experiment already exists on the server:
                # drop it and start a fresh one. Catching Exception (not a
                # bare except) lets KeyboardInterrupt / SystemExit propagate.
                self.cc.remove_experiment('pix2pix')
                self.logger = self.cc.create_experiment('pix2pix')

        self.gen = self.cudafy(Generator(nic, noc, ngf))
        self.dis = self.cudafy(Discriminator(nic, noc, ndf))

        # NOTE(review): torch.optim optimizers have no .cuda() method. If
        # cudafy calls .cuda() on its argument, the two optimizer wraps
        # below would fail when cuda=True — confirm cudafy's behavior.

        # Optimizers for generators
        self.gen_optim = self.cudafy(optim.Adam(
            self.gen.parameters(), lr=lr, betas=(beta, 0.999)))

        # Optimizers for discriminators
        self.dis_optim = self.cudafy(optim.Adam(
            self.dis.parameters(), lr=lr, betas=(beta, 0.999)))

        # Loss functions
        self.criterion_bce = nn.BCELoss()
        self.criterion_mse = nn.MSELoss()
        self.criterion_l1 = nn.L1Loss()

        # Weight of the L1 reconstruction term in the generator loss.
        self.lamb = lamb
项目:rarepepes    作者:kendricktan    | 项目源码 | 文件源码
def train(self, loader, c_epoch):
        """Run one training epoch over ``loader``.

        For every batch the discriminator is updated with BCE on real
        (input, target) pairs vs. (input, generated) pairs. Then the
        generator is updated with BCE (make the discriminator call its
        output real) plus ``self.lamb`` times the L1 distance to the target.

        Args:
            loader: Sized iterable of batches; ``features[0]`` is used as
                the input and ``features[1]`` as the target.
            c_epoch: Current epoch number, used only in progress output.
        """
        self.dis.train()
        self.gen.train()
        self.reset_gradients()

        def scalar(loss):
            # PyTorch >= 0.4 returns 0-dim losses that expose .item() and
            # reject .data[0]; older Variables hold the value at .data[0].
            return loss.item() if hasattr(loss, 'item') else loss.data[0]

        max_idx = len(loader)
        for idx, features in enumerate(tqdm(loader)):
            orig_x = Variable(self.cudafy(features[0]))
            orig_y = Variable(self.cudafy(features[1]))

            # --- Discriminator ---
            # Train with real
            # NOTE(review): `volatile` is not an nn.Module flag, so this (and
            # the assignment in the generator section) does not appear to
            # change gradient tracking — confirm nothing else reads it.
            self.dis.volatile = False
            dis_real = self.dis(torch.cat((orig_x, orig_y), 1))
            real_labels = Variable(self.cudafy(
                torch.ones(dis_real.size())
            ))
            dis_real_loss = self.criterion_bce(
                dis_real, real_labels)

            # Train with fake; detach so the discriminator loss does not
            # backpropagate into the generator.
            gen_y = self.gen(orig_x)
            dis_fake = self.dis(torch.cat((orig_x, gen_y.detach()), 1))
            fake_labels = Variable(self.cudafy(
                torch.zeros(dis_fake.size())
            ))
            dis_fake_loss = self.criterion_bce(
                dis_fake, fake_labels)

            # Update weights
            dis_loss = dis_real_loss + dis_fake_loss
            dis_loss.backward()

            self.dis_optim.step()
            self.reset_gradients()

            # --- Generator ---
            self.dis.volatile = True
            dis_real = self.dis(torch.cat((orig_x, gen_y), 1))
            real_labels = Variable(self.cudafy(
                torch.ones(dis_real.size())
            ))
            gen_loss = self.criterion_bce(dis_real, real_labels) + \
                self.lamb * self.criterion_l1(gen_y, orig_y)
            gen_loss.backward()
            self.gen_optim.step()
            # gen_loss.backward() also accumulates gradients into the
            # discriminator's parameters. Clear them here, as is done after
            # the discriminator step, so they do not leak into the next
            # iteration's discriminator update.
            self.reset_gradients()

            # Pycrayon or nah
            if self.crayon:
                self.logger.add_scalar_value('pix2pix_gen_loss', scalar(gen_loss))
                self.logger.add_scalar_value('pix2pix_dis_loss', scalar(dis_loss))

            if idx % 50 == 0:
                tqdm.write('Epoch: {} [{}/{}]\t'
                           'D Loss: {:.4f}\t'
                           'G Loss: {:.4f}'.format(
                               c_epoch, idx, max_idx, scalar(dis_loss), scalar(gen_loss)
                           ))