所有数据集都是torch.utils.data.Dataset的子类,即都实现了__getitem__和__len__方法。因此,它们都可以传给torch.utils.data.DataLoader,后者可以借助torch.multiprocessing工作进程并行加载多个样本。例如:
# Create the ImageNet dataset object, then wrap it in a DataLoader that
# shuffles the samples, groups them into batches of 4 and loads them with
# args.nThreads worker processes.
imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(
    imagenet_data,
    batch_size=4,
    shuffle=True,
    num_workers=args.nThreads,
)
好在官方文档上面说了:只需继承Dataset这个抽象类,并实现__getitem__和__len__方法就可以了。
class CatDogDataSet(Dataset):
    """Skeleton of a custom map-style dataset for the cat/dog images.

    A ``Dataset`` subclass has to provide ``__getitem__`` and ``__len__``.
    Until they are filled in, both raise ``NotImplementedError`` so that
    using the unfinished class fails loudly instead of returning ``None``.
    """

    def __init__(self):
        pass

    def __getitem__(self, index):
        """Return the sample stored at position *index*."""
        raise NotImplementedError("CatDogDataSet.__getitem__ is not implemented yet")

    def __len__(self):
        """Return the number of samples in the dataset."""
        # A bare `pass` would return None, which makes len() fail with an
        # unrelated TypeError deep inside DataLoader.
        raise NotImplementedError("CatDogDataSet.__len__ is not implemented yet")
我知道ImageNet是把下载好的压缩包解压后再读取、处理图片的,不妨看看ImageNet是如何实现的:class ImageNet(ImageFolder):。
ImageFolder!!!看到这个类,我就预感很快可以直接copy了。
果然,datasets.ImageFolder(root)
只要传入数据根目录,且目录结构符合下面的格式,就可以读取自定义数据集。
# Excerpt from torchvision's ImageFolder (torchvision/datasets/folder.py); the
# docstring and the rest of the class are truncated here. The layout it shows
# is one sub-directory per class under `root`, holding that class's images.
class ImageFolder(DatasetFolder):
"""A generic data loader where the images are arranged in this way by default: ::
root/dog/xxx.png
root/dog/xxy.png
root/dog/[...]/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/[...]/asd932_.png
把源码中find_classes方法里根据目录名定义classes变量的那一行,改为按文件名前缀(第一个.之前的部分)提取类别:classes = list(frozenset([i.split('.')[0] for i in os.listdir(directory)]))
就可以了。(注意:frozenset的遍历顺序会随Python字符串哈希随机化而变化,最好再用sorted()排序,保证每次运行得到的类别编号一致。)
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset.

    Every immediate sub-directory of *directory* is one class. The folder
    names are sorted alphabetically and numbered from 0 in that order.

    See :class:`DatasetFolder` for details.

    Returns:
        ``(classes, class_to_idx)``: the sorted folder names and a dict
        mapping each name to its index.

    Raises:
        FileNotFoundError: if *directory* contains no sub-directory.
    """
    folder_names = [entry.name for entry in os.scandir(directory) if entry.is_dir()]
    folder_names.sort()
    if len(folder_names) == 0:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
    name_to_index = dict(zip(folder_names, range(len(folder_names))))
    return folder_names, name_to_index
再看构建样本列表(make_dataset)的部分:
for target_class in sorted(class_to_idx.keys()):
    class_index = class_to_idx[target_class]
    # The torchvision original only looked inside the sub-directory named
    # after the class, and skipped the class if that directory was missing:
    #     target_dir = os.path.join(directory, target_class)
    #     if not os.path.isdir(target_dir):
    #         continue
    # Here the whole tree under `directory` is walked instead, and each file
    # is matched to the class by its name.
    for root, _, fnames in sorted(os.walk(directory, followlinks=True)):
        for fname in sorted(fnames):
            # Compare the file-name prefix (text before the first ".") exactly,
            # the same rule used to derive class names. A substring test
            # (`target_class in fname`) would also put "bobcat.0.jpg" into
            # class "cat", duplicating it under two labels.
            if fname.split(".")[0] == target_class:
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    item = path, class_index
                    instances.append(item)
                    if target_class not in available_classes:
                        available_classes.add(target_class)
import os
from typing import Dict, Optional, Tuple, Callable, List, Union, cast
from torchvision.datasets import ImageFolder
from torchvision.datasets.folder import IMG_EXTENSIONS, has_file_allowed_extension
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Derive class names from the entry names found directly in *directory*.

    Each entry returned by ``os.listdir`` contributes the part of its name
    before the first ``.`` as a class name, so a flat folder holding
    ``cat.0.jpg, dog.0.jpg, ...`` yields ``["cat", "dog"]``.

    Returns:
        ``(classes, class_to_idx)``: the sorted, de-duplicated class names and
        a dict mapping each name to its index.

    Raises:
        FileNotFoundError: if no class name can be derived.
    """
    prefixes = {entry.split(".")[0] for entry in os.listdir(directory)}
    # Hidden files such as ".DS_Store" have an empty prefix; an empty class
    # name is meaningless and would match every file in a substring test.
    prefixes.discard("")
    # Sort instead of relying on set iteration order: the order of a set of
    # str depends on hash randomization (PYTHONHASHSEED), which would give
    # each class a different index in different processes.
    classes = sorted(prefixes)
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx
def make_dataset(
    directory: str,
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Union[str, Tuple[str, ...]]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
    """Generates a list of samples of a form (path_to_sample, class).

    Unlike torchvision's version, files are not expected to sit in one
    sub-directory per class. The whole tree under *directory* is walked, and
    each file belongs to the class equal to the part of its file name before
    the first ``.`` (e.g. ``cat.12.jpg`` -> ``"cat"``). This is the same rule
    ``find_classes`` uses to discover the classes.

    Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
    by default.

    Args:
        directory: root directory to scan (``~`` is expanded).
        class_to_idx: mapping from class name to class index.
        extensions: allowed file extensions; exactly one of *extensions* and
            *is_valid_file* must be given.
        is_valid_file: predicate deciding whether a path is a valid sample.

    Returns:
        ``(path, class_index)`` pairs grouped by class in sorted class-name
        order; within a class, in sorted directory-walk / file-name order.

    Raises:
        ValueError: if *class_to_idx* is empty, or if both or neither of
            *extensions* and *is_valid_file* are given.
        FileNotFoundError: if some class has no valid file.
    """
    directory = os.path.expanduser(directory)
    if class_to_idx is None:
        _, class_to_idx = find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:

        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    # Walk the tree once (not once per class) and bucket every valid file
    # under the class its name prefix designates. Exact comparison replaces
    # the old substring test, which put e.g. "bobcat.0.jpg" into both
    # "bobcat" and "cat".
    paths_by_class: Dict[str, List[str]] = {name: [] for name in class_to_idx}
    for root, _, fnames in sorted(os.walk(directory, followlinks=True)):
        for fname in sorted(fnames):
            target_class = fname.split(".")[0]
            if target_class not in paths_by_class:
                continue
            path = os.path.join(root, fname)
            if is_valid_file(path):
                paths_by_class[target_class].append(path)

    instances = []
    available_classes = set()
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        class_paths = paths_by_class[target_class]
        instances.extend((path, class_index) for path in class_paths)
        if class_paths:
            available_classes.add(target_class)

    empty_classes = set(class_to_idx.keys()) - available_classes
    if empty_classes:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances
class CatDogLoader(ImageFolder):
    """``ImageFolder`` for a flat directory of ``<class>.<id>.<ext>`` images.

    Class discovery and sample collection are delegated to the module-level
    ``find_classes`` and ``make_dataset`` functions through the two
    overriding methods below. ``find_classes`` uses the file-name prefix
    before the first ``.`` as the class.

    Args:
        root: directory holding the images.
        transform: optional callable applied to each loaded image.
        target_transform: optional callable applied to each class index.
        is_valid_file: optional predicate selecting sample files.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        # The base-class constructor already calls the overridden
        # self.find_classes / self.make_dataset and stores their results.
        # The previous code re-ran both right afterwards and overwrote only
        # self.samples. That meant a second full disk walk, and a list that
        # could drift from the targets computed by the base class if files
        # changed in between.
        # NOTE(review): relies on torchvision >= 0.10, where DatasetFolder
        # dispatches through these methods — confirm against the pinned version.
        super().__init__(
            root,
            transform,
            target_transform,
            is_valid_file=is_valid_file,
        )

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
        """Return ``(classes, class_to_idx)`` computed by the module-level ``find_classes``."""
        return find_classes(directory)

    def make_dataset(
        self,
        directory: str,
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        allow_empty: bool = False,
    ) -> List[Tuple[str, int]]:
        """Collect ``(path, class_index)`` samples via the module-level ``make_dataset``.

        Args:
            directory: root directory to scan.
            class_to_idx: mapping from class name to index; must not be None.
            extensions: allowed file extensions (mutually exclusive with
                *is_valid_file*).
            is_valid_file: predicate selecting sample files.
            allow_empty: accepted only so this override matches the base
                class's call signature; it is ignored, and classes without
                any file still raise ``FileNotFoundError``.

        Raises:
            ValueError: if *class_to_idx* is None.
        """
        # NOTE(review): newer torchvision releases pass `allow_empty` from
        # DatasetFolder.__init__; without the parameter the constructor would
        # fail with a TypeError. Confirm against the installed version.
        if class_to_idx is None:
            raise ValueError("The class_to_idx parameter cannot be None.")
        return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)