|
import io |
|
import gradio as gr |
|
import requests, validators |
|
import torch |
|
import pathlib |
|
from PIL import Image |
|
import datasets |
|
from transformers import AutoFeatureExtractor, AutoModelForImageClassification |
|
import os |
|
import IPython |
|
|
|
|
|
# Allow more than one OpenMP runtime in the process. Presumably this works
# around the Intel OpenMP "libiomp5 already initialized" abort that can occur
# when several libraries bundle their own runtime — confirm it is still needed.
os.environ.update({"KMP_DUPLICATE_LIB_OK": "TRUE"})

# Both the feature extractor and the fine-tuned classifier are loaded from the
# same local path.
_MODEL_DIR = "saved_model_files"
feature_extractor = AutoFeatureExtractor.from_pretrained(_MODEL_DIR)
model = AutoModelForImageClassification.from_pretrained(_MODEL_DIR)

# Display names for the classifier outputs, looked up by class index.
# NOTE(review): assumes this order matches model.config.id2label — confirm.
labels = [
    "angular_leaf_spot",
    "bean_rust",
    "healthy",
]
|
|
|
def classify(im):
    """Classify the plant-health status shown in a leaf image.

    Args:
        im: The image to classify. In this app it is a PIL image: the Gradio
            inputs use ``type='pil'`` and the URL path uses ``Image.open``.

    Returns:
        dict mapping each name in ``labels`` to its softmax probability as a
        Python float. This is the format ``gr.Label`` displays.
    """
    # Images opened from a URL keep their native mode (e.g. RGBA/P/L for many
    # PNGs). The ViT checkpoint presumably expects 3-channel input, so
    # normalise to RGB. Objects without a `mode` (e.g. arrays) are passed
    # through untouched.
    if getattr(im, "mode", "RGB") != "RGB":
        im = im.convert("RGB")

    features = feature_extractor(im, return_tensors="pt")

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**features).logits

    # Softmax over the class dimension; take the single image in the batch.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0].numpy()

    # NOTE(review): assumes `labels` is in the model's class-index order.
    return {label: float(probs[i]) for i, label in enumerate(labels)}
|
|
|
def get_original_image(url_input):
    """Download the image at ``url_input`` so it can be previewed/classified.

    Args:
        url_input: Text from the URL textbox.

    Returns:
        A PIL image, or ``None`` when ``url_input`` is not a valid URL.

    Raises:
        requests.RequestException: on network errors, a timeout, or a non-2xx
            HTTP status (via ``raise_for_status``).
        PIL.UnidentifiedImageError: if the response body is not an image.
    """
    # Bind PIL's Image locally so this function does not depend on the
    # module-level name `Image`, which other top-level code may rebind.
    from PIL import Image

    if not validators.url(url_input):
        return None

    # Bound the wait so a slow/unresponsive host cannot hang the worker.
    timeout_s = 10
    response = requests.get(url_input, timeout=timeout_s)
    response.raise_for_status()

    # `.content` is the decoded body (handles gzip/deflate Content-Encoding).
    # The previous `.raw` stream is not decoded.
    return Image.open(io.BytesIO(response.content))
|
|
|
def detect_plant_health(url_input, image_input, webcam_input):
    """Classify the first available image source.

    Sources are checked in priority order:

    1. a valid URL in ``url_input`` (fetched with ``get_original_image``);
    2. the uploaded image ``image_input``;
    3. the webcam frame ``webcam_input``.

    Returns:
        The ``classify`` result dict, or ``None`` when no source is supplied.
        Presumably ``gr.Label`` renders ``None`` as an empty label.
    """
    if validators.url(url_input):
        image = get_original_image(url_input)
    elif image_input is not None:
        image = image_input
    elif webcam_input is not None:
        image = webcam_input
    else:
        # Nothing to classify. Return early instead of reaching classify()
        # with an unbound `image`.
        return None

    return classify(image)
|
|
|
def set_example_image(example: list) -> dict:
    """Return a ``gr.Image`` update whose value is the example row's first entry.

    Only ``example[0]`` is read; any further entries are ignored.
    """
    selected = example[0]
    return gr.Image.update(value=selected)
|
|
|
def set_example_url(example: list) -> tuple:
    """Build updates for the URL textbox and the preview image from an example row.

    Args:
        example: An example row whose first element is an image URL.

    Returns:
        A ``(textbox_update, image_update)`` pair. The textbox is set to the
        URL and the image to ``get_original_image(url)``, which is ``None``
        for an invalid URL.
    """
    url = example[0]
    textbox_update = gr.Textbox.update(value=url)
    image_update = gr.Image.update(value=get_original_image(url))
    return textbox_update, image_update
|
|
|
|
|
# HTML heading rendered at the top of the app; `h1#title` is styled by `css`.
title = """<h1 id="title">Plant Health Classification with ViT</h1>"""

# NOTE(review): these components are created outside the `with demo:` context,
# so presumably they are never rendered in the app (notebook leftovers).
gr.Image(pathlib.Path('images/Healthy.png'), label='Healthy Plant')
gr.Image(pathlib.Path('images/sickie.png'), label='Infected Plant')

# Import IPython's Image under an alias. A bare `Image` here would rebind the
# module-level PIL `Image`, which the URL helpers call as `Image.open(...)`.
# IPython's Image has no `open`, so URL classification would fail.
from IPython.display import display, Image as IPyImage

display(IPyImage(filename=pathlib.Path('images/sickie.png')))
|
|
|
# Markdown shown under the title: what the model is and how to use the app.
description = """
This Plant Health classifier app was built to detect the health of plants using images of leaves by fine-tuning a Vision Transformer (ViT) [google/vit-base-patch16-224](https://huggingface.co/google/vit-base-patch16-224) on the [Beans](https://huggingface.co/datasets/beans) dataset.
The finetuned model has an accuracy of 98.4% on the test (unseen) dataset and 100% on the validation dataset.

How to use the app:
- Upload an image via 3 options, uploading the image from local device, using a URL (image from the web) or a webcam
- The app will take a few seconds to generate a prediction with the following labels:
- *angular_leaf_spot*
- *bean_rust*
- *healthy*
- Feel free to click the image examples as well.
"""
|
# Example URLs for the "Image URL" tab. Each must point directly at an image
# file, because the downloaded body is handed to PIL's Image.open. (A web-page
# URL here would make the example fail every time.)
urls = [
    "https://huggingface.co/nateraw/vit-base-beans/resolve/main/angular_leaf_spot.jpeg",
    "https://huggingface.co/nateraw/vit-base-beans/resolve/main/bean_rust.jpeg",
]

# One single-element example row per file under ./images whose name matches
# `*.p*g` (e.g. `.png`). Rows are sorted for a stable display order.
images = [[path.as_posix()] for path in sorted(pathlib.Path('images').rglob('*.p*g'))]
|
|
|
# Markdown badge linking to nickmuchi's Twitter profile.
twitter_link = """
[![](https://img.shields.io/twitter/follow/nickmuchi?label=@nickmuchi&style=social)](https://twitter.com/nickmuchi)
"""

# Custom CSS passed to gr.Blocks: centers the `<h1 id="title">` heading from `title`.
css = '''
h1#title {
text-align: center;
}
'''
|
# Per-tab Classify handlers. detect_plant_health checks URL > upload > webcam.
# If every button passed all three inputs, a leftover URL or upload would
# override the tab the user is actually on. Each handler therefore forwards
# only its own tab's input.
def _classify_url(url):
    """Classify the image referenced by the URL textbox only."""
    return detect_plant_health(url, None, None)


def _classify_upload(image):
    """Classify the uploaded image only."""
    # "" mirrors what the empty URL textbox sends, so the URL branch is skipped.
    return detect_plant_health("", image, None)


def _classify_webcam(frame):
    """Classify the current webcam frame only."""
    return detect_plant_health("", None, frame)


demo = gr.Blocks(css=css)

with demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(twitter_link)

    with gr.Tabs():
        with gr.TabItem('Image Upload'):
            with gr.Row():
                with gr.Column():
                    img_input = gr.Image(type='pil', shape=(450, 450))
                    label_from_upload = gr.Label(num_top_classes=3)
            with gr.Row():
                example_images = gr.Examples(examples=images, inputs=[img_input])
            img_but = gr.Button('Classify')

        with gr.TabItem('Image URL'):
            with gr.Row():
                with gr.Column():
                    url_input = gr.Textbox(lines=2, label='Enter valid image URL here..')
                    original_image = gr.Image(shape=(450, 450))
                    # Refresh the preview whenever the URL text changes.
                    url_input.change(get_original_image, url_input, original_image)
                with gr.Column():
                    label_from_url = gr.Label(num_top_classes=3)
            with gr.Row():
                # gr.Examples takes one inner list per example (one value per input).
                example_url = gr.Examples(examples=[[url] for url in urls], inputs=[url_input])
            url_but = gr.Button('Classify')

        with gr.TabItem('WebCam'):
            with gr.Row():
                with gr.Column():
                    web_input = gr.Image(source='webcam', type='pil', shape=(750, 750), streaming=True)
                with gr.Column():
                    label_from_webcam = gr.Label(num_top_classes=3)
            cam_but = gr.Button('Classify')

    url_but.click(_classify_url, inputs=[url_input], outputs=[label_from_url], queue=True)
    img_but.click(_classify_upload, inputs=[img_input], outputs=[label_from_upload], queue=True)
    cam_but.click(_classify_webcam, inputs=[web_input], outputs=[label_from_webcam], queue=True)

    gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=nickmuchi-plant-health)")
|
|
|
|
|
# Start the Gradio server for `demo` with request queueing and debug mode enabled.
demo.launch(enable_queue=True, debug=True)