import json
from typing import Any, Dict, List
import tensorflow as tf
from tensorflow.keras.models import load_model
import base64
import io
import os
import numpy as np
from PIL import Image

# Most of this code was adapted from Datature's prediction script:
# https://github.com/datature/resources/blob/main/scripts/bounding_box/prediction.py


class PreTrainedPipeline:
    """Inference pipeline around a TensorFlow SavedModel stored under ``path``.

    NOTE(review): inference is not implemented yet. ``__call__`` ignores its
    input, never invokes ``self.model``, and always returns an empty list.
    """

    def __init__(self, path: str):
        """Load the SavedModel from the ``saved_model`` subdirectory of ``path``.

        Args:
            path: Directory that contains a ``saved_model`` folder.
        """
        model_dir = os.path.join(path, "saved_model")
        self.model = tf.saved_model.load(model_dir)

    def __call__(self, inputs: "Image.Image") -> List[Dict[str, Any]]:
        """Run detection on ``inputs`` (placeholder: currently returns ``[]``).

        Args:
            inputs: The image to analyse. Currently unused.

        Returns:
            A list of detection dicts. It is always empty for now.

        NOTE(review): the intended element format (inferred from sample output
        in an earlier draft of this file; confirm against the widget spec) is::

            {"score": 0.998, "label": "car",
             "box": {"xmin": 405, "ymin": 146, "xmax": 640, "ymax": 297}}
        """
        detections: List[Dict[str, Any]] = []
        return detections