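# Gradio demo: handwritten text recognition with a Florence-2 model fine-tuned on the
# IAM subset of the HuggingFace Cauldron dataset (Oysiyl/Florence-2-FT-OCR-Cauldron-IAM).
# Draw text on the sketchpad or upload an image and the app returns the recognized text.
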
import os
from unittest.mock import patch

import gradio as gr
import numpy as np
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.dynamic_module_utils import get_imports


def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Workaround for the unnecessary flash_attn import in the remote Florence-2 code."""
    if not str(filename).endswith("modeling_florence2.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    if "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports

# Workaround for the unnecessary flash_attn requirement.
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    model = AutoModelForCausalLM.from_pretrained(
        "Oysiyl/Florence-2-FT-OCR-Cauldron-IAM", attn_implementation="sdpa", trust_remote_code=True
    )

processor = AutoProcessor.from_pretrained("Oysiyl/Florence-2-FT-OCR-Cauldron-IAM", trust_remote_code=True)

prompt = "OCR"


def predict(im):
    # The ImageEditor value is a dict; "composite" is the drawing merged with any uploaded image.
    composite_image = Image.fromarray(im["composite"].astype(np.uint8)).convert("RGBA")
    # Flatten transparency onto a white background before converting to RGB for the model.
    background_image = Image.new("RGBA", composite_image.size, (255, 255, 255, 255))
    image = Image.alpha_composite(background_image, composite_image).convert("RGB")
    inputs = processor(text=prompt, images=image, return_tensors="pt")

    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        do_sample=False,
        num_beams=3,
    )
    # Keep special tokens: post_process_generation uses them to parse the task output.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text, task=prompt, image_size=(image.width, image.height)
    )
    return parsed_answer[prompt]


sketchpad = gr.ImageEditor(label="Draw something or upload an image")

interface = gr.Interface(
    predict,
    inputs=sketchpad,
    outputs="text",
    theme="gradio/monochrome",
    title="Handwritten Recognition using Florence 2 model finetuned on IAM subset from HuggingFace Cauldron dataset",
    description="<p style='text-align: center'>Write some text or upload an image with handwritten notes and let the model try to recognize it!</p>",
    article="<p style='text-align: center'>Handwritten Text Recognition | Demo Model</p>",
)
interface.launch(debug=True)