# ------------ suppress noisy warnings
import os
import warnings


def warn(*args, **kwargs):
    pass

warnings.warn = warn
warnings.filterwarnings("ignore")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import random

import gdown
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image

import mrcnn.model as modellib
from config import WheatDetectorConfig
from config import WheatInferenceConfig
from mrcnn import utils
from mrcnn import visualize
from mrcnn.model import log
from utils import get_ax


# for reproducibility
def seed_all(SEED):
    """Seed Python, NumPy, and the hash seed for reproducibility."""
    random.seed(SEED)
    np.random.seed(SEED)
    os.environ["PYTHONHASHSEED"] = str(SEED)


ORIG_SIZE = 1024
seed_all(42)
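
# Note: seed_all() does not seed TensorFlow's own random generator. If fully
# deterministic runs were needed, a TF 1.x graph-level seed could be added as
# well (optional, not part of the original pipeline):
# tf.set_random_seed(42)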

config = WheatDetectorConfig()
inference_config = WheatInferenceConfig()


def get_model_weight(model_id):
    """Get the trained weights."""
    if not os.path.exists("model.h5"):
        # Download to a fixed filename so the cache check above finds it on reruns.
        model_weight = gdown.download(id=model_id, output="model.h5", quiet=False)
    else:
        model_weight = "model.h5"
    return model_weight


def get_model():
    """Get the model."""
    model = modellib.MaskRCNN(mode="inference", config=inference_config, model_dir="./")
    return model


def load_model(model_id):
    """Load trained model."""
    weight = get_model_weight(model_id)
    model = get_model()
    model.load_weights(weight, by_name=True)
    return model
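

# A minimal optimization sketch (assumption: the Gradio process is long-lived):
# the model could be built and its weights loaded once at module import, so
# that each request does not rebuild the graph or re-check the download, e.g.
#
#     MODEL = load_model(model_id="1k4_WGBAUJCPbkkHkvtscX2jufTqETNYd")
#
# predict_fn would then call MODEL.detect(...) instead of load_model().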


def prepare_image(image):
    """Prepare an incoming sample for the detector."""
    # If grayscale, convert to RGB for consistency. Do this before any
    # channel indexing, which assumes a 3-channel image.
    if len(image.shape) != 3 or image.shape[2] != 3:
        image = np.stack((image,) * 3, -1)

    # Reverse the channel order; it is flipped back in predict_fn for display.
    image = image[:, :, ::-1]

    resized_image, window, scale, padding, crop = utils.resize_image(
        image,
        min_dim=config.IMAGE_MIN_DIM,
        min_scale=config.IMAGE_MIN_SCALE,
        max_dim=config.IMAGE_MAX_DIM,
        mode=config.IMAGE_RESIZE_MODE,
    )

    return resized_image
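

# Note: `window`, `scale`, and `padding` returned by utils.resize_image are
# discarded above because predictions are drawn directly on the resized image
# rather than being mapped back to the original 1024x1024 frame.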


def predict_fn(image):
    """Run detection on a single image and draw the predicted instances."""
    image = prepare_image(image)

    model = load_model(model_id="1k4_WGBAUJCPbkkHkvtscX2jufTqETNYd")
    results = model.detect([image])
    r = results[0]
    # Single-class problem: every detected instance is a wheat head.
    class_names = ["Wheat"] * len(r["rois"])

    image = visualize.display_instances(
        image,
        r["rois"],
        r["masks"],
        r["class_ids"],
        class_names,
        r["scores"],
        ax=get_ax(),
        title="Predictions",
    )

    return image[:, :, ::-1]  # flip the channel order back for display

title="Global Wheat Detection with Mask-RCNN Model"
description="<strong>Model</strong>: Mask-RCNN. <strong>Backbone</strong>: ResNet-101. Trained on: <a href='https://www.kaggle.com/competitions/global-wheat-detection/overview'>Global Wheat Detection Dataset (Kaggle)</a>. </br>The code is written in <code>Keras (TensorFlow 1.14)</code>. One can run the full code on Kaggle: <a href='https://www.kaggle.com/code/ipythonx/keras-global-wheat-detection-with-mask-rcnn'>[Keras]:Global Wheat Detection with Mask-RCNN</a>"
article = "<p>The model received <strong>0.6449</strong> and <strong>0.5675</strong> mAP (0.5:0.75:0.05) on the public and private test dataset respectively. The above examples are from test dataset without ground truth bounding box. Details: <a href='https://www.kaggle.com/competitions/global-wheat-detection/data'>Global Wheat Dataset</a></p>"

iface = gr.Interface(
    fn=predict_fn,
    inputs=gr.Image(label="Input Image"),
    outputs=gr.Image(label="Prediction"),
    title=title,
    description=description,
    article=article,
    examples=[
        ["examples/2fd875eaa.jpg"],
        ["examples/51b3e36ab.jpg"],
        ["examples/51f1be19e.jpg"],
        ["examples/53f253011.jpg"],
        ["examples/348a992bb.jpg"],
        ["examples/796707dd7.jpg"],
        ["examples/aac893a91.jpg"],
        ["examples/cb8d261a3.jpg"],
        ["examples/cc3532ff6.jpg"],
        ["examples/f5a1f0358.jpg"],
    ],
)
iface.launch(share=True)  # share=True also exposes a temporary public URL
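
# A quick local smoke test, kept as a comment so it does not interfere with
# the Gradio app (assumptions: the example image exists on disk and
# predict_fn returns a uint8 RGB array):
#
#     img = np.array(Image.open("examples/2fd875eaa.jpg"))
#     out = predict_fn(img)
#     Image.fromarray(out).save("prediction.jpg")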