Pclanglais's picture
Update app.py
ba05a34 verified
raw
history blame
No virus
3.24 kB
import difflib
import html
import os
import re
from concurrent.futures import ThreadPoolExecutor

import gradio as gr
import torch
import transformers
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# OCR correction model (a GPT-2 language-model checkpoint on the Hugging Face Hub).
model_name = "PleIAs/OCRonos-Vintage"

# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the pretrained tokenizer and weights, then move the weights onto `device`.
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
model = model.to(device)
# Inline stylesheet that process_text prepends to every rendered result:
# `.generation` gives the output container side margins and a larger font;
# `.inserted` gives words added by the correction a light-green background.
css = """
<style>
.generation {
margin-left: 2em;
margin-right: 2em;
font-size: 1.2em;
}
.inserted {
background-color: #90EE90;
}
</style>
"""
def generate_html_diff(old_text, new_text):
    """Render *new_text* as HTML, highlighting words that are not in *old_text*.

    Both texts are split on whitespace and compared word by word with
    ``difflib.Differ``. Unchanged words are emitted as-is and inserted words
    are wrapped in ``<span class="inserted">``. Deleted words (``- `` lines)
    and intraline hints (``? `` lines) are dropped, so the result reads as
    *new_text*.

    Every word is HTML-escaped before being emitted. Both texts come from
    user input or model output and are rendered directly as HTML, so
    unescaped text would allow markup/script injection and mis-render
    characters such as ``<`` and ``&``.

    Args:
        old_text: The original text.
        new_text: The corrected text.

    Returns:
        The escaped words of *new_text*, with insertions highlighted,
        joined by single spaces.
    """
    pieces = []
    for entry in difflib.Differ().compare(old_text.split(), new_text.split()):
        # Differ prefixes every entry with a two-character tag.
        tag = entry[:2]
        if tag == '  ':
            pieces.append(html.escape(entry[2:]))
        elif tag == '+ ':
            pieces.append(f'<span class="inserted">{html.escape(entry[2:])}</span>')
    return ' '.join(pieces)
def split_text(text, max_tokens=400):
    """Cut *text* into consecutive pieces of at most *max_tokens* tokens.

    The text is tokenized with the module-level ``tokenizer``, and each slice
    of tokens is turned back into a string with ``convert_tokens_to_string``.
    Every piece except possibly the last holds exactly ``max_tokens`` tokens.
    An integer ``max_tokens`` below 1 behaves like 1 (one token per piece).
    Empty text yields an empty list.

    NOTE(review): boundaries fall on token positions, not word boundaries,
    so a multi-token word can be split across two pieces.
    """
    pieces = tokenizer.tokenize(text)
    step = max(max_tokens, 1)
    return [
        tokenizer.convert_tokens_to_string(pieces[start:start + step])
        for start in range(0, len(pieces), step)
    ]
def ocr_correction(prompt, max_new_tokens=600, num_threads=None):
    """Run the OCR-correction model on a single chunk of text.

    The chunk is wrapped in the ``### Text ###`` / ``### Correction ###``
    prompt template and decoded greedily. The text generated after the
    ``### Correction ###`` header is returned.

    Args:
        prompt: Raw (possibly noisy) text to correct.
        max_new_tokens: Upper bound on the number of generated tokens.
        num_threads: Intra-op CPU thread count passed to
            ``torch.set_num_threads``. ``None`` (the default) uses
            ``os.cpu_count()``, resolved at call time rather than at import time.

    Returns:
        The corrected text, stripped of surrounding whitespace.
    """
    if num_threads is None:
        # os.cpu_count() can return None; fall back to a single thread.
        num_threads = os.cpu_count() or 1
    torch.set_num_threads(num_threads)

    template = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
    # Calling the tokenizer (instead of .encode) also yields an attention mask for generate().
    encoded = tokenizer(template, return_tensors="pt").to(device)
    prompt_len = encoded["input_ids"].shape[-1]

    # Generate directly in this thread; the previous single-task thread pool added nothing.
    # Decoding is greedy (do_sample=False), so sampling knobs such as top_k would be ignored.
    output = model.generate(
        **encoded,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        num_return_sequences=1,
        do_sample=False,
    )

    # Decode only the newly generated tokens. A "### Correction ###" marker
    # inside the user's own text then cannot shift which part is returned.
    generated = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
    # As before, stop at a repeated correction header if the model emits one.
    return generated.split("### Correction ###")[0].strip()
def process_text(user_message):
    """Correct *user_message* chunk by chunk and return the result as HTML.

    The message is split by ``split_text`` and each piece is corrected
    independently by ``ocr_correction``. The corrections are joined with
    single spaces and diffed word by word against the original message.
    The return value is the ``css`` block followed by an "OCR Correction"
    heading and the highlighted diff.
    """
    corrected_text = ' '.join(
        ocr_correction(piece) for piece in split_text(user_message)
    )
    diff_markup = generate_html_diff(user_message, corrected_text)
    result_markup = f'<h2 style="text-align:center">OCR Correction</h2>\n<div class="generation">{diff_markup}</div>'
    return css + result_markup
# Build the Gradio page: a title, an input textbox, a trigger button and an
# HTML panel that receives the rendered correction (in that on-page order).
with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
    gr.HTML('<h1 style="text-align:center">Vintage OCR corrector (CPU)</h1>')
    text_input = gr.Textbox(label="Your (bad?) text", type="text", lines=5)
    process_button = gr.Button("Process Text")
    text_output = gr.HTML(label="Processed text")
    # On click, run process_text on the textbox value and show its HTML in text_output.
    process_button.click(fn=process_text, inputs=text_input, outputs=[text_output])
if __name__ == "__main__":
    # Enable Gradio's request queue (queue() returns the same Blocks app), then serve it.
    app = demo.queue()
    app.launch()