"""Convert Markdown documents into plain-text passages sized for embedding."""

import re
from io import StringIO
from typing import List

from markdown import Markdown

from embedding import EMBEDDING_CHAR_LIMIT, num_tokens_from_str

# Any HTML-style tag; non-greedy so each tag is removed individually.
HTMLR = re.compile(r"<.*?>")
# Runs of whitespace, collapsed to a single space by clean_md().
WS = re.compile(r"\s+")
# "[lightgallery ...]" shortcodes. Non-greedy so that only the shortcode is
# removed, not everything up to the last "]" on the same line.
LIGHTGALLERY = re.compile(r"\[lightgallery.*?\]")


def unmark_element(element, stream=None):
    """Serialize an element tree to plain text, discarding all markup.

    Writes the element's own text, then each child recursively, then the
    element's tail text, in document order.

    Args:
        element: Element whose ``text``, children and ``tail`` are emitted.
        stream: Buffer shared across the recursion; a new ``StringIO`` is
            created when omitted (i.e. for the top-level call).

    Returns:
        Everything written to ``stream`` so far, as a string.
    """
    if stream is None:
        stream = StringIO()
    if element.text:
        stream.write(element.text)
    for sub in element:
        unmark_element(sub, stream)
    if element.tail:
        stream.write(element.tail)
    return stream.getvalue()


# Patching Markdown: register a "plain" output format that emits text only.
Markdown.output_formats["plain"] = unmark_element
__md = Markdown(output_format="plain", extensions=["tables"])
# The plain serializer produces no wrapping tag, so there is nothing to strip.
__md.stripTopLevelTags = False


def unmark(text):
    """Return ``text`` with its Markdown formatting removed."""
    # reset() clears state kept from the previous document before reusing the
    # shared converter instance.
    return __md.reset().convert(text)


def clean_md(text: str) -> List[str]:
    """Split a Markdown document into cleaned, lower-cased plain-text sections.

    HTML tags and lightgallery shortcodes are removed, the document is split
    wherever a line starts with ``#``, each part is converted to plain text,
    lower-cased, and its whitespace runs are collapsed to single spaces.

    Args:
        text: Raw Markdown source.

    Returns:
        One plain-text string per section, in document order.
    """
    cleantext = HTMLR.sub("", text)
    cleantext = LIGHTGALLERY.sub("", cleantext)
    return [WS.sub(" ", unmark(part).lower()) for part in cleantext.split("\n#")]


# Token cost of the "passage: " prefix, reserved out of every sequence budget.
# NOTE(review): presumably the prefix is prepended before embedding - confirm
# against the caller.
start_seq_length = num_tokens_from_str("passage: ")


def truncate_to_sequences(text: str, max_char=EMBEDDING_CHAR_LIMIT) -> List[str]:
    """Split ``text`` into consecutive pieces that fit the embedding limit.

    The number of pieces is estimated from the token count of ``text``; the
    text is then cut near evenly spaced character offsets, preferring to cut
    right after the last ". " before the offset, then after the last space,
    and only as a last resort in the middle of a word. A piece that still
    exceeds ``max_char`` tokens is split again recursively.

    Args:
        text: Text to split.
        max_char: Limit compared against ``num_tokens_from_str`` results.
            Must be greater than ``start_seq_length``.

    Returns:
        The pieces in order; concatenated they reproduce ``text`` exactly.
        An empty ``text`` yields an empty list.
    """
    length = len(text)
    sequence_count = num_tokens_from_str(text) // (max_char - start_seq_length) + 1
    # Never let the step drop to zero, otherwise the loop could not advance.
    separator = max(1, length // sequence_count)

    sequences = []
    base = 0
    while base < length:
        # Target offset for the end of this piece (absolute index into text).
        target = min(separator * (len(sequences) + 1), length)
        end = target
        if target < length:
            # Prefer a sentence boundary, then a word boundary. In both cases
            # the cut lands just after the matched character.
            cut = text.rfind(". ", base, target)
            if cut == -1:
                cut = text.rfind(" ", base, target)
            if cut != -1:
                end = cut + 1
            # Otherwise keep the hard cut at ``target`` so progress is made.

        chunk = text[base:end]
        # Only recurse on a strictly shorter chunk so recursion terminates.
        if len(chunk) < length and num_tokens_from_str(chunk) > max_char:
            sequences += truncate_to_sequences(chunk, max_char)
        else:
            sequences.append(chunk)
        # ``end`` is already an absolute offset: continue right after it.
        base = end
    return sequences


def md_to_passages(md: str) -> List[str]:
    """Convert a Markdown document into a flat list of embedding passages.

    Args:
        md: Raw Markdown source.

    Returns:
        The sequences of every cleaned section, in document order.
    """
    passages = []
    for section in clean_md(md):
        passages.extend(truncate_to_sequences(section))
    return passages