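# Tumor-Classification / tumorseg.py
# Tumor segmentation helpers: Dice / IoU metrics, loading of the saved
# segmentation model, and rendering of the predicted mask over the input image.
# Added by itachi-ai in commit 539be38 ("tumor seg and chest xray added").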
import os
import tensorflow as tf
from tensorflow.keras import backend as K
import tf_keras
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('agg')
def dice_coefficients(y_true, y_pred, smooth=100):
    """Return the smoothed Dice coefficient between ``y_true`` and ``y_pred``.

    Both tensors are flattened before comparison, so the element-wise product
    needs matching element counts (presumably the inputs share one shape).
    ``smooth`` is added to numerator and denominator, so two all-zero inputs
    score 1 instead of dividing by zero.
    """
    truth = K.flatten(y_true)
    guess = K.flatten(y_pred)
    overlap = K.sum(truth * guess)
    # Sum of both masks' totals (|A| + |B|), not a set union.
    total_mass = K.sum(truth) + K.sum(guess)
    numerator = 2 * overlap + smooth
    denominator = total_mass + smooth
    return numerator / denominator
def dice_coefficients_loss(y_true, y_pred, smooth=100):
    """Return the Dice loss: ``1 - dice_coefficients(...)`` (lower is better)."""
    dice_score = dice_coefficients(y_true, y_pred, smooth)
    return 1.0 - dice_score
def iou(y_true, y_pred, smooth=100):
    """Return the smoothed intersection-over-union (Jaccard index).

    ``K.sum`` reduces over every element, so the inputs need no flattening.
    ``smooth`` is added to numerator and denominator, so two all-zero inputs
    score 1 instead of dividing by zero.
    """
    intersection = K.sum(y_true * y_pred)
    # |A| + |B|. The name avoids shadowing the builtin ``sum``, which the
    # original local variable did.
    total = K.sum(y_true + y_pred)
    union = total - intersection
    return (intersection + smooth) / (union + smooth)
def jaccard_distance(y_true, y_pred):
    """Return the Jaccard distance ``1 - IoU`` between two masks.

    The previous version returned ``-IoU``. That value lies in [-1, 0] and is
    not a distance. ``1 - IoU`` lies in [0, 1] and has the same gradient, so
    training with it as a loss behaves identically.
    """
    y_true_flatten = K.flatten(y_true)
    y_pred_flatten = K.flatten(y_pred)
    return 1.0 - iou(y_true_flatten, y_pred_flatten)
# Loaded once at import time. The path is relative to the current working
# directory. The custom metric/loss functions must be registered by name so the
# saved model's compile configuration can be deserialized.
segmodel = tf_keras.models.load_model(
    "segment_model/V2",
    custom_objects=dict(
        dice_coefficients_loss=dice_coefficients_loss,
        iou=iou,
        dice_coefficients=dice_coefficients,
    ),
)
def load_image_for_pred(image_path):
    """Load an image file as a single-item batch ready for ``segmodel.predict``.

    The image is read as RGB, resized to 256x256 with nearest-neighbour
    interpolation (aspect ratio not preserved), and scaled from [0, 255] to
    [0, 1]. The result has shape ``(1, 256, 256, 3)``.
    """
    pil_image = tf.keras.utils.load_img(
        image_path,
        target_size=(256, 256),
        color_mode='rgb',
        keep_aspect_ratio=False,
        interpolation='nearest',
    )
    pixels = tf.keras.utils.img_to_array(pil_image)
    scaled = pixels / 255
    # Prepend a batch axis of length 1.
    return scaled[np.newaxis, ...]
def make_segmentation(image_path):
    """Segment one image and save a side-by-side visualisation as a PNG.

    The figure has three panels in a 1x3 grid:
      1. the preprocessed input image,
      2. the predicted mask (model output thresholded at 0.5),
      3. the input with the mask overlaid at 50% opacity.

    Args:
        image_path: path to the input image file.

    Returns:
        Path of the saved figure: ``image_path`` with its extension replaced
        by ``_segmented.png``.
    """
    img = load_image_for_pred(image_path)
    predicted_img = segmodel.predict(img)

    image = np.squeeze(img)
    # Binarise once; the same mask is shown in panels 2 and 3.
    mask = np.squeeze(predicted_img) > 0.5

    fig = plt.figure(figsize=(5, 3))
    try:
        ax = fig.add_subplot(1, 3, 1)
        ax.imshow(image)
        ax.set_title('Original Image')
        ax.set_axis_off()

        ax = fig.add_subplot(1, 3, 2)
        ax.imshow(mask)
        ax.set_title('Prediction')
        ax.set_axis_off()

        # Third cell of the same 1x3 grid. The old (1, 4, 4) used a 4-column
        # grid, which mis-placed and mis-sized this panel.
        ax = fig.add_subplot(1, 3, 3)
        ax.imshow(image)
        ax.imshow(mask, cmap='gray', alpha=0.5)
        ax.set_title('Image w/h Mask')
        ax.set_axis_off()

        save_file_name = os.path.splitext(image_path)[0] + '_segmented.png'
        fig.savefig(save_file_name)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        # Without this, each call leaks one figure.
        plt.close(fig)
    return save_file_name