问题讨论:Too many open files error · Issue #11201 · pytorch/pytorch (github.com)
解决方法:
maskformer
官方代码依赖detectron2
库,我没有安装成功,直接把源码下载下来,然后将里面的detectron2
包拷贝的MaskFormer的项目目录下,可以正常使用。
数据加载器的构建类:detectron2/data/build.py
文件下的build_detection_train_loader
。
数据加载使用到的类所在文件路径:
MapDataset
:detectron2/data/common.pyMaskFormerSemanticDatasetMapper
:mask_former/data/dataset_mappers/mask_former_semantic_dataset_mapper.pyTrainingSampler
: detectron2/data/samplers/distributed_sampler.py流程:
在train_net.py中,找到Trainer的父类DefaultTrainer
,在父类中,会有构建模型、优化器和数据加载器的代码:
# Assume these objects must be constructed in this order.
model = self.build_model(cfg)
optimizer = self.build_optimizer(cfg, model)
data_loader = self.build_train_loader(cfg)
实际调用的Trainer
中的build_train_loader
方法,首先会创建一个mapper;
然后,将mapper作为参数,调用build_detection_train_loader
;
在build_detection_train_loader
中,将根据_train_loader_from_config
来构建没有传递的参数,比如dataset、mapper、sampler,如果我们有自己的参数,可以在调用build_detection_train_loader
是直接传递进去。
# 传递了mapper
build_detection_train_loader(cfg, mapper=mapper)
后面就是的调用建立真正的数据加载器。
Mapper是由MaskFormerSemanticDatasetMapper
类实现的,具体作用:将图像和标签,转化为Detetron2要求的格式。
maskformer在进行训练时需要:
data={
'image':图像,shape=(3,256,256),
'sem_seg_gt':标签,shape=(256,256),
'classes':当前图像中的所有类别,是一个一维的torch数组,shape=(4),
'masks':每一个类别的区域蒙版,shape=(4,256,256),
}
classes=np.unique(sem_seg_gt)
masks = []
for class_id in classes:
masks.append(sem_seg_gt == class_id)
Mapper就是对图像和标签进行了处理得到训练模型需要的数据及格式。
注意:由于这时的数据格式已经不是普通的torch数组或者numpy数据,而是复杂的json,我们需要重新自定义一个collator函数,告诉程序,如何将一个batch的数据堆叠起来。
这里实现比较简单,直接什么都不处理,也就是不需要每一个项都堆叠,直接以列表形式组织在一起就可以了。
def trivial_batch_collator(batch): """ A batch collator that does nothing. """ return batch
maskformer默认的数据采样器是:TrainerSample,这是一个无限流的采样器,如果你使用for进行迭代获取数据,将会永不终止,不断地随机获取数据索引。
实际使用时,我进行了修改,不想要无限流这个特点。
def __iter__(self):
start = self._rank
yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)
def _infinite_indices(self):
g = torch.Generator()
g.manual_seed(self._seed)
while True:
if self._shuffle:
yield from torch.randperm(self._size, generator=g).tolist()
else:
yield from torch.arange(self._size).tolist()
# 改之后
def __iter__(self):
return iter(torch.randperm(self._size).tolist())
在这里,我并没有使用官方的数据集格式,我的数据集:
----images
-----001.jpg
-----002.jpg
-----003.jpg
----labels
-----001.png
-----002.png
-----003.png
然后,在构建数据加载器时,自然也没有使用官方的数据集类,而是将自己写的类直接作为参数传递给构建函数,这样后面程序看到dataset不是None,也就不会根据配置文件创建一个官方dataset类。
build_detection_train_loader(cfg, mapper=mapper,dataset=dataset)