Ahmed-El-Sharkawy committed on
Commit b65b15a · 1 Parent(s): 794a8d6

Add application file

Files changed (3)
  1. .gitingore +0 -0
  2. app.py +82 -0
  3. requirements.txt +6 -0
.gitingore ADDED
File without changes
app.py ADDED
@@ -0,0 +1,95 @@
+ import gradio as gr
+ import cv2
+ import torch
+
+ # Load Models
+ def load_model(model_path, backbone_name, num_classes):
+     # Both checkpoints are full serialized models (not state_dicts), so
+     # torch.load restores them directly: weights_only=False allows that,
+     # and map_location keeps loading working on CPU-only hosts. num_classes
+     # is kept for API symmetry; the class count is baked into the checkpoint.
+     if backbone_name == "resnet50":
+         model = torch.load(model_path, map_location="cpu", weights_only=False)
+     elif backbone_name == "mobilenet":
+         model = torch.load(model_path, map_location="cpu", weights_only=False)
+     else:
+         raise ValueError(f"Unknown backbone: {backbone_name}")
+     model.eval()
+     return model
+
+ resnet_model = load_model('fasterrcnnResnet.pth', 'resnet50', num_classes=6)
+ mobilenet_model = load_model('fasterrcnnMobilenet.pth', 'mobilenet', num_classes=6)
+
+ class_names = ['background', 'Ambulance', 'Bus', 'Car', 'Motorcycle', 'Truck']
+
+ # Inference Functions for Images and Videos
+
+ def predict_image(image_path, model):
+     image = cv2.imread(image_path)
+     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+     image_tensor = torch.tensor(image / 255.0).permute(2, 0, 1).float().unsqueeze(0)
+     with torch.no_grad():
+         output = model(image_tensor)[0]
+     # Draw boxes and labels for detections above a 0.5 confidence threshold
+     for box, label, score in zip(output['boxes'], output['labels'], output['scores']):
+         if score > 0.5:
+             x1, y1, x2, y2 = map(int, box.tolist())
+             cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
+             cv2.putText(image, f"{class_names[label]}: {score:.2f}", (x1, y1 - 10),
+                         cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
+     return image
+
+
+ def predict_video(video_path, model):
+     # The output component is a single image, so annotate and return the
+     # first decodable frame rather than running the model on every frame.
+     cap = cv2.VideoCapture(video_path)
+     annotated = None
+     while cap.isOpened():
+         ret, frame = cap.read()
+         if not ret:
+             break
+         # Convert OpenCV's BGR to RGB to match the model's expected input
+         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+         frame_tensor = torch.tensor(frame / 255.0).permute(2, 0, 1).float().unsqueeze(0)
+         with torch.no_grad():
+             output = model(frame_tensor)[0]
+         for box, label, score in zip(output['boxes'], output['labels'], output['scores']):
+             if score > 0.5:
+                 x1, y1, x2, y2 = map(int, box.tolist())
+                 cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
+                 cv2.putText(frame, f"{class_names[label]}: {score:.2f}", (x1, y1 - 10),
+                             cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
+         annotated = frame
+         break
+     cap.release()
+     return annotated
+
+ # Gradio Interface for Image and Video Inference
+
+ # Each tab gets its own components: a single Dropdown instance cannot be
+ # rendered in two Interfaces at once.
+ inputs_image = [gr.Image(type="filepath", label="Upload Image"),
+                 gr.Dropdown(choices=["ResNet50", "MobileNet"], label="Select Model")]
+ outputs_image = gr.Image(type="numpy", label="Detection Output")
+
+ # gr.Video passes a filepath to the function by default (it takes no type argument)
+ inputs_video = [gr.Video(label="Upload Video"),
+                 gr.Dropdown(choices=["ResNet50", "MobileNet"], label="Select Model")]
+ outputs_video = gr.Image(type="numpy", label="Detection Output")
+
+ image_interface = gr.Interface(
+     fn=lambda img, model_name: predict_image(img, resnet_model if model_name == "ResNet50" else mobilenet_model),
+     inputs=inputs_image,
+     outputs=outputs_image,
+     title="Image Inference"
+ )
+
+ video_interface = gr.Interface(
+     fn=lambda vid, model_name: predict_video(vid, resnet_model if model_name == "ResNet50" else mobilenet_model),
+     inputs=inputs_video,
+     outputs=outputs_video,
+     title="Video Inference"
+ )
+
+ gr.TabbedInterface([image_interface, video_interface], ["Image", "Video"]).launch()
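
The original file imported FasterRCNN and FastRCNNPredictor without using them (they are dropped above); those imports hint at how the two checkpoints were presumably produced. A minimal sketch of constructing a 6-class Faster R-CNN for fine-tuning, assuming the standard torchvision detection API (build_fasterrcnn is an illustrative name, not part of this commit):

from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

def build_fasterrcnn(num_classes=6):
    # Start from a COCO-pretrained detector, then swap in a box predictor
    # sized for the six classes in class_names (index 0 is background).
    model = fasterrcnn_resnet50_fpn(weights="DEFAULT")
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model

After fine-tuning, saving the whole object with torch.save(model, 'fasterrcnnResnet.pth') yields exactly the kind of checkpoint that load_model above restores with a single torch.load.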
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ gradio
+ opencv-python-headless
+ torch
+ torchvision
+ numpy
+ matplotlib
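
Because app.py assumes each .pth file deserializes into a complete torchvision detection model, a quick sanity check before launching can catch a bad checkpoint early. A minimal sketch, assuming the two checkpoint filenames referenced in app.py sit in the working directory:

import torch

for path in ["fasterrcnnResnet.pth", "fasterrcnnMobilenet.pth"]:
    model = torch.load(path, map_location="cpu", weights_only=False)
    # A full Faster R-CNN exposes its classification head; out_features
    # should equal len(class_names), i.e. 6 (background included).
    n = model.roi_heads.box_predictor.cls_score.out_features
    print(f"{path}: predicts {n} classes")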