Slicing Aided Hyper Inference
pip install ultralytics sahi

from sahi import AutoDetectionModel

detection_model = AutoDetectionModel.from_pretrained(
    model_type="ultralytics",
    model_path="model/best.pt",
    confidence_threshold=0.3,
    device="cuda:0",  # or 'cpu'
)

from sahi.predict import get_prediction, get_sliced_prediction
import time

start = time.time()
# result = get_prediction("test_image/frame_240805_1538_03280.jpg", detection_model)
result = get_sliced_prediction(
    "test_image/frame_240805_1538_03280.jpg",
    detection_model,
    slice_height=1280,
    slice_width=1920,
    overlap_height_ratio=0.1,
    overlap_width_ratio=0.2,
)
print("Time taken:", time.time() - start)
result.export_visuals(export_dir="demo_data_sahi", file_name="result_4_onnx_half.jpg")

Batch inference changes
At line 232 of sahi.predict, it looks at first as if simply changing num_batch would be enough.
Warning
Not feasible: many places have to change. Changing only this makes each group run inference on its first slice only, so slices are silently skipped. For the full set of changes, see sahi_custom.
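Roughly why: in the stock slice loop a group of num_batch slices is assembled, but get_prediction is still handed only the first element. A paraphrased sketch of the upstream loop (not verbatim):

for group_ind in range(num_group):
    image_list = []
    shift_amount_list = []
    for image_ind in range(num_batch):
        image_list.append(slice_image_result.images[group_ind * num_batch + image_ind])
        shift_amount_list.append(slice_image_result.starting_pixels[group_ind * num_batch + image_ind])
    prediction_result = get_prediction(
        image=image_list[0],                # always the first slice of the group
        detection_model=detection_model,
        shift_amount=shift_amount_list[0],  # always the first shift of the group
        full_shape=[
            slice_image_result.original_image_height,
            slice_image_result.original_image_width,
        ],
    )

So with num_batch > 1, num_group shrinks but each group still contributes a single slice to the model; the rest of the group is never inferred.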
def get_prediction(
image,
detection_model,
shift_amount: list = [0, 0],
full_shape=None,
postprocess: Optional[PostprocessPredictions] = None,
verbose: int = 0,
num_batch: int = 1,
exclude_classes_by_name: Optional[List[str]] = None,
exclude_classes_by_id: Optional[List[int]] = None,
) -> PredictionResult:
"""
Function for performing prediction for given image using given detection_model.
Arguments:
image: list[np.ndarray] or list[str]
List of slice images (numpy arrays) or image paths to run inference on; with num_batch == 1 only the first element is used
detection_model: model.DetectionModel
shift_amount: List
To shift the box and mask predictions from sliced image to full
sized image, should be in the form of [shift_x, shift_y]
full_shape: List
Size of the full image, should be in the form of [height, width]
postprocess: sahi.postprocess.combine.PostprocessPredictions
verbose: int
0: no print (default)
1: print prediction duration
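num_batch: int
Number of slice images passed to the model in a single forward pass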
exclude_classes_by_name: Optional[List[str]]
None: if no classes are excluded
List[str]: set of classes to exclude using its/their class label name/s
exclude_classes_by_id: Optional[List[int]]
None: if no classes are excluded
List[int]: set of classes to exclude using one or more IDs
Returns:
A dict with fields:
object_prediction_list: a list of ObjectPrediction
durations_in_seconds: a dict containing elapsed times for profiling
"""
durations_in_seconds = dict()
# read image as pil
if num_batch == 1:
    image_as_pil = read_image_as_pil(image[0])
    time_start = time.time()
    detection_model.perform_inference(np.ascontiguousarray(image_as_pil), num_batch=num_batch)
    time_end = time.time() - time_start
    durations_in_seconds["prediction"] = time_end
else:
    images_as_pil = [np.ascontiguousarray(read_image_as_pil(img)) for img in image]
    time_start = time.time()
    detection_model.perform_inference(images_as_pil, num_batch=num_batch)
    time_end = time.time() - time_start
    durations_in_seconds["prediction"] = time_end
if full_shape is None:
    if num_batch == 1:
        full_shape = [image_as_pil.height, image_as_pil.width]
    else:
        # images_as_pil holds numpy arrays here, so take (height, width) from the first one
        full_shape = list(images_as_pil[0].shape[:2])
# process prediction
time_start = time.time()
# works only with 1 batch
detection_model.convert_original_predictions(
shift_amount=shift_amount,
full_shape=full_shape,
)
object_prediction_list: List[ObjectPrediction] = detection_model.object_prediction_list
object_prediction_list = filter_predictions(object_prediction_list, exclude_classes_by_name, exclude_classes_by_id)
# postprocess matching predictions
if postprocess is not None:
object_prediction_list = postprocess(object_prediction_list)
time_end = time.time() - time_start
durations_in_seconds["postprocess"] = time_end
if verbose == 1:
print(
"Prediction performed in",
durations_in_seconds["prediction"],
"seconds.",
)
return PredictionResult(
image=image[0], object_prediction_list=object_prediction_list, durations_in_seconds=durations_in_seconds
)
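A minimal usage sketch of the modified get_prediction (the slice shapes, offsets and full_shape below are made-up illustration values): pass a list of slice arrays plus one [shift_x, shift_y] entry per slice, and set num_batch to the number of slices.

import numpy as np

slices = [np.zeros((1280, 1920, 3), dtype=np.uint8) for _ in range(4)]  # 4 dummy RGB slices
shifts = [[0, 0], [1536, 0], [0, 1152], [1536, 1152]]                   # top-left pixel of each slice

batched_result = get_prediction(
    image=slices,
    detection_model=detection_model,  # the AutoDetectionModel loaded earlier
    shift_amount=shifts,
    full_shape=[2432, 3456],          # original image [height, width]
    num_batch=4,
)
print(len(batched_result.object_prediction_list))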
def get_sliced_prediction(
image,
detection_model=None,
slice_height: Optional[int] = None,
slice_width: Optional[int] = None,
overlap_height_ratio: float = 0.2,
overlap_width_ratio: float = 0.2,
perform_standard_pred: bool = True,
postprocess_type: str = "GREEDYNMM",
postprocess_match_metric: str = "IOS",
postprocess_match_threshold: float = 0.5,
postprocess_class_agnostic: bool = False,
verbose: int = 1,
merge_buffer_length: Optional[int] = None,
auto_slice_resolution: bool = True,
slice_export_prefix: Optional[str] = None,
slice_dir: Optional[str] = None,
exclude_classes_by_name: Optional[List[str]] = None,
exclude_classes_by_id: Optional[List[int]] = None,
batch: int = 1,
) -> PredictionResult:
"""
Function for slicing the image, getting a prediction for each slice, and combining the predictions over the full image.
Args:
image: str or np.ndarray
Location of image or numpy image matrix to slice
detection_model: model.DetectionModel
slice_height: int
Height of each slice. Defaults to ``None``.
slice_width: int
Width of each slice. Defaults to ``None``.
overlap_height_ratio: float
Fractional overlap in height of each window (e.g. an overlap of 0.2 for a window
of size 512 yields an overlap of 102 pixels).
Default to ``0.2``.
overlap_width_ratio: float
Fractional overlap in width of each window (e.g. an overlap of 0.2 for a window
of size 512 yields an overlap of 102 pixels).
Default to ``0.2``.
perform_standard_pred: bool
Perform a standard prediction on top of sliced predictions to increase large object
detection accuracy. Default: True.
postprocess_type: str
Type of the postprocess to be used after sliced inference while merging/eliminating predictions.
Options are 'NMM', 'GREEDYNMM' or 'NMS'. Default is 'GREEDYNMM'.
postprocess_match_metric: str
Metric to be used during object prediction matching after sliced prediction.
'IOU' for intersection over union, 'IOS' for intersection over smaller area.
postprocess_match_threshold: float
Sliced predictions having higher iou than postprocess_match_threshold will be
postprocessed after sliced prediction.
postprocess_class_agnostic: bool
If True, postprocess will ignore category ids.
verbose: int
0: no print
1: print number of slices (default)
2: print number of slices and slice/prediction durations
merge_buffer_length: int
The length of the buffer of slices used during sliced prediction, suitable for low-memory setups.
It may affect the AP if specified: the higher the value, the closer the results to the non-buffered
scenario. See [the discussion](https://github.com/obss/sahi/pull/445).
auto_slice_resolution: bool
If slice parameters (slice_height, slice_width) are not given, they are calculated automatically
from the image resolution and orientation.
slice_export_prefix: str
Prefix for the exported slices. Defaults to None.
slice_dir: str
Directory to save the slices. Defaults to None.
exclude_classes_by_name: Optional[List[str]]
None: if no classes are excluded
List[str]: set of classes to exclude using its/their class label name/s
exclude_classes_by_id: Optional[List[int]]
None: if no classes are excluded
List[int]: set of classes to exclude using one or more IDs
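batch: int
Number of slices passed to the model per forward pass. Defaults to ``1``.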
Returns:
A Dict with fields:
object_prediction_list: a list of sahi.prediction.ObjectPrediction
durations_in_seconds: a dict containing elapsed times for profiling
"""
# for profiling
durations_in_seconds = dict()
# number of slices handed to the model per forward pass
num_batch = batch
# create slices from full image
time_start = time.time()
slice_image_result = slice_image(
image=image,
output_file_name=slice_export_prefix,
output_dir=slice_dir,
slice_height=slice_height,
slice_width=slice_width,
overlap_height_ratio=overlap_height_ratio,
overlap_width_ratio=overlap_width_ratio,
auto_slice_resolution=auto_slice_resolution,
)
from sahi.models.ultralytics import UltralyticsDetectionModel
num_slices = len(slice_image_result)
time_end = time.time() - time_start
durations_in_seconds["slice"] = time_end
if isinstance(detection_model, UltralyticsDetectionModel) and detection_model.is_obb:
# Only NMS is supported for OBB model outputs
postprocess_type = "NMS"
# init match postprocess instance
if postprocess_type not in POSTPROCESS_NAME_TO_CLASS.keys():
raise ValueError(
f"postprocess_type should be one of {list(POSTPROCESS_NAME_TO_CLASS.keys())} but given as {postprocess_type}"
)
postprocess_constructor = POSTPROCESS_NAME_TO_CLASS[postprocess_type]
postprocess = postprocess_constructor(
match_threshold=postprocess_match_threshold,
match_metric=postprocess_match_metric,
class_agnostic=postprocess_class_agnostic,
)
postprocess_time = 0
time_start = time.time()
# create prediction input
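# note: if num_slices is not divisible by num_batch, the trailing remainder slices are skipped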
num_group = int(num_slices / num_batch)
if verbose == 1 or verbose == 2:
tqdm.write(f"Performing prediction on {num_slices} slices.")
object_prediction_list = []
# perform sliced prediction
for group_ind in range(num_group):
# prepare a batch of num_batch slices
image_list = []
shift_amount_list = []
for image_ind in range(num_batch):
image_list.append(slice_image_result.images[group_ind * num_batch + image_ind])
shift_amount_list.append(slice_image_result.starting_pixels[group_ind * num_batch + image_ind])
# perform batch prediction
prediction_result = get_prediction(
image=image_list[0:num_batch],
detection_model=detection_model,
shift_amount=shift_amount_list[0:num_batch],
full_shape=[
slice_image_result.original_image_height,
slice_image_result.original_image_width,
],
exclude_classes_by_name=exclude_classes_by_name,
exclude_classes_by_id=exclude_classes_by_id,
num_batch= num_batch,
)
# convert sliced predictions to full predictions
for object_prediction in prediction_result.object_prediction_list:
if object_prediction: # if not empty
object_prediction_list.append(object_prediction.get_shifted_object_prediction())
# merge matching predictions during sliced prediction
if merge_buffer_length is not None and len(object_prediction_list) > merge_buffer_length:
postprocess_time_start = time.time()
object_prediction_list = postprocess(object_prediction_list)
postprocess_time += time.time() - postprocess_time_start
# perform standard prediction
if num_slices > 1 and perform_standard_pred:
prediction_result = get_prediction(
image=[image],
detection_model=detection_model,
shift_amount=[0, 0],
full_shape=[
slice_image_result.original_image_height,
slice_image_result.original_image_width,
],
postprocess=None,
exclude_classes_by_name=exclude_classes_by_name,
exclude_classes_by_id=exclude_classes_by_id,
)
object_prediction_list.extend(prediction_result.object_prediction_list)
# merge matching predictions
if len(object_prediction_list) > 1:
postprocess_time_start = time.time()
object_prediction_list = postprocess(object_prediction_list)
postprocess_time += time.time() - postprocess_time_start
time_end = time.time() - time_start
durations_in_seconds["prediction"] = time_end - postprocess_time
durations_in_seconds["postprocess"] = postprocess_time
if verbose == 2:
print(
"Slicing performed in",
durations_in_seconds["slice"],
"seconds.",
)
print(
"Prediction performed in",
durations_in_seconds["prediction"],
"seconds.",
)
print(
"Postprocessing performed in",
durations_in_seconds["postprocess"],
"seconds.",
)
return PredictionResult(
image=image, object_prediction_list=object_prediction_list, durations_in_seconds=durations_in_seconds
)

UltralyticsDetectionModel.perform_inference also has to accept a list of images:

def perform_inference(self, image: np.ndarray, num_batch: int = 1):
"""
Prediction is performed using self.model and the prediction result is set to self._original_predictions.
Args:
image: np.ndarray or list[np.ndarray]
A numpy array containing the image to be predicted, or a list of such arrays when num_batch > 1. 3 channel images should be in RGB order.
num_batch: int
Number of images passed to the model in one forward pass.
"""
# Confirm model is loaded
if self.model is None:
raise ValueError("Model is not loaded, load it by calling .load_model()")
kwargs = {"cfg": self.config_path, "verbose": False, "conf": self.confidence_threshold, "device": self.device}
if self.image_size is not None:
kwargs = {"imgsz": self.image_size, **kwargs}
if num_batch == 1:
prediction_result = self.model(image[:, :, ::-1], **kwargs)
else:
prediction_result = self.model([img[:, :, ::-1] for img in image], **kwargs) # For batch inference
# Handle different result types for PyTorch vs ONNX models
# ONNX models might return results in a different format
if self.has_mask:
from ultralytics.engine.results import Masks
if not prediction_result[0].masks:
# Create empty masks if none exist
if hasattr(self.model, "device"):
device = self.model.device
else:
device = "cpu" # Default for ONNX models
prediction_result[0].masks = Masks(
torch.tensor([], device=device), prediction_result[0].boxes.orig_shape
)
# We do not filter results again as confidence threshold is already applied above
prediction_result = [
(
result.boxes.data,
result.masks.data,
)
for result in prediction_result
]
elif self.is_obb:
# For OBB task, get OBB points in xyxyxyxy format
device = getattr(self.model, "device", "cpu")
prediction_result = [
(
# Get OBB data: xyxy, conf, cls
torch.cat(
[
result.obb.xyxy, # box coordinates
result.obb.conf.unsqueeze(-1), # confidence scores
result.obb.cls.unsqueeze(-1), # class ids
],
dim=1,
)
if result.obb is not None
else torch.empty((0, 6), device=device),
# Get OBB points in (N, 4, 2) format
result.obb.xyxyxyxy if result.obb is not None else torch.empty((0, 4, 2), device=device),
)
for result in prediction_result
]
else: # If model doesn't do segmentation or OBB then no need to check masks
prediction_result = [result.boxes.data for result in prediction_result]
self._original_predictions = prediction_result
# self._original_shape = image[0].shape

_create_object_prediction_list_from_original_predictions also needs adjusting so that the shift of each slice in the batch is applied and the per-slice predictions end up in one list:

def _create_object_prediction_list_from_original_predictions(
self,
shift_amount_list: Optional[List[List[int]]] = [[0, 0]],
full_shape_list: Optional[List[List[int]]] = None,
):
"""
self._original_predictions is converted to a list of prediction.ObjectPrediction and set to
self._object_prediction_list_per_image.
Args:
shift_amount_list: list of list
To shift the box and mask predictions from sliced image to full sized image, should
be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...]
full_shape_list: list of list
Size of the full image after shifting, should be in the form of
List[[height, width],[height, width],...]
"""
original_predictions = self._original_predictions
if not isinstance(shift_amount_list[0], list):
    shift_amount_list = fix_shift_amount_list(shift_amount_list)
full_shape_list = fix_full_shape_list(full_shape_list)
# handle all predictions
object_prediction_list_per_image = []
for image_ind, image_predictions in enumerate(original_predictions):
shift_amount = shift_amount_list[image_ind]
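# all slices come from the same full image, so full_shape_list[0] applies to every prediction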
full_shape = None if full_shape_list is None else full_shape_list[0]
object_prediction_list = []
# Extract boxes and optional masks/obb
if self.has_mask or self.is_obb:
boxes = image_predictions[0].cpu().detach().numpy()
masks_or_points = image_predictions[1].cpu().detach().numpy()
else:
boxes = image_predictions.data.cpu().detach().numpy()
masks_or_points = None
# Process each prediction
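# each row of boxes is [x1, y1, x2, y2, score, class_id]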
for pred_ind, prediction in enumerate(boxes):
# Get bbox coordinates
bbox = prediction[:4].tolist()
score = prediction[4]
category_id = int(prediction[5])
category_name = self.category_mapping[str(category_id)]
# Fix box coordinates
bbox = [max(0, coord) for coord in bbox]
if full_shape is not None:
bbox[0] = min(full_shape[1], bbox[0])
bbox[1] = min(full_shape[0], bbox[1])
bbox[2] = min(full_shape[1], bbox[2])
bbox[3] = min(full_shape[0], bbox[3])
# Ignore invalid predictions
if not (bbox[0] < bbox[2]) or not (bbox[1] < bbox[3]):
logger.warning(f"ignoring invalid prediction with bbox: {bbox}")
continue
# Get segmentation or OBB points
segmentation = None
if masks_or_points is not None:
if self.has_mask:
bool_mask = masks_or_points[pred_ind]
# Resize mask to original image size
bool_mask = cv2.resize(
bool_mask.astype(np.uint8), (self._original_shape[1], self._original_shape[0])
)
segmentation = get_coco_segmentation_from_bool_mask(bool_mask)
else: # is_obb
obb_points = masks_or_points[pred_ind] # Get OBB points for this prediction
segmentation = [obb_points.reshape(-1).tolist()]
if len(segmentation) == 0:
continue
# Create and append object prediction
object_prediction = ObjectPrediction(
bbox=bbox,
category_id=category_id,
score=score,
segmentation=segmentation,
category_name=category_name,
shift_amount=shift_amount,
full_shape=self._original_shape[:2] if full_shape is None else full_shape, # (height, width)
)
object_prediction_list.append(object_prediction)
object_prediction_list_per_image.append(object_prediction_list)
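# custom change: fold the per-slice lists of this batch into index 0, so downstream code that
# reads object_prediction_list_per_image[0] sees every detection from the batch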
if len(object_prediction_list_per_image) > 1:
for i in range(1, len(object_prediction_list_per_image)):
object_prediction_list_per_image[0].extend(object_prediction_list_per_image[i])
self._object_prediction_list_per_image = object_prediction_list_per_image
Other changes
result.export_visuals(): add one line at the end (return the rendered image):
def export_visuals(
self,
export_dir: str,
text_size: Optional[float] = None,
rect_th: Optional[int] = None,
hide_labels: bool = False,
hide_conf: bool = False,
file_name: str = "prediction_visual",
):
"""
Args:
export_dir: directory for resulting visualization to be exported
text_size: size of the category name over box
rect_th: rectangle thickness
hide_labels: hide labels
hide_conf: hide confidence
file_name: saving name
Returns:
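np.ndarray: the rendered visualization image (res["image"])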
"""
Path(export_dir).mkdir(parents=True, exist_ok=True)
res = visualize_object_predictions(
image=np.ascontiguousarray(self.image),
object_prediction_list=self.object_prediction_list,
rect_th=rect_th,
text_size=text_size,
text_th=None,
color=None,
hide_labels=hide_labels,
hide_conf=hide_conf,
output_dir=export_dir,
file_name=file_name,
export_format="png",
)
return res["image"]

Also add a yolo_detections helper to PredictionResult:

def yolo_detections(self):
"""
Returns a list of YOLO detections in the format:
[x1, y1, x2, y2, confidence, class_id]
"""
yolo_detections = []
for object_prediction in self.object_prediction_list:
x1, y1, x2, y2 = object_prediction.bbox.to_xyxy()
class_id = object_prediction.category.id
confidence = object_prediction.score.value
yolo_detections.append([x1, y1, x2, y2, confidence, class_id])
return yolo_detections

An engine model cannot be used directly; the following needs modifying:
if self.model_path and self.model_path.endswith(".pt"):