从零开发短视频电商 PaddleOCR Java推理 (四)优化OCR工具类

发布时间:2024年01月16日

从零开发短视频电商 PaddleOCR Java推理 (四)优化OCR工具类

在这里插入图片描述

参考：https://github.com/mymagicpower/AIAS/blob/9dc3c65d07568087ac71453de9070a416eb4e1d0/1_image_sdks/ocr_v4_sdk/src/main/java/top/aias/ocr/OcrV4RecExample.java

import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.paddlepaddle.zoo.cv.imageclassification.PpWordRotateTranslator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import lombok.SneakyThrows;

import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Runs a PaddleOCR (PP-OCR) pipeline through DJL.
 *
 * <p>The pipeline has three stages: text-region detection, text-direction classification, and
 * text recognition. The three models are loaded once in the static initializer and shared. A
 * fresh {@link Predictor} is created for every inference and closed immediately afterwards,
 * because DJL predictors hold native resources and are not meant to be shared across threads.
 */
public class PPOCR {
    /**
     * Directory of the recognition model. It can be overridden with
     * {@code -Dppocr.rec.modelPath=...}; the default keeps the original hard-coded location.
     */
    private static final String REC_MODEL_PATH =
            System.getProperty("ppocr.rec.modelPath", "C:\\laker-2");

    /**
     * File that receives the input image annotated with the detected boxes. It can be
     * overridden with {@code -Dppocr.debug.output=...}.
     */
    private static final String DEBUG_OUTPUT_PATH =
            System.getProperty("ppocr.debug.output", "C:\\laker\\demo3\\1-1-1.png");

    /**
     * Two boxes whose top edges differ by less than this many pixels are treated as being on
     * the same text line.
     */
    private static final double SAME_LINE_THRESHOLD_PX = 20;

    /** Text-region detector (PP-OCRv4 Chinese detection model). */
    static ZooModel<Image, DetectedObjects> detectionModel = null;

    /** Text-direction classifier. A confident "Rotate" result triggers a rotation in {@link #doOCR}. */
    static ZooModel<Image, Classifications> rotateModel = null;

    /** Text recognizer, loaded from {@link #REC_MODEL_PATH}. */
    static ZooModel<Image, String> recognitionModel;

    static {
        try {
            Criteria<Image, DetectedObjects> detectionCriteria = Criteria.builder()
                    .optEngine("PaddlePaddle")
                    .setTypes(Image.class, DetectedObjects.class)
                    .optModelUrls("https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/ch_PP-OCRv4_det_infer.tar")
                    .optTranslator(new PpWordDetectionTranslator2(new ConcurrentHashMap<String, String>()))
                    .build();
            detectionModel = detectionCriteria.loadModel();

            Criteria<Image, Classifications> rotateCriteria = Criteria.builder()
                    .optEngine("PaddlePaddle")
                    .setTypes(Image.class, Classifications.class)
                    .optModelUrls("https://resources.djl.ai/test-models/paddleOCR/mobile/cls.zip")
                    .optTranslator(new PpWordRotateTranslator())
                    .build();
            rotateModel = rotateCriteria.loadModel();

            Criteria<Image, String> recognitionCriteria = Criteria.builder()
                    .optEngine("PaddlePaddle")
                    .setTypes(Image.class, String.class)
                    .optModelPath(Path.of(REC_MODEL_PATH))
                    .optTranslator(new PpWordRecognitionTranslator2())
                    .build();
            recognitionModel = recognitionCriteria.loadModel();
        } catch (Exception e) {
            // A model that fails to load makes the whole class unusable, so fail fast and keep the cause.
            throw new IllegalStateException("Failed to load PP-OCR models", e);
        }
    }

    public static void main(String[] args) {
        doOCRFromUrl("https://img-blog.csdnimg.cn/direct/96de53d999c64c2589d0ab6a630e59d6.png");
//        doOCRFromFile("C:\\laker\\demo3\\1.png");
//        doOCRFromUrl("https://resources.djl.ai/images/flight_ticket.jpg");
    }

    /**
     * Downloads an image and runs OCR on it.
     *
     * @param url image URL
     * @return the recognised text; see {@link #doOCR(Image)}
     */
    @SneakyThrows
    public static String doOCRFromUrl(String url) {
        Image img = ImageFactory.getInstance().fromUrl(url);
        return doOCR(img);
    }

    /**
     * Loads an image from disk and runs OCR on it.
     *
     * @param path image file path
     * @return the recognised text; see {@link #doOCR(Image)}
     */
    @SneakyThrows
    public static String doOCRFromFile(String path) {
        Image img = ImageFactory.getInstance().fromFile(Path.of(path));
        return doOCR(img);
    }

    /**
     * Detects, orients and recognises every text region in {@code img}.
     *
     * <p>Regions are grouped into lines: a region starts a new line when its top edge is at
     * least {@link #SAME_LINE_THRESHOLD_PX} pixels away from the first region of the current
     * line. Within a line, regions are separated by a space, and every line ends with
     * {@code '\n'}. Blank recognitions are skipped.
     *
     * <p>Side effects: prints debug information to stdout and writes the input image, annotated
     * with the detected boxes, to {@link #DEBUG_OUTPUT_PATH}.
     *
     * @param img image to read
     * @return the recognised text, or {@code ""} when no text region is detected
     */
    @SneakyThrows
    public static String doOCR(Image img) {
        List<DetectedObjects.DetectedObject> boxes = detection(img);
        if (boxes.isEmpty()) {
            // No text regions, so there is nothing to recognise, group, or annotate.
            return "";
        }

        List<String> names = new ArrayList<>(boxes.size());
        List<Double> probs = new ArrayList<>(boxes.size());
        List<BoundingBox> rects = new ArrayList<>(boxes.size());

        List<List<String>> lines = new ArrayList<>();
        List<String> line = new ArrayList<>();
        lines.add(line);
        BoundingBox lineStart = boxes.get(0).getBoundingBox();

        for (DetectedObjects.DetectedObject box : boxes) {
            BoundingBox bbox = box.getBoundingBox();
            System.out.println(bbox);
            // Box coordinates are normalised, so scale the vertical offset to pixels.
            double dy = Math.abs(bbox.getBounds().getY() - lineStart.getBounds().getY()) * img.getHeight();

            String text = recognizeRegion(img, bbox);
            names.add(text);
            probs.add(-1.0); // The recognizer yields no score; -1.0 is a placeholder for drawing.
            rects.add(bbox);

            if (dy < SAME_LINE_THRESHOLD_PX) { // Same line.
                line.add(text);
            } else { // Line break: this box becomes the reference for the new line.
                lineStart = bbox;
                line = new ArrayList<>();
                line.add(text);
                lines.add(line);
            }
        }

        String fullText = joinLines(lines);
        System.out.println("fullText--------------\n" + fullText);

        saveAnnotatedImage(img, names, probs, rects);
        return fullText;
    }

    /**
     * Crops one detected region, fixes its orientation, and recognises its text.
     *
     * <p>Tall crops (height/width &gt; 1.5) are rotated first. The crop is then rotated again if
     * the direction classifier says "Rotate" with probability &gt; 0.8.
     * NOTE(review): the PaddleOCR direction classifier is usually described as 0° vs 180°,
     * while {@link #rotateImg} turns by 90° only. This mirrors the upstream example; confirm
     * it is intended.
     */
    private static String recognizeRegion(Image img, BoundingBox bbox) {
        Image subImg = getSubImage(img, bbox);
        if (subImg.getHeight() * 1.0 / subImg.getWidth() > 1.5) {
            subImg = rotateImg(subImg);
        }
        Classifications.Classification direction = getRotateResult(subImg);
        if ("Rotate".equals(direction.getClassName()) && direction.getProbability() > 0.8) {
            subImg = rotateImg(subImg);
        }
        return recognizer(subImg);
    }

    /**
     * Joins grouped text. Each non-blank item is followed by a space, and each line ends with '\n'.
     */
    private static String joinLines(List<List<String>> lines) {
        StringBuilder sb = new StringBuilder();
        for (List<String> line : lines) {
            for (String text : line) {
                if (!text.trim().isEmpty()) {
                    sb.append(text).append(' ');
                }
            }
            sb.append('\n');
        }
        return sb.toString();
    }

    /** Draws the detected boxes on a copy of {@code img} and saves it as PNG to {@link #DEBUG_OUTPUT_PATH}. */
    @SneakyThrows
    private static void saveAnnotatedImage(Image img, List<String> names, List<Double> probs, List<BoundingBox> rects) {
        Image annotated = img.duplicate();
        annotated.drawBoundingBoxes(new DetectedObjects(names, probs, rects));
        try (var out = Files.newOutputStream(Paths.get(DEBUG_OUTPUT_PATH))) {
            annotated.save(out, "png");
        }
    }

    /** Runs text-region detection, using and then closing a dedicated predictor. */
    @SneakyThrows
    private static List<DetectedObjects.DetectedObject> detection(Image img) {
        try (Predictor<Image, DetectedObjects> detector = detectionModel.newPredictor()) {
            DetectedObjects detectedObj = detector.predict(img);
            System.out.println(detectedObj);
            return detectedObj.items();
        }
    }

    /** Classifies the text direction of a crop and returns the best class. */
    @SneakyThrows
    private static Classifications.Classification getRotateResult(Image img) {
        try (Predictor<Image, Classifications> rotateClassifier = rotateModel.newPredictor()) {
            Classifications predict = rotateClassifier.predict(img);
            System.out.println(predict);
            return predict.best();
        }
    }

    /** Recognises the text in a single, already-oriented crop. */
    @SneakyThrows
    private static String recognizer(Image img) {
        try (Predictor<Image, String> recognizer = recognitionModel.newPredictor()) {
            String text = recognizer.predict(img);
            System.out.println(text);
            return text;
        }
    }

    /**
     * Crops the region described by a normalised bounding box out of {@code img}.
     *
     * <p>The pixel rectangle is clamped to the image bounds, with a minimum size of 1x1. Boxes
     * that touch or exceed the image edge therefore still yield a valid crop instead of an
     * out-of-bounds error. Boxes already inside the image produce exactly the same crop as
     * before.
     */
    static Image getSubImage(Image img, BoundingBox box) {
        Rectangle rect = box.getBounds();
        int width = img.getWidth();
        int height = img.getHeight();
        int x = clamp((int) (rect.getX() * width), 0, width - 1);
        int y = clamp((int) (rect.getY() * height), 0, height - 1);
        int w = clamp((int) (rect.getWidth() * width), 1, width - x);
        int h = clamp((int) (rect.getHeight() * height), 1, height - y);
        return img.getSubImage(x, y, w, h);
    }

    /** Restricts {@code v} to the inclusive range [{@code lo}, {@code hi}]. */
    private static int clamp(int v, int lo, int hi) {
        return Math.max(lo, Math.min(hi, v));
    }

    /** Rotates an image by 90° using NDImageUtils.rotate90 with one quarter turn. */
    private static Image rotateImg(Image image) {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray rotated = NDImageUtils.rotate90(image.toNDArray(manager), 1);
            return ImageFactory.getInstance().fromNDArray(rotated);
        }
    }
}

输出纯文本：

 java   博客之望 
 lakernote @ te码龄11年 (编辑资料) 
907,648 总访问量丨 丨499原创丨 |2,426排名| |52,478粉丝| 702铁粉| 」学习成就 
个人简介:不停的复盘自己,砥砺前行,不忘初衷 
IP属地:日本 
查看详细资料 
成就〉 最近  文章514  问答5 適资源42 原力等级 
CSDN @lakernote 

结果

在这里插入图片描述

文章来源:https://blog.csdn.net/abu935009066/article/details/135613485
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。