Python model module: Model() example source code

We extracted the following 50 code examples from open-source Python projects to illustrate how to use model.Model().
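
Before the project excerpts, here is a minimal sketch of the pattern they share. The module path and the constructor arguments are placeholders; every project below defines its own Model class with its own signature:

import model

# Placeholder call only: the real constructor arguments vary by project
# (a config object, parsed command-line args, a data file path, ...).
net = model.Model()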

Project: GeneGAN    Author: Prinsphield    | project source | file source
def main():
    parser = argparse.ArgumentParser(description='test', formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        '-a', '--attribute', 
        default='Smiling',
        type=str,
        help='Specify attribute name for training. \ndefault: %(default)s. \nAll attributes can be found in list_attr_celeba.txt'
    )
    parser.add_argument(
        '-g', '--gpu', 
        default='0',
        type=str,
        help='Specify GPU id. \ndefault: %(default)s. \nUse comma to separate several ids, for example: 0,1'
    )
    args = parser.parse_args()

    celebA = Dataset(args.attribute)
    GeneGAN = Model(is_train=True)
    run(config, celebA, GeneGAN, gpu=args.gpu)
Project: OpenSAPM    Author: pathfinder14    | project source | file source
def _assemble_model(self):
        """
        Create a model
        :return model
        """
        result_model = model.Model({
            "dimension" : self._dimension,
            "type" : self._type,
            "image_path" : self._image_path,
            "elasticity_quotient": self._elasticity_quotient,
            "mu_lame": self._mu_lame,
            "density": self._density,
            "v_p": self._v_p,
            "v_s": self._v_s
        }, self.GRID_SIZE)
        return result_model


    # TODO: produce different source types
Project: torch_light    Author: ne7ermore    | project source | file source
def train():
    rnn.train()
    total_loss = 0
    hidden = rnn.init_hidden(args.batch_size)
    for data, label in tqdm(training_data, mininterval=1,
                desc='Train Processing', leave=False):
        optimizer.zero_grad()
        hidden = repackage_hidden(hidden)
        target, hidden = rnn(data, hidden)
        loss = criterion(target, label)

        loss.backward()
        torch.nn.utils.clip_grad_norm(rnn.parameters(), args.clip)
        optimizer.step()

        total_loss += loss.data
    return total_loss[0]/training_data.sents_size
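
The repackage_hidden call above detaches the recurrent state from the previous batch's graph so that backpropagation stops at the batch boundary. The helper itself is not shown in this excerpt; a common implementation (a sketch, not necessarily this project's exact code) is:

import torch

def repackage_hidden(h):
    """Detach hidden states from their history so gradients do not flow across batches."""
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)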

# ##############################################################################
# Save Model
# ##############################################################################
Project: torch_light    Author: ne7ermore    | project source | file source
def __init__(self, model=None, model_source=None, src_dict=None, args=None):
        assert model is not None or model_source is not None

        if model is None:
            model_source = torch.load(model_source, map_location=lambda storage, loc: storage)
            self.dict = model_source["src_dict"]
            self.args = model_source["settings"]
            model = Model(self.args)
            model.load_state_dict(model_source['model'])
        else:
            self.dict = src_dict
            self.args = args

        self.num_directions = 2 if self.args.bidirectional else 1
        self.idx2word = {v: k for k, v in self.dict.items()}
        self.model = model.eval()
Project: torch_light    Author: ne7ermore    | project source | file source
def train():
    model.train()
    total_loss = 0
    for word, char, label in tqdm(training_data, mininterval=1,
                desc='Train Processing', leave=False):

        optimizer.zero_grad()
        loss, _ = model(word, char, label)
        loss.backward()

        optimizer.step()
        optimizer.update_learning_rate()
        total_loss += loss.data
    return total_loss[0]/training_data.sents_size/args.word_max_len
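
Unlike the previous loop, the optimizer here also exposes update_learning_rate(), which suggests a scheduling wrapper around a stock optimizer. A minimal sketch of such a wrapper (the class name and decay rule are assumptions, not this project's code):

import torch.optim as optim

class ScheduledOptim(object):
    """Wrap an optimizer and decay its learning rate on demand."""
    def __init__(self, parameters, lr=1e-3, decay=0.999):
        self.optimizer = optim.Adam(parameters, lr=lr)
        self.lr = lr
        self.decay = decay

    def zero_grad(self):
        self.optimizer.zero_grad()

    def step(self):
        self.optimizer.step()

    def update_learning_rate(self):
        self.lr *= self.decay
        for group in self.optimizer.param_groups:
            group['lr'] = self.lr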

# ##############################################################################
# Save Model
# ##############################################################################
Project: Stereo-Pose-Machines    Author: ppwwyyxx    | project source | file source
def get_runner(path):
    param_dict = np.load(path, encoding='latin1').item()
    predict_func = OfflinePredictor(PredictConfig(
        model=Model(),
        session_init=ParamRestore(param_dict),
        session_config=get_default_sess_config(0.99),
        input_names=['input'],
        #output_names=['Mconv7_stage6/output']
        output_names=['resized_map']
    ))
    def func_single(img):
        # img is bgr, [0,255]
        # return the output in WxHx15
        return predict_func([[img]])[0][0]
    def func_batch(imgs):
        # img is bgr, [0,255], nhwc
        # return the output in nhwc
        return predict_func([imgs])[0]
    return func_single, func_batch
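
A hypothetical call site for the two predictors returned above (the weight path and the frame are illustrative):

import numpy as np

frame = np.zeros((368, 368, 3), dtype=np.uint8)  # stand-in for a BGR camera frame
single, batch = get_runner('cpm_weights.npy')    # weight path is illustrative
heatmap = single(frame)                          # one image in, WxHx15 map out
heatmaps = batch(np.stack([frame, frame]))       # NHWC batch in, NHWC maps out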
Project: Stereo-Pose-Machines    Author: ppwwyyxx    | project source | file source
def get_parallel_runner_1(path):
    param_dict = np.load(path, encoding='latin1').item()
    cfg = PredictConfig(
        model=Model(),
        session_init=ParamRestore(param_dict),
        session_config=get_default_sess_config(0.99),
        input_names=['input'],
        output_names=['resized_map']
    )
    inque = mp.Queue()
    outque = mp.Queue()
    with change_gpu(0):
        proc = MultiProcessQueuePredictWorker(1, inque, outque, cfg)
        proc.start()
    with change_gpu(1):
        pred1 = OfflinePredictor(cfg)
    def func1(img):
        inque.put((0,[[img]]))
    func1.outque = outque
    def func2(img):
        return pred1([[img]])[0][0]
    return func1, func2
Project: meinkurve    Author: michgur    | project source | file source
def __init__(self, key_right='Right', key_left='Left', color='red', color2='pink', learn=False, iteration=0, net=None):
        Player.__init__(self, key_right=key_right, key_left=key_left, color=color, color2=color2)

        self.learn = learn
        self.iteration = iteration
        self.file = None
        # if self.learn:
        # self.file = open('data.txt','w')
        # self.file.write('a,a,a,a,a,a,a,a,a,a,x,y,class\n')
        # else:

        self.net = net
        self.model = Model('data.txt')
        self.model_list = []
        self.model_list.append(self.model)
        self.radar = Radar(self, range=1000)
        if not os.path.exists('files'):
            os.makedirs('files')
Project: Dave-Godot    Author: finchMFG    | project source | file source
def sample_main(args):
    model_path, config_path, vocab_path = get_paths(args.save_dir)
    # Arguments passed to sample.py direct us to a saved model.
    # Load the separate arguments by which that model was previously trained.
    # That's saved_args. Use those to load the model.
    with open(config_path, 'rb') as f:
        print(f)
        saved_args = pickle.load(f)
    # Separately load chars and vocab from the save directory.
    with open(vocab_path, 'rb') as f:
        chars, vocab = pickle.load(f)
    # Create the model from the saved arguments, in inference mode.
    print("Creating model...")
    net = Model(saved_args, True)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(net.save_variables_list())
        # Restore the saved variables, replacing the initialized values.
        print("Restoring weights...")
        saver.restore(sess, model_path)
        chatbot(net, sess, chars, vocab, args.n, args.beam_width, args.relevance, args.temperature)
        #beam_sample(net, sess, chars, vocab, args.n, args.prime,
            #args.beam_width, args.relevance, args.temperature)
Project: sequelspeare    Author: raidancampbell    | project source | file source
def __init__(self, save_dir=SAVE_DIR, prime_text=PRIME_TEXT, num_sample_symbols=NUM_SAMPLE_SYMBOLS):
        self.save_dir = save_dir
        self.prime_text = prime_text
        self.num_sample_symbols = num_sample_symbols
        with open(os.path.join(Sampler.SAVE_DIR, 'chars_vocab.pkl'), 'rb') as file:
            self.chars, self.vocab = cPickle.load(file)
            self.model = Model(len(self.chars), is_sampled=True)

            # polite GPU memory allocation: don't grab everything you can.
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            config.gpu_options.allocator_type = 'BFC'
            self.sess = tf.Session(config=config)

            tf.initialize_all_variables().run(session=self.sess)
            self.checkpoint = tf.train.get_checkpoint_state(self.save_dir)
            if self.checkpoint and self.checkpoint.model_checkpoint_path:
                tf.train.Saver(tf.all_variables()).restore(self.sess, self.checkpoint.model_checkpoint_path)
Project: chatbot-rnn    Author: zenixls2    | project source | file source
def sample_main(args):
    model_path, config_path, vocab_path = get_paths(args.save_dir)
    # Arguments passed to sample.py direct us to a saved model.
    # Load the separate arguments by which that model was previously trained.
    # That's saved_args. Use those to load the model.
    with open(config_path) as f:
        saved_args = cPickle.load(f)
    # Separately load chars and vocab from the save directory.
    with open(vocab_path) as f:
        chars, vocab = cPickle.load(f)
    # Create the model from the saved arguments, in inference mode.
    print("Creating model...")
    net = Model(saved_args, True)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(net.save_variables_list())
        # Restore the saved variables, replacing the initialized values.
        print("Restoring weights...")
        saver.restore(sess, model_path)
        chatbot(net, sess, chars, vocab, args.n, args.beam_width, args.relevance, args.temperature)
        #beam_sample(net, sess, chars, vocab, args.n, args.prime,
            #args.beam_width, args.relevance, args.temperature)
Project: neural-semantic-role-labeler    Author: hiroki13    | project source | file source
def set_model(self):
        argv = self.argv

        #####################
        # Network variables #
        #####################
        x = T.ftensor3()
        d = T.imatrix()

        n_in = self.init_emb.shape[1]
        n_h = argv.hidden
        n_y = self.arg_dict.size()
        reg = argv.reg

        #################
        # Build a model #
        #################
        say('\n\nMODEL:  Unit: %s  Opt: %s' % (argv.unit, argv.opt))
        self.model = Model(argv=argv, x=x, y=d, n_in=n_in, n_h=n_h, n_y=n_y, reg=reg)
Project: TISP    Author: kaayy    | project source | file source
def decode_sentence(kb, sentid, weightfile):
    indepkb = IndepKnowledgeBase()
    model = Model()

    parser = Parser(indepkb, kb, model, State)

    State.model = model
    State.model.weights = pickle.load(open(weightfile))
    State.ExtraInfoGen = ExprGenerator
    ExprGenerator.setup()

    ret = parser.parse(kb.questions[sentid])
    print >> LOGS, "============================="
    print >> LOGS, simplify_expr(ret.get_expr())
    print >> LOGS, "TRACING"
    for s in ret.trace_states():
        print >> LOGS, s, s.extrainfo
Project: CNN-LSTM-Caption-Generator    Author: mosessoh    | project source | file source
def main(argv):
    opts, args = getopt.getopt(argv, 'i:')
    for opt, arg in opts:
        if opt == '-i':
            img_path = arg

    config = Config()
    with tf.variable_scope('CNNLSTM') as scope:
        print '-'*20
        print 'Model info'
        print '-'*20
        model = Model(config)
        print '-'*20
    saver = tf.train.Saver()

    img_vector = forward_cnn(img_path)

    with tf.Session() as session:
        save_path = best_model_dir + '/model-37'
        saver.restore(session, save_path)
        print '2 Layer LSTM loaded'
        print 'Generating caption...'
        caption = model.generate_caption(session, img_vector)
        print 'Output:', caption
Project: word-rnn-tf    Author: jtoy    | project source | file source
def sample(args):
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
        chars, vocab = cPickle.load(f)
    model = Model(saved_args, True)
    val_loss_file = args.save_dir + '/val_loss.json'
    with tf.Session() as sess:
        saver = tf.train.Saver(tf.all_variables())
        if os.path.exists(val_loss_file):
            with open(val_loss_file, "r") as text_file:
                text = text_file.read()
                loss_json = json.loads(text)
                losses = loss_json.keys()
                losses.sort(key=lambda x: float(x))
                loss = losses[0]
                model_checkpoint_path =  loss_json[loss]['checkpoint_path']
                #print(model_checkpoint_path)
                saver.restore(sess, model_checkpoint_path)
                result = model.sample(sess, chars, vocab, args.n, args.prime, args.sample_rule, args.temperature)
                print(result)  # add this back in later; not sure why it's not working
                output = "/data/output/"+ str(int(time.time())) + ".txt"
                with open(output, "w") as text_file:
                    text_file.write(result)
                print(output)
Project: olive-gui    Author: dturevski    | project source | file source
def openCollection(self, fileName):

        try:
            f = open(unicode(fileName), 'r')
            Mainframe.model = model.Model()
            Mainframe.model.delete(0)
            for data in yaml.load_all(f):
                Mainframe.model.add(model.makeSafe(data), False)
            f.close()
            Mainframe.model.is_dirty = False
        except IOError:
            msgBox(Lang.value('MSG_IO_failed'))
            Mainframe.model = model.Model()
        except yaml.YAMLError as e:
            msgBox(Lang.value('MSG_YAML_failed') % e)
            Mainframe.model = model.Model()
        else:
            if len(Mainframe.model.entries) == 0:
                Mainframe.model = model.Model()
            Mainframe.model.filename = unicode(fileName)
        finally:
            Mainframe.sigWrapper.sigModelChanged.emit()
Project: olive-gui    Author: dturevski    | project source | file source
def __init__(self):
        super(Mainframe, self).__init__()

        Mainframe.model = model.Model()

        self.initLayout()
        self.initActions()
        self.initMenus()
        self.initToolbar()
        self.initSignals()
        self.initFrame()

        self.updateTitle()
        self.overview.rebuild()
        self.show()

        if Conf.value('check-for-latest-binary'):
            self.checkNewVersion = Mainframe.CheckNewVersion(self)
            self.checkNewVersion.start()
Project: olive-gui    Author: dturevski    | project source | file source
def onImportCcv(self):
        if not self.doDirtyCheck():
            return
        default_dir = './collections/'
        if Mainframe.model.filename != '':
            default_dir, tail = os.path.split(Mainframe.model.filename)
        fileName, encoding = self.getOpenFileNameAndEncoding(
            Lang.value('MI_Import_CCV'), default_dir, "(*.ccv)")
        if not fileName:
            return
        try:
            Mainframe.model = model.Model()
            Mainframe.model.delete(0)
            for data in fancy.readCvv(fileName, encoding):
                Mainframe.model.add(model.makeSafe(data), False)
            Mainframe.model.is_dirty = False
        except IOError:
            msgBox(Lang.value('MSG_IO_failed'))
        except:
            msgBox(Lang.value('MSG_CCV_import_failed'))
        finally:
            if len(Mainframe.model.entries) == 0:
                Mainframe.model = model.Model()
            self.overview.rebuild()
            Mainframe.sigWrapper.sigModelChanged.emit()
Project: ChineseNER    Author: zjy-ucas    | project source | file source
def evaluate_line():
    config = load_config(FLAGS.config_file)
    logger = get_logger(FLAGS.log_file)
    # limit GPU memory
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    with open(FLAGS.map_file, "rb") as f:
        char_to_id, id_to_char, tag_to_id, id_to_tag = pickle.load(f)
    with tf.Session(config=tf_config) as sess:
        model = create_model(sess, Model, FLAGS.ckpt_path, load_word2vec, config, id_to_char, logger)
        while True:
            # try:
            #     line = input("Enter a test sentence: ")
            #     result = model.evaluate_line(sess, input_from_line(line, char_to_id), id_to_tag)
            #     print(result)
            # except Exception as e:
            #     logger.info(e)

                line = input("Enter a test sentence: ")
                result = model.evaluate_line(sess, input_from_line(line, char_to_id), id_to_tag)
                print(result)
Project: LSTM-CRF-For-Named-Entity-Recognition    Author: zpppy    | project source | file source
def evaluate_line():
    config = load_config(FLAGS.config_file)
    logger = get_logger(FLAGS.log_file)
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    with open(FLAGS.map_file, "rb") as f:
        char_to_id, id_to_char, tag_to_id, id_to_tag = pickle.load(f)
    with tf.Session(config=tf_config) as sess:
        model = create_model(sess, Model, FLAGS.ckpt_path, load_word2vec, config, id_to_char, logger)
        while True:
            # try:
            #     line = input("Enter a test sentence: ")
            #     result = model.evaluate_line(sess, input_from_line(line, char_to_id), id_to_tag)
            #     print(result)
            # except Exception as e:
            #     logger.info(e)

                line = input("Enter a test sentence: ")
                result = model.evaluate_line(sess, input_from_line(line, char_to_id), id_to_tag)
                print(result)
Project: adversarial-autoencoder    Author: musyoku    | project source | file source
def plot_scatter():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", "-m", type=str, default="model.hdf5")
    args = parser.parse_args()

    dataset_train, dataset_test = chainer.datasets.get_mnist()
    images_train, labels_train = dataset_train._datasets
    images_test, labels_test = dataset_test._datasets

    model = Model()
    assert model.load(args.model)

    # normalize
    images_train = (images_train - 0.5) * 2
    images_test = (images_test - 0.5) * 2

    with chainer.no_backprop_mode(), chainer.using_config("train", False):
        z = model.encode_x_yz(images_test)[1].data
    plot.scatter_labeled_z(z, labels_test, "scatter_gen.png")
Project: adversarial-autoencoder    Author: musyoku    | project source | file source
def plot_representation():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", "-m", type=str, default="model.hdf5")
    args = parser.parse_args()

    dataset_train, dataset_test = chainer.datasets.get_mnist()
    images_train, labels_train = dataset_train._datasets
    images_test, labels_test = dataset_test._datasets

    model = Model()
    assert model.load(args.model)

    # normalize
    images_train = (images_train - 0.5) * 2
    images_test = (images_test - 0.5) * 2

    with chainer.no_backprop_mode(), chainer.using_config("train", False):
        y_onehot, z = model.encode_x_yz(images_test, apply_softmax_y=True)
        representation = model.encode_yz_representation(y_onehot, z).data
    plot.scatter_labeled_z(representation, labels_test, "scatter_r.png")
Project: adversarial-autoencoder    Author: musyoku    | project source | file source
def plot_z():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", "-m", type=str, default="model.hdf5")
    args = parser.parse_args()

    dataset_train, dataset_test = chainer.datasets.get_mnist()
    images_train, labels_train = dataset_train._datasets
    images_test, labels_test = dataset_test._datasets

    model = Model()
    assert model.load(args.model)

    # normalize
    images_train = (images_train - 0.5) * 2
    images_test = (images_test - 0.5) * 2

    with chainer.no_backprop_mode(), chainer.using_config("train", False):
        z = model.encode_x_yz(images_test)[1].data
    plot.scatter_labeled_z(z, labels_test, "scatter_z.png")
Project: adversarial-autoencoder    Author: musyoku    | project source | file source
def plot_scatter():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", "-m", type=str, default="model.hdf5")
    args = parser.parse_args()

    dataset_train, dataset_test = chainer.datasets.get_mnist()
    images_train, labels_train = dataset_train._datasets
    images_test, labels_test = dataset_test._datasets

    model = Model()
    assert model.load(args.model)

    # normalize
    images_train = (images_train - 0.5) * 2
    images_test = (images_test - 0.5) * 2

    with chainer.no_backprop_mode(), chainer.using_config("train", False):
        z = model.encode_x_z(images_test).data
    plot.scatter_labeled_z(z, labels_test, "scatter_gen.png")
Project: adversarial-autoencoder    Author: musyoku    | project source | file source
def plot_scatter():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", "-m", type=str, default="model.hdf5")
    args = parser.parse_args()

    dataset_train, dataset_test = chainer.datasets.get_mnist()
    images_train, labels_train = dataset_train._datasets
    images_test, labels_test = dataset_test._datasets

    model = Model()
    assert model.load(args.model)

    # normalize
    images_train = (images_train - 0.5) * 2
    images_test = (images_test - 0.5) * 2

    with chainer.no_backprop_mode(), chainer.using_config("train", False):
        z = model.encode_x_z(images_test).data
    plot.scatter_labeled_z(z, labels_test, "scatter_z.png")
Project: adversarial-autoencoder    Author: musyoku    | project source | file source
def plot_scatter():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", "-m", type=str, default="model.hdf5")
    args = parser.parse_args()

    dataset_train, dataset_test = chainer.datasets.get_mnist()
    images_train, labels_train = dataset_train._datasets
    images_test, labels_test = dataset_test._datasets

    model = Model()
    assert model.load(args.model)

    # normalize
    images_train = (images_train - 0.5) * 2
    images_test = (images_test - 0.5) * 2

    with chainer.no_backprop_mode(), chainer.using_config("train", False):
        z = model.encode_x_yz(images_test)[1].data
    plot.scatter_labeled_z(z, labels_test, "scatter_gen.png")
Project: adversarial-autoencoder    Author: musyoku    | project source | file source
def plot_z():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", "-m", type=str, default="model.hdf5")
    args = parser.parse_args()

    dataset_train, dataset_test = chainer.datasets.get_mnist()
    images_train, labels_train = dataset_train._datasets
    images_test, labels_test = dataset_test._datasets

    model = Model()
    assert model.load(args.model)

    # normalize
    images_train = (images_train - 0.5) * 2
    images_test = (images_test - 0.5) * 2

    with chainer.no_backprop_mode(), chainer.using_config("train", False):
        z = model.encode_x_yz(images_test)[1].data
    plot.scatter_labeled_z(z, labels_test, "scatter_z.png")
Project: char-rnn-tensorflow-master    Author: JDonnelly1    | project source | file source
def sample(args):
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
        chars, vocab = cPickle.load(f)
    model = Model(saved_args, training=False)
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            ret, hidden = model.sample(sess, chars, vocab, args.n, args.prime,
                               args.sample)#.encode('utf-8'))
            print("Number of characters generated: ", len(ret))

            for i in range(len(ret)):
                print("Generated character: ", ret[i])
                print("Associated hidden state:", hidden[i])
Project: keras_zoo    Author: david-vazquez    | project source | file source
def make_discriminator(self):
        # TODO just to have something, 5 layers vgg-like
        inputs = Input(shape=self.img_shape)
        enc1 = self.downsampling_block_basic(inputs, 64, 7)
        enc2 = self.downsampling_block_basic(enc1,   64, 7)
        enc3 = self.downsampling_block_basic(enc2,   92, 7)
        enc4 = self.downsampling_block_basic(enc3,  128, 7)
        enc5 = self.downsampling_block_basic(enc4,  128, 7)
        flat = Flatten()(enc5)
        dense1 = Dense(512, activation='sigmoid')(flat)
        dense2 = Dense(512, activation='sigmoid')(dense1)
        fake = Dense(1, activation='sigmoid', name='generation')(dense2)
        # Dense(2,... two classes : real and fake
        # change last activation to softmax ?
        discriminator = kmodels.Model(input=inputs, output=fake)

        lr = 1e-04
        optimizer = RMSprop(lr=lr, rho=0.9, epsilon=1e-8, clipnorm=10)
        print ('   Optimizer discriminator: rmsprop. Lr: {}. Rho: 0.9, epsilon=1e-8, '
               'clipnorm=10'.format(lr))

        discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)
        # TODO metrics=metrics,
        return discriminator
Project: keras_zoo    Author: david-vazquez    | project source | file source
def make_gan(self, img_shape, optimizer,
                 the_loss='categorical_crossentropy', metrics=[]):
        # Build stacked GAN model
        gan_input = Input(shape=img_shape)
        H = self.generator(gan_input)
        gan_V = self.discriminator(H)
        GAN = kmodels.Model(gan_input, gan_V)

        # Compile model
        GAN.compile(loss=the_loss, metrics=metrics, optimizer=optimizer)

        # Show model
        if self.cf.show_model:
            print('GAN')
            GAN.summary()
            plot(GAN, to_file=os.path.join(self.cf.savepath, 'model_GAN.png'))

        return GAN

    # Make the network trainable or not
Project: jarvis    Author: whittlbc    | project source | file source
def prep_for_app_use(self):
        self.args = Args(None).parse_args()
        self.args.test = TestMode.DAEMON

        self.load_model_params()
        self.text_data = TextData(self.args)

        with tf.device(self.get_device()):
            self.model = Model(self.args, self.text_data)

        self.writer = tf.train.SummaryWriter(self._get_summary_name())
        self.saver = tf.train.Saver(max_to_keep=200)

        self.sess = tf.Session()
        self.sess.run(tf.initialize_all_variables())
        self.manage_previous_model(self.sess)

    # Training initialization
Project: token-rnn-tensorflow    Author: aalmendoza    | project source | file source
def evaluate(args):
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        (saved_args, reverse_input) = cPickle.load(f)
    with open(os.path.join(args.save_dir, 'token_vocab.pkl'), 'rb') as f:
        tokens, vocab = cPickle.load(f)
    model = Model(saved_args, reverse_input, True)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            if args.pre_tokenized:
                with open(args.source, 'r') as f:
                    token_list = f.read().split()
            else:
                token_list = get_tokens(args.source, args.language)

            token_list = convert_to_vocab_tokens(vocab, token_list, model.start_token,
                model.end_token, model.unk_token)
            probs = model.evaluate(sess, tokens, vocab, token_list)
Project: token-rnn-tensorflow    Author: aalmendoza    | project source | file source
def evaluate(args):
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        (saved_args, reverse_input) = cPickle.load(f)
    with open(os.path.join(args.save_dir, 'token_vocab.pkl'), 'rb') as f:
        tokens, vocab = cPickle.load(f)
    model = Model(saved_args, reverse_input, True)

    if args.pre_tokenized:
        with open(args.source, 'r') as f:
            token_list = f.read().split()
    else:
        token_list = get_tokens(args.source, args.language)

    token_list = convert_to_vocab_tokens(vocab, token_list, model.start_token,
        model.end_token, model.unk_token)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            entropy_list = model.get_entropy_per_token(sess, vocab, token_list)
            display_results(token_list, entropy_list)
Project: tf_rnnlm    Author: Ubiqus    | project source | file source
def Model(self, *args, **kwargs):
    model_class = RnnlmOp.MODELS[self.model]
    return model_class(*args, **kwargs)
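
This Model method is a small factory: it looks up a class in RnnlmOp.MODELS keyed by the self.model string and forwards all arguments to its constructor. A self-contained sketch of the same dispatch pattern (the class and key names are assumptions):

class SmallModel(object):
    def __init__(self, config):
        self.config = config

class LargeModel(SmallModel):
    pass

class RnnlmOp(object):
    MODELS = {'small': SmallModel, 'large': LargeModel}

    def __init__(self, model='small'):
        self.model = model

    def Model(self, *args, **kwargs):
        return RnnlmOp.MODELS[self.model](*args, **kwargs)

op = RnnlmOp('large')
net = op.Model({'hidden_size': 128})  # instantiates LargeModel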
Project: lstm-poetry    Author: dvictor    | project source | file source
def train():
    cleanup.cleanup()
    c.save(c.work_dir)

    data_loader = TextLoader(c.work_dir, c.batch_size, c.seq_length)
    with open(os.path.join(c.work_dir, 'chars_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)

    model = Model(c.rnn_size, c.num_layers, len(data_loader.chars), c.grad_clip, c.batch_size, c.seq_length)

    with tf.Session() as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        for e in range(c.num_epochs):
            sess.run(tf.assign(model.lr, c.learning_rate * (c.decay_rate ** e)))
            data_loader.reset_batch_pointer()
            state = model.initial_state.eval()
            for b in range(data_loader.num_batches):
                start = time.time()
                x, y = data_loader.next_batch()
                feed = {model.input_data: x, model.targets: y, model.initial_state: state}
                train_loss, state, _ = sess.run([model.cost, model.final_state, model.train_op], feed)
                end = time.time()
                print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                    .format(e * data_loader.num_batches + b,
                            c.num_epochs * data_loader.num_batches,
                            e, train_loss, end - start))
                if (e * data_loader.num_batches + b) % c.save_every == 0:
                    checkpoint_path = os.path.join(c.work_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=e * data_loader.num_batches + b)
                    print("model saved to {}".format(checkpoint_path))
Project: miniluv    Author: fsantovito    | project source | file source
def test_observed_should_not_share_observers(self):

        model2 = Model()
        self.assertFalse(model._Observable__observers is model2._Observable__observers)
Project: hostapd-mana    Author: adde88    | project source | file source
def setUp(self, request, node, data):
        """
        Override this method to set up your Widget prior to generateDOM. This
        is a good place to call methods like L{add}, L{insert}, L{__setitem__}
        and L{__getitem__}.

        Overriding this method obsoletes overriding generateDOM directly, in
        most cases.

        @type request: L{twisted.web.server.Request}.
        @param node: The DOM node which this Widget is operating on.
        @param data: The Model data this Widget is meant to operate upon.
        """
        pass
Project: hostapd-mana    Author: adde88    | project source | file source
def __init__(self, model, raw=0, clear=1, *args, **kwargs):
        """
        @param model: The text to render.
        @type model: A string or L{model.Model}.
        @param raw: A boolean that specifies whether to render the text as
              a L{domhelpers.RawText} or as a DOM TextNode.
        """
        self.raw = raw
        self.clearNode = clear
        Widget.__init__(self, model, *args, **kwargs)
Project: kor-char-rnn-tensorflow    Author: insikk    | project source | file source
def sample(args):
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
        chars, vocab = cPickle.load(f)
    model = Model(saved_args, training=False)
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print(model.sample(sess, chars, vocab, args.n, args.prime,
                               args.sample))
Project: sketch_rnn_classification    Author: payalbajaj    | project source | file source
def trainer(model_params):
  """Train a sketch-rnn model."""
  np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)

  tf.logging.info('sketch-rnn')
  tf.logging.info('Hyperparams:')
  for key, val in model_params.values().iteritems():
    tf.logging.info('%s = %s', key, str(val))
  tf.logging.info('Loading data files.')
  datasets = load_dataset(FLAGS.data_dir, model_params)

  train_set = datasets[0]
  valid_set = datasets[1]
  test_set = datasets[2]
  model_params = datasets[3]
  eval_model_params = datasets[4]

  reset_graph()
  model = sketch_rnn_model.Model(model_params)
  eval_model = sketch_rnn_model.Model(eval_model_params, reuse=True)

  sess = tf.InteractiveSession()
  sess.run(tf.global_variables_initializer())

  if FLAGS.resume_training:
    load_checkpoint(sess, FLAGS.log_root)

  # Write config file to json file.
  tf.gfile.MakeDirs(FLAGS.log_root)
  with tf.gfile.Open(
      os.path.join(FLAGS.log_root, 'model_config.json'), 'w') as f:
    json.dump(model_params.values(), f, indent=True)

  train(sess, model, eval_model, train_set, valid_set, test_set)
Project: ConditionalGAN    Author: seungjooli    | project source | file source
def main(argv):
    m = model.Model(FLAGS.log_dir, FLAGS.ckpt_dir, FLAGS.load_ckpt, FLAGS.input_height, FLAGS.input_width)
    if FLAGS.mode == 'train':
        train(m)
    elif FLAGS.mode == 'test':
        test(m)
    else:
        print('Unexpected mode: {}  Choose \'train\' or \'test\''.format(FLAGS.mode))
    m.close()
Project: optimove    Author: nicolasramy    | project source | file source
def __init__(self, username=None, password=None, timeout=30):
        self.general = General(self)
        self.model = Model(self)
        self.actions = Actions(self)
        self.groups = Groups(self)
        self.customers = Customers(self)
        self.segments = Segments(self)
        self.integrations = Integrations(self)
        self.timeout = timeout

        if username and password:
            self.general.login(username, password)
Project: Tree-LSTM-LM    Author: vgene    | project source | file source
def sample(args):
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
        chars, vocab = cPickle.load(f)
    model = Model(saved_args, training=False)
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print(model.sample(sess, chars, vocab, args.n, args.prime.decode('utf-8'),
                               args.sample).encode('utf-8'))
Project: Stereo-Pose-Machines    Author: ppwwyyxx    | project source | file source
def get_parallel_runner(path):
    param_dict = np.load(path, encoding='latin1').item()
    cfg = PredictConfig(
        model=Model(),
        session_init=ParamRestore(param_dict),
        session_config=get_default_sess_config(1.0),
        input_names=['input'],
        output_names=['resized_map']
    )
    predictor = DataParallelOfflinePredictor(cfg, [0,1])

    def func(im1, im2):
        o = predictor([[im1], [im2]])
        return o[0][0], o[1][0]
    return func
Project: meinkurve    Author: michgur    | project source | file source
def place(self):
        print 'Game Index:', self.root.gameindex

        # --------DO NOT DELETE!!!-------

        if self.iteration < 10:
            a = np.random.rand(10, 10)
            b = np.random.rand(10)
            c = np.random.rand(10, 3)
            d = np.random.rand(3)
            self.net = neural.Net(a, b, c, d, 1)
        # else:
        #     if self.file:
        #         self.file.close()
        #         bisect.insort_left(self.model_list, Model('files/data%s.txt' % self.iteration))
        #         # self.model_list.append(Model('files/data%s.txt' % self.iteration))
        #     if self.iteration%20 == 0:
        #         for i in self.model_list:
        #             print i.size
        #         for i in range(10):
        #             os.remove(self.model_list[0].file)
        #             del self.model_list[0]

        self.iteration += 1


        Player.place(self)
Project: YellowFin    Author: JianGoForIt    | project source | file source
def sample(args):
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
        chars, vocab = cPickle.load(f)
    model = Model(saved_args, training=False)
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print(model.sample(sess, chars, vocab, args.n, args.prime,
                               args.sample).encode('utf-8'))
Project: adaptivemd    Author: markovmodel    | project source | file source
def initialize(self, resource):
        """
        Initialize a project with a specific resource.

        Notes
        -----
        This should only be called to setup the project and only the very
        first time.

        Parameters
        ----------
        resource : `Resource`
            the resource used in this project

        """
        self.storage.close()

        self.resource = resource

        st = MongoDBStorage(self.name, 'w')
        # st.create_store(ObjectStore('objs', None))
        st.create_store(ObjectStore('generators', TaskGenerator))
        st.create_store(ObjectStore('files', File))
        st.create_store(ObjectStore('resources', Resource))
        st.create_store(ObjectStore('models', Model))
        st.create_store(ObjectStore('tasks', Task))
        st.create_store(ObjectStore('workers', Worker))
        st.create_store(ObjectStore('logs', LogEntry))
        st.create_store(FileStore('data', DataDict))
        # st.create_store(ObjectStore('commands', Command))

        st.save(self.resource)

        st.close()

        self._open_db()
Project: KGP-ASR    Author: KGPML    | project source | file source
def loadCLM(sess):
    with open('/users/TeamASR/char-rnn-tensorflow/save/config.pkl', 'rb') as f:
        saved_args = cPickle.load(f)
    with open('/users/TeamASR/char-rnn-tensorflow/save/chars_vocab.pkl', 'rb') as f:
        chars, vocab = cPickle.load(f)
    model = Model(saved_args, True)
    sess.run(tf.initialize_all_variables())
    saver = tf.train.Saver(tf.all_variables())
    ckpt = tf.train.get_checkpoint_state('/users/TeamASR/char-rnn-tensorflow/save')
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
    return model,chars,vocab
Project: Dave-Godot    Author: finchMFG    | project source | file source
def process_user_command(user_input, states, relevance, temperature, beam_width):
    user_command_entered = False
    reset = False
    try:
        if user_input.startswith('--temperature '):
            user_command_entered = True
            temperature = max(0.001, float(user_input[len('--temperature '):]))
            print("[Temperature set to {}]".format(temperature))
        elif user_input.startswith('--relevance '):
            user_command_entered = True
            new_relevance = float(user_input[len('--relevance '):])
            if relevance <= 0. and new_relevance > 0.:
                states = [states, copy.deepcopy(states)]
            elif relevance > 0. and new_relevance <= 0.:
                states = states[0]
            relevance = new_relevance
            print("[Relevance disabled]" if relevance < 0. else "[Relevance set to {}]".format(relevance))
        elif user_input.startswith('--beam_width '):
            user_command_entered = True
            beam_width = max(1, int(user_input[len('--beam_width '):]))
            print("[Beam width set to {}]".format(beam_width))
        elif user_input.startswith('--reset'):
            user_command_entered = True
            reset = True
            print("[Model state reset]")
    except ValueError:
        print("[Value error with provided argument.]")
    return user_command_entered, reset, states, relevance, temperature, beam_width
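
A hypothetical loop wiring this parser into a chat session (the initial values and the fallthrough are illustrative):

states, relevance, temperature, beam_width = None, -1.0, 1.0, 2
while True:
    user_input = input('> ')
    handled, reset, states, relevance, temperature, beam_width = \
        process_user_command(user_input, states, relevance, temperature, beam_width)
    if handled:
        continue  # a settings command was handled; prompt again
    # otherwise feed user_input to the chatbot here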
Project: MENGEL    Author: CodeSpaceHQ    | project source | file source
def save(self, filename='-1'):
        """Save the configuration to an XML file."""
        if filename == '-1':
            filename = self.config_file_name

        # Update models
        for model_xml in self.root.iter('Model'):
            model = self.models[model_xml.get('name')]
            for param_xml in model_xml.iter('Param'):
                param = model.params[param_xml.get('name')]
                for detail, value in param.details.items():
                    param_xml.set(detail, value)
        self.tree.write(filename)
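
For reference, the element layout this method iterates over: each Model element carries a name attribute and Param children whose attributes are synced from param.details. A minimal sketch of building that shape with ElementTree (the concrete names are illustrative):

import xml.etree.ElementTree as ET

root = ET.Element('Config')
model_xml = ET.SubElement(root, 'Model', name='gbm')      # looked up in self.models
ET.SubElement(model_xml, 'Param', name='learning_rate')   # looked up in model.params
ET.ElementTree(root).write('config.xml')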