"""Gradio demo that classifies an uploaded image with a pretrained ViT model."""

import gradio as gr
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification

# Single source of truth for the checkpoint used by both the preprocessor
# and the classifier, so the two can never drift apart.
MODEL_NAME = 'google/vit-base-patch16-224'

# Initialize the ViT feature extractor and model once at import time so that
# every request reuses the same loaded objects.
feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_NAME)
model = ViTForImageClassification.from_pretrained(MODEL_NAME)


def classify_image(image) -> str:
    """Return the label the model predicts for *image*.

    Args:
        image: The image to classify. The Gradio input below is configured
            with ``type="pil"``, so the UI passes a PIL image (or ``None``
            when nothing was uploaded).

    Returns:
        The entry of ``model.config.id2label`` for the highest-scoring logit.

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    if image is None:
        # Fail early with a readable message instead of an opaque error from
        # inside the feature extractor.
        raise ValueError("No image provided; upload an image to classify.")

    # PIL images expose ``mode``; objects without it (e.g. arrays passed by a
    # direct caller) are forwarded untouched. Grayscale/RGBA images are
    # converted because the extractor normalises three channels.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    # Preprocess the image
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Make predictions; gradients are never needed for inference, so skip
    # building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits

    # Get the predicted class index and label
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


# Define Gradio components for image upload and text output. Newer Gradio
# releases expose components at the top level and no longer ship the
# ``gr.inputs`` / ``gr.outputs`` namespaces; fall back to them only when the
# top-level classes are missing.
if hasattr(gr, "Image") and hasattr(gr, "Textbox"):
    input_image = gr.Image(type="pil", label="Upload an image")
    output_text = gr.Textbox(label="Predicted Class")
else:
    input_image = gr.inputs.Image(type="pil", label="Upload an image")
    output_text = gr.outputs.Textbox(label="Predicted Class")

# Create Gradio interface and start serving it.
demo = gr.Interface(
    fn=classify_image,
    inputs=input_image,
    outputs=output_text,
    title="Vision Transformer Image Classifier",
    description="Upload an image to classify its content.",
)
demo.launch()