Python chainer 模块：chainer.serializers 实例源码

我们从Python开源项目中，提取了以下4个代码示例，用于说明如何使用 chainer.serializers 模块（load_hdf5、save_hdf5、load_npz、save_npz）。

项目:lencon    作者:kiyukuta    | 项目源码 | 文件源码
def _build_model(self, config, src_vocab, trg_vocab):
        """Instantiate the model described by ``config['Model']`` and prepare it.

        The ``name`` entry of the ``Model`` section selects a class from the
        ``models`` module. Every other entry of that section is passed to the
        class as a keyword argument, after strings that parse as numbers are
        converted to ``int``/``float``. If ``<save_dir>/model.hdf`` exists,
        its weights are loaded into the new model with
        ``chainer.serializers.load_hdf5``.

        Args:
            config: Section mapping; ``config['Model']`` and
                ``config['Training']`` are read.
            src_vocab: Source vocabulary; its ``stoi`` is called on the
                ``'<s>'`` and ``'</s>'`` tokens.
            trg_vocab: Target vocabulary; its ``stoi`` is called on the
                ``'<s>'`` and ``'</s>'`` tokens, and it is attached to the
                model as ``m.vocab`` when byte mode is enabled.

        Returns:
            The configured model instance.
        """

        def convert(val):
            """Return *val* as an int or float when it parses as one, else unchanged."""
            # Non-string values (e.g. an already-typed dict config) are passed
            # through instead of failing on ``.isdigit()``.
            if not isinstance(val, str):
                return val
            # Try int() first rather than gating on str.isdigit(): isdigit()
            # is False for "-1" (which then silently became -1.0) and True for
            # characters such as "2" superscript that int() rejects.
            try:
                return int(val)
            except ValueError:
                pass
            try:
                return float(val)
            except ValueError:
                # Not numeric: keep the raw string value.
                return val

        model_config = config['Model']

        # 'name' selects the model class; all other entries become kwargs.
        kwargs = {k: convert(v) for k, v in model_config.items() if k != 'name'}
        m = getattr(models, model_config['name'])(**kwargs)

        # Load previously saved weights when the file is present.
        model_path = os.path.join(self.save_dir, 'model.hdf')
        if os.path.exists(model_path):
            chainer.serializers.load_hdf5(model_path, m)

        # Look up the ids of the '<s>' / '</s>' tokens in both vocabularies
        # and hand them to the model.
        xstoi = src_vocab.stoi
        ystoi = trg_vocab.stoi
        xbos = xstoi('<s>')
        xeos = xstoi('</s>')
        ybos = ystoi('<s>')
        yeos = ystoi('</s>')
        m.set_symbols(xbos, xeos, ybos, yeos)

        m.name = model_config['name']
        m.byte = self._load_binary_config(config['Training'], 'byte')
        m.reverse_output = self._load_binary_config(
            config['Training'], 'reverse_output')
        if m.byte:
            # NOTE(review): presumably byte-level output needs the target
            # vocabulary at decode time — confirm against the model class.
            m.vocab = trg_vocab
        return m
项目:lencon    作者:kiyukuta    | 项目源码 | 文件源码
def save(self):
        """Persist the model weights and vocabularies under ``self.save_dir``.

        The weights go to ``model.hdf`` via ``chainer.serializers.save_hdf5``.
        The ``(src_vcb, trg_vcb)`` pair is pickled to ``vocab.pkl``.
        """
        target_dir = self.save_dir

        # Serialize a copy so that to_cpu() is applied to the copy rather
        # than to self.model; the copy's name is restored from the original.
        snapshot = self.model.copy()
        snapshot.name = self.model.name
        snapshot.to_cpu()

        weights_file = os.path.join(target_dir, 'model.hdf')
        chainer.serializers.save_hdf5(weights_file, snapshot)

        vocab_pair = (self.src_vcb, self.trg_vcb)
        with open(os.path.join(target_dir, "vocab.pkl"), "wb") as handle:
            pickle.dump(vocab_pair, handle)
项目:chainer_nmt    作者:odashi    | 项目源码 | 文件源码
def load_params(prefix, mdl, opt):
  """Restore model and optimizer state from ``prefix + '.mdl'`` and ``prefix + '.opt'``.

  Both files are read with ``chainer.serializers.load_npz``, model first.
  """
  logging.getLogger(__name__).info('Loading model/optimizer parameters')
  for suffix, target in (('.mdl', mdl), ('.opt', opt)):
    chainer.serializers.load_npz(prefix + suffix, target)
项目:chainer_nmt    作者:odashi    | 项目源码 | 文件源码
def save_params(prefix, mdl, opt):
  """Write model and optimizer state to ``prefix + '.mdl'`` and ``prefix + '.opt'``.

  Both files are written with ``chainer.serializers.save_npz``, model first.
  """
  logging.getLogger(__name__).info('Saving model/optimizer parameters')
  for suffix, target in (('.mdl', mdl), ('.opt', opt)):
    chainer.serializers.save_npz(prefix + suffix, target)