Python utils 模块，save_model() 实例源码

我们从 Python 开源项目中提取了以下 2 个代码示例，用于说明如何使用 utils.save_model()。注意：第二个示例中的 save_model() 调用已被注释掉。

项目：kaggle-quora-solution-8th    作者：qqgeogor    | 项目源码 | 文件源码
def train():
    """Train an RCNNModel, checkpointing periodically and evaluating per epoch.

    Every non-local name used here -- configuration such as ``train_file``,
    ``valid_file``, ``test_file``, ``batch_size``, ``maxlen``, ``model_dir``,
    ``goto_line``, ``NEPOCH``, ``lr``, ``disp_freq``, ``dump_freq`` and helpers
    such as ``TextIterator``, ``RCNNModel``, ``load_model``, ``save_model``,
    ``evaluate`` and ``log`` -- is defined elsewhere in the module.

    If ``model_dir`` is an existing file the model is restored with
    ``load_model`` before training; if ``goto_line`` is non-zero,
    ``train_data.goto_line(goto_line)`` is called first. Batches whose last
    dimension differs from ``batch_size`` are skipped. Parameters are saved
    every ``dump_freq`` batches and ``evaluate`` runs after every epoch.

    Returns:
        -1 as soon as a training cost is NaN or Inf; otherwise None.
    """
    import traceback  # only used to report unexpected errors in the loop

    log.info('loading dataset...')
    train_data = TextIterator(train_file, n_batch=batch_size, maxlen=maxlen)
    valid_data = TextIterator(valid_file, n_batch=batch_size, maxlen=maxlen)
    test_data = TextIterator(test_file, n_batch=batch_size, maxlen=maxlen, mode=2)
    log.info('building models....')
    model = RCNNModel(
        n_input=n_input,
        n_vocab=VOCABULARY_SIZE,
        n_hidden=n_hidden,
        cell='gru',
        optimizer=optimizer,
        dropout=dropout,
        sim=sim,
        maxlen=maxlen,
        batch_size=batch_size,
    )
    start = time.time()

    if os.path.isfile(model_dir):
        print('loading checkpoint parameters.... %s' % (model_dir,))
        model = load_model(model_dir, model)
    if goto_line != 0:
        train_data.goto_line(goto_line)
        print('goto line: %s' % (goto_line,))

    log.info('training start...')
    for epoch in range(NEPOCH):
        costs = 0
        idx = 0        # batches read from the iterator (drives disp/dump freq)
        n_trained = 0  # batches actually passed to model.train
        error_rate_list = []
        # Reset every epoch: otherwise the error report below would either hit
        # an unbound name (exception before the first batch is unpacked) or
        # print a stale batch left over from the previous epoch.
        x = y = None
        try:
            for (x, xmask), (y, ymask), label in train_data:
                idx += 1
                if x.shape[-1] != batch_size:
                    continue
                cost, error_rate = model.train(x, xmask, y, ymask, label, lr)
                n_trained += 1
                costs += cost
                error_rate_list.append(error_rate)
                if not np.isfinite(cost):
                    print('Nan Or Inf detected!')
                    print('%s %s' % (x.shape, y.shape))
                    print('%s %s' % (cost, error_rate))
                    return -1
                if idx % disp_freq == 0:
                    # Average over trained batches only; dividing by idx would
                    # also count skipped partial batches and skew it toward 0.
                    mean_rate = np.mean(
                        list(itertools.chain.from_iterable(error_rate_list)))
                    log.info('epoch: %d, idx: %d cost: %.3f, Accuracy: %.3f '
                             % (epoch, idx, costs / n_trained, mean_rate))
                if idx % dump_freq == 0:
                    save_model('./model/parameters_%.2f.pkl' % (time.time() - start), model)
        except Exception:
            # Best effort, as before: report and fall through to evaluation,
            # but keep the actual traceback instead of silently dropping it.
            print(traceback.format_exc())
            if x is not None and y is not None:
                print('%s %s' % (np.max(x), np.max(y)))
                print('%s %s' % (x.shape, y.shape))

        evaluate(train_data, valid_data, test_data, model)

    log.info("Finished. Time = " + str(time.time() - start))
项目：kaggle-quora-solution-8th    作者：qqgeogor    | 项目源码 | 文件源码
def test():
    """Score an RCNNModel on per-epoch data splits without training it.

    For each epoch ``e`` in ``range(NEPOCH)`` this reads
    ``<train_file>.train.<e>`` and ``<train_file>.valid.<e>``. It runs
    ``model.predict`` on every batch of the first file whose last dimension
    equals ``batch_size``, logs the mean cost/accuracy for the epoch, then logs
    the ``(loss, acc)`` pair returned by ``evaluate`` on the second file.

    Configuration (``n_input``, ``batch_size``, ``model_dir``, ``NEPOCH``, ...)
    and helpers (``RCNNModel``, ``TextIterator``, ``load_model``,
    ``evaluate``, ``log``) are defined elsewhere in the module.

    Returns:
        -1 as soon as a batch's mean cost is NaN or Inf; otherwise None.
    """
    log.info('loading dataset...')

    log.info('building models....')
    model = RCNNModel(
        n_input=n_input,
        n_vocab=VOCABULARY_SIZE,
        n_hidden=n_hidden,
        cell='gru',
        optimizer=optimizer,
        dropout=dropout,
        sim=sim,
        maxlen=maxlen,
        batch_size=batch_size,
    )
    log.info('training start....')
    t0 = time.time()

    if os.path.isfile(model_dir):
        print('loading checkpoint parameters.... %s' % (model_dir,))
        model = load_model(model_dir, model)

    for epoch in range(NEPOCH):
        batch_costs = []
        batch_accs = []
        tag = str(epoch)
        train_data = TextIterator(train_file + '.train.' + tag, n_batch=batch_size, maxlen=maxlen)
        valid_data = TextIterator(train_file + '.valid.' + tag, n_batch=batch_size, maxlen=maxlen)

        for (x, xmask), (y, ymask), label in train_data:
            # Only full-size batches are scored.
            if x.shape[-1] != batch_size:
                continue
            cost, acc = model.predict(x, xmask, y, ymask, label)
            batch_costs.append(cost)
            batch_accs.append(acc)
            if not np.isfinite(np.mean(cost)):
                print('Nan Or Inf detected!')
                print('x: %s' % (x,))
                print(str(x.shape))
                print('y: %s' % (y,))
                print(str(y.shape))
                return -1

        # Each per-batch cost/acc is itself iterable; flatten before averaging.
        flat_costs = [c for per_batch in batch_costs for c in per_batch]
        flat_accs = [a for per_batch in batch_accs for a in per_batch]
        log.info('epoch: %d, cost: %.3f, Accuracy: %.3f ' % (
            epoch, np.mean(flat_costs), np.mean(flat_accs)))

        val_loss, val_acc = evaluate(valid_data, model)
        log.info('validation cost: %.3f, Accuracy: %.3f' % (val_loss, val_acc))

    log.info("Finished. Time = " + str(time.time() - t0))