lombardata's picture
Update app.py
1acb5c4 verified
raw
history blame
No virus
3.36 kB
import numpy as np
import gradio as gr
import torch
from transformers import Dinov2Config, Dinov2Model, Dinov2ForImageClassification, AutoImageProcessor
import torch.nn as nn
import os
# Custom classification head and model class
def create_head(num_features, number_classes, dropout_prob=0.5, activation_func=nn.ReLU):
    """Build an MLP classification head that halves the feature width twice.

    Two hidden stages map ``num_features -> num_features // 2 -> num_features // 4``.
    Each stage is Linear, activation, BatchNorm1d and, unless ``dropout_prob``
    is 0, Dropout. A final Linear layer projects to ``number_classes`` outputs.

    Args:
        num_features: Width of the input feature vector.
        number_classes: Number of output units of the final layer.
        dropout_prob: Dropout probability of each hidden stage; 0 omits the
            dropout layers entirely.
        activation_func: Zero-argument callable returning the activation module.

    Returns:
        An ``nn.Sequential`` holding the layers in the order described above.
    """
    widths = (num_features, num_features // 2, num_features // 4)
    use_dropout = dropout_prob != 0

    modules = []
    for width_in, width_out in zip(widths, widths[1:]):
        modules += [
            nn.Linear(width_in, width_out),
            activation_func(),
            nn.BatchNorm1d(width_out),
        ]
        if use_dropout:
            modules.append(nn.Dropout(dropout_prob))

    # Output projection from the narrowest hidden width to the class scores.
    modules.append(nn.Linear(widths[-1], number_classes))
    return nn.Sequential(*modules)
class NewheadDinov2ForImageClassification(Dinov2ForImageClassification):
    """DINOv2 image classifier whose ``classifier`` is the MLP built by ``create_head``."""

    def __init__(self, config: Dinov2Config) -> None:
        """Set up the backbone and attach the custom classification head.

        Args:
            config: Model configuration; its ``hidden_size`` and ``num_labels``
                determine the head's input and output widths.
        """
        super().__init__(config)

        self.num_labels = config.num_labels
        # NOTE(review): the parent constructor presumably already creates a
        # backbone under this attribute; it is rebuilt here as in the original.
        self.dinov2 = Dinov2Model(config)

        # Classifier head. Its input is twice the backbone's hidden size --
        # presumably to match what the parent's forward() feeds it; confirm.
        head_input_width = 2 * config.hidden_size
        self.classifier = create_head(head_input_width, config.num_labels)
# Hub repository that holds the fine-tuned classification checkpoint.
checkpoint_name = "lombardata/dino-base-2023_11_27-with_custom_head"

# Class names, listed in class-index order.
classes_names = [
    "Acropore_branched",
    "Acropore_digitised",
    "Acropore_tabular",
    "Algae_assembly",
    "Algae_limestone",
    "Algae_sodding",
    "Dead_coral",
    "Fish",
    "Human_object",
    "Living_coral",
    "Millepore",
    "No_acropore_encrusting",
    "No_acropore_massive",
    "No_acropore_sub_massive",
    "Rock",
    "Sand",
    "Scrap",
    "Sea_cucumber",
    "Syringodium_isoetifolium",
    "Thalassodendron_ciliatum",
    "Useless",
]
classes_nb = list(np.arange(len(classes_names)))

# Lookup tables between class index (plain int) and class name.
id2label = {int(idx): name for idx, name in zip(classes_nb, classes_names)}
label2id = {name: idx for idx, name in id2label.items()}

# Load the fine-tuned weights into the custom-head model class.
model = NewheadDinov2ForImageClassification.from_pretrained(checkpoint_name)
def sigmoid(_outputs):
    """Apply the logistic function ``1 / (1 + exp(-x))`` element-wise.

    Args:
        _outputs: Raw score(s); any value that supports unary minus and
            ``np.exp`` (e.g. a number or a NumPy array).

    Returns:
        The logistic transform of the input, each value lying in [0, 1].
    """
    exp_of_negated = np.exp(-_outputs)
    return 1.0 / (1.0 + exp_of_negated)
# Image processor shared by every call to predict(); created on first use.
_image_processor = None


def _get_image_processor():
    """Return the checkpoint's image processor, loading it only once."""
    global _image_processor
    if _image_processor is None:
        _image_processor = AutoImageProcessor.from_pretrained(checkpoint_name)
    return _image_processor


def predict(input_image):
    """Classify an image and return the labels scoring above 0.5.

    Args:
        input_image: Image in a form accepted by the checkpoint's image
            processor (supplied by the Gradio interface).

    Returns:
        Dict mapping label name to its sigmoid score as a ``float``, limited
        to labels whose score is strictly greater than 0.5. The dict is empty
        when no label passes the threshold.
    """
    # The processor does not depend on the request, so it is cached instead of
    # being re-loaded from the checkpoint on every prediction.
    inputs = _get_image_processor()(input_image, return_tensors="pt")

    with torch.no_grad():
        model_outputs = model(**inputs)

    # Logits of the first (and only) image in the batch. Sigmoid scores each
    # class independently, so several labels can be returned for one image.
    scores = sigmoid(model_outputs["logits"][0])

    labelled_scores = (
        (id2label[class_idx], float(score))
        for class_idx, score in enumerate(scores)
    )
    return {label: prob for label, prob in labelled_scores if prob > 0.5}
# Texts shown at the top of the demo page.
title = "DinoVd'eau image classification"
description = (
    "This is a prototype application that demonstrates how artificial "
    "intelligence-based systems can recognize what object(s) is present in an "
    "underwater image. To use it, simply upload your image, or click one of the "
    "example images to load them. For predictions, we use the open-source model "
    f"{checkpoint_name}"
)

# Build the web interface around predict() and start serving it.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(shape=(224, 224)),
    outputs="label",
    title=title,
    description=description,
    examples=[
        "GOPR0106.JPG",
        "session_2021_08_30_Mayotte_10_image_00066.jpg",
        "session_2018_11_17_kite_Le_Morne_Manawa_G0065777.JPG",
        "session_2023_06_28_caplahoussaye_plancha_body_v1B_00_GP1_3_1327.jpeg",
    ],
)
demo.launch()