Python _pickle module: dump() code examples

The following code examples, extracted from open-source Python projects, illustrate how to use _pickle.dump().
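
For reference, here is a minimal, self-contained sketch of the pattern the examples below share (the file name and sample data are placeholders): open a file in binary mode, dump an object, then load it back. The optional third argument of dump() selects the pickle protocol.

import _pickle

data = {'iteration': 100, 'loss': 0.25}

# dump() needs a file opened in binary mode; the optional third
# argument is the pickle protocol
with open('example.pkl', 'wb') as f:
    _pickle.dump(data, f, 2)

# read the object back
with open('example.pkl', 'rb') as f:
    restored = _pickle.load(f)

assert restored == data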

Project: GANGogh    Author: rkjones4    | project source | file source
def flush():
    prints = []

    for name, vals in _since_last_flush.items():
        prints.append("{}\t{}".format(name, np.mean(list(vals.values()))))
        _since_beginning[name].update(vals)

        x_vals = np.sort(list(_since_beginning[name].keys()))
        y_vals = [_since_beginning[name][x] for x in x_vals]

        plt.clf()
        plt.plot(x_vals, y_vals)
        plt.xlabel('iteration')
        plt.ylabel(name)
        plt.savefig('generated/'+name.replace(' ', '_')+'.jpg')

    print("iter {}\t{}".format(_iter[0], "\t".join(prints)))
    _since_last_flush.clear()

    with open('log.pkl', 'wb') as f:
        pickle.dump(dict(_since_beginning), f, 4)
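
In the call above, the third positional argument of pickle.dump() is the pickle protocol (here 4). A minimal sketch of reading the log back, assuming the same log.pkl path:

import pickle

with open('log.pkl', 'rb') as f:
    history = pickle.load(f)  # maps each metric name to {iteration: value}
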
Project: Caption-Generation    Author: m516825    | project source | file source
def build_w2v_matrix(vocab_processor, w2v_path, vector_path, dim_size):
    w2v_dict = {}
    f = open(vector_path, 'r')
    for line in f.readlines():
        word, vec = line.strip().split(' ', 1)
        w2v_dict[word] = np.loadtxt([vec], dtype='float32')

    vocab_list = vocab_processor._reverse_mapping
    w2v_W = np.zeros(shape=(len(vocab_list), dim_size), dtype='float32')

    for i, vocab in enumerate(vocab_list):
        # unknown vocab
        if i == 0:
            continue
        else:
            if vocab in w2v_dict:
                w2v_W[i] = w2v_dict[vocab]
            else:
                w2v_W[i] = get_unknown_word_vec(dim_size)

    cPickle.dump(w2v_W, open(w2v_path, 'wb'))

    return w2v_W
Project: text_imbalance_multiclass_classifier    Author: ManishKV    | project source | file source
def create_classifier(doc,class_freq):        
    ## iterate through all classes in the dataset
    for i in class_freq:         
        ## create a copy of dataframe
        temp_doc = doc.copy()
        ## assign class '-1' to every other class.
        temp_doc.category[temp_doc['category'] != class_freq[i]] = '-1'

        temp_doc = temp_doc.values.tolist()
        ## training classifier on the temp_doc
        classifier = NaiveBayesClassifier(temp_doc)

        # save the classifier on disk
        with open('classifier_'+i+'.pkl', 'wb') as fid:
            cPickle.dump(classifier, fid)
        ## reassign doc with the new reduced dataset
        doc = doc[doc['category'] != class_freq[i]].copy()



## to track time taken in creating classifiers
Project: personality    Author: meyersbs    | project source | file source
def aggregate_data(data=constants.DATASET_CLEANED_PATH):
    with open(data, newline='') as csvfile:
        csv_reader = csv.reader(csvfile, delimiter=',', quotechar='"')
        next(csv_reader, None) # skip the header row

        words = dict()
        # for each line in the csv file
        for status in csv_reader:
            # for each word in the status
            for w in re.split(r'[\s_\-]+', _clean(status[0].lower())):
                # create a Word object if we don't have one
                if _clean(w) not in words.keys():
                    words[_clean(w)] = word.Word.init(_clean(w), status[6:])
                # update the Word object if it exists
                else:
                    words[_clean(w)].update_freqs(status[6:])

    with open(constants.AGGREGATE_INFO_FILE, 'wb') as out_file:
        _pickle.dump(words, out_file)
    return words
Project: chinese-char-rnn    Author: indiejoseph    | project source | file source
def preprocess(self, input_file, vocab_file, tensor_file):
    with codecs.open(input_file, "r", encoding=self.encoding) as f:
      train_data = f.read()
      train_data = normalize_unicodes(train_data)

    counter = collections.Counter(train_data)
    count_pairs = sorted(counter.items(), key=lambda x: -x[1])
    threshold = 10
    self.chars, counts = zip(*count_pairs)
    self.chars = START_VOCAB + [c for i, c in enumerate(self.chars) if c not in START_VOCAB and counts[i] > threshold]
    self.vocab_size = len(self.chars)
    self.vocab = dict(zip(self.chars, range(len(self.chars))))
    with open(vocab_file, 'wb') as f:
      cPickle.dump(self.chars, f)
    unk_index = START_VOCAB.index(UNK)
    self.tensor = np.array([self.vocab.get(c, unk_index) for c in train_data], dtype=np.int64)
    train_size = int(self.tensor.shape[0] * 0.9)
    self.valid = self.tensor[train_size:]
    self.train = self.tensor[:train_size]
    np.save(tensor_file, self.tensor)
Project: autoxd    Author: nessessary    | project source | file source
def get(fn, *args, **kwargs):
    """??redis?cache??
    fn: ??, ????????
    return: data fn????"""
    key = gen_keyname(fn)
    r = createRedis()
    #r.flushall()
    if key not in r.keys():
        o = fn(*args, **kwargs)
        # pickle the result into an in-memory buffer before storing it in Redis
        f = cStringIO.StringIO()
        cPickle.dump(o, f)
        s = f.getvalue()
        f.close()        
        r.set(key, s)
    s = r.get(key)
    f = cStringIO.StringIO(s)
    o = cPickle.load(f)
    f.close()
    return o
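
The cStringIO buffer above is a Python 2 idiom. On Python 3, pickle.dumps()/pickle.loads() operate on bytes directly, so a Redis cache like this one can skip the intermediate buffer. A hedged sketch under that assumption (gen_keyname and createRedis are the project's own helpers):

import pickle

def get_cached(fn, *args, **kwargs):
    """Return fn's result from Redis, computing and caching it on a miss."""
    key = gen_keyname(fn)                 # helper assumed from the project
    r = createRedis()                     # helper assumed from the project
    cached = r.get(key)
    if cached is None:
        result = fn(*args, **kwargs)
        r.set(key, pickle.dumps(result))  # bytes go straight into Redis
        return result
    return pickle.loads(cached)
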
Project: Caption-Generation    Author: m516825    | project source | file source
def clean_str(string):
    string = re.sub(r"\.", r"", string)
    return string.strip().lower()

# load and dump the mapped train and validation captions
Project: IDNNs    Author: ravidziv    | project source | file source
def save_data(self, parent_dir='jobs/', file_to_save='data.pickle'):
        """Save the data to the file """
        directory = '{0}/{1}{2}/'.format(os.getcwd(), parent_dir, self.params['directory'])

        data = {'information': self.information,
                'test_error': self.test_error, 'train_error': self.train_error, 'var_grad_val': self.grads,
                'loss_test': self.loss_test, 'loss_train': self.loss_train, 'params': self.params,
                'l1_norms': self.l1_norms, 'weights': self.weights, 'ws': self.ws}

        if not os.path.exists(directory):
            os.makedirs(directory)
        self.dir_saved = directory
        with open(self.dir_saved + file_to_save, 'wb') as f:
            cPickle.dump(data, f, protocol=2)
Project: bot2017Fin    Author: AllanYiin    | project source | file source
def save(self, filename='word2vec.pklz'):
        """
        :param filename: path of the output file
        """
        fil = gzip.open(filename, 'wb')
        cPickle.dump(self, fil, protocol=pickle.HIGHEST_PROTOCOL)
        fil.close()
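
The matching loader is symmetric: open the gzip archive in binary read mode and unpickle. A minimal sketch, assuming the same word2vec.pklz file name:

import gzip
import _pickle as cPickle

def load(filename='word2vec.pklz'):
    with gzip.open(filename, 'rb') as fil:
        return cPickle.load(fil)
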
Project: epfl-semester-project-biaxialnn    Author: onanypoint    | project source | file source
def generate_sample(model, pieces, directory, name):
    """Generate a sample and save it to disk

    Note
    ----
    Only use the first note just as in the original model

    Parameters
    ----------
    model : utils.model.Model
        The model to use
    pieces : dict
        Dictionary containing state matrices as values
    directory : str
        Path to the parent folder
    name : str
        Specific name of the file; it will be appended after the prefix
    """
    xIpt, xOpt = map(np.array, model.data_manager.get_piece_segment(pieces))
    seed_i, seed_o = (xIpt[0], xOpt[0])
    generated_sample = model.generate_fun(seq_len, 1, np.expand_dims(seed_i, axis=0))
    statematrix = np.concatenate((np.expand_dims(seed_o, 0), generated_sample), axis=0)
    s = model.data_manager.s.statematrix_to_stream(statematrix)
    np.save(directory + 'samples/sample_{}.npy'.format(name), statematrix)
    s.write('musicxml', directory + 'samples/sample_{}.xml'.format(name))

    pickle.dump(model.learned_config, open(directory + 'weights/params_{}.p'.format(name), 'wb'))
Project: RFHO    Author: lucfra    | project source | file source
def save_obj(obj, name, root_dir=None, notebook_mode=True, default_overwrite=False):
    if root_dir is None: root_dir = os.getcwd()
    directory = check_or_create_dir(join_paths(root_dir, FOLDER_NAMINGS['OBJ_DIR']),
                                    notebook_mode=notebook_mode)

    filename = join_paths(directory, '%s.pkgz' % name)  # directory + '/%s.pkgz' % name
    if not default_overwrite and os.path.isfile(filename):
        overwrite = input('A file named %s already exists. Overwrite (Leave string empty for NO!)?' % filename)
        if not overwrite:
            print('No changes done.')
            return
        print('Overwriting...')
    with gzip.open(filename, 'wb') as f:
        pickle.dump(obj, f)
        # print('File saved!')
Project: ForumAnalysis    Author: consjuly542    | project source | file source
def write_current_article_list(articles_list, cnt_visible_article=70):
    """
    Write the top cnt_visible_article articles,
    i.e. the ones the user wants to see, to a file.

    Parameters
    -------------
    *articles_list (list of instance ArticleStatistics): all articles
    *cnt_visible_article (int): count of visible articles 
    """
    with open("./../data/statistics/current_article_list", "wb") as f:
        article_dict = data2dict(copy(articles_list)[:cnt_visible_article])
        cPickle.dump(article_dict, f, protocol=pickle.HIGHEST_PROTOCOL)
Project: ForumAnalysis    Author: consjuly542    | project source | file source
def get_article_index(self):
        """
        Create the dictionary {article: article_statistics},
        an empty index for future statistics.
        """
        articles = load_file_article.load_data()
        print ("Count official articles: %d" % len(articles))
        self.article_index = {a.article_ID: ArticleStatistics(a) for a in articles}

        with open("./../data/statistics/article_index", "wb") as f:
            cPickle.dump(self.article_index, f, protocol=pickle.HIGHEST_PROTOCOL)
Project: ForumAnalysis    Author: consjuly542    | project source | file source
def get_article_statistics(self, recompute_statistics=True):
        """
        Aggregate statistics from both forums.
        """
        if recompute_statistics:
            self.get_article_index()
            data_generator = loadDataGenerator()
            cnt_not_match_links = 0
            links_cnt = 0
            l2a = Link2Article()
            # log = open("./logs", "w")
            # error_link = []
            for question_batch in data_generator:
                for question in question_batch:

                    links = LinksSearcher(question.get_all_text()).get_simple_links()
                    for link in links:
                        # log.write(link.link_text + "\n")
                        # log.flush()
                        # function from Alexandrina
                        article = l2a.link2article(link)
                        # print (article)
                        if article:
                            # print (article.article_ID)
                            links_cnt += 1
                            self.article_index[article.article_ID].add_question(question, link)
                        else:
                            cnt_not_match_links += 1

                    sys.stderr.write("\r\t\t\t\t\tALL LINKS: %d; CAN't MATCH: %d" % (links_cnt, cnt_not_match_links))

            with open("./../data/statistics/article_statistics", "wb") as f:
                cPickle.dump(self.article_index, f, protocol=pickle.HIGHEST_PROTOCOL)
        else:
            with open("./../data/statistics/article_statistics", "rb") as f:
                self.article_index = cPickle.load(f)
Project: ForumAnalysis    Author: consjuly542    | project source | file source
def add_question(self, question, link):
        if link.part_num:
            if link.part_num not in self.parts_statistics:
                self.parts_statistics[link.part_num] = 0
            self.parts_statistics[link.part_num] += 1

        self.questions_cnt += 1
        self.sum_answers_cnt += len(question.answers)

        self.cur_mean_answers = float(self.sum_answers_cnt) / self.questions_cnt

        with open(self.questions_filename, "ab") as f:
            tmp = copy(question)
            tmp.date = convert_date(question.date)
            cPickle.dump(tmp.to_dict(), f, protocol=pickle.HIGHEST_PROTOCOL)

        date_parts = question.date.strip().split("_")
        if len(date_parts) == 1:
            date_parts = question.date.strip().split(".")

        # print (date_parts)

        q_date = date(int(date_parts[0]), \
                     int(date_parts[1]), \
                     int(date_parts[2]))
        self.dates.append(q_date)

        if (self.first_date and self.first_date > q_date) or (self.first_date is None):
            self.first_date = q_date
        if (self.last_date and self.last_date < q_date) or (self.last_date is None):
            self.last_date = q_date
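
Because add_question() appends one pickled record per call to the same file (note the 'ab' mode), reading that file back means calling load() repeatedly until the end of the stream. A hedged sketch of such a reader, with questions_filename being the per-article file used above:

import _pickle

def read_questions(questions_filename):
    """Yield every question dict appended to the file, in order."""
    with open(questions_filename, 'rb') as f:
        while True:
            try:
                yield _pickle.load(f)
            except EOFError:
                break
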
Project: personality    Author: meyersbs    | project source | file source
def aggregate_train_test(train_size, data=constants.DATASET_CLEANED_PATH):
    with open(data, newline='') as csvfile:
        csv_reader = csv.reader(csvfile, delimiter=',', quotechar='"')
        statuses = list(csv_reader)
        statuses = statuses[1:] # skip the header row
        size = len(statuses)

    train_words, test_words = dict(), dict()
    train_size = int(size*train_size)

    random.shuffle(statuses)

    for i in range(0, train_size):
        status = statuses[i]
        for w in re.split(r'[\s_\-]+', _clean(status[0].lower())):
            if _clean(w) not in train_words.keys():
                train_words[_clean(w)] = word.Word.init(_clean(w), status[6:])
            else:
                train_words[_clean(w)].update_freqs(status[6:])

    for i in range(train_size, size):
        status = statuses[i]
        for w in re.split(r'[\s_\-]+', _clean(status[0].lower())):
            if _clean(w) not in test_words.keys():
                test_words[_clean(w)] = word.Word.init(_clean(w), status[6:])
            else:
                test_words[_clean(w)].update_freqs(status[6:])

    with open(constants.AGGREGATE_TRAINING, 'wb') as out_file:
        _pickle.dump(train_words, out_file)
    with open(constants.AGGREGATE_TESTING, 'wb') as out_file:
        _pickle.dump(test_words, out_file)
    stats = []
    with open(constants.AGGREGATE_TESTING_STATUSES, 'w') as out_file:
        csv_writer = csv.writer(out_file, delimiter=',', quotechar='"',
                                quoting=csv.QUOTE_MINIMAL)
        for status in statuses:
            csv_writer.writerow([status[0]])
            stats.append(status[0])
    return train_words, test_words, stats
Project: personality    Author: meyersbs    | project source | file source
def predict_split(train_size):
    if os.path.getsize(TRAIN_PREDICTIONS) == 0:
        print("Splitting Training/Testing Data...")
        _, _, statuses = aggregate.aggregate_train_test(train_size)
        statuses = "\n".join(statuses)
        print("Gathering Features...")
        train_preds = predict(statuses, text_type='str', data=AGGREGATE_TRAINING)
        print("==== DUMPING TRAINING ====")
        with open(TRAIN_PREDICTIONS, 'wb') as f:
            _pickle.dump(train_preds, f)
        test_preds = predict(statuses, text_type='str', data=AGGREGATE_TESTING)
        print("==== DUMPING TESTING ====")
        with open(TEST_PREDICTIONS, 'wb') as f:
            _pickle.dump(test_preds, f)

    with open(TRAIN_PREDICTIONS, 'rb') as f:
        train_preds = _pickle.load(f)
        helpers.print_preds(train_preds, "TRAINING RESULTS")
    with open(TEST_PREDICTIONS, 'rb') as f:
        test_preds = _pickle.load(f)
        helpers.print_preds(test_preds, "TESTING RESULTS")
    print("Reading Testing Data...")
    with open(AGGREGATE_TESTING_STATUSES, 'r') as f:
        csv_reader = csv.reader(f, delimiter=',', quotechar='"')

        predictions = {}
        print("Collecting Predictions...\n")
        for i, status in enumerate(csv_reader):
            print("\r" + str(i), end="")
            train_pred = predict(status[0], text_type='str', data=AGGREGATE_TRAINING)
            test_pred = predict(status[0], text_type='str', data=AGGREGATE_TESTING)
            predictions[i] = [status[0], pred_to_labels(train_pred), pred_to_labels(test_pred)]

    sys.exit()
    return predictions
Project: personality    Author: meyersbs    | project source | file source
def print_data(args):
    # TODO: Move this elsewhere.
    with open(constants.AGGREGATE_INFO_FILE, 'rb') as f:
        preds = _pickle.load(f)

    results = {'eRatio_y': 0, 'eRatio_n': 0, 'nRatio_y': 0, 'nRatio_n': 0,
               'aRatio_y': 0, 'aRatio_n': 0, 'cRatio_y': 0, 'cRatio_n': 0,
               'oRatio_y': 0, 'oRatio_n': 0}
    for key, pred in preds.items():
        if pred.eRatio_y > pred.eRatio_n:
            results['eRatio_y'] += 1
        else:
            results['eRatio_n'] += 1
        if pred.nRatio_y > pred.nRatio_n:
            results['nRatio_y'] += 1
        else:
            results['nRatio_n'] += 1
        if pred.aRatio_y > pred.aRatio_n:
            results['aRatio_y'] += 1
        else:
            results['aRatio_n'] += 1
        if pred.cRatio_y > pred.cRatio_n:
            results['cRatio_y'] += 1
        else:
            results['cRatio_n'] += 1
        if pred.oRatio_y > pred.oRatio_n:
            results['oRatio_y'] += 1
        else:
            results['oRatio_n'] += 1

    helpers.print_preds(results, '')
    with open('results.pkl', 'wb') as f:
        _pickle.dump(results, f)
Project: universal_tool_template.py    Author: shiningdesign    | project source | file source
def writeDataFile(self,data,file,binary=0):
        with open(file, 'w') as f:
            if binary == 0:
                json.dump(data, f)
            else:
                cPickle.dump(data, f)
Project: universal_tool_template.py    Author: shiningdesign    | project source | file source
def writeFileData(self,data,file,binary=0):
        with open(file, 'w') as f:
            if binary == 0:
                json.dump(data, f)
            else:
                cPickle.dump(data, f)
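
Both helpers above open the file in text mode ('w'); that is fine with cPickle on Python 2, but on Python 3 pickle requires a binary file object. A hedged Python 3 variant with a hypothetical name:

import json
import pickle

def write_data_file_py3(data, file, binary=0):
    if binary == 0:
        with open(file, 'w') as f:      # JSON needs a text file
            json.dump(data, f)
    else:
        with open(file, 'wb') as f:     # pickle needs a binary file
            pickle.dump(data, f)
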
Project: autoxd    Author: nessessary    | project source | file source
def test_save():
    import stock
    r = redis.Redis(host='localhost', port=6379, db=0) 
    ths = stock.createThs()
    #df = stock.Guider(code).ToDataFrame()
    f = cStringIO.StringIO()
    cPickle.dump(ths, f)
    #df.to_csv(f)
    s = f.getvalue()
    f.close()
    print(len(s))
    r.set('ths', s)
Project: autoxd    Author: nessessary    | project source | file source
def set(fn, *args, **kwargs):
    """????, ????????????"""
    key = gen_keyname(fn)
    r = createRedis()
    #r.flushall()
    o = fn(*args, **kwargs)
    # pickle the result into an in-memory buffer before storing it in Redis
    f = cStringIO.StringIO()
    cPickle.dump(o, f)
    s = f.getvalue()
    f.close()        
    r.set(key, s)
Project: autoxd    Author: nessessary    | project source | file source
def set_obj(key, o):
    """????, ????"""
    r = createRedis()
    f = cStringIO.StringIO()
    cPickle.dump(o, f)
    s = f.getvalue()
    f.close()        
    r.set(key, s)
Project: mflod    Author: arachnid42    | project source | file source
def dump_key_pickle(self, key_ind, path):

        if not 0 < key_ind < len(self.keys):
            raise IndexError("Key index out of range")

        try:
            with open(path, 'wb') as f:
                pkl.dump(self.keys[key_ind], f)
        except Exception as e:
            print("failed to dump key: %s" % e)
Project: GeneralCrawlers    Author: shaolinjr    | project source | file source
def getPhotos(urls, thumbs=False):
    puts("Locating Photos...")
    photos = {}
    typeErrorCount = 0
    keyErrorCount = 0
    urlErrorCount = 0
    for url in progress.bar(urls):
        try:
            data = urlopen(ROOT_URL + url).read()
            soup = BeautifulSoup(data, 'lxml')
            result = soup.find('img')
            if result is None:
                typeErrorCount += 1
                continue
            if thumbs:
                photos[url] = result['src']
            else:
                photos[url] = result.parent['href']
        except TypeError:
            typeErrorCount += 1
        except KeyError:
            keyErrorCount += 1
        except URLError:
            urlErrorCount += 1
    puts(colored.green("Found %d photos." % len(photos.values())))
    puts(colored.red("URL Error Count: %d" % urlErrorCount))
    puts(colored.red("Key Error Count: %d" % keyErrorCount))
    puts(colored.red("Type Error Count: %d" % typeErrorCount))
    with open('photos.pkl', 'wb') as output:
        pickle.dump(photos, output, pickle.HIGHEST_PROTOCOL)
    return photos
Project: Trading-Brain    Author: Prediction-Machines    | project source | file source
def save_pkl(obj, path):
    with open(path, 'w') as f:
        cPickle.dump(obj, f)
        print("  [*] save %s" % path)
Project: Caption-Generation    Author: m516825    | project source | file source
def load_text_data(train_lab, prepro_train_p, vocab_path):
    tlab = json.load(open(train_lab, 'r'))
    vocab_dict = collections.defaultdict(int)
    train_dict = {}

    for caps in tlab:
        train_dict[caps['id']] = ['<BOS> '+clean_str(cap)+' <EOS>' for cap in caps['caption']]

    # build vocabulary
    maxlen = 0
    avglen = 0
    total_seq = 0
    for cid, captions in train_dict.items():
        for caption in captions:
            s_caption = caption.split()
            avglen += len(s_caption)
            total_seq += 1
            if len(s_caption) >= maxlen:
                maxlen = len(s_caption)

            for word in s_caption:
                vocab_dict[word] += 1
    vocabulary = []
    for k, v in sorted(vocab_dict.items(), key=lambda x:x[1], reverse=True):
        if v >= min_count:
            vocabulary.append(k)

    # map sequence to its id
    vocab_processor = VocabularyProcessor(max_document_length=math.ceil(avglen/total_seq)+add_count, vocabulary=vocabulary, drop=True)

    t_number = 0
    min_t = float('Inf')
    avg_t = 0
    for cid, _ in train_dict.items():
        train_c_dat, lengths = vocab_processor.transform(train_dict[cid])
        train_dict[cid] = {'captions':train_c_dat, 'lengths':lengths}
        t_number += len(train_c_dat)
        if len(train_c_dat) < min_t:
            min_t = len(train_c_dat)

    cPickle.dump(train_dict, open(prepro_train_p, 'wb'))
    vocab_processor.save(vocab_path)

    print('init sequence number: {}'.format(total_seq))
    print('maximum sequence length: {}'.format(maxlen))
    print('average sequence length: {}'.format(avglen/total_seq))
    print('drop length: > {}'.format(math.ceil(avglen/total_seq)+add_count))
    print('remaining total train number: {}'.format(t_number))
    print('total video number: {}'.format(len(train_dict)))
    print('minimum train number: {} per video'.format(min_t))
    print('average train number: {} per video'.format(t_number//len(train_dict)))

    return vocab_processor, train_dict
Project: ForumAnalysis    Author: consjuly542    | project source | file source
def cancel_filter(self, filter_type):
        if filter_type in self.filters_type:
            self.filters_data.pop(self.filters_type.index(filter_type))
            self.filters_type.remove(filter_type)

        self.cur_articles_list = self.articles_list_all.copy()
        for i, f in enumerate(self.filters_type):
            self.add_filter(filter_type=f, filter_data=self.filters_data[i])

        write_current_article_list(self.cur_articles_list)

# index = StatisticsModule(recompute_statistics = True)
# index = StatisticsModule(recompute_statistics = False)

# index = StatisticsModule(recompute_statistics = False).get_graphics()
# index.add_filter(filter_type='law', filter_data = '??????????? ??????')

# # print(len(index.article_index))
# # # for idx, k in enumerate(index.article_index.keys()):
# # #   if idx > 2:
# # #       break
# # #   print (k.to_dict())

# with open("./../data/statistics/current_article_list", "rb") as f:
#   articles = pickle.load(f)

#   for idx, k in enumerate(articles):
#       if idx > 1:
#           break

#       print (k['official_article']['law'])

#       print (k['questions_cnt'], len(get_questions(k['questions_filename'])), \
#           get_questions(k['questions_filename'])[0])
# g1= Guide("PPN.txt", "???????????? ?? ???????")
# g2= Guide("PPVS.txt", "??????? ??????? ?????? ?????")
# g3= Guide("PSP.txt", "???????????? ?? ???????? ????????")
# g4= Guide("PTS.txt", "???????????? ?? ???????? ??????")

# guides_list = [g1, g2, g3, g4]

# # print (g.name, g.articles_link.values(), len(g.articles_link))
# with open("../data/guide_articles/guides_list", "wb") as f:
#     cPickle.dump(guides_list, f, protocol=pickle.HIGHEST_PROTOCOL)
Project: LSH_Memory    Author: RUSH-LAB    | project source | file source
def write_datafiles(directory, write_file,
                    resize=True, rotate=False,
                    new_width=IMAGE_NEW_SIZE, new_height=IMAGE_NEW_SIZE,
                    first_label=0):
  """Load and preprocess images from a directory and write them to a file.

  Args:
    directory: Directory of alphabet sub-directories.
    write_file: Filename to write to.
    resize: Whether to resize the images.
    rotate: Whether to augment the dataset with rotations.
    new_width: New resize width.
    new_height: New resize height.
    first_label: Label to start with.

  Returns:
    Number of new labels created.
  """

  # these are the default sizes for Omniglot:
  imgwidth = IMAGE_ORIGINAL_SIZE
  imgheight = IMAGE_ORIGINAL_SIZE

  logging.info('Reading the data.')
  images, labels, info = crawl_directory(directory, augment_with_rotations=rotate, first_label=first_label)

  images_np = np.zeros([len(images), imgwidth, imgheight], dtype=np.bool)
  labels_np = np.zeros([len(labels)], dtype=np.uint32)
  for idx in range(len(images)):
    images_np[idx, :, :] = images[idx]
    labels_np[idx] = labels[idx]

  if resize:
    logging.info('Resizing images.')
    resized_images = resize_images(images_np, new_width, new_height)

    logging.info('Writing resized data in float32 format.')
    data = {'images': resized_images,
            'labels': labels_np,
            'info': info}
    with tf.gfile.GFile(write_file, 'w') as f:
        pickle.dump(data, f)
  else:
    logging.info('Writing original sized data in boolean format.')
    data = {'images': images_np,
            'labels': labels_np,
            'info': info}
    with tf.gfile.GFile(write_file, 'w') as f:
        pickle.dump(data, f)

  return len(np.unique(labels_np))
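
Reading the file written above follows the same pattern in reverse; a hedged sketch that reuses the snippet's tf and pickle imports (the file must be opened in binary read mode for unpickling):

def read_datafile(read_file):
  """Load the images/labels/info dict written by write_datafiles."""
  with tf.gfile.GFile(read_file, 'rb') as f:
    data = pickle.load(f)
  return data['images'], data['labels'], data['info']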