Python dataset module: Dataset() example source code

We extracted the following 28 code examples from open-source Python projects to illustrate how to use dataset.Dataset(). Note that each snippet imports Dataset from its own project's dataset module, so the constructor signature differs from example to example.

Project: GeneGAN    Author: Prinsphield
def main():
    parser = argparse.ArgumentParser(description='test', formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        '-a', '--attribute', 
        default='Smiling',
        type=str,
        help='Specify attribute name for training. \ndefault: %(default)s. \nAll attributes can be found in list_attr_celeba.txt'
    )
    parser.add_argument(
        '-g', '--gpu', 
        default='0',
        type=str,
        help='Specify GPU id. \ndefault: %(default)s. \nUse comma to separate several ids, for example: 0,1'
    )
    args = parser.parse_args()

    celebA = Dataset(args.attribute)
    GeneGAN = Model(is_train=True)
    run(config, celebA, GeneGAN, gpu=args.gpu)
Project: speed    Author: keon
def init_datasets(arg, resize, n):
    """ Initialize N number of datasets for ensemble training """
    datasets = []
    for i in range(n):
        dset = Dataset(arg.train_folder,
                       resize=resize,
                       batch_size=arg.batch_size,
                       timesteps=arg.timesteps,
                       windowsteps=arg.timesteps // 2, shift=i*2, train=True)
        print('[!] train dataset len: %d - shift: %d' % (len(dset.data), i*2))
        datasets.append(dset)
    # Validation Dataset
    v_dataset = Dataset(arg.valid_folder,
                        resize=resize,
                        batch_size=arg.batch_size//2,
                        timesteps=arg.timesteps,
                        windowsteps=arg.timesteps //2, shift=0, train=True)
    print('[!] validation dataset samples: %d' % len(v_dataset.data))
    return datasets, v_dataset
Project: instacart-basket-prediction    Author: colinmorris
def main():
  logging.basicConfig(level=logging.INFO)
  parser = argparse.ArgumentParser()
  parser.add_argument('tags', metavar='tag', nargs='+')
  parser.add_argument('--fold', default='test', 
      help='identifier for file with the users to test on (default: test)')
  args = parser.parse_args()


  for model_tag in args.tags:
    hps = hypers.hps_for_tag(model_tag)
    dataset = Dataset(args.fold, hps, mode=Mode.inference)
    path = common.resolve_xgboostmodel_path(model_tag)
    logging.info('Loading model with tag {}'.format(model_tag))
    model = xgb.Booster(model_file=path)
    logging.info('Computing probs for tag {}'.format(model_tag))
    with time_me('Computed probs for {}'.format(model_tag), mode='stderr'):
      pdict = get_pdict(model, dataset)
      logging.info('Got probs for {} users'.format(len(pdict)))
      # TODO: might want to enforce some namespace separation between 
      # rnn-generated pdicts and ones coming from xgboost models?
      common.save_pdict_for_tag(model_tag, pdict, args.fold)
Project: instacart-basket-prediction    Author: colinmorris
def main():
  logging.basicConfig(level=logging.INFO)
  parser = argparse.ArgumentParser()
  parser.add_argument('tag')
  parser.add_argument('--train-recordfile', default='train', 
      help='identifier for file with the users to train on (default: train). deprecated: specify in hps...')
  parser.add_argument('-n', '--n-rounds', type=int, default=50,
      help='Number of rounds of boosting. Deprecated: specify this in hp config file')
  parser.add_argument('--weight', action='store_true',
      help='Whether to do per-instance weighting. Deprecated: specify in hps')
  args = parser.parse_args()

  try:
    hps = hypers.hps_for_tag(args.tag)
  except hypers.NoHpsDefinedException:
    logging.warn('No hps found for tag {}. Creating and saving some.'.format(args.tag))
    hps = hypers.get_default_hparams()
    hps.train_file = args.train_recordfile
    hps.rounds = args.n_rounds
    hps.weight = args.weight
    hypers.save_hps(args.tag, hps)
  validate_hps(hps)
  dataset = Dataset(hps.train_file, hps)
  with time_me(mode='stderr'):
    train(dataset, args.tag, hps)
Project: tf_serving_example    Author: Vetal1977
def main():
    # preparations
    create_checkpoints_dir()
    utils.download_train_and_test_data()
    trainset, testset = utils.load_data_sets()

    # create real input for the GAN model (its discriminator) and
    # GAN model itself
    real_size = (32, 32, 3)
    z_size = 100
    learning_rate = 0.0003

    tf.reset_default_graph()
    input_real = tf.placeholder(tf.float32, (None, *real_size), name='input_real')
    net = GAN(input_real, z_size, learning_rate)

    # create dataset
    dataset = Dataset(trainset, testset)

    # train the model
    batch_size = 128
    epochs = 25
    _, _, _ = train(net, dataset, epochs, batch_size, z_size)
Project: speed    Author: keon
def main(arg):
    resize = (200, 66)

    # initialize dataset
    dataset = Dataset(arg.test_folder,
                      resize=resize,
                      batch_size=1,
                      timesteps=arg.timesteps,
                      windowsteps=1,
                      shift=0,
                      train=False)
    print('[!] testing dataset samples: %d' % len(dataset.data))

    # initialize model
    cuda = th.cuda.is_available()
    models = init_models(arg.model, n=3, lr=0, restore=True, cuda=cuda)

    # Initiate Prediction
    t0 = datetime.datetime.now()
    try:
        predict(models, dataset, arg, cuda=cuda)
    except KeyboardInterrupt:
        print('[!] KeyboardInterrupt: Stopped Prediction...')
    t1 = datetime.datetime.now()

    print('[!] Finished Prediction, Time Taken: %s' % (t1-t0))
Project: tensorflow-action-conditional-video-prediction    Author: williamd4112
def main(args):
    with tf.Graph().as_default() as graph:
        # Create dataset
        logging.info('Create data flow from %s' % args.train)
        train_data = Dataset(directory=args.train, mean_path=args.mean, batch_size=args.batch_size, num_threads=2, capacity=10000)

        # Create initializer
        init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

        # Config session
        config = get_config(args)

        # Setup summary
        check_summary_writer = tf.summary.FileWriter(os.path.join(args.log, 'check'), graph)

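        # Undo the input normalization (scale back to [0, 255] and add the mean)
        # so a sample batch can be written out as a summary image for inspection.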
        check_op = tf.cast(train_data()['x_t_1'] * 255.0 + train_data()['mean'], tf.uint8)

        tf.summary.image('x_t_1_batch_restore', check_op, collections=['check'])
        check_summary_op = tf.summary.merge_all('check')

        # Start session
        with tf.Session(config=config) as sess:
            coord = tf.train.Coordinator()
            sess.run(init)
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            for i in range(10):
                x_t_1_batch, summary = sess.run([check_op, check_summary_op])
                check_summary_writer.add_summary(summary, i)
            coord.request_stop()
            coord.join(threads)
Project: rl-attack-detection    Author: yenchenlin
def main(args):
    with tf.Graph().as_default() as graph:
        # Create dataset
        logging.info('Create data flow from %s' % args.train)
        train_data = Dataset(directory=args.train, mean_path=args.mean, batch_size=args.batch_size, num_threads=2, capacity=10000)

        # Create initializer
        init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

        # Config session
        config = get_config(args)

        # Setup summary
        check_summary_writer = tf.summary.FileWriter(os.path.join(args.log, 'check'), graph)

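        # Undo the input normalization (scale back to [0, 255] and add the mean)
        # so a sample batch can be written out as a summary image for inspection.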
        check_op = tf.cast(train_data()['x_t_1'] * 255.0 + train_data()['mean'], tf.uint8)

        tf.summary.image('x_t_1_batch_restore', check_op, collections=['check'])
        check_summary_op = tf.summary.merge_all('check')

        # Start session
        with tf.Session(config=config) as sess:
            coord = tf.train.Coordinator()
            sess.run(init)
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            for i in range(10):
                x_t_1_batch, summary = sess.run([check_op, check_summary_op])
                check_summary_writer.add_summary(summary, i)
            coord.request_stop()
            coord.join(threads)
Project: instacart-basket-prediction    Author: colinmorris
def dataset():
  return Dataset('testuser', hypers.get_default_hparams())
Project: dataset    Author: analysiscenter
def gen_data(num_items):
    features_array = np.arange(num_items * 3).reshape(num_items, -1)
    labels_array = np.random.choice(10, size=num_items)
    data = features_array, labels_array

    index = np.arange(num_items)
    # when your data fits into memory, just preload it
    dataset = Dataset(index=index, batch_class=MyBatch, preloaded=data)
    return dataset
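For readers new to this package, a minimal consumption sketch follows. It assumes analysiscenter's dataset package (later renamed batchflow), in which Batch is the base batch class and gen_batch() iterates over the index in batches; treat the exact import path and method names as assumptions rather than part of the example above.

import numpy as np
from dataset import Dataset, Batch  # assumed import path for analysiscenter's package

num_items = 100
features_array = np.arange(num_items * 3).reshape(num_items, -1)
labels_array = np.random.choice(10, size=num_items)

# preload the in-memory arrays, mirroring gen_data() above
ds = Dataset(index=np.arange(num_items), batch_class=Batch,
             preloaded=(features_array, labels_array))

# draw batches of 10 items for one epoch (gen_batch and batch.indices are assumed API)
for batch in ds.gen_batch(10, n_epochs=1, shuffle=True):
    print(batch.indices)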
Project: dataset    Author: analysiscenter
def gen_data(num_items):
    index = np.arange(num_items)
    data = np.arange(num_items * 3).reshape(num_items, -1)
    # when your data fits into memory, just preload it
    dataset = Dataset(index=index, batch_class=ArrayBatch, preloaded=data)
    return dataset
Project: dataset    Author: analysiscenter
def gen_data(num_items):
    index = np.arange(num_items)
    dataset = Dataset(index=index, batch_class=ArrayBatch)
    return dataset
Project: dataset    Author: analysiscenter
def gen_data(num_items):
    ix = np.arange(num_items).astype('str')
    data = np.arange(num_items * 3).reshape(num_items, -1)
    ds = Dataset(index=ix, batch_class=MyBatch, preloaded=data)
    return ds, data


# Create datasets
Project: dataset    Author: analysiscenter
def gen_data():
        ix = np.arange(K)
        images = np.random.randint(0, 255, size=K*S*S).reshape(-1, S, S).astype('uint8')
        labels = np.random.randint(0, 3, size=K).astype('uint8')
        masks = np.random.randint(0, 10, size=K).astype('uint8') + 100
        targets = np.random.randint(0, 10, size=K).astype('uint8') + 1000
        data = images, labels, masks, targets

        ds = Dataset(index=ix, batch_class=MyBatch)
        return ds, data


    # Create datasets
Project: dataset    Author: analysiscenter
def gen_data():
        ix = np.arange(K)
        images = np.random.randint(0, 255, size=K//2*S*S).reshape(-1, S, S).astype('uint8')
        top = np.random.randint(0, 3, size=K*2).reshape(-1, 2).astype('uint8')
        size = np.random.randint(3, 7, size=K*2).reshape(-1, 2).astype('uint8')
        pos = np.random.choice(range(len(images)), replace=True, size=K)
        data = images, top, size, pos

        #dsindex = DatasetIndex(ix)
        ds = Dataset(index=ix, batch_class=MyBatch)
        return ds, data


    # Create datasets
Project: dataset    Author: analysiscenter
def gen_data(num_items, shape):
        index = np.arange(num_items)
        data = np.random.randint(0, 255, size=num_items * shape[0] * shape[1])
        data = data.reshape(num_items, shape[0], shape[1]).astype('uint8')
        ds = Dataset(index=index, batch_class=ImagesBatch)
        return ds, data


    # Create a dataset
Project: dataset    Author: analysiscenter
def gen_data():
        ix = np.arange(K)
        images = np.random.randint(0, 255, size=K*S*S).reshape(-1, S, S).astype('uint8')
        labels = np.random.randint(0, 3, size=K).astype('uint8')
        masks = np.random.randint(0, 10, size=K).astype('uint8') + 100
        targets = np.random.randint(0, 10, size=K).astype('uint8') + 1000
        data = images, labels, masks, targets

        ds = Dataset(index=ix, batch_class=MyBatch)
        return ds, data


    # Create datasets
Project: object-classification    Author: HenrYxZ
def test_dataset():
    dataset = Dataset(constants.DATASET_PATH)
    pickle.dump(dataset, open(constants.DATASET_OBJ_FILENAME, "wb"), protocol=constants.PICKLE_PROTOCOL)
    classes = dataset.get_classes()
    print("Dataset generated with {0} classes.".format(len(classes)))
    print(classes)
    train = dataset.get_train_set()
    test = dataset.get_test_set()
    for i in range(len(classes)):
        print(
            "There are {0} training files and {1} testing files for class number {2} ({3})".format(
                len(train[i]), len(test[i]), i, classes[i]
            )
        )
Project: FunnyPyML    Author: MrPig
def load(self, target_col_name=None):
        with open(self.__path, 'r') as fp:
            data, categorical = self.__loader(fp)
            if target_col_name is None or target_col_name not in data.columns:
                target_col_name = data.columns[-1]
        return Dataset(data=data, target_col_name=target_col_name, categorical=categorical)
Project: LinguisticAnalysis    Author: DucAnhPhi
def test_training_set(self):
        dataset = ds.Dataset("realDonaldTrump", "HillaryClinton", 0.8)
        # should have 50% positive and 50% negative examples
        trainSet = dataset.trainSet
        validationSet = dataset.validationSet
        tCount = ds.get_positive_negative_amount(trainSet)
        vCount = ds.get_positive_negative_amount(validationSet)
        self.assertEqual(tCount[0], tCount[1])
        self.assertEqual(vCount[0], vCount[1])
Project: EKLAVYA    Author: shensq04
def training(config_info):
    data_folder = config_info['data_folder']
    func_path = config_info['func_path']
    embed_path = config_info['embed_path']
    tag = config_info['tag']
    data_tag = config_info['data_tag']
    process_num = int(config_info['process_num'])
    embed_dim = int(config_info['embed_dim'])
    max_length = int(config_info['max_length'])
    num_classes = int(config_info['num_classes'])
    epoch_num = int(config_info['epoch_num'])
    save_batch_num = int(config_info['save_batchs'])
    output_dir = config_info['output_dir']

    '''create model & log folder'''
    if os.path.exists(output_dir):
        pass
    else:
        os.mkdir(output_dir)
    model_basedir = os.path.join(output_dir, 'model')
    if os.path.exists(model_basedir):
        pass
    else:
        os.mkdir(model_basedir)
    log_basedir = os.path.join(output_dir, 'log')
    if tf.gfile.Exists(log_basedir):
        tf.gfile.DeleteRecursively(log_basedir)
    tf.gfile.MakeDirs(log_basedir)
    config_info['log_path'] = log_basedir
    print('Created all folders!')

    '''load dataset'''
    if data_tag == 'callee':
        my_data = dataset.Dataset(data_folder, func_path, embed_path, process_num, embed_dim, max_length, num_classes, tag)
    else: #caller
        my_data = dataset_caller.Dataset(data_folder, func_path, embed_path, process_num, embed_dim, max_length, num_classes, tag)

    print('Created the dataset!')

    with tf.Graph().as_default(), tf.Session() as session:
        # generate placeholder
        data_pl, label_pl, length_pl, keep_prob_pl = placeholder_inputs(num_classes, max_length, embed_dim)

        # generate model
        model = Model(session, my_data, config_info, data_pl, label_pl, length_pl, keep_prob_pl)
        print('Created the model!')

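        # Train until the requested number of epochs is complete,
        # checkpointing every save_batch_num training steps.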
        while my_data._complete_epochs < epoch_num:
            model.train()
            if model.run_count % save_batch_num == 0:
                model.saver.save(session, os.path.join(model_basedir, 'model'), global_step = model.run_count)
                print('Saved the model ... %d' % model.run_count)
            else:
                pass
        model.train_writer.close()
        model.test_writer.close()
Project: EKLAVYA    Author: shensq04
def testing(config_info):
    data_folder = config_info['data_folder']
    func_path = config_info['func_path']
    embed_path = config_info['embed_path']
    tag = config_info['tag']
    data_tag = config_info['data_tag']
    process_num = int(config_info['process_num'])
    embed_dim = int(config_info['embed_dim'])
    max_length = int(config_info['max_length'])
    num_classes = int(config_info['num_classes'])
    model_dir = config_info['model_dir']
    output_dir = config_info['output_dir']

    '''create model & log folder'''
    if os.path.exists(output_dir):
        pass
    else:
        os.mkdir(output_dir)
    print('Created all folders!')

    '''load dataset'''
    if data_tag == 'callee':
        my_data = dataset.Dataset(data_folder, func_path, embed_path, process_num, embed_dim, max_length, num_classes, tag)
    else: # caller
        my_data = dataset_caller.Dataset(data_folder, func_path, embed_path, process_num, embed_dim, max_length, num_classes, tag)
    print('Created the dataset!')

    '''get model id list'''
    # model_id_list = sorted(get_model_id_list(model_dir), reverse=True)
    model_id_list = sorted(get_model_id_list(model_dir))

    with tf.Graph().as_default(), tf.Session() as session:
        # generate placeholder
        data_pl, label_pl, length_pl, keep_prob_pl = placeholder_inputs(num_classes, max_length, embed_dim)
        # generate model
        model = Model(session, my_data, config_info, data_pl, label_pl, length_pl, keep_prob_pl)
        print('Created the model!')

        for model_id in model_id_list:
            result_path = os.path.join(output_dir, 'test_result_%d.pkl' % model_id)
            if os.path.exists(result_path):
                continue
            else:
                pass
            model_path = os.path.join(model_dir, 'model-%d' % model_id)
            model.saver.restore(session, model_path)

            total_result = model.test()
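            # Reset the dataset's test cursor so the next checkpoint is evaluated on the full test set.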
            my_data._index_in_test = 0
            my_data.test_tag = True
            with open(result_path, 'w') as f:
                pickle.dump(total_result, f)
            print('Save the test result !!! ... %s' % result_path)
Project: instacart-basket-prediction    Author: colinmorris
def train(traindat, tag, hps):
  valdat = Dataset('validation', hps, mode=Mode.eval)
  # TODO: try set_base_margin (https://github.com/dmlc/xgboost/blob/master/demo/guide-python/boost_from_prediction.py)
  with time_me('Made training dmatrix', mode='stderr'):
    dtrain = traindat.as_dmatrix()
  def quick_fscore(preds, _notused_dtrain):
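    # Mean per-user F1 on the validation set; uses dval/valdat from the
    # enclosing scope and the module-level THRESH and counter.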
    global counter
    counter += 1
    if 0 and counter % 5 != 0:
      return 'fscore', 0.0
    with time_me('calculated validation fscore', mode='print'):
      user_counts = defaultdict(lambda : dict(tpos=0, fpos=0, fneg=0))
      uids = valdat.uids
      labels = dval.get_label()
      for i, prob in enumerate(preds):
        uid = uids[i]
        pred = prob >= THRESH
        label = labels[i]
        if pred and label:
          user_counts[uid]['tpos'] += 1
        elif pred and not label:
          user_counts[uid]['fpos'] += 1
        elif label and not pred:
          user_counts[uid]['fneg'] += 1
      fscore_sum = 0
      for uid, res in user_counts.iteritems():
        numerator = 2 * res['tpos']
        denom = numerator + res['fpos'] + res['fneg']
        if denom == 0:
          fscore = 1
        else:
          fscore = numerator / denom
        fscore_sum += fscore
      return 'fscore', fscore_sum / len(user_counts)

  dval = valdat.as_dmatrix()
  # If you pass in more than one value to evals, early stopping uses the
  # last one. Because why not.
  watchlist = [(dtrain, 'train'), (dval, 'validation'),]
  #watchlist = [(dval, 'validation'),]

  xgb_params = hypers.xgb_params_from_hps(hps)
  evals_result = {}
  t0 = time.time()
  model = xgb.train(xgb_params, dtrain, hps.rounds, evals=watchlist, 
      early_stopping_rounds=hps.early_stopping_rounds, evals_result=evals_result) #, feval=quick_fscore, maximize=True)

  t1 = time.time()
  model_path = common.resolve_xgboostmodel_path(tag)
  model.save_model(model_path)
  preds = model.predict(dval)
  _, fscore = quick_fscore(preds, None)
  logging.info('Final validation (quick) fscore = {}'.format(fscore))
  resultsdict = dict(fscore=fscore, evals=evals_result, duration=t1-t0)
  res_path = os.path.join(common.XGBOOST_DIR, 'results', tag+'.pickle')
  with open(res_path, 'w') as f:
    pickle.dump(resultsdict, f)
Project: pytorch-nips2017-attack-example    Author: rwightman
def run_attack(args, attack):
    assert args.input_dir

    if args.targeted:
        dataset = Dataset(
            args.input_dir,
            transform=default_inception_transform(args.img_size))
    else:
        dataset = Dataset(
            args.input_dir,
            target_file='',
            transform=default_inception_transform(args.img_size))

    loader = data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False)

    model = torchvision.models.inception_v3(pretrained=False, transform_input=False)
    if not args.no_gpu:
        model = model.cuda()

    if args.checkpoint_path is not None and os.path.isfile(args.checkpoint_path):
        checkpoint = torch.load(args.checkpoint_path)
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            model.load_state_dict(checkpoint['state_dict'])
        else:
            model.load_state_dict(checkpoint)
    else:
        print("Error: No checkpoint found at %s." % args.checkpoint_path)

    model.eval()

    for batch_idx, (input, target) in enumerate(loader):
        if not args.no_gpu:
            input = input.cuda()
            target = target.cuda()

        input_adv = attack.run(model, input, target, batch_idx)

        start_index = args.batch_size * batch_idx
        indices = list(range(start_index, start_index + input.size(0)))
        for filename, o in zip(dataset.filenames(indices, basename=True), input_adv):
            output_file = os.path.join(args.output_dir, filename)
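            # rescale the adversarial output from [-1, 1] to [0, 1] before saving as PNG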
            imsave(output_file, (o + 1.0) * 0.5, format='png')
Project: universal_postagger    Author: hpzhao
def build_dataset(data_file):
    global word2id,char2id,word2cluster,upos2id,xpos2id

    sent_words_list = []
    sent_chars_list = []
    sent_clusters_list = []
    sent_upos_list = []
    sent_xpos_list = []

    words_list = []
    chars_list = []
    clusters_list = []
    upos_list = []
    xpos_list = []

    for line in open(data_file):
        line = line.strip().decode('utf8')
        if line and line[0] != u'#':
            tokens = line.split('\t')
            if u'-' not in tokens[0] and u'.' not in tokens[0]:
                word = tokens[1].lower()
                words_list.append(word2id[word] if word in word2id else 1)
                chars_list.append([char2id[char] if char in char2id else 1 for char in word])
                clusters_list.append(word2cluster[word] if word in word2cluster else 0)
                upos,xpos = tokens[3:5]
                upos_list.append(upos2id[upos] if upos in upos2id else 0)
                xpos_list.append(xpos2id[xpos] if xpos in xpos2id else 0)
        if line == '':
            sent_words_list.append(words_list)
            sent_chars_list.append(chars_list)
            sent_clusters_list.append(clusters_list)
            sent_xpos_list.append(xpos_list)
            sent_upos_list.append(upos_list)

            words_list = []
            chars_list = []
            clusters_list = []
            upos_list = []
            xpos_list = []
    upos_word_dataset = Dataset(sent_words_list,sent_upos_list)
    xpos_word_dataset = Dataset(sent_words_list,sent_xpos_list)
    char_dataset = Dataset(sent_chars_list,sent_upos_list)
    cluster_dataset = Dataset(sent_clusters_list,sent_upos_list)
    return upos_word_dataset,xpos_word_dataset,char_dataset,cluster_dataset
Project: Sing_Par    Author: wanghm92
def __init__(self, model, *args, **kwargs):
    """"""

    if args:
      if len(args) > 1:
        raise TypeError('Parser takes at most one argument')

    kwargs['name'] = kwargs.pop('name', model.__name__)
    super(Network, self).__init__(*args, **kwargs)
    if not os.path.isdir(self.save_dir):
      os.mkdir(self.save_dir)
    with open(os.path.join(self.save_dir, 'config.cfg'), 'w') as f:
      self._config.write(f)

    self._global_step = tf.Variable(0., trainable=False)
    self._global_epoch = tf.Variable(0., trainable=False)
    self._model = model(self._config, global_step=self.global_step)

    self._vocabs = []
    vocab_files = [(self.word_file, 1, 'Words'),
                   (self.tag_file, [3, 4], 'Tags'),
                   (self.rel_file, 7, 'Rels')]
    for i, (vocab_file, index, name) in enumerate(vocab_files):
      vocab = Vocab(vocab_file, index, self._config,
                    name=name,
                    cased=self.cased if not i else True,
                    load_embed_file=(not i),
                    global_step=self.global_step)
      self._vocabs.append(vocab)

    self._trainset = Dataset(self.train_file, self._vocabs, model, self._config, name='Trainset')
    self._validset = Dataset(self.valid_file, self._vocabs, model, self._config, name='Validset')
    self._testset = Dataset(self.test_file, self._vocabs, model, self._config, name='Testset')

    self._ops = self._gen_ops()
    self.history = {
      'train_loss': [],
      'train_accuracy': [],
      'valid_loss': [],
      'valid_accuracy': [],
      'test_acuracy': 0
    }
    return

  #=============================================================
Project: Parser-v1    Author: tdozat
def __init__(self, model, *args, **kwargs):
    """"""

    if args:
      if len(args) > 1:
        raise TypeError('Parser takes at most one argument')

    kwargs['name'] = kwargs.pop('name', model.__name__)
    super(Network, self).__init__(*args, **kwargs)
    if not os.path.isdir(self.save_dir):
      os.mkdir(self.save_dir)
    with open(os.path.join(self.save_dir, 'config.cfg'), 'w') as f:
      self._config.write(f)

    self._global_step = tf.Variable(0., trainable=False)
    self._global_epoch = tf.Variable(0., trainable=False)
    self._model = model(self._config, global_step=self.global_step)

    self._vocabs = []
    vocab_files = [(self.word_file, 1, 'Words'),
                   (self.tag_file, [3, 4], 'Tags'),
                   (self.rel_file, 7, 'Rels')]
    for i, (vocab_file, index, name) in enumerate(vocab_files):
      vocab = Vocab(vocab_file, index, self._config,
                    name=name,
                    cased=self.cased if not i else True,
                    use_pretrained=(not i),
                    global_step=self.global_step)
      self._vocabs.append(vocab)

    self._trainset = Dataset(self.train_file, self._vocabs, model, self._config, name='Trainset')
    self._validset = Dataset(self.valid_file, self._vocabs, model, self._config, name='Validset')
    self._testset = Dataset(self.test_file, self._vocabs, model, self._config, name='Testset')

    self._ops = self._gen_ops()
    self._save_vars = filter(lambda x: u'Pretrained' not in x.name, tf.all_variables())
    self.history = {
      'train_loss': [],
      'train_accuracy': [],
      'valid_loss': [],
      'valid_accuracy': [],
      'test_acuracy': 0
    }
    return

  #=============================================================
Project: ActiveBoundary    Author: MiriamHu
def init(self, opt, X, y, y_groundtruth, X_val, y_val, unlabeled_class=-5, ali_model=None):
        dataset = Dataset(X, y, y_groundtruth, X_val, y_val, unlabeled_class=unlabeled_class,
                          al_batch_size=opt.al_batch_size, save_path_db_points=opt.save_path,
                          dataset=opt.hdf5_dataset_encoded)
        print "All samples: ", len(dataset)
        print "Labeled samples: ", dataset.len_labeled()
        print "Unlabeled samples: ", dataset.len_unlabeled()

        print "Initializing model"
        model = JointOptimisationSVM(initial_model=self.initial_model,
                                     hyperparameters=opt.hyperparameters,
                                     save_path_boundaries=opt.save_path)  # declare model instance
        print "Done declaring model"
        print "Initializing query strategy", opt.query_strategy
        if opt.query_strategy == "uncertainty":
            query_strategy = UncertaintySamplingLine(dataset, model=model, generative_model=ali_model,
                                                     save_path_queries=opt.save_path,
                                                     human_experiment=self.human_experiment,
                                                     base_precision=opt.base_precision)  # declare a QueryStrategy instance
        elif opt.query_strategy == "uncertainty-dense":
            query_strategy = UncertaintyDenseSamplingLine(dataset, model=model, generative_model=ali_model,
                                                          save_path_queries=opt.save_path,
                                                          human_experiment=self.human_experiment,
                                                          base_precision=opt.base_precision)
        elif opt.query_strategy == "clustercentroids":
            query_strategy = ClusterCentroidsLine(dataset, model=model, generative_model=ali_model,
                                                  batch_size=opt.al_batch_size, save_path_queries=opt.save_path,
                                                  human_experiment=self.human_experiment,
                                                  base_precision=opt.base_precision)
        elif opt.query_strategy == "random":
            query_strategy = RandomSamplingLine(dataset, model=model, generative_model=ali_model,
                                                save_path_queries=opt.save_path,
                                                human_experiment=self.human_experiment,
                                                base_precision=opt.base_precision)
        else:
            raise Exception("Please specify a query strategy")
        print "Done declaring query strategy", opt.query_strategy

        if opt.oracle_type == "noisy_line_labeler":
            labeler = NoisyLineLabeler(dataset, opt.std_noise, pretrained_groundtruth=self.groundtruth_model,
                                       hyperparameters=opt.hyperparameters)
            print "Done declaring NoisyLineLabeler"
        elif opt.oracle_type == "human_line_labeler":
            labeler = HumanLineLabeler(dataset, ali_model, hyperparameters=opt.hyperparameters)
        else:
            labeler = LineLabeler(dataset, pretrained_groundtruth=self.groundtruth_model,
                                  hyperparameters=opt.hyperparameters)  # declare Labeler instance
            print "Done declaring LineLabeler"
        print "Done initializing"
        return dataset, model, query_strategy, labeler