Python chainer.training.extensions 模块,ExponentialShift() 实例源码

我们从 Python 开源项目中提取了以下 3 个代码示例,用于说明如何使用 chainer.training.extensions.ExponentialShift()。

项目:signature-embedding    作者:hrantzsch    | 项目源码 | 文件源码
def get_trainer(updater, evaluator, epochs):
    """Assemble a trainer that stops after ``epochs`` epochs.

    The trainer writes to the ``result`` directory and has, in this order,
    the given evaluator, a log report, a progress bar and a console report
    of the training/validation loss attached.

    Args:
        updater: Updater passed straight to ``training.Trainer``.
        evaluator: Extension registered first on the trainer
            (presumably a validation evaluator -- confirm with the caller).
        epochs: Number of epochs used both as the stop trigger and as the
            training length shown by the progress bar.

    Returns:
        The configured trainer. It is not started here; the caller is
        expected to run it.
    """
    stop_trigger = (epochs, 'epoch')
    trainer = training.Trainer(updater, stop_trigger, out='result')

    trainer.extend(evaluator)

    # TODO: reduce LR -- how to update every X epochs?
    # Disabled attempt (note that ``lr`` is not defined in this function):
    #   trainer.extend(extensions.ExponentialShift('lr', 0.1, target=lr*0.0001))

    log_report = extensions.LogReport()
    trainer.extend(log_report)

    progress_bar = extensions.ProgressBar(stop_trigger, update_interval=10)
    trainer.extend(progress_bar)

    reported_keys = ['epoch', 'main/loss', 'validation/main/loss']
    trainer.extend(extensions.PrintReport(reported_keys))

    return trainer
项目:chainer-ADDA    作者:pfnet-research    | 项目源码 | 文件源码
def pretrain_source_cnn(data, args, epochs=1000):
    """Pretrain the source encoder and return the trained model.

    Builds a ``Loss(num_classes=10)`` model, optimizes it with Adam (default
    hyperparameters) on iterators obtained from ``data2iterator`` and runs a
    Chainer trainer to completion.

    Args:
        data: Dataset handed to ``data2iterator`` to obtain the train/test
            iterators.
        args: Settings object; this function reads ``args.device``,
            ``args.batchsize`` and ``args.output``.
        epochs: Number of training epochs. Also used as the snapshot
            trigger, so the model is saved once, when training finishes.

    Returns:
        The trained ``Loss`` model instance.
    """
    print(":: pretraining source encoder")

    encoder = Loss(num_classes=10)
    on_gpu = args.device >= 0
    if on_gpu:
        encoder.to_gpu()

    optimizer = chainer.optimizers.Adam()
    optimizer.setup(encoder)

    train_iter, test_iter = data2iterator(
        data, args.batchsize, multiprocess=False)
    # Disabled alternative for the training iterator:
    #   chainer.iterators.MultiprocessIterator(data, args.batchsize, n_processes=4)

    updater = chainer.training.StandardUpdater(
        iterator=train_iter,
        optimizer=optimizer,
        device=args.device,
    )
    stop_trigger = (epochs, 'epoch')
    trainer = chainer.training.Trainer(updater, stop_trigger, out=args.output)

    # Learning-rate decay is currently disabled (10E-5 below equals 1e-4):
    #   trainer.extend(extensions.ExponentialShift("alpha", rate=0.9, init=args.learning_rate, target=args.learning_rate*10E-5))

    evaluator = extensions.Evaluator(test_iter, encoder, device=args.device)
    trainer.extend(evaluator)

    # Periodic full-trainer snapshots are disabled:
    #   trainer.extend(extensions.snapshot(filename='snapshot_epoch_{.updater.epoch}'), trigger=(10, "epoch"))
    model_snapshot = extensions.snapshot_object(
        optimizer.target, "source_model_epoch_{.updater.epoch}")
    trainer.extend(model_snapshot, trigger=(epochs, "epoch"))

    trainer.extend(extensions.ProgressBar(update_interval=10))
    trainer.extend(extensions.LogReport(trigger=(1, "epoch")))

    reported_keys = [
        'epoch', 'main/loss', 'validation/main/loss',
        'main/accuracy', 'validation/main/accuracy', 'elapsed_time',
    ]
    trainer.extend(extensions.PrintReport(reported_keys))

    trainer.run()

    return encoder
项目:depccg    作者:masashi-y    | 项目源码 | 文件源码
def train(args):
    """Train a ``BiaffineJaLSTMParser`` according to ``args``.

    Steps, in order: build the model, dump ``args`` to ``<args.model>/params``
    via ``log``, optionally load initial weights and pretrained embeddings,
    optionally move the model to a GPU, then set up the iterators, the Adam
    optimizer and a Chainer trainer and run it to completion.

    Args:
        args: Settings object. Fields read here: ``model``, ``word_emb_size``,
            ``char_emb_size``, ``nlayers``, ``hidden_dim``, ``dep_dim``,
            ``dropout_ratio``, ``initmodel``, ``pretrained``, ``gpu``,
            ``train``, ``val``, ``batchsize`` and ``epoch``.
    """
    model = BiaffineJaLSTMParser(
        args.model, args.word_emb_size, args.char_emb_size,
        args.nlayers, args.hidden_dim, args.dep_dim, args.dropout_ratio)

    # Record the run configuration next to the model files.
    with open(args.model + "/params", "w") as params_file:
        log(args, params_file)

    if args.initmodel:
        print('Load model from', args.initmodel)
        chainer.serializers.load_npz(args.initmodel, model)

    if args.pretrained:
        print('Load pretrained word embeddings from', args.pretrained)
        model.load_pretrained_embeddings(args.pretrained)

    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()
        model.to_gpu()

    # Training data is iterated with the iterator defaults; validation data
    # is visited exactly once per evaluation, in order.
    train_data = LSTMParserDataset(args.model, args.train)
    train_iter = chainer.iterators.SerialIterator(train_data, args.batchsize)
    val_data = LSTMParserDataset(args.model, args.val)
    val_iter = chainer.iterators.SerialIterator(
        val_data, args.batchsize, repeat=False, shuffle=False)

    optimizer = chainer.optimizers.Adam(beta2=0.9)
    # Disabled alternative optimizer:
    #   chainer.optimizers.MomentumSGD(momentum=0.7)
    optimizer.setup(model)
    optimizer.add_hook(WeightDecay(2e-6))
    # Disabled hook:
    #   optimizer.add_hook(GradientClipping(5.))

    updater = training.StandardUpdater(
        train_iter, optimizer, device=args.gpu, converter=converter)
    stop_trigger = (args.epoch, 'epoch')
    trainer = training.Trainer(updater, stop_trigger, args.model)

    val_interval = (1000, 'iteration')
    log_interval = (200, 'iteration')

    # Evaluation runs on a copy of the model with its ``train`` flag off.
    eval_model = model.copy()
    eval_model.train = False

    # NOTE(review): "eps" is Adam's numerical-stability constant in Chainer,
    # not its step size ("alpha"). Confirm that decaying "eps" from 2e-3 by
    # a factor of 0.75 every 2500 iterations is really what is intended.
    hyperparam_shift = extensions.ExponentialShift("eps", .75, 2e-3)
    trainer.extend(hyperparam_shift, trigger=(2500, 'iteration'))

    evaluator = extensions.Evaluator(
        val_iter, eval_model, converter, device=args.gpu)
    trainer.extend(evaluator, trigger=val_interval)

    model_snapshot = extensions.snapshot_object(
        model, 'model_iter_{.updater.iteration}')
    trainer.extend(model_snapshot, trigger=val_interval)

    trainer.extend(extensions.LogReport(trigger=log_interval))

    reported_keys = [
        'epoch', 'iteration',
        'main/tagging_accuracy', 'main/tagging_loss',
        'main/parsing_accuracy', 'main/parsing_loss',
        'validation/main/tagging_accuracy',
        'validation/main/parsing_accuracy'
    ]
    trainer.extend(extensions.PrintReport(reported_keys), trigger=log_interval)

    trainer.extend(extensions.ProgressBar(update_interval=10))

    trainer.run()