Stefan commited on
Commit
bb3407a
1 Parent(s): 7ce98a0

feat(setup): initial commit

Browse files
Files changed (10) hide show
  1. .gitignore +2 -0
  2. .vscode/settings.json +3 -0
  3. Pipfile +33 -0
  4. Pipfile.lock +0 -0
  5. embedding.py +48 -0
  6. main.py +28 -0
  7. pg.py +41 -0
  8. processing.py +95 -0
  9. requirements.txt +97 -0
  10. vectors.py +38 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ data*/
2
+ .env
.vscode/settings.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "editor.defaultFormatter": "ms-python.black-formatter"
3
+ }
Pipfile ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [[source]]
2
+ url = "https://pypi.org/simple"
3
+ verify_ssl = true
4
+ name = "pypi"
5
+
6
+ [packages]
7
+ numpy = "*"
8
+ pandas = "*"
9
+ torch = "*"
10
+ transformers = "*"
11
+ accelerate = "*"
12
+ sentencepiece = "*"
13
+ protobuf = "==3.20.1"
14
+ aiohttp = "*"
15
+ aiodns = "*"
16
+ brotli = "*"
17
+ python-dotenv = "*"
18
+ openai = "*"
19
+ nest-asyncio = "*"
20
+ tqdm = "*"
21
+ tiktoken = "*"
22
+ instructorembedding = "*"
23
+ markdown = "*"
24
+ sentence-transformers = "*"
25
+ pinecone-client = "*"
26
+ psycopg2 = "*"
27
+ gradio = "*"
28
+
29
+ [dev-packages]
30
+ ipykernel = "*"
31
+
32
+ [requires]
33
+ python_version = "3.11"
Pipfile.lock ADDED
The diff for this file is too large to render. See raw diff
 
embedding.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import Tensor
2
+ import tiktoken
3
+ from transformers import AutoTokenizer, AutoModel
4
+
5
+ tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-large-v2")
6
+ model = AutoModel.from_pretrained("intfloat/e5-large-v2")
7
+
8
+ EMBEDDING_CHAR_LIMIT = 512
9
+
10
+
11
+ def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
12
+ last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
13
+ return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
14
+
15
+
16
+ def strings_to_vectors(strings: list[str]):
17
+ passage_batch = tokenizer(
18
+ strings,
19
+ max_length=EMBEDDING_CHAR_LIMIT,
20
+ padding=True,
21
+ truncation=True,
22
+ return_tensors="pt",
23
+ )
24
+ passage_outputs = model(**passage_batch)
25
+ return average_pool(
26
+ passage_outputs.last_hidden_state, passage_batch["attention_mask"]
27
+ )
28
+
29
+
30
+ def num_tokens_from_str(string, model="gpt-3.5-turbo"):
31
+ """Returns the number of tokens used by a list of messages."""
32
+ try:
33
+ encoding = tiktoken.encoding_for_model(model)
34
+ except KeyError:
35
+ encoding = tiktoken.get_encoding("cl100k_base")
36
+ if model == "gpt-3.5-turbo": # note: future models may deviate from this
37
+ num_tokens = 0
38
+ num_tokens += (
39
+ 4 # every message follows <im_start>{role/name}\n{content}<im_end>\n
40
+ )
41
+ num_tokens += len(encoding.encode(string))
42
+ num_tokens += 2 # every reply is primed with <im_start>assistant
43
+ return num_tokens
44
+ else:
45
+ raise NotImplementedError(
46
+ f"""num_tokens_from_messages() is not presently implemented for model {model}.
47
+ See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
48
+ )
main.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from processing import md_to_passages
3
+ from pg import get_chapters
4
+ from vectors import match_query
5
+
6
+
7
+ def find_embedding(query: str):
8
+ top_res = match_query(query, 3)
9
+ # print(top_res)
10
+
11
+ chapters = get_chapters(list(map(lambda x: x["metadata"]["chapterId"], top_res)))
12
+
13
+ output = ""
14
+
15
+ for res, chapter in zip(top_res, chapters):
16
+ passages = md_to_passages(chapter["explanation"])
17
+ output += f"{res['id']}\t| score: {res['score']:.2f}%\n{passages[res['passage_idx']]}\n\n"
18
+
19
+ return output
20
+
21
+
22
+ with gr.Blocks() as quesbook_search:
23
+ question = gr.Text(label="question")
24
+ answer = gr.Text(label="answer")
25
+ submit = gr.Button("Submit")
26
+ submit.click(fn=find_embedding, inputs=question, outputs=answer)
27
+
28
+ quesbook_search.launch()
pg.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import psycopg2
2
+ import os
3
+
4
+ pg = psycopg2.connect(
5
+ dbname=os.getenv("POSTGRES_DB"),
6
+ user=os.getenv("POSTGRES_USER"),
7
+ password=os.getenv("POSTGRES_PASSWORD"),
8
+ port=os.getenv("POSTGRES_PORT"),
9
+ host=os.getenv("POSTGRES_HOST"),
10
+ )
11
+
12
+
13
+ def get_chapters(ids: list[int]):
14
+ cur = pg.cursor()
15
+ cur.execute(
16
+ """
17
+ SELECT
18
+ ch.id,
19
+ ch.explanation
20
+ FROM
21
+ chapters ch
22
+ WHERE
23
+ ch.id = ANY (%s);
24
+ """,
25
+ (ids,),
26
+ )
27
+ data = cur.fetchall()
28
+ cur.close()
29
+
30
+ chapters = list(map(lambda x: {"id": x[0], "explanation": x[1]}, data))
31
+
32
+ ordered_chapters = []
33
+ for id in ids:
34
+ chapter = next(
35
+ (ch for ch in chapters if ch["id"] == id),
36
+ None,
37
+ )
38
+ if chapter:
39
+ ordered_chapters.append(chapter)
40
+
41
+ return ordered_chapters
processing.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from markdown import Markdown
2
+ from io import StringIO
3
+ import re
4
+ from embedding import num_tokens_from_str, EMBEDDING_CHAR_LIMIT
5
+
6
+ HTMLR = re.compile("<.*?>")
7
+ WS = re.compile("\s+")
8
+ LIGHTGALLERY = re.compile("\[lightgallery.*\]")
9
+
10
+
11
+ def unmark_element(element, stream=None):
12
+ if stream is None:
13
+ stream = StringIO()
14
+ if element.text:
15
+ stream.write(element.text)
16
+ for sub in element:
17
+ unmark_element(sub, stream)
18
+ if element.tail:
19
+ stream.write(element.tail)
20
+ return stream.getvalue()
21
+
22
+
23
+ # patching Markdown
24
+ Markdown.output_formats["plain"] = unmark_element
25
+ __md = Markdown(output_format="plain", extensions=["tables"])
26
+ __md.stripTopLevelTags = False
27
+
28
+
29
+ def unmark(text):
30
+ return __md.convert(text)
31
+
32
+
33
+ def clean_md(text: str) -> list[str]:
34
+ cleantext = re.sub(HTMLR, "", text)
35
+ cleantext = re.sub(LIGHTGALLERY, "", cleantext)
36
+ para = cleantext.split("\n#")
37
+ para = [unmark(p) for p in para]
38
+ para = [re.sub(WS, " ", p.lower()) for p in para]
39
+ return para
40
+
41
+
42
+ start_seq_length = num_tokens_from_str("passage: ")
43
+
44
+
45
+ def truncate_to_sequences(text: str, max_char=EMBEDDING_CHAR_LIMIT) -> list[str]:
46
+ sequence_length = num_tokens_from_str(text) // (max_char - start_seq_length) + 1
47
+ length = len(text)
48
+ separator = length // sequence_length
49
+
50
+ sequences = []
51
+ base = 0
52
+ while base < length:
53
+ count = len(sequences) + 1
54
+ end = min(separator * count, length)
55
+ found = False
56
+
57
+ if end == length:
58
+ found = True
59
+
60
+ if found is False:
61
+ section = text[base:end]
62
+ section_rev = section[::-1]
63
+ for i in range(len(section_rev)):
64
+ if section_rev[i : i + 2] == " .":
65
+ found = True
66
+ end -= 1
67
+ break
68
+ end -= 1
69
+
70
+ if found is False:
71
+ end = separator * count
72
+ for i in range(len(section_rev)):
73
+ if section_rev[i] == " ":
74
+ found = True
75
+ break
76
+ end -= 1
77
+
78
+ if num_tokens_from_str(text[base:end]) > max_char:
79
+ sub_sequences = truncate_to_sequences(text[base:end])
80
+ sequences += sub_sequences
81
+ else:
82
+ sequences.append(text[base:end])
83
+
84
+ base = base + end
85
+ return sequences
86
+
87
+
88
+ def md_to_passages(md: str) -> list[str]:
89
+ initial_passages = clean_md(md)
90
+ passages = []
91
+ for p in initial_passages:
92
+ sequences = truncate_to_sequences(p)
93
+ passages += sequences
94
+
95
+ return passages
requirements.txt ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -i https://pypi.org/simple
2
+ accelerate==0.19.0
3
+ aiodns==3.0.0
4
+ aiofiles==23.1.0 ; python_version >= '3.7' and python_version < '4.0'
5
+ aiohttp==3.8.4
6
+ aiosignal==1.3.1 ; python_version >= '3.7'
7
+ altair==5.0.0 ; python_version >= '3.7'
8
+ anyio==3.6.2 ; python_full_version >= '3.6.2'
9
+ async-timeout==4.0.2 ; python_version >= '3.6'
10
+ attrs==23.1.0 ; python_version >= '3.7'
11
+ brotli==1.0.9
12
+ certifi==2023.5.7 ; python_version >= '3.6'
13
+ cffi==1.15.1
14
+ charset-normalizer==3.1.0 ; python_full_version >= '3.7.0'
15
+ click==8.1.3 ; python_version >= '3.7'
16
+ contourpy==1.0.7 ; python_version >= '3.8'
17
+ cycler==0.11.0 ; python_version >= '3.6'
18
+ dnspython==2.3.0 ; python_version >= '3.7' and python_version < '4.0'
19
+ fastapi==0.95.2 ; python_version >= '3.7'
20
+ ffmpy==0.3.0
21
+ filelock==3.12.0 ; python_version >= '3.7'
22
+ fonttools==4.39.4 ; python_version >= '3.8'
23
+ frozenlist==1.3.3 ; python_version >= '3.7'
24
+ fsspec==2023.5.0 ; python_version >= '3.8'
25
+ gradio==3.32.0
26
+ gradio-client==0.2.5 ; python_version >= '3.7'
27
+ h11==0.14.0 ; python_version >= '3.7'
28
+ httpcore==0.17.2 ; python_version >= '3.7'
29
+ httpx==0.24.1 ; python_version >= '3.7'
30
+ huggingface-hub==0.14.1 ; python_full_version >= '3.7.0'
31
+ idna==3.4 ; python_version >= '3.5'
32
+ instructorembedding==1.0.0
33
+ jinja2==3.1.2 ; python_version >= '3.7'
34
+ joblib==1.2.0 ; python_version >= '3.7'
35
+ jsonschema==4.17.3 ; python_version >= '3.7'
36
+ kiwisolver==1.4.4 ; python_version >= '3.7'
37
+ linkify-it-py==2.0.2
38
+ loguru==0.7.0 ; python_version >= '3.5'
39
+ markdown==3.4.3
40
+ markdown-it-py[linkify]==2.2.0 ; python_version >= '3.7'
41
+ markupsafe==2.1.2 ; python_version >= '3.7'
42
+ matplotlib==3.7.1 ; python_version >= '3.8'
43
+ mdit-py-plugins==0.3.3 ; python_version >= '3.7'
44
+ mdurl==0.1.2 ; python_version >= '3.7'
45
+ mpmath==1.3.0
46
+ multidict==6.0.4 ; python_version >= '3.7'
47
+ nest-asyncio==1.5.6
48
+ networkx==3.1 ; python_version >= '3.8'
49
+ nltk==3.8.1 ; python_version >= '3.7'
50
+ numpy==1.24.3
51
+ openai==0.27.7
52
+ orjson==3.8.13 ; python_version >= '3.7'
53
+ packaging==23.1 ; python_version >= '3.7'
54
+ pandas==2.0.1
55
+ pillow==9.5.0 ; python_version >= '3.7'
56
+ pinecone-client==2.2.1
57
+ protobuf==3.20.1
58
+ psutil==5.9.5 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
59
+ psycopg2==2.9.6
60
+ pycares==4.3.0
61
+ pycparser==2.21
62
+ pydantic==1.10.8 ; python_version >= '3.7'
63
+ pydub==0.25.1
64
+ pygments==2.15.1 ; python_version >= '3.7'
65
+ pyparsing==3.0.9 ; python_full_version >= '3.6.8'
66
+ pyrsistent==0.19.3 ; python_version >= '3.7'
67
+ python-dateutil==2.8.2 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
68
+ python-dotenv==1.0.0
69
+ python-multipart==0.0.6 ; python_version >= '3.7'
70
+ pytz==2023.3
71
+ pyyaml==6.0 ; python_version >= '3.6'
72
+ regex==2023.5.5 ; python_version >= '3.6'
73
+ requests==2.31.0 ; python_version >= '3.7'
74
+ scikit-learn==1.2.2 ; python_version >= '3.8'
75
+ scipy==1.10.1 ; python_version < '3.12' and python_version >= '3.8'
76
+ semantic-version==2.10.0 ; python_version >= '2.7'
77
+ sentence-transformers==2.2.2
78
+ sentencepiece==0.1.99
79
+ six==1.16.0 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
80
+ sniffio==1.3.0 ; python_version >= '3.7'
81
+ starlette==0.27.0 ; python_version >= '3.7'
82
+ sympy==1.12 ; python_version >= '3.8'
83
+ threadpoolctl==3.1.0 ; python_version >= '3.6'
84
+ tiktoken==0.4.0
85
+ tokenizers==0.13.3
86
+ toolz==0.12.0 ; python_version >= '3.5'
87
+ torch==2.0.1
88
+ torchvision==0.15.2 ; python_version >= '3.8'
89
+ tqdm==4.65.0
90
+ transformers==4.29.2
91
+ typing-extensions==4.6.1 ; python_version >= '3.7'
92
+ tzdata==2023.3 ; python_version >= '2'
93
+ uc-micro-py==1.0.2 ; python_version >= '3.7'
94
+ urllib3==2.0.2 ; python_version >= '3.7'
95
+ uvicorn==0.22.0 ; python_version >= '3.7'
96
+ websockets==11.0.3 ; python_version >= '3.7'
97
+ yarl==1.9.2 ; python_version >= '3.7'
vectors.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from embedding import strings_to_vectors
2
+ import pinecone
3
+ import os
4
+
5
+ PINECONE_API = os.getenv("PINECONE_API")
6
+
7
+ pinecone.init(api_key=PINECONE_API, environment="us-west4-gcp-free")
8
+
9
+ vector_index = pinecone.Index("quesmed")
10
+
11
+
12
+ def scored_vector_todict(scored_vector):
13
+ x = {
14
+ "id": scored_vector["id"],
15
+ "metadata": {
16
+ "topicId": int(scored_vector["metadata"]["topicId"]),
17
+ "chapterId": int(scored_vector["metadata"]["chapterId"]),
18
+ "conceptId": int(scored_vector["metadata"]["conceptId"]),
19
+ },
20
+ "score": scored_vector["score"] * 100,
21
+ "values": scored_vector["values"],
22
+ }
23
+ for k, v in x["metadata"].items():
24
+ x[k] = int(v)
25
+ x["passage_idx"] = int(x["id"][-1])
26
+ return x
27
+
28
+
29
+ def match_query(query: str, n_res=3):
30
+ queries = [f"query: {query.replace('?','').lower()}"]
31
+ query_embeddings = strings_to_vectors(queries)
32
+ result = vector_index.query(
33
+ query_embeddings[0].tolist(),
34
+ top_k=n_res,
35
+ include_metadata=True,
36
+ namespace="quesbook",
37
+ )
38
+ return list(map(scored_vector_todict, result["matches"]))