Rohit8y commited on
Commit
01e150d
1 Parent(s): d982c0a

Application init

Browse files
Files changed (2) hide show
  1. app.py +45 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import gradio as gr
3
+
4
+ from detectron2.config import get_cfg
5
+ from detectron2 import model_zoo
6
+ from detectron2.engine import DefaultPredictor
7
+ from detectron2.utils.visualizer import Visualizer
8
+ from detectron2.data import MetadataCatalog
9
+
10
+
11
+ def predict(input_image):
12
+ # Initialise model
13
+ cfg = get_cfg()
14
+ cfg.merge_from_file(model_zoo.get_config_file("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml"))
15
+ cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml")
16
+ cfg.MODEL.DEVICE = "cpu"
17
+ predictor = DefaultPredictor(cfg)
18
+
19
+ assert input_image.shape[2] == 3
20
+ height, width, _ = input_image.shape
21
+
22
+ # Apply Panoptic segmentation
23
+ panoptic_seg, segments_info = predictor(input_image)["panoptic_seg"]
24
+ v = Visualizer(input_image[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
25
+ out = v.draw_panoptic_seg_predictions(panoptic_seg.to("cpu"), segments_info)
26
+ segmented_image = out.get_image()[:, :, ::-1]
27
+
28
+ # Resize image if required
29
+ if not segmented_image.shape[:2] == (height, width):
30
+ segmented_image = cv2.resize(segmented_image, (height, width))
31
+
32
+ # Combine the segmented and original image
33
+ combined_image = cv2.hconcat([segmented_image, input_image])
34
+
35
+ return combined_image
36
+
37
+
38
+ # Create Gradio interface
39
+ image_input = gr.Image(type="pil", label="Input Image")
40
+
41
+ iface = gr.Interface(fn=predict,
42
+ inputs=[image_input],
43
+ outputs=gr.Image(type="pil"),
44
+ examples=[["examples/trump-fake.jpeg"]])
45
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ detectron2==0.6
2
+ gradio==4.29.0
3
+ numpy==1.23.5
4
+ opencv_python==4.8.0.76