|  | 
|  | 1 | +import os | 
|  | 2 | +import time | 
|  | 3 | +import sys | 
|  | 4 | +import numpy as np | 
|  | 5 | +import yaml | 
|  | 6 | +from addict import Dict | 
|  | 7 | + | 
|  | 8 | +__dir__ = os.path.dirname(os.path.abspath(__file__)) | 
|  | 9 | +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../"))) | 
|  | 10 | + | 
|  | 11 | +import cv2 | 
|  | 12 | +import numpy as np | 
|  | 13 | + | 
|  | 14 | +from pipeline.framework.module_base import ModuleBase | 
|  | 15 | +from pipeline.tasks import TaskType | 
|  | 16 | +from infer.classification.classification import ClsPostProcess | 
|  | 17 | +from tools.infer.text.utils import crop_text_region | 
|  | 18 | +from pipeline.data_process.utils.cv_utils import crop_box_from_image | 
|  | 19 | + | 
|  | 20 | + | 
class ClsPostNode(ModuleBase):
    """Pipeline node that post-processes text-direction classification output.

    Depending on the task type, it either rotates detected sub-images whose
    predicted angle is "180" (det+cls+rec pipeline) or exposes the raw
    (angle, score) pairs as the node's inference result.
    """

    def __init__(self, args, msg_queue, tqdm_info):
        super().__init__(args, msg_queue, tqdm_info)
        self.cls_postprocess = ClsPostProcess(args)
        self.task_type = self.args.task_type
        # Confidence required before applying a 180-degree rotation.
        # Overridable via args.cls_thresh; falls back to the previous
        # hard-coded 0.9. `or` (not a getattr default) is used because
        # args may be an addict.Dict, whose missing attributes resolve to
        # an empty (falsy) Dict rather than raising AttributeError.
        self.cls_thresh = getattr(args, "cls_thresh", None) or 0.9

    def init_self_args(self):
        super().init_self_args()

    def process(self, input_data):
        """Apply classification post-processing to one batch.

        Input:
          - input_data.data: dict with key "pred" (raw classifier output)
        Output (DET_CLS_REC task):
          - input_data.sub_image_list: images rotated 180 deg in place where
            angle == "180" and score > self.cls_thresh
          - input_data.data = None
        Output (other tasks):
          - input_data.infer_result = [(angle_str, score_float), ...]
        """
        # Pass-through for items an upstream node already marked as skipped.
        if input_data.skip:
            self.send_to_next_module(input_data)
            return

        pred = input_data.data["pred"]
        output = self.cls_postprocess(pred)
        angles = output["angles"]
        # Round-trip through np.array(...).tolist() normalizes numpy scalar
        # scores to native Python floats (works whether "scores" is a list
        # or an ndarray).
        scores = np.array(output["scores"]).tolist()

        batch = input_data.sub_image_size
        # Compare enum *values* rather than members: TaskType may be imported
        # through different module paths, which breaks identity/equality of
        # the enum members themselves.
        if self.task_type.value == TaskType.DET_CLS_REC.value:
            sub_images = input_data.sub_image_list
            for i in range(batch):
                angle, score = angles[i], scores[i]
                # Only flip when the classifier is confident enough.
                if angle == "180" and score > self.cls_thresh:
                    sub_images[i] = cv2.rotate(sub_images[i], cv2.ROTATE_180)
            input_data.sub_image_list = sub_images
        else:
            # Classification-only task: the (angle, score) pairs are the
            # final result.
            input_data.infer_result = list(zip(angles, scores))

        # Raw prediction data is no longer needed downstream.
        input_data.data = None
        self.send_to_next_module(input_data)