File size: 5,546 Bytes
7e5b31a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/vasim/.local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import gradio as gr\n",
    "import torch\n",
    "from transformers import AutoImageProcessor, ViTForImageClassification, AutoModelForImageClassification\n",
    "from peft import PeftConfig, PeftModel\n",
    "import os \n",
    "from PIL import Image\n",
    "import requests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:\n",
      "- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([37]) in the model instantiated\n",
      "- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([37, 768]) in the model instantiated\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.\n"
     ]
    }
   ],
   "source": [
    "# Class names; a name's position in this list is its class index.\n",
    "classes = [\n",
    "    'chihuahua', 'newfoundland', 'english setter', 'Persian',\n",
    "    'yorkshire terrier', 'Maine Coon', 'boxer', 'leonberger',\n",
    "    'Birman', 'staffordshire bull terrier', 'Egyptian Mau', 'shiba inu',\n",
    "    'wheaten terrier', 'miniature pinscher', 'american pit bull terrier', 'Bombay',\n",
    "    'British Shorthair', 'german shorthaired', 'american bulldog', 'Abyssinian',\n",
    "    'great pyrenees', 'Siamese', 'Sphynx', 'english cocker spaniel',\n",
    "    'japanese chin', 'havanese', 'Russian Blue', 'saint bernard',\n",
    "    'samoyed', 'scottish terrier', 'keeshond', 'Bengal',\n",
    "    'Ragdoll', 'pomeranian', 'beagle', 'basset hound',\n",
    "    'pug',\n",
    "]\n",
    "\n",
    "# Forward and reverse lookups between class names and class indices.\n",
    "label2id = {name: index for index, name in enumerate(classes)}\n",
    "id2label = dict(enumerate(classes))\n",
    "\n",
    "model_name = \"vit-base-patch16-224\"\n",
    "model_checkpoint = f\"google/{model_name}\"\n",
    "repo_name = f\"md-vasim/{model_name}-finetuned-lora-oxfordPets\"\n",
    "\n",
    "# The adapter config names the base model the LoRA weights were trained on.\n",
    "config = PeftConfig.from_pretrained(repo_name)\n",
    "\n",
    "# The base checkpoint's classifier head does not have len(classes) outputs,\n",
    "# so mismatched shapes must be tolerated and the head is re-created.\n",
    "model = AutoModelForImageClassification.from_pretrained(\n",
    "    config.base_model_name_or_path,\n",
    "    id2label=id2label,\n",
    "    label2id=label2id,\n",
    "    ignore_mismatched_sizes=True,\n",
    ")\n",
    "\n",
    "# Wrap the base model with the LoRA adapter weights stored in the repo.\n",
    "inference_model = PeftModel.from_pretrained(model, repo_name)\n",
    "\n",
    "# Preprocessing (resize / normalisation) comes from the base checkpoint.\n",
    "image_processor = AutoImageProcessor.from_pretrained(model_checkpoint)\n",
    "\n",
    "def inference(image):\n",
    "    \"\"\"Classify an image and return the predicted class label.\n",
    "\n",
    "    Args:\n",
    "        image: PIL image to classify, or None when no image was supplied.\n",
    "\n",
    "    Returns:\n",
    "        The label that `inference_model.config.id2label` maps the\n",
    "        highest-scoring logit index to.\n",
    "\n",
    "    Raises:\n",
    "        gr.Error: If `image` is None.\n",
    "    \"\"\"\n",
    "    if image is None:\n",
    "        # Without this guard a missing image fails with an opaque\n",
    "        # AttributeError on `.convert`; raise a readable UI error instead.\n",
    "        raise gr.Error(\"Please upload an image before submitting.\")\n",
    "    # Force three channels so grayscale / RGBA inputs are accepted too.\n",
    "    encoding = image_processor(image.convert(\"RGB\"), return_tensors=\"pt\")\n",
    "    # Gradients are not needed for prediction.\n",
    "    with torch.no_grad():\n",
    "        logits = inference_model(**encoding).logits\n",
    "    predicted_class_idx = logits.argmax(-1).item()\n",
    "    return inference_model.config.id2label[predicted_class_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running on local URL:  http://127.0.0.1:7861\n",
      "\n",
      "To create a public link, set `share=True` in `launch()`.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"http://127.0.0.1:7861/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Text field that displays the value returned by `inference`.\n",
    "output_box = gr.Textbox(label=\"Output\")\n",
    "\n",
    "# Sample inputs offered in the UI; one single-element row per example.\n",
    "examples = [\n",
    "    [\"../static/shiba_inu_174.jpg\"],\n",
    "    [\"../static/pomeranian_89.jpg\"],\n",
    "]\n",
    "\n",
    "# type=\"pil\" makes the image input hand `inference` a PIL image.\n",
    "demo = gr.Interface(\n",
    "    fn=inference,\n",
    "    inputs=gr.Image(type=\"pil\"),\n",
    "    outputs=output_box,\n",
    "    examples=examples,\n",
    ")\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    demo.launch()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}