Slicing Aided Hyper Inference
pip install ultralytics sahi
from sahi import AutoDetectionModel
 
import time

from sahi.predict import get_prediction, get_sliced_prediction

# Build the detector from local Ultralytics weights via SAHI's model factory.
detection_model = AutoDetectionModel.from_pretrained(
    model_type="ultralytics",
    model_path="model/best.pt",
    confidence_threshold=0.3,
    device="cuda:0",  # or 'cpu'
)

# Input frame and slicing configuration used for the timed run below.
image_path = "test_image/frame_240805_1538_03280.jpg"
slice_options = {
    "slice_height": 1280,
    "slice_width": 1920,
    "overlap_height_ratio": 0.1,
    "overlap_width_ratio": 0.2,
}

start = time.time()
# Non-sliced alternative for comparison:
#   result = get_prediction(image_path, detection_model)
result = get_sliced_prediction(image_path, detection_model, **slice_options)
print("Time taken:", time.time() - start)
 
result.export_visuals(export_dir="demo_data_sahi",file_name="result_4_onnx_half.jpg")
批量推理改动
sahi.predict 的第 232 行:直接修改 num_batch 似乎可行。
Warning
不可行,要改很多地方。只改动这一处的话,每个 group 只会推理第一张图片,导致少推理图片。
改动可参考 sahi_custom:
def get_prediction(
    image,
    detection_model,
    shift_amount: Optional[list] = None,
    full_shape=None,
    postprocess: Optional[PostprocessPredictions] = None,
    verbose: int = 0,
    num_batch: int = 1,
    exclude_classes_by_name: Optional[List[str]] = None,
    exclude_classes_by_id: Optional[List[int]] = None,
) -> PredictionResult:
    """
    Function for performing prediction for given image(s) using given detection_model.

    Arguments:
        image: str or np.ndarray or list of them
            Location of image or numpy image matrix to predict on. A list/tuple is
            treated as a batch; any other value is treated as one single image.
        detection_model: model.DetectionModel
        shift_amount: List
            To shift the box and mask predictions from sliced image to full
            sized image, should be in the form of [shift_x, shift_y] (or a list of
            such pairs, one per image, for a batch). Defaults to [0, 0].
        full_shape: List
            Size of the full image, should be in the form of [height, width].
            When omitted, the size of the first image is used.
        postprocess: sahi.postprocess.combine.PostprocessPredictions
        verbose: int
            0: no print (default)
            1: print prediction duration
        num_batch: int
            1: only the first image is sent to the model, as a single array (default)
            otherwise: every given image is sent to the model as one list
        exclude_classes_by_name: Optional[List[str]]
            None: if no classes are excluded
            List[str]: set of classes to exclude using its/their class label name/s
        exclude_classes_by_id: Optional[List[int]]
            None: if no classes are excluded
            List[int]: set of classes to exclude using one or more IDs
    Returns:
        A PredictionResult built from:
            image: the first input image
            object_prediction_list: a list of ObjectPrediction
            durations_in_seconds: a dict containing elapsed times for profiling
    Raises:
        ValueError: if an empty list/tuple of images is given.
    """
    durations_in_seconds = dict()

    # A fresh list per call instead of a shared mutable default argument.
    if shift_amount is None:
        shift_amount = [0, 0]

    # Accept a bare image as well as a list of images.
    images = list(image) if isinstance(image, (list, tuple)) else [image]
    if not images:
        raise ValueError("image should contain at least one image")

    # read image(s) as pil; reading is deliberately kept out of the timed section
    if num_batch == 1:
        image_as_pil = read_image_as_pil(images[0])
        model_input = np.ascontiguousarray(image_as_pil)
        first_image_shape = [image_as_pil.height, image_as_pil.width]
    else:
        model_input = [np.ascontiguousarray(read_image_as_pil(img)) for img in images]
        # The arrays are indexed as (height, width, ...) here.
        first_image_shape = list(model_input[0].shape[:2])

    time_start = time.time()
    detection_model.perform_inference(model_input, num_batch=num_batch)
    durations_in_seconds["prediction"] = time.time() - time_start

    # Previously only the single-image branch could supply this fallback; the
    # batched branch raised NameError because image_as_pil was never assigned.
    if full_shape is None:
        full_shape = first_image_shape

    # process prediction
    time_start = time.time()
    detection_model.convert_original_predictions(
        shift_amount=shift_amount,
        full_shape=full_shape,
    )
    object_prediction_list: List[ObjectPrediction] = detection_model.object_prediction_list
    object_prediction_list = filter_predictions(object_prediction_list, exclude_classes_by_name, exclude_classes_by_id)

    # postprocess matching predictions
    if postprocess is not None:
        object_prediction_list = postprocess(object_prediction_list)

    durations_in_seconds["postprocess"] = time.time() - time_start

    if verbose == 1:
        print(
            "Prediction performed in",
            durations_in_seconds["prediction"],
            "seconds.",
        )

    return PredictionResult(
        image=images[0], object_prediction_list=object_prediction_list, durations_in_seconds=durations_in_seconds
    )
 
 
def get_sliced_prediction(
    image,
    detection_model=None,
    slice_height: Optional[int] = None,
    slice_width: Optional[int] = None,
    overlap_height_ratio: float = 0.2,
    overlap_width_ratio: float = 0.2,
    perform_standard_pred: bool = True,
    postprocess_type: str = "GREEDYNMM",
    postprocess_match_metric: str = "IOS",
    postprocess_match_threshold: float = 0.5,
    postprocess_class_agnostic: bool = False,
    verbose: int = 1,
    merge_buffer_length: Optional[int] = None,
    auto_slice_resolution: bool = True,
    slice_export_prefix: Optional[str] = None,
    slice_dir: Optional[str] = None,
    exclude_classes_by_name: Optional[List[str]] = None,
    exclude_classes_by_id: Optional[List[int]] = None,
    batch: int = 1,
) -> PredictionResult:
    """
    Function for slice image + get prediction for each slice + combine predictions in full image.

    Args:
        image: str or np.ndarray
            Location of image or numpy image matrix to slice
        detection_model: model.DetectionModel
        slice_height: int
            Height of each slice.  Defaults to ``None``.
        slice_width: int
            Width of each slice.  Defaults to ``None``.
        overlap_height_ratio: float
            Fractional overlap in height of each window (e.g. an overlap of 0.2 for a window
            of size 512 yields an overlap of 102 pixels).
            Default to ``0.2``.
        overlap_width_ratio: float
            Fractional overlap in width of each window (e.g. an overlap of 0.2 for a window
            of size 512 yields an overlap of 102 pixels).
            Default to ``0.2``.
        perform_standard_pred: bool
            Perform a standard prediction on top of sliced predictions to increase large object
            detection accuracy. Default: True.
        postprocess_type: str
            Type of the postprocess to be used after sliced inference while merging/eliminating predictions.
            Options are 'NMM', 'GREEDYNMM' or 'NMS'. Default is 'GREEDYNMM'.
        postprocess_match_metric: str
            Metric to be used during object prediction matching after sliced prediction.
            'IOU' for intersection over union, 'IOS' for intersection over smaller area.
        postprocess_match_threshold: float
            Sliced predictions having higher iou than postprocess_match_threshold will be
            postprocessed after sliced prediction.
        postprocess_class_agnostic: bool
            If True, postprocess will ignore category ids.
        verbose: int
            0: no print
            1: print number of slices (default)
            2: print number of slices and slice/prediction durations
        merge_buffer_length: int
            The length of buffer for slices to be used during sliced prediction, which is suitable for low memory.
            It may affect the AP if it is specified. The higher the amount, the closer results to the non-buffered.
            scenario. See [the discussion](https://github.com/obss/sahi/pull/445).
        auto_slice_resolution: bool
            if slice parameters (slice_height, slice_width) are not given,
            it enables automatically calculate these params from image resolution and orientation.
        slice_export_prefix: str
            Prefix for the exported slices. Defaults to None.
        slice_dir: str
            Directory to save the slices. Defaults to None.
        exclude_classes_by_name: Optional[List[str]]
            None: if no classes are excluded
            List[str]: set of classes to exclude using its/their class label name/s
        exclude_classes_by_id: Optional[List[int]]
            None: if no classes are excluded
            List[int]: set of classes to exclude using one or more IDs
        batch: int
            Number of slices sent to the model per inference call. The last group may be
            smaller when the slice count is not a multiple of ``batch``. Default: 1.
    Returns:
        A PredictionResult built from:
            image: the given full image
            object_prediction_list: a list of sahi.prediction.ObjectPrediction
            durations_in_seconds: a dict containing elapsed times for profiling
    Raises:
        ValueError: if ``batch`` is smaller than 1 or ``postprocess_type`` is unknown.
    """
    if batch < 1:
        raise ValueError(f"batch should be a positive integer but given as {batch}")

    # for profiling
    durations_in_seconds = dict()

    num_batch = batch
    # create slices from full image
    time_start = time.time()
    slice_image_result = slice_image(
        image=image,
        output_file_name=slice_export_prefix,
        output_dir=slice_dir,
        slice_height=slice_height,
        slice_width=slice_width,
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio,
        auto_slice_resolution=auto_slice_resolution,
    )
    from sahi.models.ultralytics import UltralyticsDetectionModel

    num_slices = len(slice_image_result)
    time_end = time.time() - time_start
    durations_in_seconds["slice"] = time_end

    if isinstance(detection_model, UltralyticsDetectionModel) and detection_model.is_obb:
        # Only NMS is supported for OBB model outputs
        postprocess_type = "NMS"

    # init match postprocess instance
    if postprocess_type not in POSTPROCESS_NAME_TO_CLASS:
        raise ValueError(
            f"postprocess_type should be one of {list(POSTPROCESS_NAME_TO_CLASS.keys())} but given as {postprocess_type}"
        )
    postprocess_constructor = POSTPROCESS_NAME_TO_CLASS[postprocess_type]
    postprocess = postprocess_constructor(
        match_threshold=postprocess_match_threshold,
        match_metric=postprocess_match_metric,
        class_agnostic=postprocess_class_agnostic,
    )

    postprocess_time = 0
    time_start = time.time()

    if verbose in (1, 2):
        tqdm.write(f"Performing prediction on {num_slices} slices.")
    object_prediction_list = []
    full_image_shape = [
        slice_image_result.original_image_height,
        slice_image_result.original_image_width,
    ]
    # Read the slice collections once instead of once per slice.
    slice_images = slice_image_result.images
    slice_starting_pixels = slice_image_result.starting_pixels

    # perform sliced prediction, one group of at most num_batch slices at a time.
    # Stepping by num_batch (instead of iterating over num_slices // num_batch
    # groups) also covers the trailing slices when num_slices is not a multiple
    # of num_batch; previously those slices were never predicted.
    for group_start in range(0, num_slices, num_batch):
        group_end = group_start + num_batch
        image_list = slice_images[group_start:group_end]
        shift_amount_list = slice_starting_pixels[group_start:group_end]
        # perform batch prediction
        prediction_result = get_prediction(
            image=image_list,
            detection_model=detection_model,
            shift_amount=shift_amount_list,
            full_shape=full_image_shape,
            exclude_classes_by_name=exclude_classes_by_name,
            exclude_classes_by_id=exclude_classes_by_id,
            num_batch=len(image_list),
        )
        # convert sliced predictions to full predictions
        for object_prediction in prediction_result.object_prediction_list:
            if object_prediction:  # if not empty
                object_prediction_list.append(object_prediction.get_shifted_object_prediction())

        # merge matching predictions during sliced prediction
        if merge_buffer_length is not None and len(object_prediction_list) > merge_buffer_length:
            postprocess_time_start = time.time()
            object_prediction_list = postprocess(object_prediction_list)
            postprocess_time += time.time() - postprocess_time_start

    # perform standard prediction
    if num_slices > 1 and perform_standard_pred:
        prediction_result = get_prediction(
            image=[image],
            detection_model=detection_model,
            shift_amount=[0, 0],
            full_shape=full_image_shape,
            postprocess=None,
            exclude_classes_by_name=exclude_classes_by_name,
            exclude_classes_by_id=exclude_classes_by_id,
        )
        object_prediction_list.extend(prediction_result.object_prediction_list)

    # merge matching predictions
    if len(object_prediction_list) > 1:
        postprocess_time_start = time.time()
        object_prediction_list = postprocess(object_prediction_list)
        postprocess_time += time.time() - postprocess_time_start

    time_end = time.time() - time_start
    durations_in_seconds["prediction"] = time_end - postprocess_time
    durations_in_seconds["postprocess"] = postprocess_time

    if verbose == 2:
        print(
            "Slicing performed in",
            durations_in_seconds["slice"],
            "seconds.",
        )
        print(
            "Prediction performed in",
            durations_in_seconds["prediction"],
            "seconds.",
        )
        print(
            "Postprocessing performed in",
            durations_in_seconds["postprocess"],
            "seconds.",
        )

    return PredictionResult(
        image=image, object_prediction_list=object_prediction_list, durations_in_seconds=durations_in_seconds
    )

    def perform_inference(self, image: np.ndarray, num_batch: int = 1):
        """
        Prediction is performed using self.model and the prediction result is set to self._original_predictions.
        The shape of the first input image is stored in self._original_shape and the shapes of
        all input images in self._original_shapes.
        Args:
            image: np.ndarray or list of np.ndarray
                A numpy array that contains the image to be predicted, or a list of such arrays
                for batch inference. 3 channel image should be in RGB order.
            num_batch: int
                1: ``image`` is a single array unless a list/tuple is given (default)
                otherwise: ``image`` is iterated as a batch of arrays
        Raises:
            ValueError: if the model is not loaded or an empty batch is given.
        """
        # Confirm model is loaded
        if self.model is None:
            raise ValueError("Model is not loaded, load it by calling .load_model()")

        kwargs = {"cfg": self.config_path, "verbose": False, "conf": self.confidence_threshold, "device": self.device}

        if self.image_size is not None:
            kwargs = {"imgsz": self.image_size, **kwargs}

        # A list/tuple is always a batch, so a one-element list works with num_batch == 1 too.
        is_batch = num_batch != 1 or isinstance(image, (list, tuple))
        images = list(image) if is_batch else [image]
        if not images:
            raise ValueError("image batch is empty")

        # The last axis is reversed before the model call (RGB input, see Args).
        if is_batch:
            prediction_result = self.model([img[:, :, ::-1] for img in images], **kwargs)  # For batch inference
        else:
            prediction_result = self.model(image[:, :, ::-1], **kwargs)

        # _create_object_prediction_list_from_original_predictions reads these to
        # resize masks and as the fallback full shape; they were never assigned before.
        self._original_shape = images[0].shape
        self._original_shapes = [img.shape for img in images]

        # Handle different result types for PyTorch vs ONNX models
        # ONNX models might return results in a different format
        if self.has_mask:
            from ultralytics.engine.results import Masks

            device = getattr(self.model, "device", "cpu")  # "cpu" is the default for ONNX models
            # Create empty masks for every result that has none. Patching only the
            # first result made the tuple construction below fail on result.masks.data
            # for any later image of the batch without detections.
            for result in prediction_result:
                if not result.masks:
                    result.masks = Masks(torch.tensor([], device=device), result.boxes.orig_shape)

            # We do not filter results again as confidence threshold is already applied above
            prediction_result = [
                (
                    result.boxes.data,
                    result.masks.data,
                )
                for result in prediction_result
            ]
        elif self.is_obb:
            # For OBB task, get OBB points in xyxyxyxy format
            device = getattr(self.model, "device", "cpu")
            prediction_result = [
                (
                    # Get OBB data: xyxy, conf, cls
                    torch.cat(
                        [
                            result.obb.xyxy,  # box coordinates
                            result.obb.conf.unsqueeze(-1),  # confidence scores
                            result.obb.cls.unsqueeze(-1),  # class ids
                        ],
                        dim=1,
                    )
                    if result.obb is not None
                    else torch.empty((0, 6), device=device),
                    # Get OBB points in (N, 4, 2) format
                    result.obb.xyxyxyxy if result.obb is not None else torch.empty((0, 4, 2), device=device),
                )
                for result in prediction_result
            ]
        else:  # If model doesn't do segmentation or OBB then no need to check masks
            prediction_result = [result.boxes.data for result in prediction_result]

        self._original_predictions = prediction_result

    def _create_object_prediction_list_from_original_predictions(
        self,
        shift_amount_list: Optional[List[List[int]]] = [[0, 0]],
        full_shape_list: Optional[List[List[int]]] = None,
    ):
        """
        self._original_predictions is converted to a list of prediction.ObjectPrediction and set to
        self._object_prediction_list_per_image. The predictions of every image are additionally
        appended to the first per-image list, so index 0 holds the predictions of the whole batch.
        Args:
            shift_amount_list: list of list
                To shift the box and mask predictions from sliced image to full sized image, should
                be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...]
            full_shape_list: list of list
                Size of the full image after shifting, should be in the form of
                List[[height, width],[height, width],...]. Only the first entry is used,
                for every image.
        """
        original_predictions = self._original_predictions

        if shift_amount_list is None:
            shift_amount_list = [[0, 0]]
        if not isinstance(shift_amount_list[0], list):
            shift_amount_list = fix_shift_amount_list(shift_amount_list)
        full_shape_list = fix_full_shape_list(full_shape_list)

        # One full shape is shared by every image, so resolve it once.
        full_shape = None if full_shape_list is None else full_shape_list[0]
        # Per-image input shapes when available; otherwise fall back to the single shape.
        per_image_shapes = getattr(self, "_original_shapes", None) or []

        # handle all predictions
        object_prediction_list_per_image = []

        for image_ind, image_predictions in enumerate(original_predictions):
            shift_amount = shift_amount_list[image_ind]
            if image_ind < len(per_image_shapes):
                original_shape = per_image_shapes[image_ind]
            else:
                original_shape = getattr(self, "_original_shape", None)
            object_prediction_list = []

            # Extract boxes and optional masks/obb
            if self.has_mask or self.is_obb:
                boxes = image_predictions[0].cpu().detach().numpy()
                masks_or_points = image_predictions[1].cpu().detach().numpy()
            else:
                boxes = image_predictions.data.cpu().detach().numpy()
                masks_or_points = None

            # Process each prediction
            for pred_ind, prediction in enumerate(boxes):
                # Get bbox coordinates
                bbox = prediction[:4].tolist()
                score = prediction[4]
                category_id = int(prediction[5])
                category_name = self.category_mapping[str(category_id)]

                # Fix box coordinates
                bbox = [max(0, coord) for coord in bbox]
                if full_shape is not None:
                    bbox[0] = min(full_shape[1], bbox[0])
                    bbox[1] = min(full_shape[0], bbox[1])
                    bbox[2] = min(full_shape[1], bbox[2])
                    bbox[3] = min(full_shape[0], bbox[3])

                # Ignore invalid predictions
                if not (bbox[0] < bbox[2]) or not (bbox[1] < bbox[3]):
                    logger.warning(f"ignoring invalid prediction with bbox: {bbox}")
                    continue

                # Get segmentation or OBB points
                segmentation = None
                if masks_or_points is not None:
                    if self.has_mask:
                        bool_mask = masks_or_points[pred_ind]
                        # Resize mask to the size of the image it was predicted on
                        bool_mask = cv2.resize(bool_mask.astype(np.uint8), (original_shape[1], original_shape[0]))
                        segmentation = get_coco_segmentation_from_bool_mask(bool_mask)
                    else:  # is_obb
                        obb_points = masks_or_points[pred_ind]  # Get OBB points for this prediction
                        segmentation = [obb_points.reshape(-1).tolist()]

                    if len(segmentation) == 0:
                        continue

                # Create and append object prediction
                object_prediction = ObjectPrediction(
                    bbox=bbox,
                    category_id=category_id,
                    score=score,
                    segmentation=segmentation,
                    category_name=category_name,
                    shift_amount=shift_amount,
                    full_shape=original_shape[:2] if full_shape is None else full_shape,  # (height, width)
                )
                object_prediction_list.append(object_prediction)

            object_prediction_list_per_image.append(object_prediction_list)

        # Collect the predictions of the remaining images into the first list.
        # NOTE(review): presumably because callers read only index 0 — confirm.
        if len(object_prediction_list_per_image) > 1:
            merged_predictions = object_prediction_list_per_image[0]
            for other_predictions in object_prediction_list_per_image[1:]:
                merged_predictions.extend(other_predictions)

        self._object_prediction_list_per_image = object_prediction_list_per_image
其他处理
result.export_visuals() 这块在最后加一行(即末尾的 return res["image"]):
    def export_visuals(
        self,
        export_dir: str,
        text_size: Optional[float] = None,
        rect_th: Optional[int] = None,
        hide_labels: bool = False,
        hide_conf: bool = False,
        file_name: str = "prediction_visual",
    ):
        """
 
        Args:
            export_dir: directory for resulting visualization to be exported
            text_size: size of the category name over box
            rect_th: rectangle thickness
            hide_labels: hide labels
            hide_conf: hide confidence
            file_name: saving name
        Returns:
 
        """
        Path(export_dir).mkdir(parents=True, exist_ok=True)
        res = visualize_object_predictions(
            image=np.ascontiguousarray(self.image),
            object_prediction_list=self.object_prediction_list,
            rect_th=rect_th,
            text_size=text_size,
            text_th=None,
            color=None,
            hide_labels=hide_labels,
            hide_conf=hide_conf,
            output_dir=export_dir,
            file_name=file_name,
            export_format="png",
        )
 
        return res["image"]    def yolo_detections(self):
        """
        Returns a list of YOLO detections in the format:
        [x1, y1, x2, y2, confidence, class_id]
        """
        yolo_detections = []
        for object_prediction in self.object_prediction_list:
            x1, y1, x2, y2 = object_prediction.bbox.to_xyxy()
            class_id = object_prediction.category.id
            confidence = object_prediction.score.value
            yolo_detections.append([x1, y1, x2, y2, confidence, class_id])
        return yolo_detections不能直接使用engine,修改
if self.model_path and self.model_path.endswith(".pt"):