import json

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import (
    AutoImageProcessor,
    Dinov2Config,
    Dinov2ForImageClassification,
    Dinov2Model,
)

# DEFINE MODEL NAME
model_name = "DinoVdeau_Aina-large-2024_06_12-batch-size32_epochs150_freeze"
checkpoint_name = "lombardata/" + model_name

# Load the checkpoint's config.json to get the label mapping and input size.
config_path = hf_hub_download(repo_id=checkpoint_name, filename="config.json")
with open(config_path, "r", encoding="utf-8") as config_file:
    config = json.load(config_file)

# JSON object keys are always strings, which is why id2label is indexed with
# str(i) further down.
id2label = config["id2label"]
label2id = config["label2id"]
image_size = config["image_size"]
num_labels = len(id2label)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# IMPORT CLASSIFICATION MODEL
def create_head(num_features, number_classes, dropout_prob=0.5, activation_func=nn.ReLU):
    """Build the MLP classification head.

    Two hidden stages halve the width each time
    (num_features -> num_features//2 -> num_features//4). Each stage is
    Linear -> activation -> BatchNorm1d, followed by Dropout when
    ``dropout_prob`` is non-zero. A final Linear maps to ``number_classes``.

    Args:
        num_features: Width of the feature vector fed to the head.
        number_classes: Number of output logits.
        dropout_prob: Dropout probability; 0 disables the Dropout layers.
        activation_func: Activation module class, instantiated once per stage.

    Returns:
        An ``nn.Sequential`` containing the layers above.
    """
    features_lst = [num_features, num_features // 2, num_features // 4]
    layers = []
    # NOTE: the layer order fixes the nn.Sequential indices, which are the
    # state-dict keys of the trained checkpoint. Do not reorder.
    for in_f, out_f in zip(features_lst[:-1], features_lst[1:]):
        layers.append(nn.Linear(in_f, out_f))
        layers.append(activation_func())
        layers.append(nn.BatchNorm1d(out_f))
        if dropout_prob != 0:
            layers.append(nn.Dropout(dropout_prob))
    layers.append(nn.Linear(features_lst[-1], number_classes))
    return nn.Sequential(*layers)


class NewheadDinov2ForImageClassification(Dinov2ForImageClassification):
    """DINOv2 image classifier whose head is replaced by ``create_head``'s MLP."""

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__(config)
        # Classifier head. The input width is hidden_size * 2 to match what the
        # parent's forward feeds its classifier (per the transformers
        # implementation, presumably the CLS token concatenated with the mean
        # patch token).
        self.classifier = create_head(config.hidden_size * 2, config.num_labels)


model = NewheadDinov2ForImageClassification.from_pretrained(checkpoint_name)
model.to(device)
# Inference only: BatchNorm1d and Dropout in the head must run in eval mode.
model.eval()

# Load the preprocessing pipeline once at start-up rather than on every request.
processor = AutoImageProcessor.from_pretrained(checkpoint_name)


def sigmoid(_outputs):
    """Return the element-wise logistic sigmoid of a NumPy array (``predict`` does not use it)."""
    return 1.0 / (1.0 + np.exp(-_outputs))


def predict(image, slider_threshold=0.5, fixed_thresholds=None):
    """Classify one image with independent per-label probabilities.

    Args:
        image: Input image as delivered by the Gradio Image component
            (``type="pil"``), or None when the user submitted without an image.
        slider_threshold: A label is kept when its probability is strictly
            greater than this value.
        fixed_thresholds: Optional mapping from label name to that label's own
            threshold. It must contain every label in ``id2label``.

    Returns:
        A tuple ``(slider_results, fixed_threshold_labels_str)``.
        ``slider_results`` maps each label above ``slider_threshold`` to its
        probability. ``fixed_threshold_labels_str`` is a comma-separated string
        of the labels above their fixed threshold, or None when
        ``fixed_thresholds`` is None.
    """
    if image is None:
        # Nothing to classify; avoid crashing inside the processor.
        return {}, None

    inputs = processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        model_outputs = model(**inputs)
    # Single-image batch: take the first (only) row of logits.
    logits = model_outputs.logits[0]
    # Multi-label: an independent sigmoid per label, not a softmax across labels.
    probabilities = torch.sigmoid(logits).cpu().numpy()

    labels = [id2label[str(i)] for i in range(len(probabilities))]

    slider_results = {
        label: float(prob)
        for label, prob in zip(labels, probabilities)
        if prob > slider_threshold
    }

    fixed_threshold_labels_str = None
    if fixed_thresholds is not None:
        fixed_threshold_labels_str = ", ".join(
            label
            for label, prob in zip(labels, probabilities)
            if prob > fixed_thresholds[label]
        )

    return slider_results, fixed_threshold_labels_str


def predict_wrapper(image, slider_threshold=0.5):
    """Gradio callback that returns ``(slider_results, fixed_threshold_results)``.

    No fixed thresholds are passed to ``predict``, so the second value is
    always None.
    """
    slider_results, fixed_threshold_results = predict(image, slider_threshold)
    return slider_results, fixed_threshold_results


# Define style
title = "Aina image classification"
model_link = "https://huggingface.co/" + checkpoint_name
# Markdown description. model_link is interpolated so the link targets the
# checkpoint page on the Hugging Face Hub.
description = (
    "This application showcases the capability of artificial intelligence-based systems "
    "to identify objects within underwater images. To utilize it, you can either upload "
    "your own image or select one of the provided examples for analysis. "
    f"\nFor predictions, we use this [open-source model]({model_link})"
)

# New subtitle message to be added.
# NOTE(review): this is shown even when `device` resolved to CUDA.
subtitle = "Note: the model runs on CPU, so it may take a while to run the prediction."
# Combine description and subtitle
full_description = f"{description}\n\n{subtitle}"

# predict_wrapper returns (slider_results, fixed_threshold_results), so the
# interface needs one output component per returned value, in that order.
iface = gr.Interface(
    fn=predict_wrapper,
    inputs=[
        gr.components.Image(type="pil"),
        gr.components.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Threshold"),
    ],
    outputs=[
        # Label -> probability mapping for labels above the slider threshold.
        gr.components.Label(label="Predictions"),
        # Comma-separated labels above their per-label fixed thresholds. It stays
        # empty (None) while predict_wrapper passes no fixed thresholds.
        gr.components.Textbox(label="Fixed Threshold Predictions"),
    ],
    title=title,
    description=full_description,
)
# Keep `iface` bound to the Interface itself, then start the server.
iface.launch()