# app.py (2.39 kB) — latest commit e32677f "Update app.py" by Oysiyl (verified; no virus detected).
import os
import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Workaround for the unnecessary flash_attn requirement.
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
import numpy as np
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Return the imports of *filename*, leaving out ``flash_attn`` for Florence-2.

    Intended as a replacement for ``get_imports`` while the model's remote
    code is loaded, so that a missing ``flash_attn`` package does not block
    loading.

    Args:
        filename: Path of the module file whose imports are collected.

    Returns:
        The list produced by ``get_imports``. For files whose name ends with
        ``modeling_florence2.py`` every ``flash_attn`` entry is filtered out;
        all other files are returned unchanged.
    """
    imports = get_imports(filename)
    if not str(filename).endswith("modeling_florence2.py"):
        return imports
    # Filter instead of list.remove(): remove() raises ValueError when
    # "flash_attn" is not in the list, which would abort model loading.
    return [name for name in imports if name != "flash_attn"]
# Checkpoint used for both the model weights and the processor.
_MODEL_ID = "Oysiyl/Florence-2-FT-OCR-Cauldron-IAM"

# Workaround for the unnecessary flash_attn requirement: load the model and
# processor while get_imports is swapped for fixed_get_imports.
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    model = AutoModelForCausalLM.from_pretrained(
        _MODEL_ID,
        attn_implementation="sdpa",
        trust_remote_code=True,
    )
    model = model.to(device)
    processor = AutoProcessor.from_pretrained(_MODEL_ID, trust_remote_code=True)

# Task prompt given to the processor; also the key read from the parsed answer.
prompt = "OCR"
def predict(im):
    """Recognize the handwritten text in the image submitted through the editor.

    Args:
        im: Value delivered by the image editor component; a mapping whose
            ``'composite'`` entry holds the flattened image as an array.
            May be ``None`` (or have a ``None`` composite) when nothing was
            drawn or uploaded.

    Returns:
        The entry of the post-processed model answer stored under the task
        prompt (presumably the recognized text), or an empty string when
        there is no image to read.
    """
    composite = None if im is None else im["composite"]
    if composite is None:
        # Nothing drawn or uploaded: avoid crashing on None.astype(...).
        return ""

    # Flatten any transparency onto a white background so the model gets RGB.
    composite_image = Image.fromarray(composite.astype(np.uint8)).convert("RGBA")
    background_image = Image.new("RGBA", composite_image.size, (255, 255, 255, 255))
    image = Image.alpha_composite(background_image, composite_image).convert("RGB")

    inputs = processor(text=prompt, images=image, return_tensors="pt")
    # The model was moved to `device`; its inputs must live on the same device,
    # otherwise generation fails with a device mismatch when CUDA is used.
    generated_ids = model.generate(
        input_ids=inputs["input_ids"].to(device),
        pixel_values=inputs["pixel_values"].to(device),
        max_new_tokens=1024,
        do_sample=False,
        num_beams=3,
    )
    # Special tokens are kept because post_process_generation receives the raw text.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=prompt,
        image_size=(image.width, image.height),
    )
    return parsed_answer[prompt]
# Texts shown around the demo.
_TITLE = "Handwritten Recognition using Florence 2 model finetuned on IAM subset from HuggingFace Cauldron dataset"
_DESCRIPTION = "<p style='text-align: center'>Draw a text or upload an image with handwritten notes and let's model try to guess the text!</p>"
_ARTICLE = "<p style='text-align: center'>Handwritten Text Recognition | Demo Model</p>"

# Input component: the user can draw on it or upload an image.
sketchpad = gr.ImageEditor(label="Draw something or upload an image")

# Wire the editor to predict() and show the result as plain text.
interface = gr.Interface(
    fn=predict,
    inputs=sketchpad,
    outputs="text",
    theme="gradio/monochrome",
    title=_TITLE,
    description=_DESCRIPTION,
    article=_ARTICLE,
)

# Start the app with debug output enabled.
interface.launch(debug=True)