Python torch.utils.data module: Dataset() example source code

The following code examples, collected from open-source Python projects, illustrate how to use torch.utils.data.Dataset().
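
Before the examples, here is a minimal, self-contained sketch of the pattern most of them follow: subclass Dataset, implement __getitem__ and __len__, and hand the instance to a DataLoader. The ToyDataset name and tensor shapes are illustrative only, not from any of the projects below.

import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """A map-style dataset of 100 random (feature, label) pairs."""
    def __init__(self, n=100):
        self.x = torch.randn(n, 3)          # features
        self.y = torch.randint(0, 2, (n,))  # binary labels

    def __getitem__(self, idx):
        # Return one sample; DataLoader batches samples via default_collate.
        return self.x[idx], self.y[idx]

    def __len__(self):
        return len(self.x)

loader = DataLoader(ToyDataset(), batch_size=16, shuffle=True)
xb, yb = next(iter(loader))
print(xb.shape, yb.shape)  # torch.Size([16, 3]) torch.Size([16])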

Project: vsepp | Author: fartashf
def get_test_loader(split_name, data_name, vocab, crop_size, batch_size,
                    workers, opt):
    dpath = os.path.join(opt.data_path, data_name)
    if opt.data_name.endswith('_precomp'):
        test_loader = get_precomp_loader(dpath, split_name, vocab, opt,
                                         batch_size, False, workers)
    else:
        # Build Dataset Loader
        roots, ids = get_paths(dpath, data_name, opt.use_restval)

        transform = get_transform(data_name, split_name, opt)
        test_loader = get_loader_single(opt.data_name, split_name,
                                        roots[split_name]['img'],
                                        roots[split_name]['cap'],
                                        vocab, transform, ids=ids[split_name],
                                        batch_size=batch_size, shuffle=False,
                                        num_workers=workers,
                                        collate_fn=collate_fn)

    return test_loader
Project: make_dataset | Author: hyzhan
def __init__(self, audio_conf, manifest_filepath, labels, normalize=False, augment=False):
        """
        Dataset that loads tensors from a CSV manifest containing comma-separated
        paths to audio files and their transcripts. Each line is one sample. Example:

        /path/to/audio.wav,/path/to/audio.txt
        ...

        :param audio_conf: Dictionary containing the sample rate, window and the window length/stride in seconds
        :param manifest_filepath: Path to the manifest CSV described above
        :param labels: String containing all the possible characters to map to
        :param normalize: Apply standard mean and deviation normalization to the audio tensor
        :param augment: Apply random tempo and gain perturbations (default: False)
        """
        with open(manifest_filepath) as f:
            ids = f.readlines()
        ids = [x.strip().split(',') for x in ids]
        self.ids = ids
        self.size = len(ids)
        self.labels_map = {c: i for i, c in enumerate(labels)}
        super(SpectrogramDataset, self).__init__(audio_conf, normalize, augment)
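
For context, a hypothetical instantiation of this class could look like the following; the audio_conf keys and the manifest path are assumptions based on the docstring, not taken from the project.

audio_conf = dict(sample_rate=16000, window_size=0.02,
                  window_stride=0.01, window='hamming')
labels = "_'ABCDEFGHIJKLMNOPQRSTUVWXYZ "  # every character a transcript may contain
dataset = SpectrogramDataset(audio_conf, 'data/train_manifest.csv', labels,
                             normalize=True, augment=False)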
Project: audio | Author: pytorch
def __init__(self, root, downsample=True, transform=None, target_transform=None, download=False, dev_mode=False):
        self.root = os.path.expanduser(root)
        self.downsample = downsample
        self.transform = transform
        self.target_transform = target_transform
        self.dev_mode = dev_mode
        self.data = []
        self.labels = []
        self.chunk_size = 1000
        self.num_samples = 0
        self.max_len = 0
        self.cached_pt = 0

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')
        self._read_info()
        self.data, self.labels = torch.load(os.path.join(
            self.root, self.processed_folder, "vctk_{:04d}.pt".format(self.cached_pt)))
Project: audio | Author: pytorch
def __init__(self, root, transform=None, target_transform=None, download=False, dev_mode=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.dev_mode = dev_mode
        self.data = []
        self.labels = []
        self.num_samples = 0
        self.max_len = 0

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')
        self.data, self.labels = torch.load(os.path.join(
            self.root, self.processed_folder, self.processed_file))
Project: c3d_pytorch | Author: whitesnowdrop
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train                              # training set or test set

        if download:
            self.download()

        #if not self._check_exists():
        #    raise RuntimeError('Dataset not found.' + ' You can use download=True to download it')

        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(
                os.path.join(root, self.processed_folder, self.test_file))
Project: seq2seq-dataloader | Author: yunjey
def get_loader(src_path, trg_path, src_word2id, trg_word2id, batch_size=100):
    """Returns data loader for custom dataset.

    Args:
        src_path: txt file path for source domain.
        trg_path: txt file path for target domain.
        src_word2id: word-to-id dictionary (source domain).
        trg_word2id: word-to-id dictionary (target domain).
        batch_size: mini-batch size.

    Returns:
        data_loader: data loader for custom dataset.
    """
    # build a custom dataset
    dataset = Dataset(src_path, trg_path, src_word2id, trg_word2id)

    # data loader for the custom dataset
    # this will return (src_seqs, src_lengths, trg_seqs, trg_lengths) for each iteration
    # please see collate_fn for details
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              collate_fn=collate_fn)

    return data_loader
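
A usage sketch: the file paths below are placeholders, and the word-to-id dictionaries are assumed to have been built elsewhere from the training text.

data_loader = get_loader('data/src_train.txt', 'data/trg_train.txt',
                         src_word2id, trg_word2id, batch_size=64)
for src_seqs, src_lengths, trg_seqs, trg_lengths in data_loader:
    pass  # feed one padded mini-batch to the seq2seq model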
Project: two-stream-action-recognition | Author: jeffreyhuang1
def __getitem__(self, idx):
        #print ('mode:',self.mode,'calling Dataset:__getitem__ @ idx=%d'%idx)

        if self.mode == 'train':
            self.video, nb_clips = self.keys[idx].split('-')
            self.clips_idx = random.randint(1,int(nb_clips))
        elif self.mode == 'val':
            self.video,self.clips_idx = self.keys[idx].split('-')
        else:
            raise ValueError('There are only train and val modes')

        label = self.values[idx]
        label = int(label)-1 
        data = self.stackopf()

        if self.mode == 'train':
            sample = (data,label)
        elif self.mode == 'val':
            sample = (self.video,data,label)
        else:
            raise ValueError('There are only train and val modes')
        return sample
Project: generative_zoo | Author: DL-IT
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found. You can use download=True to download it')

        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.test_file))
Project: deepspeech.pytorch | Author: SeanNaren
def __init__(self, audio_conf, manifest_filepath, labels, normalize=False, augment=False):
        """
        Dataset that loads tensors from a CSV manifest containing comma-separated
        paths to audio files and their transcripts. Each line is one sample. Example:

        /path/to/audio.wav,/path/to/audio.txt
        ...

        :param audio_conf: Dictionary containing the sample rate, window and the window length/stride in seconds
        :param manifest_filepath: Path to the manifest CSV described above
        :param labels: String containing all the possible characters to map to
        :param normalize: Apply standard mean and deviation normalization to the audio tensor
        :param augment: Apply random tempo and gain perturbations (default: False)
        """
        with open(manifest_filepath) as f:
            ids = f.readlines()
        ids = [x.strip().split(',') for x in ids]
        self.ids = ids
        self.size = len(ids)
        self.labels_map = {c: i for i, c in enumerate(labels)}
        super(SpectrogramDataset, self).__init__(audio_conf, normalize, augment)
Project: nnmnkwii | Author: r9y9
def test_sequence_wise_torch_data_loader():
    import torch
    from torch.utils import data as data_utils

    X, Y = _get_small_datasets(padded=False)

    class TorchDataset(data_utils.Dataset):
        def __init__(self, X, Y):
            self.X = X
            self.Y = Y

        def __getitem__(self, idx):
            return torch.from_numpy(self.X[idx]), torch.from_numpy(self.Y[idx])

        def __len__(self):
            return len(self.X)

    def __test(X, Y, batch_size):
        dataset = TorchDataset(X, Y)
        loader = data_utils.DataLoader(
            dataset, batch_size=batch_size, num_workers=1, shuffle=True)
        for idx, (x, y) in enumerate(loader):
            assert len(x.shape) == len(y.shape)
            assert len(x.shape) == 3
            print(idx, x.shape, y.shape)

    # Test with batch_size = 1
    yield __test, X, Y, 1
    # Since the sequences have variable lengths, a batch size larger than 1
    # causes a runtime error.
    yield raises(RuntimeError)(__test), X, Y, 2

    # For a padded dataset, which can be represented as (N, T^max, D), the
    # batch size can be any number.
    X, Y = _get_small_datasets(padded=True)
    yield __test, X, Y, 1
    yield __test, X, Y, 2
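
The expected RuntimeError above comes from default_collate, which cannot stack tensors of unequal length. A common workaround, sketched below and not part of nnmnkwii, is a custom collate_fn that zero-pads each batch to its longest sequence:

import torch

def pad_collate(batch):
    # batch is a list of (x, y) pairs shaped (T_i, D_x) and (T_i, D_y).
    max_len = max(x.size(0) for x, _ in batch)

    def pad(t):
        out = t.new_zeros(max_len, t.size(1))
        out[:t.size(0)] = t
        return out

    xs = torch.stack([pad(x) for x, _ in batch])
    ys = torch.stack([pad(y) for _, y in batch])
    lengths = torch.tensor([x.size(0) for x, _ in batch])  # original lengths
    return xs, ys, lengths

# loader = data_utils.DataLoader(dataset, batch_size=2, collate_fn=pad_collate)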
Project: nnmnkwii | Author: r9y9
def test_frame_wise_torch_data_loader():
    import torch
    from torch.utils import data as data_utils

    X, Y = _get_small_datasets(padded=False)

    # Since torch's Dataset (and Chainer's, and maybe others) assumes the dataset
    # has a fixed length, i.e., implements the `__len__` method, we need to know
    # the number of frames for each utterance.
    # The sum of the frame counts is the dataset size for frame-wise iteration.
    lengths = np.array([len(x) for x in X], dtype=int)

    # For the above reason, we need to explicitly give the number of frames.
    X = MemoryCacheFramewiseDataset(X, lengths, cache_size=len(X))
    Y = MemoryCacheFramewiseDataset(Y, lengths, cache_size=len(Y))

    class TorchDataset(data_utils.Dataset):
        def __init__(self, X, Y):
            self.X = X
            self.Y = Y

        def __getitem__(self, idx):
            return torch.from_numpy(self.X[idx]), torch.from_numpy(self.Y[idx])

        def __len__(self):
            return len(self.X)

    def __test(X, Y, batch_size):
        dataset = TorchDataset(X, Y)
        loader = data_utils.DataLoader(
            dataset, batch_size=batch_size, num_workers=1, shuffle=True)
        for idx, (x, y) in enumerate(loader):
            assert len(x.shape) == 2
            assert len(y.shape) == 2

    yield __test, X, Y, 128
    yield __test, X, Y, 256
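
For reference, the frame-wise indexing that MemoryCacheFramewiseDataset provides can be sketched with cumulative lengths and a binary search; this shows the indexing idea only, not nnmnkwii's actual implementation.

import numpy as np

class FramewiseView(object):
    """Expose individual frames of variable-length utterances by flat index."""
    def __init__(self, utterances, lengths):
        self.utterances = utterances
        self.offsets = np.cumsum(lengths)  # end offset of each utterance

    def __getitem__(self, idx):
        u = int(np.searchsorted(self.offsets, idx, side='right'))
        start = 0 if u == 0 else self.offsets[u - 1]
        return self.utterances[u][idx - start]

    def __len__(self):
        return int(self.offsets[-1])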
Project: pytorch-arda | Author: corenel
def __init__(self, root, train=True, transform=None, download=False):
        """Init USPS dataset."""
        # init params
        self.root = os.path.expanduser(root)
        self.filename = "usps_28x28.pkl"
        self.train = train
        # Num of Train = 7438, Num of Test = 1860
        self.transform = transform
        self.dataset_size = None

        # download dataset.
        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError("Dataset not found." +
                               " You can use download=True to download it")

        self.train_data, self.train_labels = self.load_samples()
        if self.train:
            total_num_samples = self.train_labels.shape[0]
            indices = np.arange(total_num_samples)
            np.random.shuffle(indices)
            self.train_data = self.train_data[indices[0:self.dataset_size], ::]
            self.train_labels = self.train_labels[indices[0:self.dataset_size]]
        self.train_data *= 255.0
        self.train_data = self.train_data.transpose(
            (0, 2, 3, 1))  # convert to HWC
Project: sketchnet | Author: jtoy
def __init__(self, parent_ds, offset, length):
        self.parent_ds = parent_ds
        self.offset = offset
        self.length = length
        assert len(parent_ds) >= offset + length, "Parent dataset is not long enough"
        super(PartialDataset, self).__init__()
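
The snippet shows only the constructor; the rest of such a windowing wrapper is typically just two methods. A minimal sketch, assuming parent_ds supports integer indexing:

    def __getitem__(self, i):
        # Map index i into the parent's window [offset, offset + length).
        return self.parent_ds[i + self.offset]

    def __len__(self):
        return self.length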
Project: pytorch-adda | Author: corenel
def __init__(self, root, train=True, transform=None, download=False):
        """Init USPS dataset."""
        # init params
        self.root = os.path.expanduser(root)
        self.filename = "usps_28x28.pkl"
        self.train = train
        # Num of Train = 7438, Num of Test = 1860
        self.transform = transform
        self.dataset_size = None

        # download dataset.
        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError("Dataset not found." +
                               " You can use download=True to download it")

        self.train_data, self.train_labels = self.load_samples()
        if self.train:
            total_num_samples = self.train_labels.shape[0]
            indices = np.arange(total_num_samples)
            np.random.shuffle(indices)
            self.train_data = self.train_data[indices[0:self.dataset_size], ::]
            self.train_labels = self.train_labels[indices[0:self.dataset_size]]
        self.train_data *= 255.0
        self.train_data = self.train_data.transpose(
            (0, 2, 3, 1))  # convert to HWC
Project: pytorch | Author: tylergenter
def test_numpy(self):
        import numpy as np

        class TestDataset(torch.utils.data.Dataset):
            def __getitem__(self, i):
                return np.ones((2, 3, 4)) * i

            def __len__(self):
                return 1000

        loader = DataLoader(TestDataset(), batch_size=12)
        batch = next(iter(loader))
        self.assertIsInstance(batch, torch.DoubleTensor)
        self.assertEqual(batch.size(), torch.Size([12, 2, 3, 4]))
Project: pytorch | Author: tylergenter
def test_numpy_scalars(self):
        import numpy as np

        class ScalarDataset(torch.utils.data.Dataset):
            def __init__(self, dtype):
                self.dtype = dtype

            def __getitem__(self, i):
                return self.dtype()

            def __len__(self):
                return 4

        dtypes = {
            np.float64: torch.DoubleTensor,
            np.float32: torch.FloatTensor,
            np.float16: torch.HalfTensor,
            np.int64: torch.LongTensor,
            np.int32: torch.IntTensor,
            np.int16: torch.ShortTensor,
            np.int8: torch.CharTensor,
            np.uint8: torch.ByteTensor,
        }
        for dt, tt in dtypes.items():
            dset = ScalarDataset(dt)
            loader = DataLoader(dset, batch_size=2)
            batch = next(iter(loader))
            self.assertIsInstance(batch, tt)
Project: PyTorchText | Author: chenyuntc
def __init__(self,train_root,labels_file,type_='char'):
        '''
        Dataset('/mnt/7/zhihu/ieee_zhihu_cup/train.npz','/mnt/7/zhihu/ieee_zhihu_cup/a.json')
        '''
        import json
        with open(labels_file) as f:
            labels_ = json.load(f)

        # embedding_d = np.load(embedding_root)['vector']
        question_d = np.load(train_root)

        all_char_title, all_char_content = question_d['title_char'], question_d['content_char']
        all_word_title, all_word_content = question_d['title_word'], question_d['content_word']

        self.train_data = (all_char_title[:-20000], all_char_content[:-20000]), \
                          (all_word_title[:-20000], all_word_content[:-20000])
        self.val_data = (all_char_title[-20000:], all_char_content[-20000:]), \
                        (all_word_title[-20000:], all_word_content[-20000:])
        self.all_num = len(all_char_title)
        # del all_data_title,all_data_content

        self.data_title,self.data_content = self.train_data
        self.len_ = len(self.data_title[0])

        self.index2qid = question_d['index2qid'].item()
        self.l_end=0
        self.labels = labels_['d']
Project: PyTorchText | Author: chenyuntc
def __init__(self,train_root,labels_file,type_='char',augument=True):
        '''
        Dataset('/mnt/7/zhihu/ieee_zhihu_cup/train.npz','/mnt/7/zhihu/ieee_zhihu_cup/a.json')
        '''
        import json
        with open(labels_file) as f:
            labels_ = json.load(f)

        # embedding_d = np.load(embedding_root)['vector']
        self.augument=augument
        question_d = np.load(train_root)
        self.type_=type_
        if type_ == 'char':
            all_data_title,all_data_content =\
                 question_d['title_char'],question_d['content_char']

        elif type_ == 'word':
            all_data_title,all_data_content =\
                 question_d['title_word'],question_d['content_word']

        self.train_data = all_data_title[:-200000],all_data_content[:-200000]
        self.val_data = all_data_title[-200000:],all_data_content[-200000:]

        self.all_num = len(all_data_content)
        # del all_data_title,all_data_content

        self.data_title,self.data_content = self.train_data
        self.len_ = len(self.data_title)

        self.index2qid = question_d['index2qid'].item()
        self.l_end=0
        self.labels = labels_['d']

        self.training=True
Project: PyTorchText | Author: chenyuntc
def __init__(self,train_root,labels_file,type_='char',fold=0):
        '''
        Dataset('/mnt/7/zhihu/ieee_zhihu_cup/train.npz','/mnt/7/zhihu/ieee_zhihu_cup/a.json')
        '''
        import json
        with open(labels_file) as f:
            labels_ = json.load(f)
        self.fold=fold 
        # embedding_d = np.load(embedding_root)['vector']
        question_d = np.load(train_root)

        all_char_title, all_char_content = question_d['title_char'], question_d['content_char']
        all_word_title, all_word_content = question_d['title_word'], question_d['content_word']

        self.train_data = (all_char_title[:-200000], all_char_content[:-200000]), \
                          (all_word_title[:-200000], all_word_content[:-200000])
        self.val_data = (all_char_title[-200000:], all_char_content[-200000:]), \
                        (all_word_title[-200000:], all_word_content[-200000:])
        self.all_num = len(all_char_title)
        # del all_data_title,all_data_content

        self.data_title,self.data_content = self.train_data
        self.len_ = len(self.data_title[0])
        self.training=True
        self.index2qid = question_d['index2qid'].item()
        self.l_end=0
        self.labels = labels_['d']
Project: PyTorchText | Author: chenyuntc
def __init__(self,train_root,labels_file,type_='char',fold=0):
        '''
        Dataset('/mnt/7/zhihu/ieee_zhihu_cup/train.npz','/mnt/7/zhihu/ieee_zhihu_cup/a.json')
        '''
        import json
        with open(labels_file) as f:
            labels_ = json.load(f)
        self.fold =fold

        # embedding_d = np.load(embedding_root)['vector']
        question_d = np.load(train_root)
        self.type_=type_
        if type_ == 'char':
            all_data_title,all_data_content =\
                 question_d['title_char'],question_d['content_char']

        elif type_ == 'word':
            all_data_title,all_data_content =\
                 question_d['title_word'],question_d['content_word']

        self.train_data = all_data_title[:-200000],all_data_content[:-200000]
        self.val_data = all_data_title[-200000:],all_data_content[-200000:]

        self.all_num = len(all_data_content)
        # del all_data_title,all_data_content

        self.data_title,self.data_content = self.train_data
        self.len_ = len(self.data_title)

        self.training=True

        self.index2qid = question_d['index2qid'].item()
        self.l_end=0
        self.labels = labels_['d']
Project: sk-torch | Author: mattHawthorn
def __init__(self, dataset: Dataset, X_encoder: Opt[Callable[[T1], TensorType]]=None,
                 y_encoder: Opt[Callable[[T2], TensorType]]=None, **kwargs):
        super().__init__(dataset=dataset, **kwargs)
        self.X_encoder = X_encoder
        self.y_encoder = y_encoder
Project: pytorch-coriander | Author: hughperkins
def test_numpy(self):
        import numpy as np

        class TestDataset(torch.utils.data.Dataset):
            def __getitem__(self, i):
                return np.ones((2, 3, 4)) * i

            def __len__(self):
                return 1000

        loader = DataLoader(TestDataset(), batch_size=12)
        batch = next(iter(loader))
        self.assertIsInstance(batch, torch.DoubleTensor)
        self.assertEqual(batch.size(), torch.Size([12, 2, 3, 4]))
Project: pytorch-coriander | Author: hughperkins
def test_numpy_scalars(self):
        import numpy as np

        class ScalarDataset(torch.utils.data.Dataset):
            def __init__(self, dtype):
                self.dtype = dtype

            def __getitem__(self, i):
                return self.dtype()

            def __len__(self):
                return 4

        dtypes = {
            np.float64: torch.DoubleTensor,
            np.float32: torch.FloatTensor,
            np.float16: torch.HalfTensor,
            np.int64: torch.LongTensor,
            np.int32: torch.IntTensor,
            np.int16: torch.ShortTensor,
            np.int8: torch.CharTensor,
            np.uint8: torch.ByteTensor,
        }
        for dt, tt in dtypes.items():
            dset = ScalarDataset(dt)
            loader = DataLoader(dset, batch_size=2)
            batch = next(iter(loader))
            self.assertIsInstance(batch, tt)
Project: MatchingNetworks | Author: gitabcworld
def __init__(self, root, transform=None, target_transform=None, download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.'
                               + ' You can use download=True to download it')

        self.all_items=find_classes(os.path.join(self.root, self.processed_folder))
        self.idx_classes=index_classes(self.all_items)
Project: MatchingNetworks | Author: gitabcworld
def __init__(self, dataroot = '/home/aberenguel/Dataset/miniImagenet', type = 'train',
                 nEpisodes = 1000, classes_per_set=10, samples_per_class=1):

        self.nEpisodes = nEpisodes
        self.classes_per_set = classes_per_set
        self.samples_per_class = samples_per_class
        self.n_samples = self.samples_per_class * self.classes_per_set
        self.n_samplesNShot = 5  # Samples per meta-test (1 when running one-shot).
        # Transformations to the image
        self.transform = transforms.Compose([filenameToPILImage,
                                             PiLImageResize,
                                             transforms.ToTensor()
                                             ])

        def loadSplit(splitFile):
            dictLabels = {}
            with open(splitFile) as csvfile:
                csvreader = csv.reader(csvfile, delimiter=',')
                next(csvreader, None)
                for i,row in enumerate(csvreader):
                    filename = row[0]
                    label = row[1]
                    if label in dictLabels.keys():
                        dictLabels[label].append(filename)
                    else:
                        dictLabels[label] = [filename]
            return dictLabels

        #requiredFiles = ['train','val','test']
        self.miniImagenetImagesDir = os.path.join(dataroot,'images')
        self.data = loadSplit(splitFile = os.path.join(dataroot,type + '.csv'))
        self.data = collections.OrderedDict(sorted(self.data.items()))
        self.classes_dict = {label: i for i, label in enumerate(self.data.keys())}
        self.create_episodes(self.nEpisodes)
Project: pytorch | Author: ezyang
def test_numpy(self):
        import numpy as np

        class TestDataset(torch.utils.data.Dataset):
            def __getitem__(self, i):
                return np.ones((2, 3, 4)) * i

            def __len__(self):
                return 1000

        loader = DataLoader(TestDataset(), batch_size=12)
        batch = next(iter(loader))
        self.assertIsInstance(batch, torch.DoubleTensor)
        self.assertEqual(batch.size(), torch.Size([12, 2, 3, 4]))
Project: pytorch | Author: ezyang
def test_numpy_scalars(self):
        import numpy as np

        class ScalarDataset(torch.utils.data.Dataset):
            def __init__(self, dtype):
                self.dtype = dtype

            def __getitem__(self, i):
                return self.dtype()

            def __len__(self):
                return 4

        dtypes = {
            np.float64: torch.DoubleTensor,
            np.float32: torch.FloatTensor,
            np.float16: torch.HalfTensor,
            np.int64: torch.LongTensor,
            np.int32: torch.IntTensor,
            np.int16: torch.ShortTensor,
            np.int8: torch.CharTensor,
            np.uint8: torch.ByteTensor,
        }
        for dt, tt in dtypes.items():
            dset = ScalarDataset(dt)
            loader = DataLoader(dset, batch_size=2)
            batch = next(iter(loader))
            self.assertIsInstance(batch, tt)
Project: pytorch | Author: pytorch
def test_numpy(self):
        import numpy as np

        class TestDataset(torch.utils.data.Dataset):
            def __getitem__(self, i):
                return np.ones((2, 3, 4)) * i

            def __len__(self):
                return 1000

        loader = DataLoader(TestDataset(), batch_size=12)
        batch = next(iter(loader))
        self.assertIsInstance(batch, torch.DoubleTensor)
        self.assertEqual(batch.size(), torch.Size([12, 2, 3, 4]))
Project: pytorch | Author: pytorch
def test_numpy_scalars(self):
        import numpy as np

        class ScalarDataset(torch.utils.data.Dataset):
            def __init__(self, dtype):
                self.dtype = dtype

            def __getitem__(self, i):
                return self.dtype()

            def __len__(self):
                return 4

        dtypes = {
            np.float64: torch.DoubleTensor,
            np.float32: torch.FloatTensor,
            np.float16: torch.HalfTensor,
            np.int64: torch.LongTensor,
            np.int32: torch.IntTensor,
            np.int16: torch.ShortTensor,
            np.int8: torch.CharTensor,
            np.uint8: torch.ByteTensor,
        }
        for dt, tt in dtypes.items():
            dset = ScalarDataset(dt)
            loader = DataLoader(dset, batch_size=2)
            batch = next(iter(loader))
            self.assertIsInstance(batch, tt)
Project: keita | Author: iwasaki-kenta
def __init__(self, root='data/omniglot', transform=None, target_transform=None, download=True):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        if download: self.download()

        assert self._check_exists(), 'Dataset not found. You can use download=True to download it'

        self.all_items = find_classes(os.path.join(self.root, self.processed_folder))
        self.classes = index_classes(self.all_items)
Project: pix2pix-pytorch | Author: 1zb
def __init__(self, image_dir):
        super(Dataset, self).__init__()
        # self.path = image_dir
        self.input_filenames = glob.glob(os.path.join(image_dir, "*.jpg"))
Project: generative_models | Author: j-min
def get_custom_dataset(config):
    dataset = None
    if config.dataset_mode == 'aligned':
        dataset = AlignedDataset()
    elif config.dataset_mode == 'unaligned':
        dataset = UnalignedDataset()
    elif config.dataset_mode == 'single':
        dataset = SingleDataset()
    else:
        raise ValueError("Dataset [%s] not recognized." % config.dataset_mode)

    print("dataset [%s] was created" % (dataset.name()))
    dataset.initialize(config)
    return dataset
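
As a design note, the if/elif chain can also be written as a dictionary registry; this is a sketch, not code from the project:

DATASETS = {
    'aligned': AlignedDataset,
    'unaligned': UnalignedDataset,
    'single': SingleDataset,
}

def get_custom_dataset(config):
    try:
        dataset = DATASETS[config.dataset_mode]()
    except KeyError:
        raise ValueError("Dataset [%s] not recognized." % config.dataset_mode)
    print("dataset [%s] was created" % dataset.name())
    dataset.initialize(config)
    return dataset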
Project: vsepp | Author: fartashf
def get_loaders(data_name, vocab, crop_size, batch_size, workers, opt):
    dpath = os.path.join(opt.data_path, data_name)
    if opt.data_name.endswith('_precomp'):
        train_loader = get_precomp_loader(dpath, 'train', vocab, opt,
                                          batch_size, True, workers)
        val_loader = get_precomp_loader(dpath, 'dev', vocab, opt,
                                        batch_size, False, workers)
    else:
        # Build Dataset Loader
        roots, ids = get_paths(dpath, data_name, opt.use_restval)

        transform = get_transform(data_name, 'train', opt)
        train_loader = get_loader_single(opt.data_name, 'train',
                                         roots['train']['img'],
                                         roots['train']['cap'],
                                         vocab, transform, ids=ids['train'],
                                         batch_size=batch_size, shuffle=True,
                                         num_workers=workers,
                                         collate_fn=collate_fn)

        transform = get_transform(data_name, 'val', opt)
        val_loader = get_loader_single(opt.data_name, 'val',
                                       roots['val']['img'],
                                       roots['val']['cap'],
                                       vocab, transform, ids=ids['val'],
                                       batch_size=batch_size, shuffle=False,
                                       num_workers=workers,
                                       collate_fn=collate_fn)

    return train_loader, val_loader
Project: triplet-network-pytorch | Author: andreasveit
def __init__(self, root,  n_train_triplets=50000, n_test_triplets=10000, train=True, transform=None, target_transform=None, download=False):
        self.root = root

        self.transform = transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(root, self.processed_folder, self.training_file))
            self.make_triplet_list(n_train_triplets)
            triplets = []
            for line in open(os.path.join(root, self.processed_folder, self.train_triplet_file)):
                triplets.append((int(line.split()[0]), int(line.split()[1]), int(line.split()[2]))) # anchor, close, far
            self.triplets_train = triplets
        else:
            self.test_data, self.test_labels = torch.load(os.path.join(root, self.processed_folder, self.test_file))
            self.make_triplet_list(n_test_triplets)
            triplets = []
            for line in open(os.path.join(root, self.processed_folder, self.test_triplet_file)):
                triplets.append((int(line.split()[0]), int(line.split()[1]), int(line.split()[2]))) # anchor, close, far
            self.triplets_test = triplets
Project: TreeLSTMSentiment | Author: ttpro1995
def read_labels(self, filename):
        with open(filename,'r') as f:
            labels = list(map(float, f.readlines()))
            labels = torch.Tensor(labels)
        return labels

# Dataset class for SICK dataset
Project: textobjdetection | Author: andfoy
def __init__(self, root, transform=None, target_transform=None,
                 train=True, test=False, top=100, group=True,
                 additional_transform=None):
        self.root = root
        self.transform = transform
        self.additional_transform = additional_transform
        self.target_transform = target_transform
        self.top_objects = top
        self.top_folder = 'top_{0}'.format(top)
        self.group = group

        if not osp.exists(self.root):
            raise RuntimeError('Dataset not found ' +
                               'please download it from: ' +
                               'http://visualgenome.org/api/v0/api_home.html')

        if not self.__check_exists():
            self.process_dataset()

        # self.region_objects, self.obj_idx = self.load_region_objects()

        if train:
            train_file = osp.join(self.data_path, self.top_folder,
                                  self.region_train_file)
            with open(train_file, 'rb') as f:
                self.regions = torch.load(f)
        elif test:
            test_file = osp.join(self.data_path, self.top_folder,
                                 self.region_test_file)
            with open(test_file, 'rb') as f:
                self.regions = torch.load(f)
        else:
            val_file = osp.join(self.data_path, self.top_folder,
                                self.region_val_file)
            with open(val_file, 'rb') as f:
                self.regions = torch.load(f)

        if self.group:
            self.regions = self.__group_regions_by_id(self.regions)

        corpus_file = osp.join(self.data_path, self.processed_folder,
                               self.corpus_file)
        with open(corpus_file, 'rb') as f:
            self.corpus = torch.load(f)

        region_obj_file = osp.join(self.data_path, self.top_folder,
                                   self.region_objects_file)
        with open(region_obj_file, 'rb') as f:
            self.region_objects = torch.load(f)

        obj_idx_path = osp.join(self.data_path, self.top_folder,
                                self.obj_idx_file)

        with open(obj_idx_path, 'rb') as f:
            self.obj_idx = torch.load(f)

        self.idx_obj = {v: k for k, v in self.obj_idx.items()}
        # del region_objects
Project: PyTorchText | Author: chenyuntc
def __init__(self,train_root,labels_file,type_='char'):
        '''
        Dataset('/mnt/7/zhihu/ieee_zhihu_cup/train.npz','/mnt/7/zhihu/ieee_zhihu_cup/a.json')
        '''
        import json
        with open(labels_file) as f:
            labels_ = json.load(f)

        # embedding_d = np.load(embedding_root)['vector']
        question_d = np.load(train_root)
        self.type_=type_
        if type_ == 'char':
            all_data_title,all_data_content =\
                 question_d['title_char'],question_d['content_char']

        elif type_ == 'word':
            all_data_title,all_data_content =\
                 question_d['title_word'],question_d['content_word']

        self.train_data = all_data_title[:-20000],all_data_content[:-20000]
        self.val_data = all_data_title[-20000:],all_data_content[-20000:]

        self.all_num = len(all_data_content)
        # del all_data_title,all_data_content

        self.data_title,self.data_content = self.train_data
        self.len_ = len(self.data_title)

        self.index2qid = question_d['index2qid'].item()
        self.l_end=0
        self.labels = labels_['d']

    # def augument(self, d):
    #     '''Shift the sequence by a small random offset (commented out).'''
    #     if self.type_ == 'char':
    #         _index = (-8, 8)
    #     else:
    #         _index = (-5, 5)
    #     r = d.new(d.size()).fill_(0)
    #     index = random.randint(-3, 4)
    #     if index > 0:
    #         r[index:] = d[:-index]
    #     else:
    #         r[:-index] = d[index:]
    #     return r

    # def augument(self, d, type_=1):
    #     if type_ == 1:
    #         return self.shuffle(d)
    #     else:
    #         if self.type_ == 'char':
    #             return self.dropout(d, p=0.6)
Project: SMASH | Author: ajbrock
def eval_parser():
    usage = 'Samples SMASH architectures and tests them on CIFAR.'
    parser = ArgumentParser(description=usage)
    parser.add_argument(
        '--SMASH', type=str, default=None, metavar='FILE',
        help='The SMASH network .pth file to evaluate.')
    parser.add_argument(
        '--batch-size', type=int, default=100,
        help='Images per batch (default: %(default)s)')
    parser.add_argument(
        '--which-dataset', type=str, default='C100',
        help='Which Dataset to train on (default: %(default)s)')
    parser.add_argument(
        '--seed', type=int, default=0,
        help='Random seed to use.')
    parser.add_argument(
        '--validate', action='store_true', default=True,
        help='Perform validation on the validation set (enabled by default)')
    parser.add_argument(
        '--validate-test', action='store_const', dest='validate',
        const='test', help='Evaluate on test set after every epoch.')
    parser.add_argument(
        '--num-random', type=int, default=500,
        help='Number of random architectures to sample (default: %(default)s)')
    parser.add_argument(
        '--num-perturb', type=int, default=100,
        help='Number of random perturbations to sample (default: %(default)s)')
    parser.add_argument(
        '--num-markov', type=int, default=100,
        help='Number of markov steps to take after perturbation (default: %(default)s)')
    parser.add_argument(
        '--perturb-prob', type=float, default=0.05,
        help='Chance of any individual element being perturbed (default: %(default)s)')
    parser.add_argument(
        '--arch-SGD', action='store_true', default=False,
        help='Perturb archs with architectural SGD. (default: %(default)s)')
    parser.add_argument(
        '--fp16', action='store_true', default=False,
        help='Evaluate with half-precision. (default: %(default)s)')
    parser.add_argument(
        '--parallel', action='store_true', default=False,
        help='Evaluate with multiple GPUs. (default: %(default)s)')
    return parser
Project: neural-combinatorial-rl-pytorch | Author: pemami4911
def create_dataset(problem_size, data_dir):

    def find_or_return_empty(data_dir, problem_size):
        val_fname1 = os.path.join(data_dir, 'tsp{}_test.txt'.format(problem_size))
        val_fname2 = os.path.join(data_dir, 'tsp-{}_test.txt'.format(problem_size))

        if not os.path.isdir(data_dir):
            os.mkdir(data_dir)
        else:
            if os.path.exists(val_fname1):
                return val_fname1
            if os.path.exists(val_fname2):
                return val_fname2
        return None

    val = find_or_return_empty(data_dir, problem_size)
    if val is None:
        download_google_drive_file(data_dir, 'tsp', '', problem_size)
        val = find_or_return_empty(data_dir, problem_size)

    return val


#######################################
# Dataset
#######################################