Python torch module: torch.__version__ usage examples

The following 8 code examples, extracted from open-source Python projects, illustrate how to use torch.__version__ (note that it is a string attribute, not a callable).
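A quick orientation before the examples (a minimal sketch, not taken from any of the projects below, and assuming the leading version components are plain integers): torch.__version__ is a string such as "0.4.1" or "1.0.0+cu92", so version checks typically parse it into a comparable tuple.

import torch

# torch.__version__ is a string, e.g. "0.4.1" or "1.0.0+cu92"
print(torch.__version__)

# common pattern: drop any "+local" suffix, then compare numeric components
numeric = torch.__version__.split("+")[0]
version_tuple = tuple(int(p) for p in numeric.split(".")[:3])
if version_tuple >= (0, 4, 0):
    print("post-0.4 tensor API available")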

Project: allennlp    Author: allenai
import logging
logger = logging.getLogger(__name__)  # module-level logger, as used by allennlp

def log_pytorch_version_info():
    import torch
    logger.info("Pytorch version: %s", torch.__version__)
Project: pyro    Author: uber
import re
import torch

def parse_torch_version():
    """
    Parses `torch.__version__` into a semver-ish version tuple.
    This is needed to handle subpatch `_n` parts outside of the semver spec.

    :returns: a tuple `(major, minor, patch, extra_stuff)`
    """
    match = re.match(r"(\d\.\d\.\d)(.*)", torch.__version__)
    major, minor, patch = map(int, match.group(1).split("."))
    extra_stuff = match.group(2)
    return major, minor, patch, extra_stuff
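A hypothetical call site for the helper above, assuming a build whose version string is "0.4.1+f964105":

major, minor, patch, extra_stuff = parse_torch_version()
# with torch.__version__ == "0.4.1+f964105" this returns (0, 4, 1, "+f964105")
if (major, minor, patch) >= (0, 4, 0):
    pass  # e.g. rely on behavior introduced in PyTorch 0.4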
Project: ExperimentPackage_PyTorch    Author: ICEORY
def paramscheck(self):
        # requires module-level: import math, torch
        torch_version = torch.__version__
        torch_version_split = torch_version.split("_")

        # old PyTorch builds used version strings like "0.1.10_2"
        if torch_version_split[0] != "0.1.10":
            self.drawNetwork = False
            print("|===>DrawNetwork is unsupported by PyTorch with version:", torch_version)

        if self.netType == "LeNet":
            self.save_path = "log_%s_%s_%s/" % (self.netType, self.data_set, self.experimentID)
        else:
            self.save_path = "log_%s_%s_%d_%s/" % (self.netType, self.data_set,
                                                   self.depth, self.experimentID)

        if self.useDefaultSetting:
            print("|===> Use Default Setting")
            if self.data_set == "cifar10" or self.data_set == "cifar100":
                if self.nEpochs == 160:
                    self.LR = 0.5
                    self.lrPolicy = "exp"
                    self.momentum = 0.9
                    self.weightDecay = 1e-4
                    self.step = 2.0
                    self.gamma = math.pow(0.001 / self.LR, 1.0/math.floor(self.nEpochs/self.step))
                else:
                    self.LR = 0.1
                    self.lrPolicy = "multistep"
                    self.momentum = 0.9
                    self.weightDecay = 1e-4
            else:
                assert False, "invalid data set"

        if self.data_set == "cifar10" or self.data_set == "mnist":
            self.nClasses = 10
        elif self.data_set == "cifar100":
            self.nClasses = 100
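Note that paramscheck gates drawNetwork on an exact version string, so every other release disables it. A more tolerant sketch (not part of the project, and assuming the version begins with dotted integers) compares parsed tuples instead:

import torch

def torch_version_at_least(required):
    """True if torch.__version__ >= required. Assumes a leading numeric
    "major.minor.patch", e.g. "0.1.10_2", "0.4.1", or "1.0.0+cu92"."""
    numeric = torch.__version__.split("+")[0].split("_")[0]
    parts = tuple(int(p) for p in numeric.split(".")[:3])
    return parts >= tuple(required)

# e.g. keep network drawing enabled only from a known-good version onward
draw_network = torch_version_at_least((0, 1, 10))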
Project: pyprob    Author: probprog
def get_config():
    # assumed module-level context: import torch, pyprob, cpuinfo;
    # from termcolor import colored; plus globals cuda_enabled and cuda_device
    ret = []
    ret.append(colored('pyprob  {}'.format(pyprob.__version__), 'blue', attrs=['bold']))
    ret.append('PyTorch {}'.format(torch.__version__))
    cpu_info = cpuinfo.get_cpu_info()
    if 'brand' in cpu_info:
        ret.append('CPU           : {}'.format(cpu_info['brand']))
    else:
        ret.append('CPU           : unknown')
    if 'count' in cpu_info:
        ret.append('CPU count     : {0} (logical)'.format(cpu_info['count']))
    else:
        ret.append('CPU count     : unknown')
    if torch.cuda.is_available():
        ret.append('CUDA          : available')
        ret.append('CUDA devices  : {0}'.format(torch.cuda.device_count()))
        if cuda_enabled:
            if cuda_device == -1:
                ret.append('CUDA selected : all')
            else:
                ret.append('CUDA selected : {0}'.format(cuda_device))
    else:
        ret.append('CUDA          : not available')
    if cuda_enabled:
        ret.append('Running on    : CUDA')
    else:
        ret.append('Running on    : CPU')
    return '\n'.join(ret)
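Typical usage, assuming the module-level globals noted in the comment above are defined:

# prints a multi-line report; the "PyTorch ..." line comes from torch.__version__
print(get_config())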
Project: LIE    Author: EmbraceLife
def version(self):
        """ Returns the PyTorch version, as a tuple of (MAJOR, MINOR, PATCH).
        """
        # requires module-level: import re, logging; logger = logging.getLogger(__name__)
        import torch                            # pylint: disable=import-error
        version = torch.__version__
        match = re.match(r'([0-9]+)\.([0-9]+)\.([0-9]+)\.*', version)
        if not match:
            logger.warning('Unable to infer PyTorch version. We '
                'cannot check for version incompatibilities.')
            return (0, 0, 0)
        return tuple(int(x) for x in match.groups())

    ###########################################################################
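Since version() returns a plain tuple, callers can compare it lexicographically; the (0, 0, 0) fallback makes an undetectable version compare as older than any requirement. A hypothetical call site (backend is assumed to be an instance of this class):

if backend.version() >= (0, 4, 0):
    pass  # use the merged Tensor/Variable API
else:
    pass  # wrap tensors in torch.autograd.Variable on older releases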
Project: kur    Author: deepgram
def version(self):
        """ Returns the PyTorch version, as a tuple of (MAJOR, MINOR, PATCH).
        """
        # requires module-level: import re, logging; logger = logging.getLogger(__name__)
        import torch                            # pylint: disable=import-error
        version = torch.__version__
        match = re.match(r'([0-9]+)\.([0-9]+)\.([0-9]+)\.*', version)
        if not match:
            logger.warning('Unable to infer PyTorch version. We '
                'cannot check for version incompatibilities.')
            return (0, 0, 0)
        return tuple(int(x) for x in match.groups())

    ###########################################################################
Project: pyprob    Author: probprog
def __init__(self, dropout=0.2, cuda=False, cuda_device_id=0, standardize_observes=False, softmax_boost=20, mixture_components=10):
        super(Artifact, self).__init__()

        # network components
        self.sample_layers = {}
        self.proposal_layers = {}
        self.observe_layer = None
        self.lstm = None

        # artifact metadata
        self.model_name = ''
        self.created = util.get_time_str()
        self.modified = util.get_time_str()
        self.on_cuda = cuda
        self.trained_on = ''
        self.cuda_device_id = cuda_device_id
        # record library versions so load_artifact can warn on a mismatch later
        self.code_version = pyprob.__version__
        self.pytorch_version = torch.__version__
        self.standardize_observes = standardize_observes
        # one-hot encodings for addresses and distributions
        self.one_hot_address = {}
        self.one_hot_distribution = {}
        self.one_hot_address_dim = None
        self.one_hot_distribution_dim = None
        self.one_hot_address_empty = None
        self.one_hot_distribution_empty = None
        self.address_histogram = {}
        self.trace_length_histogram = {}
        self.valid_size = None
        self.valid_batch = None
        self.lstm_dim = None
        self.lstm_depth = None
        self.lstm_input_dim = None
        self.smp_emb_dim = None
        self.obs_emb = None
        self.obs_emb_dim = None
        # training/validation statistics and histories
        self.num_params_history_trace = []
        self.num_params_history_num_params = []
        self.trace_length_min = sys.maxsize
        self.trace_length_max = 0
        self.trace_examples_histogram = {}
        self.trace_examples_addresses = {}
        self.trace_examples_limit = 10000
        self.train_loss_best = None
        self.train_loss_worst = None
        self.valid_loss_best = None
        self.valid_loss_worst = None
        self.valid_history_trace = []
        self.valid_history_loss = []
        self.train_history_trace = []
        self.train_history_loss = []
        self.total_training_seconds = 0
        self.total_iterations = 0
        self.total_traces = 0
        self.updates = 0
        self.optimizer = None
        self.dropout = dropout
        self.softmax_boost = softmax_boost
        self.mixture_components = mixture_components

        self._state_observes = None
        self._state_observes_embedding = None
        self._state_new_trace = True
Project: pyprob    Author: probprog
def load_artifact(file_name, cuda=False, device_id=-1):
    # assumed module-level context: import torch, pyprob, and the project's logger
    try:
        if cuda:
            artifact = torch.load(file_name)
        else:
            # remap CUDA-saved storages onto the CPU
            artifact = torch.load(file_name, map_location=lambda storage, loc: storage)
    except Exception:
        logger.log_error('load_artifact: Cannot load file')
        raise  # without re-raising, `artifact` would be undefined below
    if artifact.code_version != pyprob.__version__:
        logger.log()
        logger.log_warning('Different pyprob versions (artifact: {0}, current: {1})'.format(artifact.code_version, pyprob.__version__))
        logger.log()
    if artifact.pytorch_version != torch.__version__:
        logger.log()
        logger.log_warning('Different PyTorch versions (artifact: {0}, current: {1})'.format(artifact.pytorch_version, torch.__version__))
        logger.log()

    # if print_info:
    #     file_size = '{:,}'.format(os.path.getsize(file_name))
    #     log_print('File name             : {0}'.format(file_name))
    #     log_print('File size (Bytes)     : {0}'.format(file_size))
    #     log_print(artifact.get_info())
    #     log_print()

    if cuda:
        if device_id == -1:
            device_id = torch.cuda.current_device()

        if artifact.on_cuda:
            if device_id != artifact.cuda_device_id:
                logger.log_warning('Loading CUDA (device {0}) artifact to CUDA (device {1})'.format(artifact.cuda_device_id, device_id))
                logger.log()
                artifact.move_to_cuda(device_id)
        else:
            logger.log_warning('Loading CPU artifact to CUDA (device {0})'.format(device_id))
            logger.log()
            artifact.move_to_cuda(device_id)
    else:
        if artifact.on_cuda:
            logger.log_warning('Loading CUDA artifact to CPU')
            logger.log()
            artifact.move_to_cpu()

    return artifact
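A hypothetical call site (the file path is illustrative):

# load an artifact onto CUDA device 0; the pyprob and PyTorch versions recorded
# at save time are checked against the current torch.__version__ with warnings
artifact = load_artifact('artifacts/model.pt', cuda=True, device_id=0)
print(artifact.pytorch_version)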