




"""Custom TorchServe model handler for YOLOv5 models.
from ts.torch_handler.base_handler import BaseHandler
import numpy as np
import base64
import torch
import torchvision.transforms as tf
import torchvision
import io
from PIL import Image

class ModelHandler(BaseHandler):
    """A custom TorchServe model handler for YOLOv5 detection models.

    ``preprocess`` turns the request payload into one batched image tensor and
    ``postprocess`` turns the raw model output into per-image detection lists.
    """

    # Image size (px). Images will be resized to this resolution before
    # inference; postprocess divides box coordinates by the same value.
    img_size = 640

    def __init__(self):
        # call superclass initializer
        super().__init__()

    def preprocess(self, data):
        """Converts input images to float tensors.

        Args:
            data (List): Input data from the request. Each row is a mapping
                that carries the image under "data" or "body", either as a
                base64 string, as encoded image bytes, or as a nested list of
                pixel values.

        Returns:
            Tensor: single Tensor of shape [BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE]
        """
        to_tensor = tf.ToTensor()
        resize = tf.Resize((self.img_size, self.img_size))
        images = []

        for row in data:
            # Compat layer: normally the envelope should just return the data
            # directly, but older versions of Torchserve didn't have envelope.
            image = row.get("data") or row.get("body")
            if isinstance(image, str):
                # a string payload is treated as base64-encoded bytes
                image = base64.b64decode(image)

            if isinstance(image, (bytearray, bytes)):
                # Encoded image file: decode it, then convert to a CHW float
                # tensor. RGB is forced so every item has 3 channels and the
                # batch can be stacked.
                # NOTE(review): the decode/convert step was reconstructed from
                # a truncated line — confirm against the deployed handler.
                image = to_tensor(Image.open(io.BytesIO(image)).convert("RGB"))
            else:
                # if the image is a list
                # NOTE(review): assumed to already be laid out as [3, H, W].
                image = torch.FloatTensor(image)

            # resize to [img_size, img_size]
            images.append(resize(image))

        # convert list of equal-size tensors to single stacked tensor
        # has shape BATCH_SIZE x 3 x IMG_SIZE x IMG_SIZE
        images_tensor = torch.stack(images).to(self.device)

        return images_tensor

    def postprocess(self, inference_output):
        """Converts raw model output into JSON-friendly detections.

        Args:
            inference_output: Model output; only element 0 is used and it is
                passed to ``non_max_suppression``.

        Returns:
            list: one list per image, each holding one dict per detection with
            keys "x1", "y1", "x2", "y2" (coordinates divided by ``img_size``),
            "confidence" and "class".
        """
        # perform NMS (nonmax suppression) on model outputs
        pred = non_max_suppression(inference_output[0])

        detections = []
        for image_detections in pred:  # axis 0: for each image
            image_results = []
            for det in image_detections:  # axis 1: for each detection
                # x1,y1,x2,y2 in normalized image coordinates (i.e. 0.0-1.0)
                x1, y1, x2, y2 = (det[:4] / self.img_size).tolist()
                # index of predicted class
                class_idx = int(det[5].item())
                image_results.append({
                    "x1": x1,
                    "y1": y1,
                    "x2": x2,
                    "y2": y2,
                    # confidence value
                    "confidence": det[4].item(),
                    # label of predicted class; if missing from the mapping,
                    # fall back to the class index
                    "class": self.mapping.get(str(class_idx), class_idx),
                })
            detections.append(image_results)

        return detections

def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=(), max_det=300):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Args:
        prediction: tensor indexed as [image, box, 5 + nc], where columns 0-3
            are (center x, center y, width, height), column 4 is the
            objectness score and the remaining columns are class scores.
        conf_thres: minimum confidence for a box to be kept (0.0-1.0).
        iou_thres: IoU above which overlapping boxes are suppressed (0.0-1.0).
        classes: optional collection of class indices to keep.
        agnostic: if True, boxes of different classes suppress each other.
        multi_label: if True (and there is more than one class), a box may be
            reported once per class that passes ``conf_thres``.
        labels: optional per-image apriori labels with rows of
            [cls, x, y, w, h], appended to the predictions before NMS.
        max_det: maximum number of detections returned per image.

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]

    Raises:
        AssertionError: if ``conf_thres`` or ``iou_thres`` is outside [0, 1].
    """
    # Checks (done first so invalid thresholds never reach the filtering)
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    max_wh = 4096  # (pixels) maximum box width and height, used as class offset
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    # One independent empty result per image (a multiplied list would alias
    # a single tensor object across all images).
    output = [torch.zeros((0, 6), device=prediction.device)
              for _ in range(prediction.shape[0])]
    for xi, x in enumerate(prediction):  # image index, image inference
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + 5), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[
                conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            # sort by confidence
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]

        # Batched NMS: shifting each box by class_index * max_wh keeps boxes
        # of different classes from overlapping, so a single nms() call
        # suppresses only within a class.
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        # boxes (offset by class), scores
        boxes, scores = x[:, :4] + c, x[:, 4]
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = torchvision.ops.box_iou(
                boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float(
            ) / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]

    return output

def xywh2xyxy(x):
    """Convert boxes from [x, y, w, h] to [x1, y1, x2, y2].

    (x, y) is the box center; xy1 is the top-left corner and xy2 the
    bottom-right corner.

    Args:
        x: torch.Tensor or array-like whose last dimension holds
            (x, y, w, h). Any number of leading dimensions is accepted, so a
            single box of shape (4,), an (n, 4) matrix and a batched
            (b, n, 4) tensor all work.

    Returns:
        A new tensor/array of the same shape as ``x``; the input is not
        modified.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    # Ellipsis indexing addresses the last axis regardless of rank.
    half_w = x[..., 2] / 2
    half_h = x[..., 3] / 2
    y[..., 0] = x[..., 0] - half_w  # top left x
    y[..., 1] = x[..., 1] - half_h  # top left y
    y[..., 2] = x[..., 0] + half_w  # bottom right x
    y[..., 3] = x[..., 1] + half_h  # bottom right y
    return y


import logging
import os
from collections import Counter

import torch
from torchvision import transforms
from ultralytics import YOLO

from ts.torch_handler.object_detector import ObjectDetector

logger = logging.getLogger(__name__)

# torch_xla is an optional dependency. Record whether it could be imported so
# that device selection can consult XLA_AVAILABLE before touching `xm`.
try:
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False

class Yolov8Handler(ObjectDetector):
    """TorchServe handler that serves an Ultralytics YOLO model.

    Overrides model loading to build a ``YOLO`` object from the serialized
    file and overrides ``postprocess`` to report per-class object counts.
    """

    # Resize the shorter side to 640, crop the central 640x640 region and
    # convert to a tensor.
    image_processing = transforms.Compose(
        [transforms.Resize(640), transforms.CenterCrop(640), transforms.ToTensor()]
    )

    def __init__(self):
        super(Yolov8Handler, self).__init__()

    def initialize(self, context):
        """Select the device and load the model described by ``context``.

        Args:
            context: TorchServe context; ``system_properties`` supplies
                "model_dir" and ``manifest["model"]`` may supply
                "serializedFile".
        """
        # Set device type: CUDA first, then XLA, otherwise CPU.
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif XLA_AVAILABLE:
            self.device = xm.xla_device()
        else:
            self.device = torch.device("cpu")

        # Load the model
        properties = context.system_properties
        self.manifest = context.manifest
        model_dir = properties.get("model_dir")
        self.model_pt_path = None
        if "serializedFile" in self.manifest["model"]:
            serialized_file = self.manifest["model"]["serializedFile"]
            self.model_pt_path = os.path.join(model_dir, serialized_file)
        # NOTE(review): when the manifest has no "serializedFile" the path
        # stays None and is still handed to the loader — confirm intended.
        self.model = self._load_torchscript_model(self.model_pt_path)
        logger.debug("Model file %s loaded successfully", self.model_pt_path)

        self.initialized = True

    def _load_torchscript_model(self, model_pt_path):
        """Loads the PyTorch model and returns the NN model object.

        Args:
            model_pt_path (str): denotes the path of the model file.

        Returns:
            (NN Model Object) : Loads the model object.
        """
        # TODO: remove this method once the upstream issue it works around is
        # resolved (the issue link is missing from this copy of the file).

        model = YOLO(model_pt_path)
        return model

    def postprocess(self, res):
        """Summarise detections as per-class object counts.

        Args:
            res: iterable of result objects exposing ``boxes.cls`` (predicted
                class indices) and ``names`` (index -> class name lookup).

        Returns:
            list: one dict per result, mapping class name to the number of
            detected objects of that class.
        """
        output = []
        for data in res:
            names = data.names

            # Map class indices to class names and count the objects detected
            result = Counter(names[int(cls)] for cls in data.boxes.cls.tolist())

            # Without this append the method always returned an empty list.
            output.append(dict(result))

        return output



from torchvision import transforms
from ts.torch_handler.image_classifier import ImageClassifier
from torch.profiler import ProfilerActivity

class MNISTDigitClassifier(ImageClassifier):
    """
    MNISTDigitClassifier handler class. This handler extends class ImageClassifier from
    image_classifier.py, a default handler. This handler takes an image and returns the
    number in that image.

    Here method postprocess() has been overridden while others are reused from parent class.
    """

    # Normalize operates on tensors, so the image is converted first.
    # NOTE(review): the ToTensor step was reconstructed from a truncated
    # Compose list — confirm against the deployed handler.
    image_processing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    def __init__(self):
        super(MNISTDigitClassifier, self).__init__()
        # Profiler settings: record CPU activity and operator input shapes.
        self.profiler_args = {
            "activities": [ProfilerActivity.CPU],
            "record_shapes": True,
        }

    def postprocess(self, data):
        """The post process of MNIST converts the predicted output response to a label.

        Args:
            data: The predicted output from the inference, with one row of
                per-class scores for each input.

        Returns:
            list : The index of the highest-scoring class for each row.
        """
        return data.argmax(1).tolist()
