Building a Short-Video E-Commerce Platform from Scratch: PaddleOCR Java Inference (Part 3) - Optimizing the Text Detection Model's Input and Output

Published: January 15, 2024

Background

PaddleOCR provides a set of test images that you can download for testing.

Note that PaddleOCR's models are updated much faster than DJL, so DJL's built-in support lags behind. We therefore need a few strategies of our own to keep up with PaddleOCR's latest releases.

For the text detection model, the relevant reference material is as follows:

Text detection model inference uses the DB model's configuration parameters by default.

The parameters limit_type and det_limit_side_len constrain the size of the input image. limit_type can be either max or min, and det_limit_side_len is a positive integer, normally a multiple of 32 such as 960.

The defaults are limit_type='max', det_limit_side_len=960, meaning the longest side of the network input image must not exceed 960. If it does, the image is resized with its aspect ratio preserved so that the longest side equals det_limit_side_len. Setting limit_type='min', det_limit_side_len=960 instead constrains the shortest side of the image to 960.

If the input image has a high resolution and you want to predict at a larger resolution, set det_limit_side_len to the value you want.
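
To make that policy concrete, here is a minimal Java sketch of the limit_type / det_limit_side_len logic described above. The method name and structure are my own illustration, not PaddleOCR or DJL code:

    // Sketch of the limit_type / det_limit_side_len resize policy described above.
    // limitType and limitSideLen mirror the PaddleOCR parameters; this is an illustration only.
    private float computeResizeRatio(int h, int w, String limitType, int limitSideLen) {
        float ratio = 1.0f;
        if ("max".equals(limitType)) {
            // the longest side must not exceed limitSideLen
            if (Math.max(h, w) > limitSideLen) {
                ratio = (float) limitSideLen / Math.max(h, w);
            }
        } else {
            // "min": the shortest side must be at least limitSideLen
            if (Math.min(h, w) < limitSideLen) {
                ratio = (float) limitSideLen / Math.min(h, w);
            }
        }
        // both sides are then multiplied by ratio and snapped to a multiple of 32
        return ratio;
    }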

DJL's Original Input and Output

The model's input and output are handled in PpWordDetectionTranslator.java, which the default Criteria below uses:

Criteria<Image, DetectedObjects> criteria1 = Criteria.builder()
                .optEngine("PaddlePaddle")
                .setTypes(Image.class, DetectedObjects.class)
                .optModelUrls("https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_det_infer.tar")
                .optTranslator(new PpWordDetectionTranslator(new ConcurrentHashMap<String, String>()))
                .build();
public class PpWordDetectionTranslator implements NoBatchifyTranslator<Image, DetectedObjects> {

    private final int maxLength;

    /**
     * Creates the {@link PpWordDetectionTranslator} instance.
     *
     * @param arguments the arguments for the translator
     */
    public PpWordDetectionTranslator(Map<String, ?> arguments) {
        maxLength = ArgumentsUtil.intValue(arguments, "maxLength", 960);
    }

    /** {@inheritDoc} */
    @Override
    public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
        NDArray result = list.singletonOrThrow();
        ImageFactory factory = ImageFactory.getInstance();
        List<BoundingBox> boxes;
        // faster mechanism
        if ("ai.djl.opencv.OpenCVImageFactory".equals(factory.getClass().getName())) {
            result = result.squeeze(0);
            Image image = factory.fromNDArray(result);
            boxes =
                    image.findBoundingBoxes().parallelStream()
                            .filter(
                                    box -> {
                                        Rectangle rect = (Rectangle) box;
                                        return rect.getWidth() * image.getWidth() > 5
                                                || rect.getHeight() * image.getHeight() > 5;
                                    })
                            .collect(Collectors.toList());
        } else {
            result = result.squeeze().mul(255f).toType(DataType.UINT8, true).neq(0);
            boolean[] flattened = result.toBooleanArray();
            Shape shape = result.getShape();
            int w = (int) shape.get(0);
            int h = (int) shape.get(1);
            boolean[][] grid = new boolean[w][h];
            IntStream.range(0, flattened.length)
                    .parallel()
                    .forEach(i -> grid[i / h][i % h] = flattened[i]);
            boxes = new BoundFinder(grid).getBoxes();
        }
        List<String> names = new ArrayList<>();
        List<Double> probs = new ArrayList<>();
        int boxSize = boxes.size();
        for (int i = 0; i < boxSize; i++) {
            names.add("word");
            probs.add(1.0);
        }
        return new DetectedObjects(names, probs, boxes);
    }

    /** {@inheritDoc} */
    @Override
    public NDList processInput(TranslatorContext ctx, Image input) {
        NDArray img = input.toNDArray(ctx.getNDManager());
        int h = input.getHeight();
        int w = input.getWidth();
        int[] hw = scale(h, w, maxLength);

        img = NDImageUtils.resize(img, hw[1], hw[0]);
        img = NDImageUtils.toTensor(img);
        img =
                NDImageUtils.normalize(
                        img,
                        new float[] {0.485f, 0.456f, 0.406f},
                        new float[] {0.229f, 0.224f, 0.225f});
        img = img.expandDims(0);
        return new NDList(img);
    }

    private int[] scale(int h, int w, int max) {
        int localMax = Math.max(h, w);
        float scale = 1.0f;
        if (max < localMax) {
            scale = max * 1.0f / localMax;
        }
        // paddle model only take 32-based size
        return resize32(h * scale, w * scale);
    }

    private int[] resize32(double h, double w) {
        double min = Math.min(h, w);
        if (min < 32) {
            h = 32.0 / min * h;
            w = 32.0 / min * w;
        }
        int h32 = (int) h / 32;
        int w32 = (int) w / 32;
        return new int[] {h32 * 32, w32 * 32};
    }
}
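
For reference, here is a minimal sketch of how such a Criteria is typically used with DJL to run detection on a single image. The method name and image path are placeholders of mine:

    // Hypothetical usage sketch: load the detection model from the Criteria and run it once.
    public static DetectedObjects detect(Criteria<Image, DetectedObjects> criteria, Path imageFile)
            throws IOException, ModelException, TranslateException {
        try (ZooModel<Image, DetectedObjects> model = criteria.loadModel();
                Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
            Image img = ImageFactory.getInstance().fromFile(imageFile);
            return predictor.predict(img); // DetectedObjects with "word" boxes
        }
    }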

DJL Output Optimization (1)

// Original output:
[
	{"class": "word", "probability": 1.00000, "bounds": {"x"=0.121, "y"=0.873, "width"=0.262, "height"=0.021}}
	{"class": "word", "probability": 1.00000, "bounds": {"x"=0.384, "y"=0.854, "width"=0.538, "height"=0.031}}
	{"class": "word", "probability": 1.00000, "bounds": {"x"=0.395, "y"=0.670, "width"=0.335, "height"=0.020}}
	...

The confidence `probability` is hard-coded to 1.0, and the `bounds` are shrunk relative to the actual text regions, which is wrong.

// Output after optimization:
[
	{"class": "word", "probability": 1.00000, "bounds": {"x"=0.099, "y"=0.852, "width"=0.305, "height"=0.064}}
	{"class": "word", "probability": 1.00000, "bounds": {"x"=0.353, "y"=0.822, "width"=0.600, "height"=0.094}}
	{"class": "word", "probability": 1.00000, "bounds": {"x"=0.376, "y"=0.650, "width"=0.374, "height"=0.059}}
	...

This fixes the shrunken `bounds`. The hard-coded confidence is not worth fixing; a more useful improvement is to filter out boxes that are too large or too small, for example as in the sketch below.
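
A minimal sketch of such a size filter. The 0.001 and 0.95 thresholds are arbitrary assumptions of mine, not values from PaddleOCR or DJL, so tune them for your images:

    // Hypothetical size filter: keep only boxes whose relative width and height fall in a sane range.
    List<BoundingBox> filtered = boxes.stream()
            .filter(box -> {
                Rectangle rect = box.getBounds();
                double bw = rect.getWidth();   // relative width in [0, 1]
                double bh = rect.getHeight();  // relative height in [0, 1]
                return bw > 0.001 && bh > 0.001 && bw < 0.95 && bh < 0.95;
            })
            .collect(Collectors.toList());
    // if you apply this inside processOutput, build the names/probs lists from `filtered` instead of `boxes`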

Pull in the extra DJL OpenCV dependency:

<dependency>
    <groupId>ai.djl.opencv</groupId>
    <artifactId>opencv</artifactId>
    <version>0.25.0</version>
</dependency>

I tried porting the Python post-processing code to Java, but it was painful, so let's modify the DJL code instead, as follows:

1. Change the model loading code to use the new translator:

Criteria<Image, DetectedObjects> criteria1 = Criteria.builder()
                .optEngine("PaddlePaddle")
                .setTypes(Image.class, DetectedObjects.class)
                .optModelUrls("https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_det_infer.tar")
                .optTranslator(new PpWordDetectionTranslator2(new ConcurrentHashMap<String, String>()))
                .build();

2. Add a new PpWordDetectionTranslator2 class, changing only the processOutput method shown below:

   public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
        NDArray result = list.singletonOrThrow();
        ImageFactory factory = ImageFactory.getInstance();
        List<BoundingBox> boxes;
        // faster mechanism
        if ("ai.djl.opencv.OpenCVImageFactory".equals(factory.getClass().getName())) {
            result = result.squeeze(0);
            Image image = factory.fromNDArray(result);
            boxes =
                    image.findBoundingBoxes().parallelStream()
                            .filter(
                                    box -> {
                                        Rectangle rect = (Rectangle) box;
                                        return rect.getWidth() * image.getWidth() > 5
                                                || rect.getHeight() * image.getHeight() > 5;
                                    })
                            .collect(Collectors.toList());
        } else {
            result = result.squeeze().mul(255f).toType(DataType.UINT8, true).neq(0);
            boolean[] flattened = result.toBooleanArray();
            Shape shape = result.getShape();
            int w = (int) shape.get(0);
            int h = (int) shape.get(1);
            boolean[][] grid = new boolean[w][h];
            IntStream.range(0, flattened.length)
                    .parallel()
                    .forEach(i -> grid[i / h][i % h] = flattened[i]);
            boxes = new BoundFinder(grid).getBoxes();
        }
        List<String> names = new ArrayList<>();
        List<Double> probs = new ArrayList<>();
        // expand each detected rectangle: start
        List<BoundingBox> extendBoxes = new ArrayList<>();
        for (BoundingBox box : boxes) {
            names.add("word");
            probs.add(1.0);

            Rectangle rect = box.getBounds();
            double[] extended = extendRect(rect.getX(), rect.getY(), rect.getWidth(), rect.getHeight());
            extendBoxes.add(new Rectangle(extended[0], extended[1], extended[2], extended[3]));
        }
        // expand each detected rectangle: end
        return new DetectedObjects(names, probs, extendBoxes);
    }
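
The extendRect helper called above is not shown in the post. As a reference, here is a minimal sketch of one possible implementation: it enlarges each (relative) box around its center, since the DB detection output is shrunk, and clamps it to the [0, 1] range. The expansion factors are assumptions of mine, not necessarily the article's exact values:

    // Hypothetical extendRect: expand a relative box around its center so the crop
    // fully covers the text, then clamp it to the [0, 1] image range.
    private double[] extendRect(double xmin, double ymin, double width, double height) {
        double centerX = xmin + width / 2;
        double centerY = ymin + height / 2;
        if (width > height) {
            width += height * 2.0;   // expansion factors are illustrative assumptions
            height *= 3.0;
        } else {
            height += width * 2.0;
            width *= 3.0;
        }
        double newX = Math.max(centerX - width / 2, 0);
        double newY = Math.max(centerY - height / 2, 0);
        double newWidth = Math.min(newX + width, 1) - newX;
        double newHeight = Math.min(newY + height, 1) - newY;
        return new double[] {newX, newY, newWidth, newHeight};
    }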

Before vs. after optimization: see the comparison screenshots in the original post.

Optimizing the Input

With the input resizing changed as below, the detector picks up some additional text regions. Replace the scale and resize32 methods with the following:

    private int[] scale(int h, int w, int max) {
        float ratio = 1.0f;
        // only scale down when the longer side exceeds the limit (limit_type='max' behaviour)
        if (Math.max(h, w) > max) {
            if (h > w) {
                ratio = (float) max / (float) h;
            } else {
                ratio = (float) max / (float) w;
            }
        }
        // the Paddle model only accepts sizes that are multiples of 32
        return resize32(h * ratio, w * ratio);
    }

    private int[] resize32(double h, double w) {
        // round to the nearest multiple of 32 instead of always rounding down,
        // and clamp to at least 32 so very small inputs cannot collapse to 0
        // (this mirrors PaddleOCR's own detection preprocessing)
        int resizeH = Math.max(Math.round((float) h / 32f) * 32, 32);
        int resizeW = Math.max(Math.round((float) w / 32f) * 32, 32);
        return new int[] {resizeH, resizeW};
    }
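
To see why this picks up more text, compare the two resize32 variants on one concrete size. The 1000 x 693 input below is just an example of mine:

    // Example: a 1000 x 693 image with maxLength 960 scales by 0.96 to 960.0 x 665.28.
    double h = 960.0;
    double w = 665.28;
    int oldW = ((int) w / 32) * 32;                  // old floor-to-32 sizing -> 640
    int newW = Math.round((float) w / 32f) * 32;     // new round-to-32 sizing -> 672
    // 672 is much closer to the true scaled width of 665.28, so the resized image is
    // distorted less, which is why the detector now finds some extra regions.
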
Source: https://blog.csdn.net/abu935009066/article/details/135606317