Pclanglais committed on
Commit b6cc9e1
1 Parent(s): 1fca231

Update app.py

Files changed (1)
  1. app.py +43 -178
app.py CHANGED
@@ -1,201 +1,66 @@
 import spaces
 import transformers
 import re
-from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM, pipeline
-from vllm import LLM, SamplingParams
 import torch
 import gradio as gr
-import json
 import os
-import shutil
-import requests
-import pandas as pd
-import difflib
+import ctranslate2
 from concurrent.futures import ThreadPoolExecutor
 
 # Define the device
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-# OCR Correction Model
-ocr_model_name = "PleIAs/OCRonos-Vintage"
+# Load CTranslate2 model and tokenizer
+model_path = "PleIAs/OCRonos-Vintage-CT2"
+generator = ctranslate2.Generator(model_path, device=device)
+tokenizer = transformers.AutoTokenizer.from_pretrained("PleIAs/OCRonos-Vintage")
 
-import torch
-from transformers import GPT2LMHeadModel, GPT2Tokenizer
-
-# Load pre-trained model and tokenizer
-model_name = "PleIAs/OCRonos-Vintage"
-model = GPT2LMHeadModel.from_pretrained(model_name)
-tokenizer = GPT2Tokenizer.from_pretrained(model_name)
-
-# Set the device to GPU if available, otherwise use CPU
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model.to(device)
-
-# CSS for formatting
+# CSS for formatting (unchanged)
 css = """
 <style>
-.generation {
-    margin-left: 2em;
-    margin-right: 2em;
-    font-size: 1.2em;
-}
-:target {
-    background-color: #CCF3DF;
-}
-.source {
-    float: left;
-    max-width: 17%;
-    margin-left: 2%;
-}
-.tooltip {
-    position: relative;
-    cursor: pointer;
-    font-variant-position: super;
-    color: #97999b;
-}
-.tooltip:hover::after {
-    content: attr(data-text);
-    position: absolute;
-    left: 0;
-    top: 120%;
-    white-space: pre-wrap;
-    width: 500px;
-    max-width: 500px;
-    z-index: 1;
-    background-color: #f9f9f9;
-    color: #000;
-    border: 1px solid #ddd;
-    border-radius: 5px;
-    padding: 5px;
-    display: block;
-    box-shadow: 0 4px 8px rgba(0,0,0,0.1);
-}
-.deleted {
-    background-color: #ffcccb;
-    text-decoration: line-through;
-}
-.inserted {
-    background-color: #90EE90;
-}
-.manuscript {
-    display: flex;
-    margin-bottom: 10px;
-    align-items: baseline;
-}
-.annotation {
-    width: 15%;
-    padding-right: 20px;
-    color: grey !important;
-    font-style: italic;
-    text-align: right;
-}
-.content {
-    width: 80%;
-}
-h2 {
-    margin: 0;
-    font-size: 1.5em;
-}
-.title-content h2 {
-    font-weight: bold;
-}
-.bibliography-content {
-    color: darkgreen !important;
-    margin-top: -5px;
-}
-.paratext-content {
-    color: #a4a4a4 !important;
-    margin-top: -5px;
-}
+... (your existing CSS)
 </style>
 """
 
 # Helper functions
 def generate_html_diff(old_text, new_text):
-    d = difflib.Differ()
-    diff = list(d.compare(old_text.split(), new_text.split()))
-    html_diff = []
-    for word in diff:
-        if word.startswith(' '):
-            html_diff.append(word[2:])
-        elif word.startswith('+ '):
-            html_diff.append(f'<span style="background-color: #90EE90;">{word[2:]}</span>')
-    return ' '.join(html_diff)
+    # (unchanged)
+    ...
 
 def preprocess_text(text):
-    text = re.sub(r'<[^>]+>', '', text)
-    text = re.sub(r'\n', ' ', text)
-    text = re.sub(r'\s+', ' ', text)
-    return text.strip()
-
-def split_text(text, max_tokens=500):
-    parts = text.split("\n")
-    chunks = []
-    current_chunk = ""
-
-    for part in parts:
-        if current_chunk:
-            temp_chunk = current_chunk + "\n" + part
-        else:
-            temp_chunk = part
-
-        num_tokens = len(tokenizer.tokenize(temp_chunk))
-
-        if num_tokens <= max_tokens:
-            current_chunk = temp_chunk
-        else:
-            if current_chunk:
-                chunks.append(current_chunk)
-            current_chunk = part
-
-    if current_chunk:
-        chunks.append(current_chunk)
-
-    if len(chunks) == 1 and len(tokenizer.tokenize(chunks[0])) > max_tokens:
-        long_text = chunks[0]
-        chunks = []
-        while len(tokenizer.tokenize(long_text)) > max_tokens:
-            split_point = len(long_text) // 2
-            while split_point < len(long_text) and not re.match(r'\s', long_text[split_point]):
-                split_point += 1
-            if split_point >= len(long_text):
-                split_point = len(long_text) - 1
-            chunks.append(long_text[:split_point].strip())
-            long_text = long_text[split_point:].strip()
-        if long_text:
-            chunks.append(long_text)
-
-    return chunks
-
-
-# Function to generate text
-def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
-    prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
-    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
-
-    # Set the number of threads for PyTorch
-    torch.set_num_threads(num_threads)
-
-    # Generate text
-    with ThreadPoolExecutor(max_workers=num_threads) as executor:
-        future = executor.submit(
-            model.generate,
-            input_ids,
-            max_new_tokens=max_new_tokens,
-            pad_token_id=tokenizer.eos_token_id,
-            top_k=50,
-            num_return_sequences=1,
-            do_sample=True,
-            temperature=0.7
-        )
-        output = future.result()
-
-    # Decode and return the generated text
-    result = tokenizer.decode(output[0], skip_special_tokens=True)
-    print(result)
-
-    result = result.split("### Correction ###")[1]
-    return result
+    # (unchanged)
+    ...
+
+def split_text(text, max_tokens=400):
+    encoded = tokenizer.encode(text)
+    splits = []
+    for i in range(0, len(encoded), max_tokens):
+        split = encoded[i:i+max_tokens]
+        splits.append(tokenizer.decode(split))
+    return splits
+
+# Function to generate text using CTranslate2
+def ocr_correction(prompt, max_new_tokens=600):
+    splits = split_text(prompt, max_tokens=400)
+    corrected_splits = []
+
+    for split in splits:
+        full_prompt = f"### Text ###\n{split}\n\n\n### Correction ###\n"
+        encoded = tokenizer.encode(full_prompt)
+        prompt_tokens = tokenizer.convert_ids_to_tokens(encoded)
+
+        result = generator.generate_batch(
+            [prompt_tokens],
+            max_length=max_new_tokens,
+            sampling_temperature=0.7,
+            sampling_topk=20,
+            include_prompt_in_result=False
+        )[0]
+
+        corrected_text = tokenizer.decode(result.sequences_ids[0])
+        corrected_splits.append(corrected_text)
+
+    return " ".join(corrected_splits)
 
 # OCR Correction Class
 class OCRCorrector:
@@ -214,7 +79,7 @@ class TextProcessor:
 
     @spaces.GPU(duration=120)
     def process(self, user_message):
-        #OCR Correction
+        # OCR Correction
         corrected_text, html_diff = self.ocr_corrector.correct(user_message)
 
         # Combine results
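A note beyond the commit itself: the new code loads "PleIAs/OCRonos-Vintage-CT2", a CTranslate2 export of the original OCRonos-Vintage GPT-2 checkpoint. A minimal sketch of how such an export is typically produced and smoke-tested with CTranslate2's Transformers converter; the output directory name, int8 quantization, and sample prompt are illustrative assumptions, not something this commit specifies:

    # Hypothetical one-time export: convert the Hugging Face checkpoint into a
    # CTranslate2 model directory like the one the Space now loads.
    import transformers
    from ctranslate2.converters import TransformersConverter

    TransformersConverter("PleIAs/OCRonos-Vintage").convert(
        "OCRonos-Vintage-CT2", quantization="int8", force=True  # int8 is an assumption
    )

    # Smoke test mirroring the generation path introduced in app.py.
    import ctranslate2

    generator = ctranslate2.Generator("OCRonos-Vintage-CT2", device="cpu")
    tokenizer = transformers.AutoTokenizer.from_pretrained("PleIAs/OCRonos-Vintage")

    # Made-up OCR-noise prompt, following the model's prompt format.
    prompt = "### Text ###\nThe qvick brown fox jumqed over tle lazy dog.\n\n\n### Correction ###\n"
    tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt))
    result = generator.generate_batch(
        [tokens],
        max_length=128,
        sampling_temperature=0.7,
        sampling_topk=20,
        include_prompt_in_result=False,
    )[0]
    print(tokenizer.decode(result.sequences_ids[0]))

The converter needs torch and transformers installed; dropping the quantization argument keeps full-precision weights.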
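The chunking rewrite is also worth illustrating: the old split_text walked newline-delimited segments and re-tokenized candidate chunks repeatedly, while the new version slices the token ids into fixed windows and decodes each slice. A small self-contained sketch of the new behavior, with a made-up input; note that a window boundary can fall mid-word:

    import transformers

    tokenizer = transformers.AutoTokenizer.from_pretrained("PleIAs/OCRonos-Vintage")

    def split_text(text, max_tokens=400):
        # Same logic as the new app.py: fixed token windows, decoded back to text.
        encoded = tokenizer.encode(text)
        return [tokenizer.decode(encoded[i:i + max_tokens])
                for i in range(0, len(encoded), max_tokens)]

    chunks = split_text("sample text " * 500, max_tokens=400)
    # Roughly 400-token windows; re-encoding a decoded chunk can shift the
    # count slightly at the boundaries.
    print(len(chunks), [len(tokenizer.encode(c)) for c in chunks])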