Python torch 模块，set_num_threads() 实例源码

我们从 Python 开源项目中提取了以下 11 个代码示例，用于说明如何使用 torch.set_num_threads()。

项目:pytorch    作者:tylergenter    | 项目源码 | 文件源码
def _worker_loop(dataset, index_queue, data_queue, collate_fn):
    global _use_shared_memory
    _use_shared_memory = True

    torch.set_num_threads(1)
    while True:
        r = index_queue.get()
        if r is None:
            data_queue.put(None)
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))
项目:pytorch-coriander    作者:hughperkins    | 项目源码 | 文件源码
def _worker_loop(dataset, index_queue, data_queue, collate_fn):
    global _use_shared_memory
    _use_shared_memory = True

    torch.set_num_threads(1)
    while True:
        r = index_queue.get()
        if r is None:
            data_queue.put(None)
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))
项目:pytorch    作者:ezyang    | 项目源码 | 文件源码
def _worker_loop(dataset, index_queue, data_queue, collate_fn):
    global _use_shared_memory
    _use_shared_memory = True

    torch.set_num_threads(1)
    while True:
        r = index_queue.get()
        if r is None:
            data_queue.put(None)
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))
项目:drl.pth    作者:seba-1511    | 项目源码 | 文件源码
def async_update(agent, opt, rank, outputs):
    """Run one training process on ``agent`` and forward rank 0's rewards.

    Training continues from the current ``agent`` state, using the
    ``args``/``env`` returned by ``get_setup(seed_offset=rank)``. Only the
    rank-0 process trains verbosely and puts each reward it got from
    ``train`` onto ``outputs``; every other rank returns without output.
    """
    th.set_num_threads(1)
    args, env, _, _ = get_setup(seed_offset=rank)
    root = rank == 0
    rewards = train(args, env, agent, opt, train_update, verbose=root)
    if not root:
        return
    for reward in rewards:
        outputs.put(reward)
项目:drl.pth    作者:seba-1511    | 项目源码 | 文件源码
def init_processes(rank, size, fn, backend='tcp'):
    """ Initialize the distributed environment. """
    # Rendezvous on the local host at a fixed port; presumably consumed by
    # init_process_group's default initialization -- confirm for the torch
    # version in use.
    # NOTE(review): the 'tcp' backend only exists in old PyTorch releases;
    # confirm it is available before relying on the default.
    os.environ.update(MASTER_ADDR='127.0.0.1', MASTER_PORT='29500')
    th.set_num_threads(1)
    dist.init_process_group(backend, rank=rank, world_size=size)
    # Hand control to the user payload once the group is joined.
    fn(rank, size)
项目:pytorch-dist    作者:apaszke    | 项目源码 | 文件源码
def _worker_loop(dataset, index_queue, data_queue, collate_fn):
    torch.set_num_threads(1)
    while True:
        r = index_queue.get()
        if r is None:
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))
项目:pytorch-nlp    作者:endymecy    | 项目源码 | 文件源码
def main(args, num_threads=5):
    """Train a Word2Vec model with the method selected by ``args.method``.

    Args:
        args: parsed command-line namespace. ``args.method`` selects the
            model ('cbow' or 'skip_gram'); the other attributes read below
            are passed straight to ``Word2Vec``.
        num_threads: value passed to ``torch.set_num_threads``; defaults to
            5, the value this function used to hard-code.

    Raises:
        ValueError: if ``args.method`` is neither 'cbow' nor 'skip_gram'
            (such a value used to make the function silently do nothing).
    """
    torch.set_num_threads(num_threads)
    if args.method not in ('cbow', 'skip_gram'):
        raise ValueError(
            "unknown method %r; expected 'cbow' or 'skip_gram'" % (args.method,))
    use_cbow = args.method == 'cbow'
    # Both models share one constructor call; only the cbow/skip_gram
    # flags and the training entry point differ.
    word2vec = Word2Vec(input_file_name=args.input_file_name,
                        output_file_name=args.output_file_name,
                        emb_dimension=args.emb_dimension,
                        batch_size=args.batch_size,
                        # window_size is used by the Skip-Gram model
                        window_size=args.window_size,
                        iteration=args.iteration,
                        initial_lr=args.initial_lr,
                        min_count=args.min_count,
                        using_hs=args.using_hs,
                        using_neg=args.using_neg,
                        # context_size is used by the CBOW model
                        context_size=args.context_size,
                        hidden_size=args.hidden_size,
                        cbow=use_cbow,
                        skip_gram=not use_cbow)
    if use_cbow:
        word2vec.cbow_train()
    else:
        word2vec.skip_gram_train()
项目:vqa.pytorch    作者:Cadene    | 项目源码 | 文件源码
def _worker_loop(dataset, index_queue, data_queue, collate_fn):
    torch.set_num_threads(1)
    while True:
        r = index_queue.get()
        if r is None:
            data_queue.put(None)
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))
项目:ray    作者:ray-project    | 项目源码 | 文件源码
def __init__(self, ob_space, action_space, name="local", summarize=True):
    """Initialize counters, build the graph, and create the instance lock.

    ``ob_space`` and ``action_space`` are forwarded to ``self._setup_graph``.
    ``name`` is accepted but not used by this method. After the graph is
    set up, torch is limited to 2 intra-op threads.
    """
    # Assigned before _setup_graph runs, in case it reads either attribute.
    self.local_steps, self.summarize = 0, summarize
    self._setup_graph(ob_space, action_space)
    torch.set_num_threads(2)
    self.lock = Lock()
项目:pytorch    作者:pytorch    | 项目源码 | 文件源码
def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
    """Worker body: seed, run the init hook, then serve batch requests.

    Each request taken from ``index_queue`` is an ``(idx, batch_indices)``
    pair. The reply put on ``data_queue`` is
    ``(idx, collate_fn([dataset[i] for i in batch_indices]))``, or
    ``(idx, ExceptionWrapper(sys.exc_info()))`` when indexing or collation
    raises. A ``None`` request ends the loop; nothing is put on
    ``data_queue`` at shutdown.

    ``seed`` is passed to ``torch.manual_seed``. ``init_fn``, if not
    ``None``, is called once with ``worker_id`` after seeding and before the
    first request is read.
    """
    global _use_shared_memory
    # Module-level flag flipped for this process; presumably read by the
    # collate code to allocate in shared memory -- confirm against readers.
    _use_shared_memory = True

    # Initialize C-side signal handlers for SIGBUS and SIGSEGV. Handlers
    # from Python's signal module only run after Python returns from the
    # low-level C handler, which is likely after the same fatal signal has
    # already happened again.
    # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    torch.manual_seed(seed)

    if init_fn is not None:
        init_fn(worker_id)

    # iter(callable, sentinel) calls index_queue.get() until it yields None.
    for idx, batch_indices in iter(index_queue.get, None):
        try:
            reply = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            reply = ExceptionWrapper(sys.exc_info())
        data_queue.put((idx, reply))
项目:mxbox    作者:Lyken17    | 项目源码 | 文件源码
def _worker_loop(dataset, index_queue, data_queue, collate_fn):
    global _use_shared_memory
    _use_shared_memory = True

    # torch.set_num_threads(1)
    while True:
        r = index_queue.get()
        if r is None:
            data_queue.put(None)
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))


# numpy_type_map = {
#     'float64': torch.DoubleTensor,
#     'float32': torch.FloatTensor,
#     'float16': torch.HalfTensor,
#     'int64': torch.LongTensor,
#     'int32': torch.IntTensor,
#     'int16': torch.ShortTensor,
#     'int8': torch.CharTensor,
#     'uint8': torch.ByteTensor,
# }