Python torch 模块,set_default_tensor_type() 实例源码

我们从 Python 开源项目中提取了以下 7 个代码示例,用于说明如何使用 torch.set_default_tensor_type()。

项目:pytorch    作者:tylergenter    | 项目源码 | 文件源码
def default_tensor_type(type):
    """Build a decorator that runs a function under another default tensor type.

    The name of ``type`` is resolved once, via ``torch.typename``, when the
    decorator is created. Every call of the decorated function then records
    the default that is active at that moment, switches to the requested one,
    and switches back afterwards, whether the function returned or raised.

    :param type: tensor type whose ``torch.typename`` becomes the default.
    :return: a decorator to apply to the function that needs that default.
    """
    requested = torch.typename(type)

    def decorate(func):
        def run_with_default(*call_args, **call_kwargs):
            # The type name of a freshly built tensor is the current default.
            previous = torch.typename(torch.Tensor())
            torch.set_default_tensor_type(requested)
            try:
                outcome = func(*call_args, **call_kwargs)
            finally:
                # Restore even when ``func`` raised.
                torch.set_default_tensor_type(previous)
            return outcome

        # Copy ``func``'s metadata (name, docstring, ...) onto the wrapper.
        return wraps(func)(run_with_default)

    return decorate
项目:pytorch-coriander    作者:hughperkins    | 项目源码 | 文件源码
def default_tensor_type(type):
    """Return a decorator that temporarily swaps torch's default tensor type.

    While the decorated function runs, the default tensor type is the one
    named by ``torch.typename(type)``. The default that was in effect when
    the call started is put back once the call finishes, including when it
    finishes with an exception.

    :param type: the tensor type to use as the default during each call.
    :return: decorator wrapping a function with the set/restore logic.
    """
    target_name = torch.typename(type)

    def decorator(wrapped_fn):
        @wraps(wrapped_fn)
        def swapped(*positional, **keyword):
            # Remember the current default by inspecting a new empty tensor.
            saved_name = torch.typename(torch.Tensor())
            torch.set_default_tensor_type(target_name)
            try:
                result = wrapped_fn(*positional, **keyword)
            finally:
                torch.set_default_tensor_type(saved_name)
            return result

        return swapped

    return decorator
项目:pytorch    作者:ezyang    | 项目源码 | 文件源码
def default_tensor_type(type):
    """Create a decorator that pins the default tensor type for each call.

    :param type: tensor type; its ``torch.typename`` string is passed to
        ``torch.set_default_tensor_type`` before the decorated function runs.
    :return: a decorator. The function it produces forwards all arguments and
        the return value unchanged, and restores the previous default tensor
        type afterwards, even on error.
    """
    wanted = torch.typename(type)

    def _invoke_with_default(target, positional, keyword):
        # Run ``target`` with ``wanted`` as the default, then undo the switch.
        restore_to = torch.typename(torch.Tensor())
        torch.set_default_tensor_type(wanted)
        try:
            return target(*positional, **keyword)
        finally:
            torch.set_default_tensor_type(restore_to)

    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            return _invoke_with_default(fn, args, kwargs)

        return wrapper

    return decorator
项目:pytorch    作者:pytorch    | 项目源码 | 文件源码
def default_tensor_type(type):
    """Decorator factory: call a function with a different default tensor type.

    Usage sketch::

        @default_tensor_type(torch.DoubleTensor)
        def make():
            return torch.Tensor()

    :param type: tensor type to install as the default for the duration of
        each call to the decorated function.
    :return: the decorator.
    """
    name_to_install = torch.typename(type)

    def apply_to(callee):
        def guarded(*args, **kwargs):
            # Captured per call, so the default active at call time is restored.
            name_to_restore = torch.typename(torch.Tensor())
            torch.set_default_tensor_type(name_to_install)
            try:
                value = callee(*args, **kwargs)
            finally:
                torch.set_default_tensor_type(name_to_restore)
            return value

        guarded = wraps(callee)(guarded)
        return guarded

    return apply_to
项目:pytorch-dist    作者:apaszke    | 项目源码 | 文件源码
def default_tensor_type(type):
    """Return a decorator that runs its function under ``type`` as the default.

    The previous default tensor type is restored after every call, whether
    the call returned or raised.
    """
    new_default = torch.typename(type)
    def decorator(inner):
        def with_new_default(*args, **kwargs):
            # A fresh tensor reports the default type currently in effect.
            prior_default = torch.typename(torch.Tensor())
            torch.set_default_tensor_type(new_default)
            try:
                answer = inner(*args, **kwargs)
            finally:
                torch.set_default_tensor_type(prior_default)
            return answer
        return wraps(inner)(with_new_default)
    return decorator
项目:pyro    作者:uber    | 项目源码 | 文件源码
def tensors_default_to(host):
    """
    Context manager to temporarily use Cpu or Cuda tensors in Pytorch.

    Only the module prefix of the default tensor type is switched
    (``torch`` for cpu, ``torch.cuda`` for cuda); the class name of the
    current default (e.g. ``FloatTensor``) is kept. The previous default is
    restored when the generator is resumed or closed after its single yield.

    NOTE(review): as written this is a plain generator function. Using it in
    a ``with`` statement requires wrapping it with
    ``contextlib.contextmanager``; no such decorator is visible in this
    excerpt, so confirm it is applied where the function is defined.

    :param str host: Either "cuda" or "cpu".
    :raises AssertionError: if ``host`` is neither "cpu" nor "cuda".
    """
    assert host in ('cpu', 'cuda'), host
    # Read the default type name from a fresh tensor ('torch.FloatTensor',
    # 'torch.cuda.FloatTensor', ...) rather than from the attributes of
    # ``torch.Tensor``: those only describe the default type on PyTorch
    # versions where ``torch.Tensor`` is an alias of the default tensor class.
    old_module, name = torch.Tensor().type().rsplit('.', 1)
    new_module = 'torch.cuda' if host == 'cuda' else 'torch'
    torch.set_default_tensor_type('{}.{}'.format(new_module, name))
    try:
        yield
    finally:
        # Restore the original default even if the managed body raised.
        torch.set_default_tensor_type('{}.{}'.format(old_module, name))
项目:end-to-end-negotiator    作者:facebookresearch    | 项目源码 | 文件源码
def use_cuda(enabled, device_id=0):
    """Switch torch to CUDA float tensors on ``device_id`` when ``enabled``.

    When ``enabled`` is falsy nothing is changed and ``None`` is returned.
    Otherwise CUDA availability is asserted, the default tensor type is set
    to ``torch.cuda.FloatTensor``, ``device_id`` is made the current CUDA
    device, and ``device_id`` is returned.

    :param enabled: whether CUDA should be used.
    :param device_id: CUDA device to select; defaults to 0.
    :return: ``device_id`` if CUDA was enabled, else ``None``.
    :raises AssertionError: if ``enabled`` is truthy but CUDA is unavailable.
    """
    if enabled:
        assert torch.cuda.is_available(), 'CUDA is not available'
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.set_device(device_id)
        return device_id
    return None