pytorch一致数据增强—异用增强

发布时间:2024年01月14日

前作 [1] 介绍了一种用 pytorch 模仿 MONAI 实现多幅图(如:image 与 label)同用 random seed 保证一致变换的写法,核心是 MultiCompose 类和 to_multi 包装函数。不过 [1] 没考虑不同图用不同 augmentation 的情况,如:

  1. ColorJitter 只对 image 做,而不对 label 做;
  2. image 的 resize interpolation 可任选,但 label 只能用 nearest

本篇更新写法,支持各图同用、异用 augmentation。

Code

  • 对比 [1],主要改变是改写 MultiCompose 类,并将 to_multi 吸收入内。
  • MultiCompose 的用法还是和 torchvision.transforms.Compose 几乎一致,不过支持异用 augmentation:只要为各图指定各自的 augmentation 类/函数即可。见下一节例程。
def to_multi():
	"""不用单独的 to_multi 打包了,已并入 MultiCompose"""
	raise NotImplementedError


class MultiCompose:
    """Extension of torchvision.transforms.Compose that accepts multiple inputs
    and ensures the same random seed is applied on each of these inputs at each transforms.
    This can be useful when simultaneously transforming images & segmentation masks.
    """

    # numpy.random.seed range error:
    #   ValueError: Seed must be between 0 and 2**32 - 1
    MIN_SEED = 0 # - 0x8000_0000_0000_0000
    MAX_SEED = min(2**32 - 1, 0xffff_ffff_ffff_ffff)

    def __init__(self, transforms):
        # self.transforms = [to_multi(t) for t in transforms]
        no_op = lambda x: x # i.e. identity function
        self.transforms = []
        for t in transforms:
            if isinstance(t, (tuple, list)):
            	# convert `None` to `no_op` for convenience
                self.transforms.append([no_op if _t is None else _t for _t in t])
            else:
                self.transforms.append(t)

    def __call__(self, *images):
        for t in self.transforms:
            if isinstance(t, (tuple, list)):
                assert len(images) <= len(t) # allow redundant transform
            else:
                t = [t] * len(images)

            _aug_images = []
            _seed = random.randint(self.MIN_SEED, self.MAX_SEED)
            for _im, _t in zip(images, t):
                seed_everything(_seed)
                _aug_images.append(_t(_im))

            images = _aug_images

        if len(images) == 1:
            images = images[0]
        return images

Usage & Test

例程沿用 [1],但改一下 augmentation:

train_trans = MultiCompose([
	# image 用 bilinear,label 用 nearest
    (ResizeZoomPad((224, 256), "bilinear"), ResizeZoomPad((224, 256), "nearest")), # 异用
    transforms.RandomAffine(30, (0.1, 0.1)), # 同用,传一个就行
    transforms.RandomHorizontalFlip(), # 同用
    # ColorJitter 只对 image 做,label 不做(None)
    [transforms.ColorJitter(0.1, 0.2, 0.3, 0.4), None], # 异用
])
  • 效果:

augmented

References

  1. pytorch一致数据增强
文章来源:https://blog.csdn.net/HackerTom/article/details/135587725
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。