vteam27 commited on
Commit
40e9659
1 Parent(s): 4e9395b

Added text batching for translation

Browse files
Files changed (2) hide show
  1. app.py +24 -4
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,4 +1,6 @@
1
  import os
 
 
2
  from doctr.io import DocumentFile
3
  from doctr.models import ocr_predictor
4
  import gradio as gr
@@ -99,7 +101,19 @@ demo_ocr = gr.Interface(
99
 
100
  # demo_ocr.launch(debug=True)
101
 
102
-
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
  def run_t2tt(file_uploader , input_text: str, source_language: str, target_language: str) -> (str, bytes):
105
  if file_uploader is not None:
@@ -107,9 +121,15 @@ def run_t2tt(file_uploader , input_text: str, source_language: str, target_langu
107
  input_text=file.read()
108
  source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
109
  target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
110
- text_inputs = processor(text = input_text, src_lang=source_language_code , return_tensors="pt")
111
- output_tokens = model.generate(**text_inputs, tgt_lang=target_language_code)
112
- output = processor.decode(output_tokens[0].tolist(), skip_special_tokens=True)
 
 
 
 
 
 
113
  _output_name = "result.txt"
114
  open(_output_name, 'w').write(output)
115
  return str(output), _output_name
 
1
  import os
2
+ import nltk
3
+ nltk.download('punkt')
4
  from doctr.io import DocumentFile
5
  from doctr.models import ocr_predictor
6
  import gradio as gr
 
101
 
102
  # demo_ocr.launch(debug=True)
103
 
104
def split_text_into_batches(text, max_tokens_per_batch):
    """Split *text* into sentence-aligned batches no longer than the limit.

    Sentences are grouped greedily in order; a batch is flushed when adding
    the next sentence would exceed the limit. NOTE: despite the parameter
    name, the limit is measured in characters (``len``), not model tokens —
    it is a conservative proxy for the model's token budget.

    Args:
        text: Raw input text to split.
        max_tokens_per_batch: Maximum length, in characters, of each batch.

    Returns:
        list[str]: Whitespace-stripped batches, in original order. A single
        sentence longer than the limit still becomes its own (oversized)
        batch rather than being dropped or split mid-sentence.
    """
    sentences = nltk.sent_tokenize(text)  # tokenize text into sentences
    batches = []
    current_batch = ""
    for sentence in sentences:
        # +1 accounts for the separating space appended below.
        if len(current_batch) + len(sentence) + 1 <= max_tokens_per_batch:
            current_batch += sentence + " "
        else:
            # Bug fix: only flush a non-empty batch. Previously, a first
            # sentence longer than the limit caused an empty-string batch
            # to be appended before the sentence was placed.
            if current_batch:
                batches.append(current_batch.strip())
            current_batch = sentence + " "  # start a new batch
    if current_batch:
        batches.append(current_batch.strip())  # flush the final batch
    return batches
117
 
118
  def run_t2tt(file_uploader , input_text: str, source_language: str, target_language: str) -> (str, bytes):
119
  if file_uploader is not None:
 
121
  input_text=file.read()
122
  source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
123
  target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
124
+ max_tokens_per_batch= 256
125
+ batches = split_text_into_batches(input_text, max_tokens_per_batch)
126
+ translated_text = ""
127
+ for batch in batches:
128
+ text_inputs = processor(text=batch, src_lang=source_language_code, return_tensors="pt")
129
+ output_tokens = model.generate(**text_inputs, tgt_lang=target_language_code)
130
+ translated_batch = processor.decode(output_tokens[0].tolist(), skip_special_tokens=True)
131
+ translated_text += translated_batch + " "
132
+ output=translated_text.strip()
133
  _output_name = "result.txt"
134
  open(_output_name, 'w').write(output)
135
  return str(output), _output_name
requirements.txt CHANGED
@@ -8,4 +8,5 @@ transformers
8
  fairseq2==0.1
9
  pydub
10
  yt-dlp
11
- sentencepiece
 
 
8
  fairseq2==0.1
9
  pydub
10
  yt-dlp
11
+ sentencepiece
12
+ nltk