Python tqdm.tqdm.write() 方法实例源码

我们从 Python 开源项目中提取了以下 50 个代码示例，用于说明如何使用 tqdm.tqdm.write()。

项目:hadan-gcloud    作者:youkpan    | 项目源码 | 文件源码
def loadConversations(self, dirName):
    """Load QA pairs from every gzipped OpenSubtitles file in a folder.

    Args:
        dirName (str): folder to load
    Return:
        array(question, answer): the extracted QA pairs

    Files whose parsing raises ``ValueError`` are skipped with a notice;
    any other exception is reported and re-raised.
    """
    conversations = []
    dirList = self.filesInDir(dirName)
    for filepath in tqdm(dirList, "OpenSubtitles data files"):
        # only names ending in 'gz' are parsed
        if not filepath.endswith('gz'):
            continue
        try:
            doc = self.getXML(filepath)
            conversations.extend(self.genList(doc))
        except ValueError:
            # malformed file: skip it and keep loading the others
            tqdm.write("Skipping file %s with errors." % filepath)
        except Exception as err:
            # report through tqdm.write (not print) so the message does not
            # garble the active progress bar, then propagate the error
            tqdm.write("Unexpected error: %s" % (type(err),))
            raise
    return conversations
项目:Parallel.GAMIT    作者:demiangomez    | 项目源码 | 文件源码
def __init__(self, cnn, NetworkConfig, year, doys):
    """Build the core and secondary station sets of one network.

    Args:
        cnn: database connection, passed through to ``NetClass``.
        NetworkConfig: mapping read for 'network_id', 'stn_core' and 'stn_list'.
        year: passed through to ``NetClass``.
        doys: passed through to ``NetClass``.
    """
    self.Name = NetworkConfig['network_id'].lower()

    self.Core = NetClass(cnn, self.Name, NetworkConfig['stn_core'], year, doys)

    self.Secondary = NetClass(cnn, self.Name + '.Secondary', NetworkConfig['stn_list'], year, doys,
                              self.Core.StrStns)

    # create a StationAlias if needed, if not, just assign StationCode.
    # AllStations keeps one entry per (NetworkCode, StationCode) in first-seen
    # order; the set makes the duplicate test O(1) instead of rebuilding a
    # list of all previous entries for every station.
    self.AllStations = []
    seen = set()
    for Station in self.Core.Stations + self.Secondary.Stations:
        self.CheckStationCodes(Station)
        key = (Station.NetworkCode, Station.StationCode)
        if key not in seen:
            seen.add(key)
            self.AllStations.append({'NetworkCode': Station.NetworkCode,
                                     'StationCode': Station.StationCode,
                                     'StationAlias': Station.StationAlias})

    # counts core + secondary lists as given (duplicates included)
    self.total_stations = len(self.Core.Stations) + len(self.Secondary.Stations)

    sys.stdout.write('\n >> Total number of stations: %i (including core)\n\n' % (self.total_stations))
项目:Parallel.GAMIT    作者:demiangomez    | 项目源码 | 文件源码
def write_error(folder, filename, msg):
    """Append ``msg`` to ``folder/filename``, retrying on I/O errors.

    The write is attempted once and then retried up to three more times.
    If every attempt fails, an ``IOError`` is raised whose message is the
    last error text followed by ``' after 3 retries'``.
    """
    path = os.path.join(folder, filename)
    count = 0
    while True:
        try:
            # append so earlier error reports are kept; ``with`` closes the
            # handle even when write() itself fails
            with open(path, 'a') as fh:
                fh.write(msg)
            return
        except IOError as e:
            if count >= 3:
                raise IOError(str(e) + ' after 3 retries')
            count += 1
项目:Parallel.GAMIT    作者:demiangomez    | 项目源码 | 文件源码
def output_handle(callback):
    """Report and log the errors collected from a batch of parallel jobs.

    Args:
        callback: iterable of job results, each exposing an ``errors``
            attribute that is falsy when the job produced no error.

    Returns:
        list: always an empty list.
    """
    errors = [outmsg.errors for outmsg in callback if outmsg.errors]

    if not errors:
        return []

    tqdm.write(
        ' >> There were unhandled errors during this batch. Please check errors_pyScanArchive.log for details')

    # open the log once for the whole batch; ``with`` closes it even if a write fails
    with open('errors_pyScanArchive.log', 'a') as f:
        for msg in errors:
            f.write('ON ' + datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') + ' an unhandled error occurred:\n')
            f.write(msg + '\n')
            f.write('END OF ERROR =================== \n\n')

    return []
项目:Parallel.GAMIT    作者:demiangomez    | 项目源码 | 文件源码
def PrintStationInfo(cnn, stnlist, short=False):
    """Write the station-information records of every station to stdout.

    Each ``stnlist`` entry must provide 'NetworkCode' and 'StationCode'.
    With ``short=False`` the full record text is printed under a
    ``# NET.STN`` header; with ``short=True`` each record line is truncated
    (columns 1-109, a ``[...]`` marker, then column 160 onward) and prefixed
    with the network code. A ``pyStationInfoException`` is printed instead
    of being propagated.
    """
    for stn in stnlist:
        net = stn['NetworkCode']
        code = stn['StationCode']

        try:
            records = pyStationInfo.StationInfo(cnn, net, code).return_stninfo().split('\n')

            if not short:
                header = '# ' + net.upper() + '.' + code.upper() + '\n'
                sys.stdout.write(header + '\n'.join(records) + '\n')
                continue

            prefix = ' ' + net.upper() + '.'
            rows = [prefix + rec[1:110] + ' [...] ' + rec[160:] for rec in records]
            sys.stdout.write('\n'.join(rows) + '\n\n')

        except pyStationInfo.pyStationInfoException as e:
            sys.stdout.write(str(e))
项目:ngraph    作者:NervanaSystems    | 项目源码 | 文件源码
def __call__(self, transformer, callback_data, phase, data, idx):
    """Record and print interval losses during training.

    * ``train_pre_``: create one ``cost/<name>`` dataset per loss output plus
      a ``time/loss`` dataset, each sized to the number of intervals.
    * ``train_post``: evaluate the losses once more and print them.
    * ``minibatch_post`` on every ``frequency``-th iteration: evaluate the
      losses, store them and the evaluation time, and print them.
    """
    if phase == CallbackPhase.train_pre_:
        self.total_iterations = callback_data['config'].attrs['total_iterations']
        n_intervals = self.total_iterations // self.frequency
        for name in self.interval_loss_comp.output_keys:
            callback_data.create_dataset("cost/{}".format(name), (n_intervals,))
        callback_data.create_dataset("time/loss", (n_intervals,))
        return

    if phase == CallbackPhase.train_post:
        avg = loop_eval(self.dataset, self.interval_loss_comp)
        tqdm.write("Training complete.  Avg losses: {}".format(avg))
        return

    # only act at the end of each interval of `frequency` minibatches
    if phase != CallbackPhase.minibatch_post or (idx + 1) % self.frequency != 0:
        return

    t0 = default_timer()
    interval = idx // self.frequency

    avg = loop_eval(self.dataset, self.interval_loss_comp)

    for name, value in avg.items():
        callback_data["cost/{}".format(name)][interval] = value

    callback_data["time/loss"][interval] = (default_timer() - t0)
    tqdm.write("Interval {} Iteration {} complete.  Avg losses: {}".format(
        interval + 1, idx + 1, avg))
项目:DeepQA    作者:Conchylicultor    | 项目源码 | 文件源码
def loadConversations(self, dirName):
    """Load QA pairs from every gzipped OpenSubtitles file in a folder.

    Args:
        dirName (str): folder to load
    Return:
        array(question, answer): the extracted QA pairs

    Files whose parsing raises ``ValueError`` are skipped with a notice;
    any other exception is reported and re-raised.
    """
    conversations = []
    dirList = self.filesInDir(dirName)
    for filepath in tqdm(dirList, "OpenSubtitles data files"):
        # only names ending in 'gz' are parsed
        if not filepath.endswith('gz'):
            continue
        try:
            doc = self.getXML(filepath)
            conversations.extend(self.genList(doc))
        except ValueError:
            # malformed file: skip it and keep loading the others
            tqdm.write("Skipping file %s with errors." % filepath)
        except Exception as err:
            # report through tqdm.write (not print) so the message does not
            # garble the active progress bar, then propagate the error
            tqdm.write("Unexpected error: %s" % (type(err),))
            raise
    return conversations
项目:entity_binding    作者:JasperGuo    | 项目源码 | 文件源码
def _check_predictions(self, predictions, ground_truth):
        """
        :param predictions:     [batch_size, max_question_length]
        :param ground_truth:    [batch_size, max_question_length]
        :return:
        """
        p = np.array(predictions)
        g = np.array(ground_truth)
        result = np.sum(np.abs(p - g), axis=-1)
        correct = 0
        for idx, r in enumerate(result):
            if r == 0:
                correct += 1
                # tqdm.write(str(p[r]))
                # tqdm.write(str(g[r]))
                # tqdm.write("======================================")
        return correct
项目:entity_binding    作者:JasperGuo    | 项目源码 | 文件源码
def _check_predictions(self, predictions, ground_truth):
        """
        :param predictions:     [batch_size, max_question_length]
        :param ground_truth:    [batch_size, max_question_length]
        :return:
        """
        p = np.array(predictions)
        g = np.array(ground_truth)
        result = np.sum(np.abs(p - g), axis=-1)
        correct = 0
        for idx, r in enumerate(result):
            if r == 0:
                correct += 1
                # tqdm.write(str(p[r]))
                # tqdm.write(str(g[r]))
                # tqdm.write("======================================")
        return correct
项目:entity_binding    作者:JasperGuo    | 项目源码 | 文件源码
def _check_predictions(self, predictions, ground_truth):
        """
        :param predictions:     [batch_size, max_question_length]
        :param ground_truth:    [batch_size, max_question_length]
        :return:
        """
        p = np.array(predictions)
        g = np.array(ground_truth)
        result = np.sum(np.abs(p - g), axis=-1)
        correct = 0
        for idx, r in enumerate(result):
            if r == 0:
                correct += 1
                # tqdm.write(str(p[r]))
                # tqdm.write(str(g[r]))
                # tqdm.write("======================================")
        return correct
项目:DeepLearningAndTensorflow    作者:azheng333    | 项目源码 | 文件源码
def loadConversations(self, dirName):
    """Load QA pairs from every gzipped OpenSubtitles file in a folder.

    Args:
        dirName (str): folder to load
    Return:
        array(question, answer): the extracted QA pairs

    Files whose parsing raises ``ValueError`` are skipped with a notice;
    any other exception is reported and re-raised.
    """
    conversations = []
    dirList = self.filesInDir(dirName)
    for filepath in tqdm(dirList, "OpenSubtitles data files"):
        # only names ending in 'gz' are parsed
        if not filepath.endswith('gz'):
            continue
        try:
            doc = self.getXML(filepath)
            conversations.extend(self.genList(doc))
        except ValueError:
            # malformed file: skip it and keep loading the others
            tqdm.write("Skipping file %s with errors." % filepath)
        except Exception as err:
            # report through tqdm.write (not print) so the message does not
            # garble the active progress bar, then propagate the error
            tqdm.write("Unexpected error: %s" % (type(err),))
            raise
    return conversations
项目:odin    作者:imito    | 项目源码 | 文件源码
def pause(self):
    """Erase the last printed report and the progress bar from the output.

    Returns ``self`` so calls can be chained.

    NOTE(review): when ``_last_report`` is set, ``self.__pb.moveto`` is
    called before the ``self.__pb is not None`` check further down; this
    assumes the bar always exists whenever a report was printed — confirm.
    """
    # ====== clear the report ====== #
    if self._last_report is not None:
      # one output line per line of the previously printed report
      nlines = len(self._last_report.split("\n"))
      # presumably moves the cursor up nlines (tqdm-style moveto) — TODO confirm
      self.__pb.moveto(-nlines)
      for i in range(nlines):
        # blank the line with spaces; _environ_cols_wrapper() presumably
        # returns a callable giving the terminal width of FP — TODO confirm
        Progbar.FP.write('\r')
        Progbar.FP.write(' ' * _environ_cols_wrapper()(Progbar.FP))
        Progbar.FP.write('\r')  # place cursor back at the beginning of line
        self.__pb.moveto(1)
    else:
      nlines = 0
    # ====== clear the bar ====== #
    if self.__pb is not None:
      self.__pb.clear()
      # move back up by the number of report lines that were blanked above
      self.__pb.moveto(-nlines)
    # ====== reset the last report ====== #
    # because everything was already cleared, setting _last_report=None
    # prevents a further moveto(-nlines) in add()
    self._last_report = None
    return self
项目:ARTI-NOAH    作者:Glitch-is    | 项目源码 | 文件源码
def loadConversations(self, dirName):
    """Load QA pairs from every gzipped OpenSubtitles file in a folder.

    Args:
        dirName (str): folder to load
    Return:
        array(question, answer): the extracted QA pairs

    Files whose parsing raises ``ValueError`` are skipped with a notice;
    any other exception is reported and re-raised.
    """
    conversations = []
    dirList = self.filesInDir(dirName)
    for filepath in tqdm(dirList, "OpenSubtitles data files"):
        # only names ending in 'gz' are parsed
        if not filepath.endswith('gz'):
            continue
        try:
            doc = self.getXML(filepath)
            conversations.extend(self.genList(doc))
        except ValueError:
            # malformed file: skip it and keep loading the others
            tqdm.write("Skipping file %s with errors." % filepath)
        except Exception as err:
            # report through tqdm.write (not print) so the message does not
            # garble the active progress bar, then propagate the error
            tqdm.write("Unexpected error: %s" % (type(err),))
            raise
    return conversations
项目:cartoframes    作者:CartoDB    | 项目源码 | 文件源码
def normalize_colnames(columns):
    """SQL-normalize column names the same way CARTO's SQL API does.

    Every name that changes is listed (old -> new, in bold) through
    ``tqdm.write`` so the user knows the CARTO copy differs.

    Args:
        columns (list of str): List of column names

    Returns:
        list of str: Normalized column names, in input order
    """
    normalized = [norm_colname(col) for col in columns]
    renames = [
        '\033[1m{orig}\033[0m -> \033[1m{new}\033[0m'.format(orig=old, new=new)
        for old, new in zip(columns, normalized)
        if old != new
    ]
    if renames:
        tqdm.write('The following columns were changed in the CARTO '
                   'copy of this dataframe:\n{0}'.format('\n'.join(renames)))

    return normalized
项目:cumin    作者:wikimedia    | 项目源码 | 文件源码
def print_output(output_format, worker):
    """Print the execution results in a specific format.

    Arguments:
        output_format: the output format to use, one of: 'txt', 'json'.
        worker: the Transport worker instance to retrieve the results from.

    Raises:
        cumin.CuminError: if ``output_format`` is not in ``OUTPUT_FORMATS``.
    """
    if output_format not in OUTPUT_FORMATS:
        raise cumin.CuminError("Got invalid output format '{fmt}', expected one of {allowed}".format(
            fmt=output_format, allowed=OUTPUT_FORMATS))

    out = {}
    for nodeset, output in worker.get_results():
        # the output is shared by every node of the nodeset: fetch it once
        # per nodeset instead of once per node
        if output_format == 'txt':
            lines = list(output.lines())
            for node in nodeset:
                out[node] = '\n'.join(['{node}: {line}'.format(node=node, line=line) for line in lines])
        elif output_format == 'json':
            message = output.message()
            for node in nodeset:
                out[node] = message

    if output_format == 'txt':
        for node in sorted(out.keys()):
            tqdm.write(out[node])
    elif output_format == 'json':
        tqdm.write(json.dumps(out, indent=4, sort_keys=True))
项目:skymod    作者:DelusionalLogic    | 项目源码 | 文件源码
def fetch_file(self, uri, filename):
    """Return ``(cached_path, filename)`` for ``uri``, downloading on a cache miss.

    On a miss, every handler whose ``accept(uri)`` is true is asked to fetch
    the URI into the new cache entry's ``file``.
    """
    # tqdm.write (as already used for the cache-hit message) keeps active
    # progress bars intact; a plain print() would break them
    tqdm.write("Fetching {}".format(filename))
    # Is the uri cached
    if uri in self.cache:
        tqdm.write("{} found in cache".format(uri))
    else:
        with self.cache.atomic_add(uri) as dl_cache:
            # NOTE(review): every accepting handler fetches into the same path;
            # presumably only the first one should — confirm before adding a break
            for handler in self.handlers:
                if handler.accept(uri):
                    handler.fetch(uri, dl_cache / "file")

    return (self.cache.get(uri) / "file", filename)
项目:Parallel.GAMIT    作者:demiangomez    | 项目源码 | 文件源码
def print_columns(l):
    """Print the items of ``l`` in rows of six, each row indented by 4 spaces.

    The first five items of a full row are left-aligned in 10-character
    cells and the sixth is printed as-is. Leftover items (fewer than six)
    go on a final row where every item is padded to 10 characters.
    """
    whole = len(l) - len(l) % 6

    for start in range(0, whole, 6):
        a, b, c, d, e, f = l[start:start + 6]
        print('    {:<10}{:<10}{:<10}{:<10}{:<10}{:<}'.format(a, b, c, d, e, f))

    if whole != len(l):
        sys.stdout.write('    ')
        for item in l[whole:]:
            sys.stdout.write('{:<10}'.format(item))
        sys.stdout.write('\n')
项目:Parallel.GAMIT    作者:demiangomez    | 项目源码 | 文件源码
def get_differences(differences):
    """Print every non-empty ``error`` and return the non-None ``diff`` values.

    Args:
        differences: sequence of result objects exposing ``error`` and ``diff``.

    Returns:
        list: the ``diff`` of each item whose diff is not None, in input order.
    """
    # print out any error messages
    for item in differences:
        if item.error:
            sys.stdout.write(item.error + '\n')

    return [item.diff for item in differences if item.diff is not None]
项目:Parallel.GAMIT    作者:demiangomez    | 项目源码 | 文件源码
def GetStnGaps(cnn, stnlist, ignore_val, start_date, end_date):
    """Print the RINEX data gaps of each station between two dates.

    For every station in ``stnlist`` (mappings with 'NetworkCode' and
    'StationCode'), the RINEX records observed between ``start_date`` and
    ``end_date`` are read in observation-time order. Every run of missing
    days between two consecutive records that is longer than ``ignore_val``
    days is reported on stdout.

    Only gaps between existing records are detected; days missing before
    the first or after the last record in the range are not reported.
    """
    for stn in stnlist:
        NetworkCode = stn['NetworkCode']
        StationCode = stn['StationCode']

        # NOTE(review): the station codes are interpolated straight into the
        # SQL text; if they can come from user input this should use the
        # driver's parameter binding instead — confirm what cnn.query supports.
        rs = cnn.query(
            'SELECT * FROM rinex WHERE "NetworkCode" = \'%s\' AND "StationCode" = \'%s\' AND "ObservationSTime" BETWEEN \'%s\' AND \'%s\' ORDER BY "ObservationSTime"' % (NetworkCode, StationCode, start_date.yyyymmdd(), end_date.yyyymmdd()))

        gaps = []
        prev = None
        for rnx in rs.dictresult():
            curr = pyDate.Date(year=rnx['ObservationYear'], doy=rnx['ObservationDOY'])
            if prev is not None:
                # whole days with no record strictly between prev and curr:
                # 0 for consecutive days, negative for same-day records.
                # Each gap is evaluated on its own, so back-to-back gaps are
                # not merged and a gap before the last record is still found.
                days = curr.mjd - prev.mjd - 1
                if days > 0 and days > ignore_val:
                    gaps.append('%s.%s gap in data found %s -> %s (%i days)' % (
                        NetworkCode, StationCode, (prev + 1).yyyyddd(), (curr - 1).yyyyddd(), days))
            prev = curr

        if gaps:
            sys.stdout.write('\nData gaps in %s.%s follow:\n' % (NetworkCode, StationCode))
            sys.stdout.write('\n'.join(gaps) + '\n')
        else:
            sys.stdout.write('\nNo data gaps found for %s.%s\n' % (NetworkCode, StationCode))
项目:supernovae    作者:astrocatalogs    | 项目源码 | 文件源码
def download_file(url, path, overwrite=False):
    """Download ``url`` into directory ``path`` and return the local file path.

    The local name is the last '/'-separated component of ``url``. An
    existing file is returned untouched unless ``overwrite`` is true.

    Raises:
        requests.HTTPError: if the server answers with an error status. No
            file is left behind in that case, nor after a failed transfer.
    """
    local_filename = url.split('/')[-1]
    local_path = os.path.join(path, local_filename)
    if os.path.isfile(local_path) and not overwrite:
        return local_path
    # NOTE the stream=True parameter: the body is written chunk by chunk
    r = requests.get(url, stream=True)
    try:
        # without this an HTTP error page would be saved, and every later
        # call would then return it as an already-downloaded file
        r.raise_for_status()
        try:
            with open(local_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=1024):
                    if chunk:  # filter out keep-alive new chunks
                        f.write(chunk)
        except Exception:
            # don't leave a truncated file that later calls would reuse
            if os.path.isfile(local_path):
                os.remove(local_path)
            raise
    finally:
        # release the streamed connection back to the pool
        r.close()
    return local_path
项目:taskcv-2017-public    作者:VisionLearningGroup    | 项目源码 | 文件源码
def emit(self, record):
    """Format ``record`` and write it via ``tqdm.write``.

    Routing log output through tqdm keeps it from breaking active progress
    bars. As the ``logging.Handler`` contract requires, failures while
    formatting or writing go to ``self.handleError`` instead of propagating
    into the code that issued the log call.
    """
    try:
        msg = self.format(record)
        tqdm.write(msg)
    except (KeyboardInterrupt, SystemExit):
        raise
    except Exception:
        self.handleError(record)
项目:ngraph    作者:NervanaSystems    | 项目源码 | 文件源码
def ingest_librispeech(input_directory, manifest_file=None, absolute_paths=True):
    """ Finds all .txt files and their indicated .flac files and writes them to an Aeon
    compatible manifest file.

    Arguments:
        input_directory (str): Path to librispeech directory
        manifest_file (str): Path to manifest file to output. Defaults to
                             "manifest.tsv" inside input_directory.
        absolute_paths (bool): Whether audio file paths should be absolute or
                               relative to input_directory.

    Raises:
        IOError: if input_directory does not exist or contains no .txt files.
    """

    if not os.path.isdir(input_directory):
        raise IOError("Data directory does not exist! {}".format(input_directory))

    if manifest_file is None:
        # os.path.join needs a real file name here; joining with None raises TypeError
        manifest_file = os.path.join(input_directory, "manifest.tsv")

    transcript_files = get_files(input_directory, pattern="*.txt")
    if len(transcript_files) == 0:
        raise IOError("No .txt files were found in {}".format(input_directory))

    tqdm.write("Preparing manifest file...")
    with open(manifest_file, "w") as manifest:
        manifest.write("@FILE\tSTRING\n")
        for tfile in tqdm(transcript_files, unit=" Files", mininterval=.001):
            directory = os.path.dirname(tfile)
            if absolute_paths is False:
                directory = os.path.relpath(directory, input_directory)

            with open(tfile, "r") as fid:
                for line in fid:
                    # each line is "<utterance id> <transcript>"; strip the line
                    # ending so every manifest record is exactly one line, and
                    # skip blank lines, which cannot be split into id/transcript
                    line = line.strip()
                    if not line:
                        continue
                    id_, transcript = line.split(" ", 1)
                    afile = "{}.flac".format(os.path.join(directory, id_))
                    manifest.write("{}\t{}\n".format(afile, transcript))
项目:ngraph    作者:NervanaSystems    | 项目源码 | 文件源码
def __call__(self, transformer, callback_data, phase, data, idx):
    """Every ``frequency`` minibatches, print the mean training cost of the
    interval that just finished (read from ``callback_data["cost/train"]``)."""
    if phase == CallbackPhase.minibatch_post:
        if ((idx + 1) % self.frequency == 0):
            # iterations idx+1-frequency .. idx inclusive; a slice end is
            # exclusive, so it must be idx + 1 to include the current
            # iteration (and to be non-empty when frequency == 1)
            interval = slice(idx + 1 - self.frequency, idx + 1)
            train_cost = callback_data["cost/train"][interval].mean()
            tqdm.write("Interval {} Iteration {} complete.  Avg Train cost: {}".format(
                idx // self.frequency + 1, idx + 1, train_cost))
项目:trainer    作者:nutszebra    | 项目源码 | 文件源码
def log_progressbar(sentence):
    """Print ``sentence`` without disturbing an active tqdm progress bar.

    Returns:
        bool: always True.
    """
    write_line = tqdm.write
    write_line(sentence)
    return True
项目:trainer    作者:nutszebra    | 项目源码 | 文件源码
def save_text(data, output):
    """Save as text, one item per line.

    Edited date:
        160626

    Test:
        160626

    Examples:

    ::

        data = ['this', 'is', 'test']
        output = 'test.text'
        self.save_text(data, output)

    Args:
        data (list, tuple or str): data to save; a single item that is not a
            list or tuple is written as one line
        output (str): output name (overwritten if it exists)

    Returns:
        bool: always True; I/O errors propagate as exceptions
    """
    # isinstance also accepts list subclasses and tuples, which the old
    # ``type(data) == list`` check wrapped and then failed on
    if not isinstance(data, (list, tuple)):
        data = [data]
    with open(output, 'w') as f:
        f.writelines(item + '\n' for item in data)
    return True
项目:ml-utils    作者:LinxiFan    | 项目源码 | 文件源码
def write(self, x):
    """File-like ``write`` that forwards non-blank text to ``TqdmPro.write``.

    print() sends its trailing newline as a separate write call; such text,
    blank after stripping trailing whitespace, is dropped.
    """
    if not x.rstrip():
        return
    # WARNING: must use TqdmPro.write instead of tqdm.write because it's
    # a classmethod affected by __new__
    TqdmPro.write(x, file=self.file)
项目:entity_binding    作者:JasperGuo    | 项目源码 | 文件源码
def log(self, file, batch, predictions, is_detail=False):
    """Append one comparison record per question in ``batch`` to ``file``.

    Each record holds the question id, table id, column count, cell-value
    count of the first column, the ground truth (``t``), the prediction
    (``p``) and whether they match exactly. ``is_detail`` is unused.
    """
    with open(file, "a") as out:
        lines = []
        rows = zip(batch.ground_truth, predictions, batch.questions_ids,
                   batch.cell_value_length, batch.table_map_ids)
        for truth, pred, qid, cell_lengths, table_id in rows:
            exact = np.sum(np.abs(np.array(pred) - np.array(truth)), axis=-1) == 0
            lines.extend([
                "=======================",
                "id: " + str(qid),
                "tid: " + str(table_id),
                "max_column: " + str(len(cell_lengths)),
                "max_cell_value_per_col: " + str(len(cell_lengths[0])),
                "t: " + ", ".join(str(v) for v in truth),
                "p: " + ", ".join(str(v) for v in pred),
                "Result: " + str(exact),
            ])
        out.write("".join(line + "\n" for line in lines))
项目:entity_binding    作者:JasperGuo    | 项目源码 | 文件源码
def _epoch_log(self, file, num_epoch, train_accuracy, dev_accuracy, average_loss):
        """
        Log epoch
        :param file:
        :param num_epoch:
        :param train_accuracy:
        :param dev_accuracy:
        :param average_loss:
        :return:
        """
        with open(file, "a") as f:
            f.write("epoch: %d, train_accuracy: %f, dev_accuracy: %f, average_loss: %f\n" % (num_epoch, train_accuracy, dev_accuracy, average_loss))
项目:entity_binding    作者:JasperGuo    | 项目源码 | 文件源码
def test(self, data_iterator, is_log=False):
    """Evaluate the test model over one pass of ``data_iterator``.

    Runs ``data_iterator.batch_per_epoch`` batches and counts exactly
    matching predictions. When ``is_log`` is true, per-question details are
    appended to ``test_<curr_time>.log`` under ``self._result_log_base_path``.

    Returns:
        float: fraction of correct questions, or 0.0 when no question was
        evaluated (e.g. ``batch_per_epoch == 0``).
    """
    tqdm.write("Testing...")
    total = 0
    correct = 0
    log_file = os.path.join(self._result_log_base_path, "test_" + self._curr_time + ".log")
    for _ in tqdm(range(data_iterator.batch_per_epoch)):
        batch = data_iterator.get_batch()
        predictions, feed_dict = self._test_model.predict(batch)
        predictions = self._session.run(predictions, feed_dict=feed_dict)

        correct += self._check_predictions(
            predictions=predictions,
            ground_truth=batch.ground_truth
        )

        total += batch.size

        if is_log:
            self.log(
                file=log_file,
                batch=batch,
                predictions=predictions
            )

    # avoid ZeroDivisionError when nothing was evaluated
    accuracy = float(correct) / float(total) if total else 0.0
    tqdm.write("test_acc: %f" % accuracy)
    return accuracy
项目:entity_binding    作者:JasperGuo    | 项目源码 | 文件源码
def log(self, file, batch, predictions, is_detail=False):
    """Append one comparison record per question in ``batch`` to ``file``.

    Each record holds the question id, table id, column count, cell-value
    count of the first column, the ground truth (``t``), the prediction
    (``p``) and whether they match exactly. ``is_detail`` is unused.
    """
    with open(file, "a") as out:
        lines = []
        rows = zip(batch.ground_truth, predictions, batch.questions_ids,
                   batch.cell_value_length, batch.table_map_ids)
        for truth, pred, qid, cell_lengths, table_id in rows:
            exact = np.sum(np.abs(np.array(pred) - np.array(truth)), axis=-1) == 0
            lines.extend([
                "=======================",
                "id: " + str(qid),
                "tid: " + str(table_id),
                "max_column: " + str(len(cell_lengths)),
                "max_cell_value_per_col: " + str(len(cell_lengths[0])),
                "t: " + ", ".join(str(v) for v in truth),
                "p: " + ", ".join(str(v) for v in pred),
                "Result: " + str(exact),
            ])
        out.write("".join(line + "\n" for line in lines))
项目:entity_binding    作者:JasperGuo    | 项目源码 | 文件源码
def _epoch_log(self, file, num_epoch, train_accuracy, dev_accuracy, average_loss):
        """
        Log epoch
        :param file:
        :param num_epoch:
        :param train_accuracy:
        :param dev_accuracy:
        :param average_loss:
        :return:
        """
        with open(file, "a") as f:
            f.write("epoch: %d, train_accuracy: %f, dev_accuracy: %f, average_loss: %f\n" % (num_epoch, train_accuracy, dev_accuracy, average_loss))
项目:entity_binding    作者:JasperGuo    | 项目源码 | 文件源码
def test(self, data_iterator, is_log=False):
    """Evaluate the test model over one pass of ``data_iterator``.

    Runs ``data_iterator.batch_per_epoch`` batches and counts exactly
    matching predictions. When ``is_log`` is true, per-question details are
    appended to ``test_<curr_time>.log`` under ``self._result_log_base_path``.

    Returns:
        float: fraction of correct questions, or 0.0 when no question was
        evaluated (e.g. ``batch_per_epoch == 0``).
    """
    tqdm.write("Testing...")
    total = 0
    correct = 0
    log_file = os.path.join(self._result_log_base_path, "test_" + self._curr_time + ".log")
    for _ in tqdm(range(data_iterator.batch_per_epoch)):
        batch = data_iterator.get_batch()
        predictions, feed_dict = self._test_model.predict(batch)
        predictions = self._session.run(predictions, feed_dict=feed_dict)

        correct += self._check_predictions(
            predictions=predictions,
            ground_truth=batch.ground_truth
        )

        total += batch.size

        if is_log:
            self.log(
                file=log_file,
                batch=batch,
                predictions=predictions
            )

    # avoid ZeroDivisionError when nothing was evaluated
    accuracy = float(correct) / float(total) if total else 0.0
    tqdm.write("test_acc: %f" % accuracy)
    return accuracy
项目:entity_binding    作者:JasperGuo    | 项目源码 | 文件源码
def _epoch_log(self, file, num_epoch, train_accuracy, dev_accuracy, average_loss):
        """
        Log epoch
        :param file:
        :param num_epoch:
        :param train_accuracy:
        :param dev_accuracy:
        :param average_loss:
        :return:
        """
        with open(file, "a") as f:
            f.write("epoch: %d, train_accuracy: %f, dev_accuracy: %f, average_loss: %f\n" % (num_epoch, train_accuracy, dev_accuracy, average_loss))
项目:entity_binding    作者:JasperGuo    | 项目源码 | 文件源码
def test(self, data_iterator, is_log=False):
    """Evaluate the test model over one pass of ``data_iterator``.

    Runs ``data_iterator.batch_per_epoch`` batches and counts exactly
    matching predictions. When ``is_log`` is true, per-question details are
    appended to ``test_<curr_time>.log`` under ``self._result_log_base_path``.

    Returns:
        float: fraction of correct questions, or 0.0 when no question was
        evaluated (e.g. ``batch_per_epoch == 0``).
    """
    tqdm.write("Testing...")
    total = 0
    correct = 0
    log_file = os.path.join(self._result_log_base_path, "test_" + self._curr_time + ".log")
    for _ in tqdm(range(data_iterator.batch_per_epoch)):
        batch = data_iterator.get_batch()
        predictions, feed_dict = self._test_model.predict(batch)
        predictions = self._session.run(predictions, feed_dict=feed_dict)

        correct += self._check_predictions(
            predictions=predictions,
            ground_truth=batch.ground_truth
        )

        total += batch.size

        if is_log:
            self.log(
                file=log_file,
                batch=batch,
                predictions=predictions
            )

    # avoid ZeroDivisionError when nothing was evaluated
    accuracy = float(correct) / float(total) if total else 0.0
    tqdm.write("test_acc: %f" % accuracy)
    return accuracy
项目:entity_binding    作者:JasperGuo    | 项目源码 | 文件源码
def log(self, file, batch, tag_predictions, segment_length_predictions):
    """Append one tag/segment comparison record per question in ``batch`` to ``file``.

    Predictions and ground truth are first expanded by
    ``self._process_predictions``. Each record lists the ids and sizes, the
    unfolded ground truth and prediction, the raw ground-truth segment
    lengths (``ts``) and tags (``tt``), the predicted tags (``pt``) and
    segment lengths (``ps``), and whether the unfolded sequences match
    exactly.
    """

    unfold_predictions, unfold_ground_truth = self._process_predictions(
        tag_predictions=tag_predictions,
        segment_length_predictions=segment_length_predictions,
        ground_truth=batch.ground_truth,
        ground_truth_segmentation_length=batch.ground_truth_segmentation_length,
        ground_truth_segment_length=batch.ground_truth_segment_length,
        question_length=batch.questions_length
    )

    def _csv(seq):
        # comma-separated str() of every element
        return ", ".join(str(v) for v in seq)

    with open(file, "a") as out:
        lines = []
        rows = zip(
            batch.ground_truth,
            batch.ground_truth_segment_length,
            tag_predictions,
            segment_length_predictions,
            batch.questions_ids,
            batch.cell_value_length,
            batch.table_map_ids,
            unfold_predictions,
            unfold_ground_truth,
        )
        for tt, ts, pt, ps, qid, cv, table_id, unfold_p, unfold_t in rows:
            exact = np.sum(np.abs(np.array(unfold_p) - np.array(unfold_t)), axis=-1) == 0
            lines.extend([
                "=======================",
                "id: " + str(qid),
                "tid: " + str(table_id),
                "max_column: " + str(len(cv)),
                "max_cell_value_per_col: " + str(len(cv[0])),
                "unfold_t: " + _csv(unfold_t),
                "unfold_p: " + _csv(unfold_p),
                "ts: " + _csv(ts),
                "tt: " + _csv(tt),
                "pt: " + _csv(pt),
                "ps: " + _csv(ps),
                "Result: " + str(exact),
            ])
        out.write("".join(line + "\n" for line in lines))
项目:entity_binding    作者:JasperGuo    | 项目源码 | 文件源码
def _epoch_log(self, file, num_epoch, train_accuracy, dev_accuracy, average_loss):
        """
        Log epoch
        :param file:
        :param num_epoch:
        :param train_accuracy:
        :param dev_accuracy:
        :param average_loss:
        :return:
        """
        with open(file, "a") as f:
            f.write("epoch: %d, train_accuracy: %f, dev_accuracy: %f, average_loss: %f\n" % (
            num_epoch, train_accuracy, dev_accuracy, average_loss))
项目:entity_binding    作者:JasperGuo    | 项目源码 | 文件源码
def test(self, data_iterator, is_log=False):
    """Evaluate the tag/segment test model over one pass of ``data_iterator``.

    Runs ``data_iterator.batch_per_epoch`` batches, fetches tag and
    segment-length predictions for each, and counts correct questions via
    ``self._check_predictions``. When ``is_log`` is true, per-question
    details are appended to ``test_<curr_time>.log`` under
    ``self._result_log_base_path``.

    Returns:
        float: fraction of correct questions, or 0.0 when no question was
        evaluated (e.g. ``batch_per_epoch == 0``).
    """
    tqdm.write("Testing...")
    total = 0
    correct = 0
    log_file = os.path.join(self._result_log_base_path, "test_" + self._curr_time + ".log")
    for _ in tqdm(range(data_iterator.batch_per_epoch)):
        batch = data_iterator.get_batch()
        tag_predictions, segment_length_predictions, feed_dict = self._test_model.predict(batch)
        tag_predictions, segment_length_predictions = self._session.run(
            (tag_predictions, segment_length_predictions,),
            feed_dict=feed_dict
        )

        correct += self._check_predictions(
            tag_predictions=tag_predictions,
            segment_length_predictions=segment_length_predictions,
            ground_truth=batch.ground_truth,
            ground_truth_segment_length=batch.ground_truth_segment_length,
            ground_truth_segmentation_length=batch.ground_truth_segmentation_length,
            question_length=batch.questions_length
        )

        total += batch.size

        if is_log:
            self.log(
                file=log_file,
                batch=batch,
                tag_predictions=tag_predictions,
                segment_length_predictions=segment_length_predictions
            )

    # avoid ZeroDivisionError when nothing was evaluated
    accuracy = float(correct) / float(total) if total else 0.0
    tqdm.write("test_acc: %f" % accuracy)
    return accuracy
项目:entity_binding    作者:JasperGuo    | 项目源码 | 文件源码
def log(self, file, batch, predictions, is_detail=False):
    """Append one comparison record per question in ``batch`` to ``file``.

    Each record holds the question id, table id, column count, cell-value
    count of the first column, the ground truth (``t``), the prediction
    (``p``) and whether they match exactly. ``is_detail`` is unused.
    """
    with open(file, "a") as out:
        lines = []
        rows = zip(batch.ground_truth, predictions, batch.questions_ids,
                   batch.cell_value_length, batch.table_map_ids)
        for truth, pred, qid, cell_lengths, table_id in rows:
            exact = np.sum(np.abs(np.array(pred) - np.array(truth)), axis=-1) == 0
            lines.extend([
                "=======================",
                "id: " + str(qid),
                "tid: " + str(table_id),
                "max_column: " + str(len(cell_lengths)),
                "max_cell_value_per_col: " + str(len(cell_lengths[0])),
                "t: " + ", ".join(str(v) for v in truth),
                "p: " + ", ".join(str(v) for v in pred),
                "Result: " + str(exact),
            ])
        out.write("".join(line + "\n" for line in lines))
项目:entity_binding    作者:JasperGuo    | 项目源码 | 文件源码
def _epoch_log(self, file, num_epoch, train_accuracy, dev_accuracy, average_loss):
        """
        Log epoch
        :param file:
        :param num_epoch:
        :param train_accuracy:
        :param dev_accuracy:
        :param average_loss:
        :return:
        """
        with open(file, "a") as f:
            f.write("epoch: %d, train_accuracy: %f, dev_accuracy: %f, average_loss: %f\n" % (num_epoch, train_accuracy, dev_accuracy, average_loss))
项目:entity_binding    作者:JasperGuo    | 项目源码 | 文件源码
def test(self, data_iterator, is_log=False):
        """
        Evaluate the test model over one pass of ``data_iterator``.

        For each of ``data_iterator.batch_per_epoch`` batches, obtains the
        prediction fetch and feed dict from ``self._test_model.predict``,
        evaluates it with ``self._session.run``, and adds the count returned
        by ``self._check_predictions`` to the number of correct examples.
        When ``is_log`` is true, per-example details are appended to
        ``<_result_log_base_path>/test_<_curr_time>.log`` via ``self.log``.

        :param data_iterator: object exposing ``batch_per_epoch`` and ``get_batch()``
        :param is_log: whether to write per-example details to the test log
        :return: ``correct / total`` as a float, or 0.0 when no examples were seen
        """
        tqdm.write("Testing...")
        total = 0
        correct = 0
        file = os.path.join(self._result_log_base_path, "test_" + self._curr_time + ".log")
        for _ in tqdm(range(data_iterator.batch_per_epoch)):
            batch = data_iterator.get_batch()
            predictions, feed_dict = self._test_model.predict(batch)
            # Evaluate the fetch returned by predict() with its feed dict.
            predictions = self._session.run(predictions, feed_dict=feed_dict)

            correct += self._check_predictions(
                predictions=predictions,
                ground_truth=batch.ground_truth
            )

            total += batch.size

            if is_log:
                self.log(
                    file=file,
                    batch=batch,
                    predictions=predictions
                )

        # An empty iterator (batch_per_epoch == 0 or only empty batches)
        # previously crashed with ZeroDivisionError.
        if total == 0:
            tqdm.write("test_acc: n/a (no test examples)")
            return 0.0
        accuracy = float(correct) / float(total)
        tqdm.write("test_acc: %f" % accuracy)
        return accuracy
项目:entity_binding    作者:JasperGuo    | 项目源码 | 文件源码
def log(self, file, batch, tag_predictions, segment_length_predictions):
        """
        Append a per-example report of tag / segment-length predictions to ``file``.

        Predictions and ground truth are first expanded with
        ``self._process_predictions``. Each example then produces a block of
        lines: a separator, ``id``, ``tid``, ``max_column`` (``len(cv)``),
        ``max_cell_value_per_col`` (``len(cv[0])``, or 0 when ``cv`` is
        empty), the unfolded ground truth / prediction, the raw segment
        lengths (``ts``/``ps``) and tags (``tt``/``pt``), and ``Result`` --
        ``str(result == 0)`` where ``result`` is the summed absolute
        difference between unfolded prediction and ground truth along the
        last axis.

        :param file: path of the log file; opened in append mode
        :param batch: batch exposing the ground-truth, id and length fields read below
        :param tag_predictions: per-example tag predictions
        :param segment_length_predictions: per-example segment-length predictions
        :return: None
        """
        unfold_predictions, unfold_ground_truth = self._process_predictions(
            tag_predictions=tag_predictions,
            segment_length_predictions=segment_length_predictions,
            ground_truth=batch.ground_truth,
            ground_truth_segmentation_length=batch.ground_truth_segmentation_length,
            ground_truth_segment_length=batch.ground_truth_segment_length,
            question_length=batch.questions_length
        )

        def _joined(values):
            # Render a sequence as "a, b, c".
            return ", ".join(str(i) for i in values)

        with open(file, "a") as f:
            # Collect pieces and join once; repeated ``string +=`` is quadratic.
            parts = []
            for tt, ts, pt, ps, qid, cv, table_id, unfold_p, unfold_t in zip(
                    batch.ground_truth,
                    batch.ground_truth_segment_length,
                    tag_predictions,
                    segment_length_predictions,
                    batch.questions_ids,
                    batch.cell_value_length,
                    batch.table_map_ids,
                    unfold_predictions,
                    unfold_ground_truth
            ):
                # L1 distance along the last axis; 0 means an exact match.
                result = np.sum(np.abs(np.array(unfold_p) - np.array(unfold_t)), axis=-1)
                # An example without columns has no first column to measure.
                max_cell_values = len(cv[0]) if len(cv) > 0 else 0
                parts.append("=======================\n")
                parts.append("id: " + str(qid) + "\n")
                parts.append("tid: " + str(table_id) + "\n")
                parts.append("max_column: " + str(len(cv)) + "\n")
                parts.append("max_cell_value_per_col: " + str(max_cell_values) + "\n")
                parts.append("unfold_t: " + _joined(unfold_t) + "\n")
                parts.append("unfold_p: " + _joined(unfold_p) + "\n")
                parts.append("ts: " + _joined(ts) + "\n")
                parts.append("tt: " + _joined(tt) + "\n")
                parts.append("pt: " + _joined(pt) + "\n")
                parts.append("ps: " + _joined(ps) + "\n")
                parts.append("Result: " + str(result == 0) + "\n")
            f.write("".join(parts))
项目:entity_binding    作者:JasperGuo    | 项目源码 | 文件源码
def test(self, data_iterator, is_log=False):
        """
        Evaluate the tag / segment-length test model over one pass of ``data_iterator``.

        For each of ``data_iterator.batch_per_epoch`` batches, obtains the
        tag and segment-length fetches plus a feed dict from
        ``self._test_model.predict``, evaluates them with
        ``self._session.run``, and adds the count returned by
        ``self._check_predictions`` to the number of correct examples. When
        ``is_log`` is true, per-example details are appended to
        ``<_result_log_base_path>/test_<_curr_time>.log`` via ``self.log``.

        :param data_iterator: object exposing ``batch_per_epoch`` and ``get_batch()``
        :param is_log: whether to write per-example details to the test log
        :return: ``correct / total`` as a float, or 0.0 when no examples were seen
        """
        tqdm.write("Testing...")
        total = 0
        correct = 0
        file = os.path.join(self._result_log_base_path, "test_" + self._curr_time + ".log")
        for _ in tqdm(range(data_iterator.batch_per_epoch)):
            batch = data_iterator.get_batch()
            tag_predictions, segment_length_predictions, feed_dict = self._test_model.predict(batch)
            # Evaluate both fetches returned by predict() in a single run.
            tag_predictions, segment_length_predictions = self._session.run(
                (tag_predictions, segment_length_predictions,),
                feed_dict=feed_dict
            )

            correct += self._check_predictions(
                tag_predictions=tag_predictions,
                segment_length_predictions=segment_length_predictions,
                ground_truth=batch.ground_truth,
                ground_truth_segment_length=batch.ground_truth_segment_length,
                ground_truth_segmentation_length=batch.ground_truth_segmentation_length,
                question_length=batch.questions_length
            )

            total += batch.size

            if is_log:
                self.log(
                    file=file,
                    batch=batch,
                    tag_predictions=tag_predictions,
                    segment_length_predictions=segment_length_predictions
                )

        # An empty iterator (batch_per_epoch == 0 or only empty batches)
        # previously crashed with ZeroDivisionError.
        if total == 0:
            tqdm.write("test_acc: n/a (no test examples)")
            return 0.0
        accuracy = float(correct) / float(total)
        tqdm.write("test_acc: %f" % accuracy)
        return accuracy
项目:entity_binding    作者:JasperGuo    | 项目源码 | 文件源码
def log(self, file, batch, tag_predictions, segment_length_predictions):
        """
        Append a per-example report of tag / segment-length predictions to ``file``.

        Predictions and ground truth are first expanded with
        ``self._process_predictions``. Each example then produces a block of
        lines: a separator, ``id``, ``tid``, ``max_column`` (``len(cv)``),
        ``max_cell_value_per_col`` (``len(cv[0])``, or 0 when ``cv`` is
        empty), the unfolded ground truth / prediction, the raw segment
        lengths (``ts``/``ps``) and tags (``tt``/``pt``), and ``Result`` --
        ``str(result == 0)`` where ``result`` is the summed absolute
        difference between unfolded prediction and ground truth along the
        last axis.

        :param file: path of the log file; opened in append mode
        :param batch: batch exposing the ground-truth, id and length fields read below
        :param tag_predictions: per-example tag predictions
        :param segment_length_predictions: per-example segment-length predictions
        :return: None
        """
        unfold_predictions, unfold_ground_truth = self._process_predictions(
            tag_predictions=tag_predictions,
            segment_length_predictions=segment_length_predictions,
            ground_truth=batch.ground_truth,
            ground_truth_segmentation_length=batch.ground_truth_segmentation_length,
            ground_truth_segment_length=batch.ground_truth_segment_length,
            question_length=batch.questions_length
        )

        def _joined(values):
            # Render a sequence as "a, b, c".
            return ", ".join(str(i) for i in values)

        with open(file, "a") as f:
            # Collect pieces and join once; repeated ``string +=`` is quadratic.
            parts = []
            for tt, ts, pt, ps, qid, cv, table_id, unfold_p, unfold_t in zip(
                    batch.ground_truth,
                    batch.ground_truth_segment_length,
                    tag_predictions,
                    segment_length_predictions,
                    batch.questions_ids,
                    batch.cell_value_length,
                    batch.table_map_ids,
                    unfold_predictions,
                    unfold_ground_truth
            ):
                # L1 distance along the last axis; 0 means an exact match.
                result = np.sum(np.abs(np.array(unfold_p) - np.array(unfold_t)), axis=-1)
                # An example without columns has no first column to measure.
                max_cell_values = len(cv[0]) if len(cv) > 0 else 0
                parts.append("=======================\n")
                parts.append("id: " + str(qid) + "\n")
                parts.append("tid: " + str(table_id) + "\n")
                parts.append("max_column: " + str(len(cv)) + "\n")
                parts.append("max_cell_value_per_col: " + str(max_cell_values) + "\n")
                parts.append("unfold_t: " + _joined(unfold_t) + "\n")
                parts.append("unfold_p: " + _joined(unfold_p) + "\n")
                parts.append("ts: " + _joined(ts) + "\n")
                parts.append("tt: " + _joined(tt) + "\n")
                parts.append("pt: " + _joined(pt) + "\n")
                parts.append("ps: " + _joined(ps) + "\n")
                parts.append("Result: " + str(result == 0) + "\n")
            f.write("".join(parts))
项目:entity_binding    作者:JasperGuo    | 项目源码 | 文件源码
def _epoch_log(self, file, num_epoch, train_accuracy, dev_accuracy, average_loss):
        """
        Log epoch
        :param file:
        :param num_epoch:
        :param train_accuracy:
        :param dev_accuracy:
        :param average_loss:
        :return:
        """
        with open(file, "a") as f:
            f.write("epoch: %d, train_accuracy: %f, dev_accuracy: %f, average_loss: %f\n" % (
            num_epoch, train_accuracy, dev_accuracy, average_loss))
项目:entity_binding    作者:JasperGuo    | 项目源码 | 文件源码
def test(self, data_iterator, is_log=False):
        """
        Evaluate the tag / segment-length test model over one pass of ``data_iterator``.

        For each of ``data_iterator.batch_per_epoch`` batches, obtains the
        tag and segment-length fetches plus a feed dict from
        ``self._test_model.predict``, evaluates them with
        ``self._session.run``, and adds the count returned by
        ``self._check_predictions`` to the number of correct examples. When
        ``is_log`` is true, per-example details are appended to
        ``<_result_log_base_path>/test_<_curr_time>.log`` via ``self.log``.

        :param data_iterator: object exposing ``batch_per_epoch`` and ``get_batch()``
        :param is_log: whether to write per-example details to the test log
        :return: ``correct / total`` as a float, or 0.0 when no examples were seen
        """
        tqdm.write("Testing...")
        total = 0
        correct = 0
        file = os.path.join(self._result_log_base_path, "test_" + self._curr_time + ".log")
        for _ in tqdm(range(data_iterator.batch_per_epoch)):
            batch = data_iterator.get_batch()
            tag_predictions, segment_length_predictions, feed_dict = self._test_model.predict(batch)
            # Evaluate both fetches returned by predict() in a single run.
            tag_predictions, segment_length_predictions = self._session.run(
                (tag_predictions, segment_length_predictions,),
                feed_dict=feed_dict
            )

            correct += self._check_predictions(
                tag_predictions=tag_predictions,
                segment_length_predictions=segment_length_predictions,
                ground_truth=batch.ground_truth,
                ground_truth_segment_length=batch.ground_truth_segment_length,
                ground_truth_segmentation_length=batch.ground_truth_segmentation_length,
                question_length=batch.questions_length
            )

            total += batch.size

            if is_log:
                self.log(
                    file=file,
                    batch=batch,
                    tag_predictions=tag_predictions,
                    segment_length_predictions=segment_length_predictions
                )

        # An empty iterator (batch_per_epoch == 0 or only empty batches)
        # previously crashed with ZeroDivisionError.
        if total == 0:
            tqdm.write("test_acc: n/a (no test examples)")
            return 0.0
        accuracy = float(correct) / float(total)
        tqdm.write("test_acc: %f" % accuracy)
        return accuracy
项目:entity_binding    作者:JasperGuo    | 项目源码 | 文件源码
def _epoch_log(self, file, num_epoch, train_accuracy, dev_accuracy, average_loss):
        """
        Log epoch
        :param file:
        :param num_epoch:
        :param train_accuracy:
        :param dev_accuracy:
        :param average_loss:
        :return:
        """
        with open(file, "a") as f:
            f.write("epoch: %d, train_accuracy: %f, dev_accuracy: %f, average_loss: %f\n" % (num_epoch, train_accuracy, dev_accuracy, average_loss))
项目:entity_binding    作者:JasperGuo    | 项目源码 | 文件源码
def test(self, data_iterator, is_log=False):
        """
        Evaluate the test model over one pass of ``data_iterator``.

        For each of ``data_iterator.batch_per_epoch`` batches, obtains the
        prediction fetch and feed dict from ``self._test_model.predict``,
        evaluates it with ``self._session.run``, and adds the count returned
        by ``self._check_predictions`` to the number of correct examples.
        When ``is_log`` is true, per-example details are appended to
        ``<_result_log_base_path>/test_<_curr_time>.log`` via ``self.log``.

        :param data_iterator: object exposing ``batch_per_epoch`` and ``get_batch()``
        :param is_log: whether to write per-example details to the test log
        :return: ``correct / total`` as a float, or 0.0 when no examples were seen
        """
        tqdm.write("Testing...")
        total = 0
        correct = 0
        file = os.path.join(self._result_log_base_path, "test_" + self._curr_time + ".log")
        for _ in tqdm(range(data_iterator.batch_per_epoch)):
            batch = data_iterator.get_batch()
            predictions, feed_dict = self._test_model.predict(batch)
            # Evaluate the fetch returned by predict() with its feed dict.
            predictions = self._session.run(predictions, feed_dict=feed_dict)

            correct += self._check_predictions(
                predictions=predictions,
                ground_truth=batch.ground_truth
            )

            total += batch.size

            if is_log:
                self.log(
                    file=file,
                    batch=batch,
                    predictions=predictions
                )

        # An empty iterator (batch_per_epoch == 0 or only empty batches)
        # previously crashed with ZeroDivisionError.
        if total == 0:
            tqdm.write("test_acc: n/a (no test examples)")
            return 0.0
        accuracy = float(correct) / float(total)
        tqdm.write("test_acc: %f" % accuracy)
        return accuracy
项目:DeepPoseComparison    作者:ynaka81    | 项目源码 | 文件源码
def write(self, log):
        """
        Record a single log message.

        The message is printed with ``tqdm.write`` to the default stream,
        printed again with ``tqdm.write`` to ``self.file`` (which is flushed
        right away), and finally appended to ``self.logs``.

        :param log: the message to record
        """
        destination = self.file
        tqdm.write(log)                    # console
        tqdm.write(log, file=destination)  # log file
        destination.flush()
        self.logs.append(log)
项目:DeepPoseComparison    作者:ynaka81    | 项目源码 | 文件源码
def load_state_dict(self, state_dict):
        """
        Restore the logger state from ``state_dict``.

        Replaces ``self.logs`` with ``state_dict['logs']``, echoes the most
        recent entry (if any) to the console via ``tqdm.write``, and rewrites
        every entry to ``self.file``.

        :param state_dict: mapping with a ``'logs'`` key holding the saved
            sequence of log messages
        """
        self.logs = state_dict['logs']
        # An empty history is a valid saved state; indexing [-1] on it used
        # to raise IndexError.
        if self.logs:
            tqdm.write(self.logs[-1])
        for log in self.logs:
            tqdm.write(log, file=self.file)
        # Flush so the restored history reaches the file immediately.
        self.file.flush()
项目:DeepPoseComparison    作者:ynaka81    | 项目源码 | 文件源码
def _train(self, model, optimizer, train_iter, log_interval, logger, start_time):
        model.train()
        for iteration, batch in enumerate(tqdm(train_iter, desc='this epoch'), 1):
            image, pose, visibility = Variable(batch[0]), Variable(batch[1]), Variable(batch[2])
            if self.gpu:
                image, pose, visibility = image.cuda(), pose.cuda(), visibility.cuda()
            optimizer.zero_grad()
            output = model(image)
            loss = mean_squared_error(output, pose, visibility, self.use_visibility)
            loss.backward()
            optimizer.step()
            if iteration % log_interval == 0:
                log = 'elapsed_time: {0}, loss: {1}'.format(time.time() - start_time, loss.data[0])
                logger.write(log)