File size: 2,549 Bytes
bb3407a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b6e243
bb3407a
 
 
 
 
 
 
 
 
 
 
5b6e243
bb3407a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e250f84
bb3407a
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from markdown import Markdown
from io import StringIO
import re
from embedding import num_tokens_from_str, EMBEDDING_CHAR_LIMIT

# Pre-compiled patterns used by clean_md. Raw strings keep the regex escapes
# (\s, \[, \]) from being read as (invalid) Python string escapes.
# Shortest run from "<" to the next ">" on one line, i.e. an inline HTML tag.
HTMLR = re.compile(r"<.*?>")
# One or more consecutive whitespace characters.
WS = re.compile(r"\s+")
# A "[lightgallery ...]" shortcode.
# NOTE(review): ".*" is greedy, so the match runs to the LAST "]" on the line
# and also swallows any bracketed text that follows the shortcode -- confirm
# whether a non-greedy match was intended.
LIGHTGALLERY = re.compile(r"\[lightgallery.*\]")


def unmark_element(element, stream=None):
    """Serialize an element tree to plain text, discarding all tags.

    For every node, in document order, the node's ``text`` is written first,
    then each child (recursively), then the node's ``tail``. Empty or missing
    ``text``/``tail`` values are skipped.

    Args:
        element: Root node; it must expose ``text``, ``tail`` and be iterable
            over its children.
        stream: Optional writable text buffer to append to. A fresh
            ``StringIO`` is created when omitted.

    Returns:
        The complete contents of the buffer after the tree has been written.
    """
    out = StringIO() if stream is None else stream

    def _emit(node):
        # Leading text, then the subtree, then whatever trails the closing tag.
        if node.text:
            out.write(node.text)
        for child in node:
            _emit(child)
        if node.tail:
            out.write(node.tail)

    _emit(element)
    return out.getvalue()


# Patch Markdown: register unmark_element under the "plain" key of the class-level
# output_formats mapping, then build one module-wide converter that selects that
# format and enables the "tables" extension.
# NOTE(review): the registration precedes the constructor call; presumably the
# format name is looked up at construction time -- confirm before reordering.
Markdown.output_formats["plain"] = unmark_element
__md = Markdown(output_format="plain", extensions=["tables"])
# NOTE(review): presumably turned off because the plain-text output carries no
# wrapper tags for Markdown to strip -- confirm against the Python-Markdown docs.
__md.stripTopLevelTags = False


def unmark(text):
    """Run *text* through the module's shared ``__md`` converter and return the result.

    NOTE(review): the shared converter is reused across calls without an
    intervening ``reset()``; confirm that carrying parser state from one call
    to the next is intended.
    """
    converted = __md.convert(text)
    return converted


def clean_md(text: str):
    """Split a Markdown document into normalized plain-text sections.

    Steps, in order:
      1. delete every match of ``HTMLR``, then every match of ``LIGHTGALLERY``;
      2. split the remainder on the two-character sequence ``"\\n#"``
         (the separator itself, including that ``#``, is dropped);
      3. pass each piece through ``unmark``;
      4. lower-case it and collapse every whitespace run to a single space.

    Leading/trailing whitespace is collapsed to one space, not removed.

    Returns:
        A list with one string per piece produced by the split.
    """
    without_tags = HTMLR.sub("", text)
    without_gallery = LIGHTGALLERY.sub("", without_tags)
    return [
        WS.sub(" ", unmark(piece).lower())
        for piece in without_gallery.split("\n#")
    ]


# Size of the literal "passage: " prefix as measured by num_tokens_from_str,
# computed once at import time. truncate_to_sequences subtracts it from its
# limit when estimating how many sequences a text needs.
# NOTE(review): presumably this prefix is prepended to every passage before it
# is embedded -- confirm against the embedding module.
start_seq_length = num_tokens_from_str("passage: ")


def truncate_to_sequences(text: str, max_char=EMBEDDING_CHAR_LIMIT):
    """Split *text* into consecutive chunks sized for the embedding limit.

    The number of chunks is estimated from ``num_tokens_from_str(text)`` and
    the limit minus ``start_seq_length``; the text is then walked left to
    right in steps of roughly ``len(text) / chunks`` characters. Each cut is
    moved back to the nearest preferred break inside the current window:

      1. just after the last ``". "`` full stop (the chunk ends with ``.``);
      2. otherwise just after the last space;
      3. otherwise a hard cut at the window edge.

    A chunk whose measured size still exceeds ``max_char`` is split again
    recursively with the same limit.

    Concatenating the returned chunks always reproduces *text* exactly; no
    characters are dropped or duplicated.

    Args:
        text: The string to split.
        max_char: Upper bound for one chunk. Despite the name it is compared
            against ``num_tokens_from_str`` results, not character counts.

    Returns:
        The chunks in order; an empty list for empty input.
    """
    length = len(text)
    if length == 0:
        return []

    # Room left in one chunk after the prefix measured by start_seq_length.
    # Clamped so the division can never be by zero or by a negative number.
    budget = max(max_char - start_seq_length, 1)
    chunk_count = num_tokens_from_str(text) // budget + 1
    # Window width in characters; at least 1 so every pass makes progress.
    step = max(length // chunk_count, 1)

    sequences = []
    base = 0
    while base < length:
        # The window is measured from the current offset, so it stays valid
        # even after a recursive split added several chunks at once.
        end = min(base + step, length)

        if end < length:
            sentence_end = text.rfind(". ", base, end)
            if sentence_end != -1:
                # Keep the full stop, leave the following space to the next chunk.
                end = sentence_end + 1
            else:
                last_space = text.rfind(" ", base, end)
                if last_space != -1:
                    end = last_space + 1
                # No break point at all: keep the hard cut at the window edge
                # instead of producing an empty chunk.

        chunk = text[base:end]
        # Only recurse on something strictly smaller than the input and still
        # divisible; this guarantees the recursion terminates.
        splittable = 1 < len(chunk) < length
        if splittable and num_tokens_from_str(chunk) > max_char:
            sequences += truncate_to_sequences(chunk, max_char)
        else:
            sequences.append(chunk)

        # ``end`` is an absolute index into ``text``: continue from there.
        base = end
    return sequences


def md_to_passages(md: str):
    """Turn a Markdown document into a flat list of passages.

    The document is first cut into cleaned sections by ``clean_md``; each
    section is then divided by ``truncate_to_sequences`` and the resulting
    pieces are collected, in order, into a single list.
    """
    return [
        sequence
        for section in clean_md(md)
        for sequence in truncate_to_sequences(section)
    ]