Python chainer.training.extensions module: PlotReport() example source code

From open-source Python projects, we have extracted the following five code examples that illustrate how to use chainer.training.extensions.PlotReport().
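Before the project snippets, here is a minimal, self-contained sketch of PlotReport in a standard Chainer training loop. The dataset, model size, batch size, and epoch count are illustrative assumptions, not taken from the projects below; PlotReport itself requires matplotlib to be installed. Note also that the project snippets on this page elide their module-level imports (typically `from chainer import training` and `from chainer.training import extensions`).

import chainer
import chainer.links as L
from chainer import training
from chainer.training import extensions

# Toy setup (illustrative): MNIST with a single linear layer.
train, test = chainer.datasets.get_mnist()
model = L.Classifier(L.Linear(784, 10))

optimizer = chainer.optimizers.Adam()
optimizer.setup(model)

train_iter = chainer.iterators.SerialIterator(train, 128)
test_iter = chainer.iterators.SerialIterator(test, 128, repeat=False, shuffle=False)

updater = training.StandardUpdater(train_iter, optimizer)
trainer = training.Trainer(updater, (5, 'epoch'), out='result')
trainer.extend(extensions.Evaluator(test_iter, model))
trainer.extend(extensions.LogReport())  # optional text log alongside the plot
# PlotReport aggregates the listed report keys and writes a matplotlib
# figure (here result/loss.png) each time its trigger fires.
trainer.extend(extensions.PlotReport(
    ['main/loss', 'validation/main/loss'], x_key='epoch', file_name='loss.png'))
trainer.run()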

Project: NlpUtil | Author: trtd56
def set_trainer(self, out_dir, gpu, n_epoch, g_clip, opt_name, lr=None):
        if opt_name == "Adam":
            opt = getattr(optimizers, opt_name)()
        else:
            opt = getattr(optimizers, opt_name)(lr)
        opt.setup(self.model)
        opt.add_hook(optimizer.GradientClipping(g_clip))

        updater = training.StandardUpdater(self.train_iter, opt, device=gpu)
        self.trainer = training.Trainer(updater, (n_epoch, 'epoch'), out=out_dir)
        self.trainer.extend(extensions.Evaluator(self.test_iter, self.model, device=gpu))
        self.trainer.extend(extensions.dump_graph('main/loss'))
        self.trainer.extend(extensions.snapshot(), trigger=(n_epoch, 'epoch'))
        self.trainer.extend(extensions.LogReport())
        self.trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                                   'epoch', file_name='loss.png'))
        self.trainer.extend(extensions.PlotReport(['main/accuracy', 'validation/main/accuracy'],
                                                   'epoch', file_name='accuracy.png'))
        self.trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'validation/main/loss',
                                                    'main/accuracy', 'validation/main/accuracy',
                                                    'elapsed_time']))
        self.trainer.extend(extensions.ProgressBar())
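This method configures a trainer on a wrapper object that already holds `model`, `train_iter`, and `test_iter`. A hypothetical call (the `wrapper` instance and all values are assumptions for illustration) might look like:

# wrapper is a hypothetical instance of the class defining set_trainer
wrapper.set_trainer(out_dir='result', gpu=-1, n_epoch=20,
                    g_clip=5.0, opt_name='SGD', lr=0.01)
wrapper.trainer.run()

Note that the "Adam" branch constructs the optimizer with no arguments, so lr is only used for the non-Adam optimizers.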
Project: chainer-pspnet | Author: mitmul
def __init__(self, **kwargs):
        required_keys = []
        optional_keys = [
            'dump_graph',
            'Evaluator',
            'ExponentialShift',
            'LinearShift',
            'LogReport',
            'observe_lr',
            'observe_value',
            'snapshot',
            'PlotReport',
            'PrintReport',
        ]
        super().__init__(
            required_keys, optional_keys, kwargs, self.__class__.__name__)
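The `super().__init__` call suggests a base class that checks the supplied keyword arguments against the required and optional key lists. That base class is not reproduced on this page; a minimal sketch of such a validator (class name and behavior assumed, not taken from chainer-pspnet) could look like:

class ConfigBase(object):
    # Hypothetical validator: require some keys, allow others, reject the rest.
    def __init__(self, required_keys, optional_keys, kwargs, name):
        for key in required_keys:
            if key not in kwargs:
                raise KeyError('{} requires the key {}'.format(name, key))
        for key, value in kwargs.items():
            if key not in required_keys and key not in optional_keys:
                raise KeyError('{} got an unexpected key {}'.format(name, key))
            setattr(self, key, value)

Under this reading, each optional key such as 'PlotReport' would carry the configuration for the trainer extension of the same name.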
Project: chainer_sklearn | Author: corochann
def fit(self, X, y=None, **kwargs):
        """If hyper parameters are set to None, then instance's variable is used,
        this functionality is used Grid search with `set_params` method.
        Also if instance's variable is not set, _default_hyperparam is used. 

        Usage: model.fit(train_dataset) or model.fit(X, y)

        Args:
            train: training dataset, assumes chainer's dataset class 
            test: test dataset for evaluation, assumes chainer's dataset class
            batchsize: batchsize for both training and evaluation
            iterator_class: iterator class used for this training, 
                            currently assumes SerialIterator or MultiProcessIterator
            optimizer: optimizer instance to update parameter
            epoch: training epoch
            out: directory path to save the result
            snapshot_frequency (int): snapshot frequency in epochs.
                                A negative value disables snapshots.
            dump_graph: Save computational graph info or not, default is False.
            log_report: Enable LogReport or not
            plot_report: Enable PlotReport or not
            print_report: Enable PrintReport or not
            progress_report: Enable ProgressReport or not
            resume: specify trainer saved path to resume training.

        """
        kwargs = self.filter_sk_params(self.fit_core, kwargs)
        return self.fit_core(X, y, **kwargs)
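A hypothetical call using the keyword arguments documented in the docstring above (all values are illustrative):

model.fit(train_dataset,
          test=test_dataset,
          batchsize=64,
          epoch=10,
          out='result',
          plot_report=True,   # enable the PlotReport extension
          print_report=True)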
Project: chainer-EWC | Author: okdshin
def train_task(args, train_name, model, epoch_num,
               train_dataset, test_dataset_dict, batch_size):
    optimizer = optimizers.SGD()
    optimizer.setup(model)

    train_iter = iterators.SerialIterator(train_dataset, batch_size)
    test_iter_dict = {name: iterators.SerialIterator(
            test_dataset, batch_size, repeat=False, shuffle=False)
            for name, test_dataset in test_dataset_dict.items()}

    updater = training.StandardUpdater(train_iter, optimizer)
    trainer = training.Trainer(updater, (epoch_num, 'epoch'), out=args.out)
    for name, test_iter in test_iter_dict.items():
        trainer.extend(extensions.Evaluator(test_iter, model), name)
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss'] +
        [test+'/main/loss' for test in test_dataset_dict.keys()] +
        ['main/accuracy'] +
        [test+'/main/accuracy' for test in test_dataset_dict.keys()]))
    trainer.extend(extensions.ProgressBar())
    trainer.extend(extensions.PlotReport(
        [test+"/main/accuracy" for test
         in test_dataset_dict.keys()],
        file_name=train_name+".png"))
    trainer.run()
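Because each Evaluator above is registered under a task name, its results are reported with that name as a prefix, which is why the plot keys are built as `test + "/main/accuracy"`. For two hypothetical tasks named 'taskA' and 'taskB', the PlotReport call would effectively expand to:

# Illustrative expansion of the list comprehension for two task names.
trainer.extend(extensions.PlotReport(
    ['taskA/main/accuracy', 'taskB/main/accuracy'],
    file_name=train_name + '.png'))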
Project: char-rnn-text-generation | Author: yxtay
def train_main(args):
    """
    Trains the model specified in args.
    Main entry point for the train subcommand.
    """
    # load text
    with open(args.text_path) as f:
        text = f.read()
    logger.info("corpus length: %s.", len(text))

    # data iterator
    data_iter = DataIterator(text, args.batch_size, args.seq_len)

    # load or build model
    if args.restore:
        logger.info("restoring model.")
        load_path = args.checkpoint_path if args.restore is True else args.restore
        model = load_model(load_path)
    else:
        net = Network(vocab_size=VOCAB_SIZE,
                      embedding_size=args.embedding_size,
                      rnn_size=args.rnn_size,
                      num_layers=args.num_layers,
                      drop_rate=args.drop_rate)
        model = L.Classifier(net)

    # make checkpoint directory
    log_dir = make_dirs(args.checkpoint_path)
    with open("{}.json".format(args.checkpoint_path), "w") as f:
        json.dump(model.predictor.args, f, indent=2)
    chainer.serializers.save_npz(args.checkpoint_path, model)
    logger.info("model saved: %s.", args.checkpoint_path)

    # optimizer
    optimizer = chainer.optimizers.Adam(alpha=args.learning_rate)
    optimizer.setup(model)
    # clip gradient norm
    optimizer.add_hook(chainer.optimizer.GradientClipping(args.clip_norm))

    # trainer
    updater = BpttUpdater(data_iter, optimizer)
    trainer = chainer.training.Trainer(updater, (args.num_epochs, 'epoch'), out=log_dir)
    trainer.extend(extensions.snapshot_object(model, filename=os.path.basename(args.checkpoint_path)))
    trainer.extend(extensions.ProgressBar(update_interval=1))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PlotReport(y_keys=["main/loss"]))
    trainer.extend(LoggerExtension(text))

    # training start
    model.predictor.reset_state()
    logger.info("start of training.")
    time_train = time.time()
    trainer.run()

    # training end
    duration_train = time.time() - time_train
    logger.info("end of training, duration: %ds.", duration_train)
    # generate text
    seed = generate_seed(text)
    generate_text(model, seed, 1024, 3)
    return model
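The PlotReport call in this last example relies on the extension's defaults; assuming the standard Chainer signature, it plots against the iteration axis and writes to plot.png in the trainer's output directory. An explicit equivalent would be:

trainer.extend(extensions.PlotReport(
    y_keys=['main/loss'], x_key='iteration', file_name='plot.png'))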