Python torchvision.datasets 模块,ImageFolder() 实例源码

我们从 Python 开源项目中提取了以下 12 个代码示例，用于说明如何使用 torchvision.datasets.ImageFolder()。

项目:DeepLearning_PlantDiseases    作者:MarkoArsenovic    | 项目源码 | 文件源码
def load_labels(data_dir, resize=(224, 224)):
    """Return the class names found under ``<data_dir>/train``.

    Args:
        data_dir: Directory containing a ``train`` sub-directory that
            ``datasets.ImageFolder`` can index.
        resize: Unused. Kept only so callers that already pass it keep
            working.

    Returns:
        The ``classes`` attribute of the ``ImageFolder`` built on the
        training directory.
    """
    # Only the class list is read and no sample is ever fetched, so there is
    # no need to build an augmentation pipeline here.
    train_set = datasets.ImageFolder(os.path.join(data_dir, 'train'))
    return train_set.classes
项目:DeepLearning_PlantDiseases    作者:MarkoArsenovic    | 项目源码 | 文件源码
def load_labels(data_dir, resize=(224, 224)):
    """Return the class names found under ``<data_dir>/train``.

    Args:
        data_dir: Directory containing a ``train`` sub-directory that
            ``datasets.ImageFolder`` can index.
        resize: Unused. Kept only so callers that already pass it keep
            working.

    Returns:
        The ``classes`` attribute of the ``ImageFolder`` built on the
        training directory.
    """
    # The transform is never applied because no image is loaded; the class
    # list is all that is needed, so no transform is constructed.
    train_dir = os.path.join(data_dir, 'train')
    return datasets.ImageFolder(train_dir).classes
项目:DeepLearning_PlantDiseases    作者:MarkoArsenovic    | 项目源码 | 文件源码
def load_data(resize):
    """Build shuffled train/val ``DataLoader`` objects for ``PlantVillage``.

    Args:
        resize: Sequence of target dimensions; its maximum is used as the
            crop size for both splits.

    Returns:
        A ``(train_loader, val_loader)`` tuple.

    Note:
        ``batch_size`` is not a parameter of this function; it is read from
        the enclosing scope.
    """
    crop_size = max(resize)
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomSizedCrop(crop_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        'val': transforms.Compose([
            # Higher scale-up for inception: keep the 256:224 ratio relative
            # to the crop size. Multiplying before dividing keeps the ratio
            # from collapsing to 1 under Python 2 integer division.
            transforms.Scale(int(crop_size * 256 // 224)),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            normalize,
        ]),
    }

    data_dir = 'PlantVillage'
    splits = ('train', 'val')
    dsets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
             for x in splits}
    # Both loaders shuffle, as in the original code (including validation).
    dset_loaders = {x: torch.utils.data.DataLoader(dsets[x], batch_size=batch_size,
                                                   shuffle=True)
                    for x in splits}

    return dset_loaders['train'], dset_loaders['val']
项目:convNet.pytorch    作者:eladhoffer    | 项目源码 | 文件源码
def get_dataset(name, split='train', transform=None,
                target_transform=None, download=True, datasets_path=__DATASETS_DEFAULT_PATH):
    """Instantiate one of the supported torchvision datasets by name.

    Args:
        name: One of ``'cifar10'``, ``'cifar100'``, ``'mnist'``, ``'stl10'``
            or ``'imagenet'``.
        split: Split name. ``'train'`` selects the training set; for
            ``'stl10'`` the value is forwarded unchanged.
        transform: Transform forwarded to the dataset.
        target_transform: Target transform forwarded to the dataset.
        download: Forwarded to the datasets that accept it (not used for
            ``'imagenet'``).
        datasets_path: Base directory; the dataset root is
            ``<datasets_path>/<name>``.

    Returns:
        The constructed dataset, or ``None`` when ``name`` is not one of the
        supported values.
    """
    is_train = split == 'train'
    dataset_root = os.path.join(datasets_path, name)
    shared = dict(transform=transform, target_transform=target_transform)

    if name == 'imagenet':
        # ImageNet is read from per-split sub-folders instead of a flag.
        subdir = 'train' if is_train else 'val'
        return datasets.ImageFolder(root=os.path.join(dataset_root, subdir), **shared)

    if name == 'stl10':
        return datasets.STL10(root=dataset_root, split=split, download=download, **shared)

    # Datasets that select their split through a boolean ``train`` argument.
    train_flag_classes = {'cifar10': 'CIFAR10', 'cifar100': 'CIFAR100', 'mnist': 'MNIST'}
    if name in train_flag_classes:
        dataset_cls = getattr(datasets, train_flag_classes[name])
        return dataset_cls(root=dataset_root, train=is_train, download=download, **shared)

    return None
项目:bigBatch    作者:eladhoffer    | 项目源码 | 文件源码
def get_dataset(name, split='train', transform=None,
                target_transform=None, download=True, datasets_path=__DATASETS_DEFAULT_PATH):
    """Build the torchvision dataset identified by ``name``.

    Args:
        name: ``'cifar10'``, ``'cifar100'``, ``'mnist'``, ``'stl10'`` or
            ``'imagenet'``.
        split: ``'train'`` picks the training set; ``'stl10'`` receives the
            value as-is.
        transform: Passed through to the dataset constructor.
        target_transform: Passed through to the dataset constructor.
        download: Passed through to every dataset except ``'imagenet'``.
        datasets_path: Parent directory of the per-dataset root folder.

    Returns:
        The dataset instance, or ``None`` if ``name`` matches none of the
        supported datasets.
    """
    training = (split == 'train')
    base_dir = os.path.join(datasets_path, name)

    # cifar10 / cifar100 / mnist share the same constructor signature.
    cls_name = {'cifar10': 'CIFAR10', 'cifar100': 'CIFAR100', 'mnist': 'MNIST'}.get(name)
    if cls_name is not None:
        return getattr(datasets, cls_name)(root=base_dir,
                                           train=training,
                                           transform=transform,
                                           target_transform=target_transform,
                                           download=download)
    if name == 'stl10':
        return datasets.STL10(root=base_dir,
                              split=split,
                              transform=transform,
                              target_transform=target_transform,
                              download=download)
    if name == 'imagenet':
        split_dir = os.path.join(base_dir, 'train' if training else 'val')
        return datasets.ImageFolder(root=split_dir,
                                    transform=transform,
                                    target_transform=target_transform)
    return None
项目:pytorch_60min_blitz    作者:kyuhyoung    | 项目源码 | 文件源码
def make_dataloader_torchvison_imagefolder(dir_data, data_transforms, ext_img,
                                           n_img_per_batch, n_worker):
    """Create shuffled train/test loaders over ``ImageFolder`` datasets.

    Args:
        dir_data: Directory holding the ``train`` and ``test`` sub-folders.
        data_transforms: Mapping with ``'train'`` and ``'test'`` transforms.
        ext_img: Forwarded to ``prepare_cifar10_dataset``.
        n_img_per_batch: Batch size for both loaders.
        n_worker: Number of loader worker processes.

    Returns:
        ``(trainloader, testloader, li_class)`` where ``li_class`` is the
        value returned by ``prepare_cifar10_dataset``.
    """
    class_names = prepare_cifar10_dataset(dir_data, ext_img)

    train_set = datasets.ImageFolder(join(dir_data, 'train'), data_transforms['train'])
    test_set = datasets.ImageFolder(join(dir_data, 'test'), data_transforms['test'])

    def build_loader(dataset):
        # Both splits are shuffled, matching the original behaviour.
        return torch.utils.data.DataLoader(
            dataset, batch_size=n_img_per_batch, shuffle=True, num_workers=n_worker)

    train_loader = build_loader(train_set)
    test_loader = build_loader(test_set)
    return train_loader, test_loader, class_names
项目:pytorch_60min_blitz    作者:kyuhyoung    | 项目源码 | 文件源码
def make_dataloader_torchvison_imagefolder(dir_data, data_transforms, ext_img,
                                           n_img_per_batch, n_worker):
    """Return train/test ``DataLoader`` objects backed by ``ImageFolder``.

    Args:
        dir_data: Root directory containing ``train`` and ``test`` folders.
        data_transforms: Dict of transforms keyed by ``'train'``/``'test'``.
        ext_img: Passed to ``prepare_cifar10_dataset`` together with
            ``dir_data``.
        n_img_per_batch: Batch size used by both loaders.
        n_worker: ``num_workers`` used by both loaders.

    Returns:
        Tuple ``(trainloader, testloader, li_class)``; ``li_class`` is
        whatever ``prepare_cifar10_dataset`` returned.
    """
    li_class = prepare_cifar10_dataset(dir_data, ext_img)

    folders = []
    for split_name in ('train', 'test'):
        folders.append(
            datasets.ImageFolder(join(dir_data, split_name), data_transforms[split_name]))

    loaders = []
    for folder in folders:
        loaders.append(torch.utils.data.DataLoader(
            folder, batch_size=n_img_per_batch, shuffle=True, num_workers=n_worker))

    trainloader, testloader = loaders
    return trainloader, testloader, li_class
项目:gan-error-avoidance    作者:aleju    | 项目源码 | 文件源码
def __init__(self, opt):
        """Set up the transform pipeline and an index-based image accessor.

        Args:
            opt: Options object. Reads ``crop_height``, ``crop_width``,
                ``crop_size``, ``image_size``, ``dataset``, ``dataroot`` and,
                for LSUN, ``lsun_class``.

        Raises:
            ValueError: If ``opt.dataset`` is not one of ``'cifar10'``,
                ``'imagenet'``, ``'folder'``, ``'lfw'`` or ``'lsun'``.

        Side effects:
            Loads ``data_index.pt`` from ``opt.dataroot`` with ``torch.load``
            and stores its ``'train'`` entry on ``self.train_index``.
        """
        transform_list = []

        if (opt.crop_height > 0) and (opt.crop_width > 0):
            # CenterCrop takes a single size argument, so a rectangular crop
            # is passed as an (height, width) tuple. The original referenced
            # an undefined ``crop_width`` name here.
            transform_list.append(
                transforms.CenterCrop((opt.crop_height, opt.crop_width)))
        elif opt.crop_size > 0:
            transform_list.append(transforms.CenterCrop(opt.crop_size))

        transform_list.append(transforms.Scale(opt.image_size))
        transform_list.append(transforms.CenterCrop(opt.image_size))

        transform_list.append(transforms.ToTensor())

        if opt.dataset == 'cifar10':
            dataset1 = datasets.CIFAR10(root = opt.dataroot, download = True,
                transform = transforms.Compose(transform_list))
            dataset2 = datasets.CIFAR10(root = opt.dataroot, train = False,
                transform = transforms.Compose(transform_list))
            def get_data(k):
                # Indices past the end of the first dataset address the
                # second one (built with train=False).
                if k < len(dataset1):
                    return dataset1[k][0]
                else:
                    return dataset2[k - len(dataset1)][0]
        else:
            if opt.dataset in ['imagenet', 'folder', 'lfw']:
                dataset = datasets.ImageFolder(root = opt.dataroot,
                    transform = transforms.Compose(transform_list))
            elif opt.dataset == 'lsun':
                dataset = datasets.LSUN(db_path = opt.dataroot, classes = [opt.lsun_class + '_train'],
                    transform = transforms.Compose(transform_list))
            else:
                # Fail here rather than with a NameError on the first
                # get_data() call, where ``dataset`` would be unbound.
                raise ValueError('Unknown dataset %s' % (opt.dataset,))
            def get_data(k):
                return dataset[k][0]

        data_index = torch.load(os.path.join(opt.dataroot, 'data_index.pt'))

        self.opt = opt
        self.get_data = get_data
        self.train_index = data_index['train']
        self.counter = 0
项目:pytorch-reverse-gan    作者:yxlao    | 项目源码 | 文件源码
def get_dataloader(opt):
    """Create a shuffled ``DataLoader`` for the dataset named in ``opt``.

    Args:
        opt: Options object. Reads ``dataset``, ``dataroot``, ``imageSize``,
            ``batch_size``, ``workers`` and, for the folder and LSUN
            datasets, ``imageScaleSize``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the selected dataset.

    Raises:
        ValueError: If ``opt.dataset`` is not a supported name. The original
            code raised ``UnboundLocalError`` in this case.
    """
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    if opt.dataset in ['imagenet', 'folder', 'lfw']:
        # folder dataset
        dataset = dset.ImageFolder(root=opt.dataroot,
                                   transform=transforms.Compose([
                                       transforms.Scale(opt.imageScaleSize),
                                       transforms.CenterCrop(opt.imageSize),
                                       transforms.ToTensor(),
                                       normalize,
                                   ]))
    elif opt.dataset == 'lsun':
        dataset = dset.LSUN(db_path=opt.dataroot, classes=['bedroom_train'],
                            transform=transforms.Compose([
                                transforms.Scale(opt.imageScaleSize),
                                transforms.CenterCrop(opt.imageSize),
                                transforms.ToTensor(),
                                normalize,
                            ]))
    elif opt.dataset == 'cifar10':
        # CIFAR-10 is only rescaled; no center crop is applied.
        dataset = dset.CIFAR10(root=opt.dataroot, download=True,
                               transform=transforms.Compose([
                                   transforms.Scale(opt.imageSize),
                                   transforms.ToTensor(),
                                   normalize,
                               ]))
    else:
        raise ValueError('Unknown dataset %s' % (opt.dataset,))

    # Kept from the original: rejects a falsy dataset. Note that assert
    # statements are skipped when Python runs with -O.
    assert dataset
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size,
                                             shuffle=True,
                                             num_workers=int(opt.workers))
    return dataloader
项目:nn_tools    作者:hahnyuan    | 项目源码 | 文件源码
def Imagenet_LMDB_generate(imagenet_dir, output_dir, make_val=False, make_train=False):
    """Write ImageNet image folders out as classification LMDB databases.

    Args:
        imagenet_dir: Directory that should contain ``train`` and/or ``val``
            sub-directories, each with one folder of raw JPEG photos per
            class.
        output_dir: Directory in which ``imagenet_val_lmdb`` and
            ``imagenet_train_lmdb`` are created.
        make_val: When true, generate the validation LMDB.
        make_train: When true, generate the (shuffled) training LMDB.
    """
    train_name = 'imagenet_train_lmdb'
    val_name = 'imagenet_val_lmdb'

    def target_trans(target):
        return target

    def as_uint8_array(pipeline):
        """Wrap *pipeline* so its tensor output becomes a uint8 numpy array."""
        def convert(img):
            tensor = pipeline(img)
            # ToTensor output is multiplied back by 255 before the uint8 cast.
            return (tensor.numpy()*255).astype(np.uint8)
        return convert

    if make_val:
        val_lmdb=lmdb_datasets.LMDB_generator(osp.join(output_dir,val_name))
        # The pipeline is built once instead of once per image.
        trans_val_data = as_uint8_array(transforms.Compose([
            transforms.Scale(256),
            transforms.CenterCrop(224),
            transforms.ToTensor()
        ]))

        val = datasets.ImageFolder(osp.join(imagenet_dir,'val'), trans_val_data,target_trans)
        val_lmdb.write_classification_lmdb(val, num_per_dataset=DATASET_SIZE)
    if make_train:
        train_lmdb = lmdb_datasets.LMDB_generator(osp.join(output_dir, train_name))
        trans_train_data = as_uint8_array(transforms.Compose([
            transforms.Scale(256),
            transforms.ToTensor()
        ]))

        train = datasets.ImageFolder(osp.join(imagenet_dir, 'train'), trans_train_data, target_trans)
        # Shuffle in place: np.random.permutation would return a string
        # ndarray, turning the integer class indices into strings, whereas
        # the validation path above hands integer labels to the writer.
        np.random.shuffle(train.imgs)

        train_lmdb.write_classification_lmdb(train, num_per_dataset=DATASET_SIZE, write_shape=True)
项目:diracnets    作者:szagoruyko    | 项目源码 | 文件源码
def create_iterator(opt, mode):
    """Create a batched data iterator for the dataset named in ``opt``.

    Args:
        opt: Options object. Reads ``dataset``, ``dataroot``, ``batchSize``,
            ``nthread`` and, for CIFAR datasets, ``randomcrop_pad``.
        mode: Truthy for the training split (augmentation and shuffling are
            enabled), falsy for the evaluation split.

    Returns:
        For dataset names starting with ``'CIFAR'``, the result of
        ``parallel(...)`` on a transformed ``tnt`` tensor dataset; for
        ``'ImageNet'``, a ``torch.utils.data.DataLoader``.

    Raises:
        ValueError: If the dataset name is neither CIFAR-prefixed nor
            ``'ImageNet'``.
    """
    dataset_name = opt.dataset

    if dataset_name.startswith('CIFAR'):
        to_tensor = tnt.transform.compose([
            lambda x: x.astype(np.float32),
            T.Normalize([125.3, 123.0, 113.9], [63.0, 62.1, 66.7]),
            lambda x: x.transpose(2,0,1),
            torch.from_numpy,
        ])

        augment = tnt.transform.compose([
            T.RandomHorizontalFlip(),
            T.Pad(opt.randomcrop_pad, cv2.BORDER_REFLECT),
            T.RandomCrop(32),
            to_tensor,
        ])

        dataset_cls = getattr(datasets, dataset_name)
        source = dataset_cls(opt.dataroot, train=mode, download=True)
        # The torchvision dataset exposes <split>_data / <split>_labels.
        prefix = 'train' if mode else 'test'
        arrays = [getattr(source, prefix + '_data'),
                  getattr(source, prefix + '_labels')]
        ds = tnt.dataset.TensorDataset(arrays)
        image_transform = augment if mode else to_tensor
        ds = ds.transform({0: image_transform})
        return ds.parallel(batch_size=opt.batchSize, shuffle=mode,
                           num_workers=opt.nthread, pin_memory=True)

    if dataset_name != 'ImageNet':
        raise ValueError('dataset not understood')

    def cvload(path):
        # OpenCV reads BGR; convert so downstream code sees RGB.
        bgr = cv2.imread(path, cv2.IMREAD_COLOR)
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    convert = tnt.transform.compose([
        lambda x: x.astype(np.float32) / 255.0,
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        lambda x: x.transpose(2, 0, 1).astype(np.float32),
        torch.from_numpy,
    ])

    print("| setting up data loader...")
    if mode:
        image_dir = os.path.join(opt.dataroot, 'train')
        spatial = [T.RandomSizedCrop(224), T.RandomHorizontalFlip()]
    else:
        image_dir = os.path.join(opt.dataroot, 'val')
        spatial = [T.Scale(256), T.CenterCrop(224)]

    pipeline = tnt.transform.compose(spatial + [convert])
    ds = datasets.ImageFolder(image_dir, pipeline, loader=cvload)

    return torch.utils.data.DataLoader(ds,
                                       batch_size=opt.batchSize, shuffle=mode,
                                       num_workers=opt.nthread, pin_memory=False)
项目:MMD-GAN    作者:OctoberChang    | 项目源码 | 文件源码
def get_data(args, train_flag=True):
    """Return the dataset selected by ``args.dataset``.

    Args:
        args: Options object. Reads ``dataset``, ``dataroot`` and
            ``image_size``.
        train_flag: Selects the training split for cifar10, cifar100 and
            mnist, and the ``train`` (vs ``val``) sub-directory for celeba.

    Returns:
        The constructed dataset object.

    Raises:
        ValueError: If the dataset name is unknown, or if ``'celeba'`` is
            requested with an ``image_size`` other than 64.
    """
    normalize = transforms.Normalize(
        (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    transform = transforms.Compose([
        transforms.Scale(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        normalize,
    ])

    name = args.dataset

    if name in ['imagenet', 'folder', 'lfw']:
        return dset.ImageFolder(root=args.dataroot,
                                transform=transform)

    if name == 'lsun':
        return dset.LSUN(db_path=args.dataroot,
                         classes=['bedroom_train'],
                         transform=transform)

    if name == 'cifar10':
        return dset.CIFAR10(root=args.dataroot,
                            download=True,
                            train=train_flag,
                            transform=transform)

    if name == 'cifar100':
        return dset.CIFAR100(root=args.dataroot,
                             download=True,
                             train=train_flag,
                             transform=transform)

    if name == 'mnist':
        return dset.MNIST(root=args.dataroot,
                          download=True,
                          train=train_flag,
                          transform=transform)

    if name == 'celeba':
        if train_flag:
            imdir = 'train'
        else:
            imdir = 'val'
        dataroot = os.path.join(args.dataroot, imdir)
        if not args.image_size == 64:
            raise ValueError('the image size for CelebA dataset need to be 64!')

        # CelebA uses its own crop-and-scale step instead of the shared
        # Scale/CenterCrop pipeline built above.
        celeba_transform = transforms.Compose([
            ALICropAndScale(),
            transforms.ToTensor(),
            transforms.Normalize(
                (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        return FolderWithImages(root=dataroot,
                                input_transform=celeba_transform,
                                target_transform=transforms.ToTensor())

    raise ValueError("Unknown dataset %s" % (name))