# This was made by following this tutorial
# https://www.youtube.com/watch?v=i40ulpcacFM

!pip install -U -q segmentation-models
# # Optional workaround for an efficientnet import error: patch the installed package in place.
# # Open the file in read mode
# with open('/usr/local/lib/python3.9/dist-packages/efficientnet/keras.py', 'r') as f:
#     # Read the contents of the file
#     contents = f.read()

# # Replace the string
# new_contents = contents.replace('init_keras_custom_objects', 'init_tfkeras_custom_objects')

# # Open the file in write mode again and write the modified contents
# with open('/usr/local/lib/python3.9/dist-packages/efficientnet/keras.py', 'w') as f:
#     f.write(new_contents)
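
# If the import error persists with newer Keras versions, another common workaround
# (assumption: the tf.keras backend is wanted) is to select the framework before
# importing segmentation_models:
# import os
# os.environ['SM_FRAMEWORK'] = 'tf.keras'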

!pip install patchify
!pip install gradio

import os
from os.path import join as pjoin
import cv2
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt
from PIL import Image
import seaborn as sns
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from patchify import patchify, unpatchify

from keras import backend as K
from keras.models import load_model

import segmentation_models as sm
import gradio as gr

def jaccard_coef(y_true, y_pred):
    """Smoothed Jaccard (IoU) coefficient, used as a metric when the model was trained."""
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    # the +1.0 smoothing term avoids division by zero when both masks are empty
    return (intersection + 1.0) / (K.sum(y_true_f) + K.sum(y_pred_f) - intersection + 1.0)
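
# Quick sanity check (hypothetical tensors, not run at start-up):
# import tensorflow as tf
# jaccard_coef(tf.ones((1, 4, 4, 6)), tf.ones((1, 4, 4, 6)))  # identical masks -> 1.0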

# Loss the model was trained with: Dice loss with equal weights over the 6 classes,
# plus categorical focal loss (the name 'dice_loss_plus_1focal_loss' below comes from this sum)
weights = [0.1666, 0.1666, 0.1666, 0.1666, 0.1666, 0.1666]
dice_loss = sm.losses.DiceLoss(class_weights=weights)
focal_loss = sm.losses.CategoricalFocalLoss()
total_loss = dice_loss + (1 * focal_loss)

model_path = 'models/satellite_segmentation_100-epochs.h5'
# custom_objects keys must match the names the loss and metric had when the model was compiled
saved_model = load_model(model_path,
                         custom_objects={'dice_loss_plus_1focal_loss': total_loss,
                                         'jaccard_coef': jaccard_coef})
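
# Note: for inference only, Keras' load_model(model_path, compile=False) can usually skip
# the loss/metric custom_objects (assuming the saved graph contains no custom layers).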


def process_input_image(test_image):
  test_dataset = []
  image_patch_size = 256
  scaler = MinMaxScaler()

  # crop images so that they are divisible by image_patch_size
  test_image = np.array(test_image)
  size_x = (test_image.shape[1]//image_patch_size)*image_patch_size
  size_y = (test_image.shape[0]//image_patch_size)*image_patch_size

  test_image = Image.fromarray(test_image)
  test_image = test_image.crop((0, 0, size_x, size_y))
            
  # patchify image so that each patch is size (image_patch_size,image_patch_size)
  test_image = np.array(test_image)
  image_patches = patchify(test_image, (image_patch_size, image_patch_size, 3), step=image_patch_size)  # the 3 is the RGB channel count; make it a variable if images have more channels

  # scale values so that they are between 0 and 1
  # here, we use MinMaxScaler from sklearn

  for i in range(image_patches.shape[0]):
    for j in range(image_patches.shape[1]):
      image_patch = image_patches[i,j,:,:]

      image_patch = scaler.fit_transform(image_patch.reshape(-1, image_patch.shape[-1])).reshape(image_patch.shape)
      
      image_patch = image_patch[0] # drop the extra dimension that patchify adds
      test_dataset.append(image_patch)

  test_dataset = [np.expand_dims(np.array(x), 0) for x in test_dataset]

  # predict each patch and take the argmax over the class channel to get per-pixel class ids
  test_prediction = []

  for image in tqdm(test_dataset):
    prediction = saved_model.predict(image, verbose=0)
    predicted_image = np.argmax(prediction, axis=3)
    predicted_image = predicted_image[0, :, :]
    test_prediction.append(predicted_image)


  # stitch the per-patch predictions back into one mask the size of the cropped input
  reconstructed_image = np.reshape(np.array(test_prediction), (image_patches.shape[0], image_patches.shape[1], image_patch_size, image_patch_size))
  reconstructed_image = unpatchify(reconstructed_image, (size_y, size_x))

  # RGB colour for each integer class id, matching the palette of the dataset used in the
  # tutorial (assumed class order: building, land, road, vegetation, water, unlabeled)
  lookup = {'rgb': [np.array([ 60,  16, 152]),
                    np.array([132,  41, 246]),
                    np.array([110, 193, 228]),
                    np.array([254, 221,  58]),
                    np.array([226, 169,  41]),
                    np.array([155, 155, 155])],
            'int': [0, 1, 2, 3, 4, 5]}

  rgb_image = np.zeros((reconstructed_image.shape[0],reconstructed_image.shape[1],3), dtype=np.uint8)

  for i,l in enumerate(lookup['int']):
    rgb_image[np.where(reconstructed_image==l)] = lookup['rgb'][i]
  return 'Predicted Masked Image', rgb_image
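
# Example usage outside the Gradio UI (hypothetical local file 'sample.jpg'):
# label, mask = process_input_image(Image.open('sample.jpg'))
# plt.imshow(mask); plt.title(label); plt.show()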


my_app = gr.Blocks()
with my_app:
  gr.Markdown("Satellite Image Segmentation Application UI with Gradio")
  with gr.Tabs():
    with gr.TabItem("Select your image"):
      with gr.Row():
        with gr.Column():
            img_source = gr.Image(label="Please select source Image")
            source_image_loader = gr.Button("Load above Image")
        with gr.Column():
            output_label = gr.Label(label="Image Info")
            img_output = gr.Image(label="Image Output")
    source_image_loader.click(
        process_input_image,
        [
            img_source
        ],
        [
            output_label,
            img_output
        ]
    )

my_app.launch(debug=True)
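
# launch(debug=True, share=True) would additionally create a public link, e.g. when running in Colab.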