"""Gradio demo: correct OCR noise in a text, then segment it editorially.

Pipeline, as wired at the bottom of this module:

1. ``OCRCorrector`` sends the raw text to the ``PleIAs/OCRonos`` model through
   vLLM and builds a word-level diff against the input.
2. ``EditorialSegmenter`` runs the ``PleIAs/Segmentext`` token classifier on
   the corrected text and renders each labelled span.
3. ``TextProcessor.process`` concatenates both renderings for a ``gr.HTML``
   component.
"""
import difflib
import html
import json
import os
import re
import shutil

# NOTE(review): third-party imports keep their original relative order
# (``spaces`` first) so import-time side effects happen in the same sequence.
import spaces
import transformers
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModel,
    AutoModelForCausalLM,
    pipeline,
)
from vllm import LLM, SamplingParams
import torch
import gradio as gr
import requests
import pandas as pd

# Define the device
device = "cuda" if torch.cuda.is_available() else "cpu"

# OCR Correction Model
ocr_model_name = "PleIAs/OCRonos"
ocr_llm = LLM(ocr_model_name, max_model_len=8128)

# Editorial Segmentation Model
editorial_model = "PleIAs/Segmentext"
token_classifier = pipeline(
    "token-classification",
    model=editorial_model,
    aggregation_strategy="simple",
    device=device,
)
tokenizer = AutoTokenizer.from_pretrained(editorial_model, model_max_length=512)

# CSS for formatting
# NOTE(review): the stylesheet is a single space here — confirm whether a
# <style> block is expected.
css = """ """

# Token budget for one classifier input; the tokenizer above is configured
# with model_max_length=512.
_MAX_SEGMENT_TOKENS = 500

_TAG_PATTERN = re.compile(r'<[^>]+>')
_WHITESPACE_PATTERN = re.compile(r'\s+')

# Rendering templates per entity group. Titles get extra blank lines; every
# other group shares the same layout.
# NOTE(review): the templates carry no markup — confirm whether per-group
# HTML wrappers are intended.
_DEFAULT_SEGMENT_TEMPLATE = "\n{label}\n{word}\n"
_SEGMENT_TEMPLATES = {
    'title': "\n{label}\n\n{word}\n\n",
    'bibliography': _DEFAULT_SEGMENT_TEMPLATE,
    'paratext': _DEFAULT_SEGMENT_TEMPLATE,
}


# Helper functions
def generate_html_diff(old_text: str, new_text: str) -> str:
    """Return the words of *new_text* as an HTML-safe, space-joined string.

    The two texts are compared word by word with ``difflib.Differ``. Words
    that are unchanged or inserted are kept; removed words and ``'? '`` hint
    entries are dropped. Each kept word is HTML-escaped because the result is
    displayed in a ``gr.HTML`` component.
    """
    differ = difflib.Differ()
    kept_words = []
    for entry in differ.compare(old_text.split(), new_text.split()):
        # Differ prefixes every entry with a two-character code.
        # NOTE(review): inserted words ('+ ') get no distinguishing markup
        # here — confirm whether a highlight wrapper is intended.
        if entry.startswith((' ', '+ ')):
            kept_words.append(html.escape(entry[2:]))
    return ' '.join(kept_words)


def preprocess_text(text: str) -> str:
    """Strip ``<...>`` tags, collapse all whitespace runs, and trim *text*."""
    without_tags = _TAG_PATTERN.sub('', text)
    # '\s+' also matches newlines, so no separate newline pass is needed.
    return _WHITESPACE_PATTERN.sub(' ', without_tags).strip()


def _find_split_point(text: str) -> int:
    """Return an index strictly inside *text* (length >= 2) at which to cut.

    Prefers the first whitespace at or after the midpoint, then the nearest
    whitespace before it, and finally the midpoint itself.
    """
    middle = len(text) // 2
    for index in range(middle, len(text)):
        if text[index].isspace():
            return index
    for index in range(middle - 1, 0, -1):
        if text[index].isspace():
            return index
    return middle


def _split_oversized(text: str, max_tokens: int) -> list:
    """Halve *text* repeatedly until every piece fits in *max_tokens* tokens.

    Pieces are returned in reading order. A piece that already fits is
    returned untouched; pieces produced by a cut are stripped.
    """
    pieces = []
    pending = [text]
    while pending:
        piece = pending.pop()
        if not piece:
            continue
        # A single character cannot be cut further; emit it to guarantee
        # termination even for a degenerate max_tokens.
        if len(piece) < 2 or len(tokenizer.tokenize(piece)) <= max_tokens:
            pieces.append(piece)
            continue
        cut = _find_split_point(piece)
        # Push the tail first so the head is processed (and emitted) first.
        pending.append(piece[cut:].strip())
        pending.append(piece[:cut].strip())
    return pieces


def split_text(text: str, max_tokens: int = 500) -> list:
    """Split *text* into chunks of at most *max_tokens* tokenizer tokens.

    Lines (separated by ``"\\n"``) are packed greedily into chunks. Any chunk
    that is still too long (for example a single very long line) is then
    halved on whitespace until every piece fits.

    Args:
        text: Text to split.
        max_tokens: Maximum number of tokens per chunk, counted with the
            module-level ``tokenizer``.

    Returns:
        The list of chunks in reading order; empty when *text* is empty.
    """
    packed = []
    current = ""
    for line in text.split("\n"):
        candidate = current + "\n" + line if current else line
        if len(tokenizer.tokenize(candidate)) <= max_tokens:
            current = candidate
            continue
        if current:
            packed.append(current)
        current = line
    if current:
        packed.append(current)

    chunks = []
    for chunk in packed:
        chunks.extend(_split_oversized(chunk, max_tokens))
    return chunks


def transform_chunks(marianne_segmentation) -> str:
    """Render classifier output as one text block per labelled span.

    Args:
        marianne_segmentation: Anything ``pd.DataFrame`` accepts, with an
            ``entity_group`` and a ``word`` column.

    Returns:
        The rendered spans joined with newlines, or ``""`` when there is
        nothing to render.
    """
    frame = pd.DataFrame(marianne_segmentation)
    if frame.empty:
        return ""

    # Copy so the column assignment below acts on an independent frame
    # rather than on a filtered view.
    frame = frame[frame['entity_group'] != 'separator'].copy()
    frame['word'] = (
        frame['word']
        .astype(str)
        .str.replace('¶', '\n', regex=False)
        .apply(preprocess_text)
    )
    # preprocess_text strips its result, so empty is the only blank form.
    frame = frame[frame['word'] != '']

    html_output = []
    for entity_group, word in zip(frame['entity_group'], frame['word']):
        label = "[" + entity_group.capitalize() + "]"
        template = _SEGMENT_TEMPLATES.get(entity_group, _DEFAULT_SEGMENT_TEMPLATE)
        # Escape the text: it originates from user input / model output and
        # ends up inside a gr.HTML component.
        html_output.append(template.format(label=label, word=html.escape(word)))
    return '\n'.join(html_output)


# OCR Correction Class
class OCRCorrector:
    """Corrects OCR noise with the module-level vLLM model ``ocr_llm``."""

    def __init__(self, system_prompt="Le dialogue suivant est une conversation"):
        # Stored for callers; the prompt built in ``correct`` does not use it.
        self.system_prompt = system_prompt

    def correct(self, user_message):
        """Return ``(corrected_text, html_diff)`` for *user_message*.

        Blank input yields ``("", "")`` without calling the model.
        """
        if not user_message or not user_message.strip():
            return "", ""

        sampling_params = SamplingParams(
            temperature=0.9,
            top_p=0.95,
            max_tokens=4000,
            presence_penalty=0,
            stop=["#END#"],
        )
        detailed_prompt = f"### TEXT ###\n{user_message}\n\n### CORRECTION ###\n"
        outputs = ocr_llm.generate([detailed_prompt], sampling_params, use_tqdm=False)
        generated_text = outputs[0].outputs[0].text
        html_diff = generate_html_diff(user_message, generated_text)
        return generated_text, html_diff


# Editorial Segmentation Class
class EditorialSegmenter:
    """Labels spans of a text with the ``token_classifier`` pipeline."""

    def segment(self, text):
        """Return the rendered segmentation of *text* (``""`` if blank)."""
        # Line breaks are marked with a pilcrow before classification;
        # transform_chunks turns the marker back into whitespace.
        editorial_text = text.replace("\n", " ¶ ")
        if not editorial_text.strip():
            return ""

        num_tokens = len(tokenizer.tokenize(editorial_text))
        if num_tokens > _MAX_SEGMENT_TOKENS:
            batch_prompts = split_text(editorial_text, max_tokens=_MAX_SEGMENT_TOKENS)
        else:
            batch_prompts = [editorial_text]

        predictions = token_classifier(batch_prompts)
        frames = [pd.DataFrame(prediction) for prediction in predictions]
        return transform_chunks(pd.concat(frames))


# Combined Processing Class
class TextProcessor:
    """Runs OCR correction followed by editorial segmentation."""

    def __init__(self):
        self.ocr_corrector = OCRCorrector()
        self.editorial_segmenter = EditorialSegmenter()

    @spaces.GPU(duration=120)
    def process(self, user_message):
        """Return the combined output shown in the ``gr.HTML`` component."""
        # Step 1: OCR Correction
        corrected_text, html_diff = self.ocr_corrector.correct(user_message)

        # Step 2: Editorial Segmentation
        segmented_text = self.editorial_segmenter.segment(corrected_text)

        # Combine results
        # NOTE(review): the section wrappers carry no markup — confirm whether
        # headings/containers are intended.
        ocr_result = f'\n\nOCR Correction\n\n\n\n{html_diff}\n'
        editorial_result = f'\n\nEditorial Segmentation\n\n\n\n{segmented_text}\n'
        final_output = f"{css}{ocr_result}\n\n{editorial_result}"
        return final_output


# Create the TextProcessor instance
text_processor = TextProcessor()

# Define the Gradio interface
with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
    gr.HTML("""

PleIAs Editor

""")
    text_input = gr.Textbox(label="Your (bad?) text", type="text", lines=5)
    process_button = gr.Button("Process Text")
    text_output = gr.HTML(label="Processed text")
    process_button.click(text_processor.process, inputs=text_input, outputs=[text_output])

if __name__ == "__main__":
    demo.queue().launch()