import gradio as gr
from Models import VisionModel
import huggingface_hub
from PIL import Image
import torch.amp.autocast_mode
from pathlib import Path

MODEL_REPO = "fancyfeast/joytag"


@torch.no_grad()
def predict(image: Image.Image):
    # Run the model under CUDA autocast and map each tag name to its sigmoid probability.
    with torch.amp.autocast_mode.autocast('cuda', enabled=True):
        preds = model(image)
        tag_preds = preds['tags'].sigmoid().cpu()

    return {top_tags[i]: tag_preds[i] for i in range(len(top_tags))}


print("Downloading model...")
path = huggingface_hub.snapshot_download(MODEL_REPO)

print("Loading model...")
model = VisionModel.load_model(path)
model.eval()

# Load the list of tag names the model predicts over.
with open(Path(path) / 'top_tags.txt', 'r') as f:
    top_tags = [line.strip() for line in f.readlines() if line.strip()]

print("Starting server...")

gradio_app = gr.Interface(
    predict,
    inputs=gr.Image(label="Source", sources=['upload', 'webcam'], type='pil'),
    outputs=[gr.Label(label="Result", num_top_classes=5)],
    title="JoyTag",
)


if __name__ == '__main__':
    gradio_app.launch()