Quantcast
Channel: TAO Toolkit - NVIDIA Developer Forums
Viewing all articles
Browse latest Browse all 549

Getting wrong result on ONNX file using cutom script. (TLT to ONNX)

$
0
0

Hi @Morganh ,

I have converted one model from tlt to onnx. But when I am performing inference then It is giving me multiple b-boxes. Please suggest where I am making mistake. Below is the code.

`import onnxruntime as ort
import onnx
import cv2
import numpy as np

def validate_onnx_model(model_onnx_path):
    try:
        onnx_model = onnx.load(model_onnx_path)
        onnx.checker.check_model(onnx_model)
        print("ONNX model is valid.")
        return True
    except Exception as e:
        print(f"ONNX model validation failed: {e}")
        return False

def model_detection_init(model_onnx_path):
    try:
        session = ort.InferenceSession(model_onnx_path)
        input_name = session.get_inputs()[0].name
        output_names = [output.name for output in session.get_outputs()]
        
        input_shape = session.get_inputs()[0].shape
        detection_height, detection_width = input_shape[2], input_shape[3]
        
        return session, input_name, output_names, detection_height, detection_width
    except Exception as e:
        print("Exception in Detection model load:", e)
        return None

model_h = 544
model_w = 960
stride = 16
box_norm = 35.0

grid_h = int(model_h / stride)
grid_w = int(model_w / stride)
grid_size = grid_h * grid_w

grid_centers_w = [(i * stride + 0.5) / box_norm for i in range(grid_w)]
grid_centers_h = [(i * stride + 0.5) / box_norm for i in range(grid_h)]

def applyBoxNorm(o1, o2, o3, o4, x, y):
    o1 = (o1 - grid_centers_w[x]) * -box_norm
    o2 = (o2 - grid_centers_h[y]) * -box_norm
    o3 = (o3 + grid_centers_w[x]) * box_norm
    o4 = (o4 + grid_centers_h[y]) * box_norm
    return o1, o2, o3, o4

def postprocess(outputs, min_confidence, analysis_classes, wh_format=True):
    bbs = []
    class_ids = []
    scores = []
    for c in analysis_classes:
        x1_idx = c * 4 * grid_size
        y1_idx = x1_idx + grid_size
        x2_idx = y1_idx + grid_size
        y2_idx = x2_idx + grid_size

        boxes = outputs[0]
        for h in range(grid_h):
            for w in range(grid_w):
                i = w + h * grid_w
                score = outputs[1][c * grid_size + i]
                if score >= min_confidence:
                    o1 = boxes[x1_idx + w + h * grid_w]
                    o2 = boxes[y1_idx + w + h * grid_w]
                    o3 = boxes[x2_idx + w + h * grid_w]
                    o4 = boxes[y2_idx + w + h * grid_w]

                    o1, o2, o3, o4 = applyBoxNorm(o1, o2, o3, o4, w, h)

                    xmin = int(o1)
                    ymin = int(o2)
                    xmax = int(o3)
                    ymax = int(o4)
                    if wh_format:
                        bbs.append([xmin, ymin, xmax - xmin, ymax - ymin])
                    else:
                        bbs.append([xmin, ymin, xmax, ymax])
                    class_ids.append(c)
                    scores.append(float(score))

    return bbs, class_ids, scores

NUM_CLASSES = 3
threshold = 0.01

def vehicle_detection(frames, session, input_name, output_names, detection_height, detection_width):
    try:
        input_images = np.stack([cv2.resize(image, (detection_width, detection_height)) for image in frames])
        input_images = input_images.transpose((0, 3, 1, 2)).astype(np.float32)

        results = session.run(output_names, {input_name: input_images})

        score_info = results[1].reshape(6120)
        output_tensor = np.squeeze(results[0])

        outputs = output_tensor.reshape(24480)

        bboxes, class_ids, scores = postprocess([outputs, score_info], threshold, list(range(NUM_CLASSES)))
        print("bboxes : ",class_ids)

        # Filter out boxes with low confidence scores
        filtered_bboxes = []
        filtered_class_ids = []
        filtered_scores = []
        for i, score in enumerate(scores):
            print("score : ",score)
            if score >= threshold:
                filtered_bboxes.append(bboxes[i])
                filtered_class_ids.append(class_ids[i])
                filtered_scores.append(score)
                
        # Apply NMS
        if len(filtered_bboxes) > 0:
            # Adjust these parameters based on your requirements
            nms_threshold = 0.4  # overlapThreshold
            score_threshold = threshold  # scoreThreshold

            # Convert bounding boxes to the format required by NMSBoxes
            bboxes = np.array(filtered_bboxes)
            scores = np.array(filtered_scores)

            # Apply NMS
            indices = cv2.dnn.NMSBoxes(bboxes.tolist(), scores.tolist(), score_threshold, nms_threshold)

            print("indices : ",indices)

            # Draw remaining boxes after NMS
            for idx in indices:
                idx = int(idx)
                xmin, ymin, w, h = filtered_bboxes[idx]
                print("xmin, ymin, w, h ",xmin, ymin, w, h)
                if xmin > 0 and ymin > 0:
                    class_id = filtered_class_ids[idx]
                    color = [255, 0, 0] if class_id else [0, 0, 255]
                    cv2.rectangle(frames[0], (xmin, ymin), (xmin + w, ymin + h), color, 2)
                    cv2.imwrite("image1.jpg", frames[0])
        else:
            print("No valid detections after score filtering.")

    except Exception as e:
        print("Exception in vehicle detection:", e)

# Path to the ONNX model file
detection_model_onnx_path = '/home/smarg/Documents/openvino_container/MODEL/VehicleDetection_MobileNetV1_ReTrained_V1.6.onnx'

if validate_onnx_model(detection_model_onnx_path):
    session, input_name, output_names, detection_height, detection_width = model_detection_init(detection_model_onnx_path)
    if session:
        img = './input_image.jpg'
        frames = [cv2.imread(img)]
        vehicle_detection(frames, session, input_name, output_names, detection_height, detection_width)
    else:
        print("Failed to initialize the detection model.")
else:
    print("ONNX model validation failed.")
`

Below is the output image.

Please suggest.

Thanks.

8 posts - 2 participants

Read full topic


Viewing all articles
Browse latest Browse all 549

Trending Articles