quesbook_search / processing.py
Stefan
fix(spaces): remove types
5b6e243
raw
history blame
2.55 kB
from markdown import Markdown
from io import StringIO
import re
from embedding import num_tokens_from_str, EMBEDDING_CHAR_LIMIT
# Matches a single HTML/XML tag (non-greedy). Without re.DOTALL, "." does not
# match a newline, so a tag spanning several lines is left in place.
HTMLR = re.compile(r"<.*?>")
# Matches any run of whitespace (spaces, tabs, newlines, ...).
WS = re.compile(r"\s+")
# Matches a "[lightgallery ...]" shortcode. The ".*" is greedy, so on a line
# containing several "]" the match extends to the last one on that line.
LIGHTGALLERY = re.compile(r"\[lightgallery.*\]")
def unmark_element(element, stream=None):
    """Collect the text content of ``element`` and all of its descendants.

    Writes, in document order, the element's own ``text``, then each child
    (recursively, so each child's text, descendants and ``tail``), and finally
    the element's own ``tail`` into ``stream``. A fresh ``StringIO`` is used
    when no stream is given.

    Returns everything the stream holds after writing, including anything
    that was already in a caller-supplied stream. This module registers the
    function as Python-Markdown's "plain" serializer.
    """
    buffer = StringIO() if stream is None else stream

    leading = element.text
    if leading:
        buffer.write(leading)

    for child in element:
        # Children write into the shared buffer; their return value is unused.
        unmark_element(child, buffer)

    trailing = element.tail
    if trailing:
        buffer.write(trailing)

    return buffer.getvalue()
# patching Markdown: register a "plain" output format whose serializer
# returns the document's text with all markup removed. This mutates the
# class-level Markdown.output_formats registry, so it is visible to every
# Markdown instance in the process.
Markdown.output_formats.update(plain=unmark_element)
# Shared converter used by unmark(); "tables" enables pipe-table syntax.
__md = Markdown(extensions=["tables"], output_format="plain")
# The plain serializer emits bare text rather than markup wrapped in a
# top-level tag, so convert() must not try to strip such a wrapper.
__md.stripTopLevelTags = False
def unmark(text):
    """Render markdown ``text`` to plain text with the markup removed.

    Uses the module's shared ``Markdown`` converter configured with the
    "plain" output format.

    ``reset()`` is called before every conversion. A Python-Markdown instance
    keeps per-document state across ``convert()`` calls, such as
    reference-style link definitions and the raw-HTML stash. Without the
    reset, that state from one document leaks into the next and the stash
    keeps growing.
    """
    __md.reset()
    return __md.convert(text)
def clean_md(text: str):
    """Strip markup from a markdown document and split it into sections.

    Processing steps:

    1. Remove HTML tags and "[lightgallery ...]" shortcodes.
    2. Split the text wherever a newline is immediately followed by "#",
       i.e. before lines starting with "#", which are typically headings.
       The separator is consumed, so each later section loses one leading
       "#" (e.g. "## Title" is parsed as "# Title").
    3. Render each section to plain text, lowercase it, and collapse every
       whitespace run to a single space.

    Returns one string per section, in order. Sections may be empty, and they
    keep a leading or trailing single space when the rendered text had one.
    """
    without_tags = HTMLR.sub("", text)
    without_galleries = LIGHTGALLERY.sub("", without_tags)
    sections = without_galleries.split("\n#")
    return [WS.sub(" ", unmark(section).lower()) for section in sections]
# Token count of the "passage: " prefix. truncate_to_sequences reserves this
# many tokens in every sequence. Presumably each sequence is later embedded
# as "passage: " + sequence (TODO confirm against the caller).
start_seq_length = num_tokens_from_str("passage: ")


def _find_split(text: str, base: int, end: int) -> int:
    """Choose where a chunk starting at ``base`` should end, at or before ``end``.

    Preference order inside ``text[base:end]``:

    1. Right after the last ". ": the chunk keeps the period and the next
       chunk starts at the space.
    2. Right after the last " ": the chunk keeps the trailing space.
    3. ``end`` itself (hard cut) when the span contains no space at all.
       Without this, a long unbroken token such as a URL would stall the
       caller.

    Always returns an index in ``(base, end]`` provided ``end > base``.
    """
    section = text[base:end]
    period = section.rfind(". ")
    if period != -1:
        return base + period + 1
    space = section.rfind(" ")
    if space != -1:
        return base + space + 1
    return end


def truncate_to_sequences(text: str, max_char=EMBEDDING_CHAR_LIMIT):
    """Split ``text`` into consecutive chunks that fit the embedding limit.

    The text is cut at evenly spaced character positions, one span per
    ``max_char - start_seq_length`` tokens. Each cut is moved back to the last
    sentence end (". ") or, failing that, the last space before it. A chunk
    that is still over budget is split again recursively. Concatenating the
    returned chunks reproduces ``text`` exactly.

    Args:
        text: Text to split.
        max_char: Per-chunk limit including the "passage: " prefix. It is
            compared against ``num_tokens_from_str`` counts, so it acts as a
            token limit despite its name.

    Returns:
        The chunks in order; ``[]`` for an empty ``text``.

    Raises:
        ValueError: If ``max_char`` leaves no room after the prefix.
    """
    budget = max_char - start_seq_length
    if budget < 1:
        raise ValueError(
            f"max_char={max_char} leaves no room for text after the "
            f"{start_seq_length}-token prefix"
        )

    length = len(text)
    # Number of spans to aim for: at least 1, and at least tokens / budget.
    sequence_count = num_tokens_from_str(text) // budget + 1
    # Span width in characters; at least 1 so the loop always advances.
    separator = max(length // sequence_count, 1)

    sequences = []
    base = 0
    step = 0
    while base < length:
        step += 1
        # Aim for the next grid point. Every chunk ends at or before its own
        # grid point, so this target is always past ``base``.
        end = min(separator * step, length)
        if end < length:
            end = _find_split(text, base, end)
        piece = text[base:end]
        # Re-split oversize pieces, but only when the piece is strictly
        # shorter than ``text``; otherwise an unsplittable piece would
        # recurse forever.
        if len(piece) < length and num_tokens_from_str(piece) > budget:
            sequences.extend(truncate_to_sequences(piece, max_char))
        else:
            sequences.append(piece)
        # ``end`` is an absolute index into ``text``; the next chunk starts there.
        base = end
    return sequences
def md_to_passages(md: str):
    """Turn a markdown document into a flat list of passage strings.

    ``clean_md`` first splits the document into cleaned plain-text sections.
    ``truncate_to_sequences`` then cuts each section into size-limited chunks.
    All chunks are returned in document order.
    """
    return [
        chunk
        for section in clean_md(md)
        for chunk in truncate_to_sequences(section)
    ]