monsoon-nlp's picture
hide sgpt for now
3d872a7
raw
history blame contribute delete
No virus
2.48 kB
import html
import os

import cohere
import gradio as gr
import numpy as np
import pinecone
import torch
from transformers import AutoModel, AutoTokenizer
# Cohere client used by query() to embed questions. The API key comes from the
# COHERE_API environment variable (empty string when unset).
co = cohere.Client(os.getenv('COHERE_API', ''))

# Configure the Pinecone client from PINECONE_API / PINECONE_ENV.
pinecone.init(
    api_key=os.getenv('PINECONE_API', ''),
    environment=os.getenv('PINECONE_ENV', ''),
)

# SGPT model setup, disabled together with the SGPT search inside query().
# `zos` is the zero padding appended to the SGPT embedding before querying.
# model = AutoModel.from_pretrained('monsoon-nlp/gpt-nyc')
# tokenizer = AutoTokenizer.from_pretrained('monsoon-nlp/gpt-nyc')
# zos = np.zeros(4096-1024).tolist()
def list_me(matches):
    """Render search matches as HTML ``<li>`` items linking to r/AskNYC threads.

    Args:
        matches: Iterable of match records. Each record is read through
            ``match['id']`` (used as the thread id in the link) and
            ``match['metadata']``, which must hold ``'question'`` and may
            hold ``'body'``.

    Returns:
        The concatenated ``<li>...</li>`` fragments as one string (empty when
        ``matches`` is empty). The surrounding ``<ul>`` is the caller's job.
    """
    items = []
    for match in matches:
        metadata = match['metadata']

        # Rewrite '/mini' to '/' in the link only, so that question/body text
        # that happens to contain '/mini' is left intact.
        # NOTE(review): presumably ids stored in the "mini" namespace carry a
        # 'mini' prefix that is not part of the Reddit id -- confirm.
        url = 'https://reddit.com/r/AskNYC/comments/' + match['id']
        url = url.replace('/mini', '/')

        # Question and body are user-written Reddit text: escape them so they
        # are displayed as text instead of being interpreted as markup.
        parts = [
            '<li><a target="_blank" href="' + html.escape(url, quote=True) + '">',
            html.escape(metadata['question']),
            '</a>',
        ]
        if 'body' in metadata:
            parts.append('<br/>' + html.escape(metadata['body']))
        parts.append('</li>')
        items.append(''.join(parts))
    return ''.join(items)
def query(question):
    """Find the stored AskNYC threads closest to a question.

    The question is embedded with Cohere's ``large`` model and the resulting
    vector is used to query the ``gptnyc`` Pinecone index for its two nearest
    matches (metadata included).

    Args:
        question: Text of the question to search for.

    Returns:
        An HTML fragment: a "Cohere" heading followed by a ``<ul>`` holding
        the matches as rendered by ``list_me``.
    """
    # Cohere search
    embedded = co.embed(model='large', texts=[question])
    index = pinecone.Index("gptnyc")
    nearest = index.query(
        vector=embedded.embeddings[0],
        top_k=2,
        include_metadata=True,
    )

    # SGPT search (disabled for now; kept so it can be switched back on)
    # batch_tokens = tokenizer(
    #     [question],
    #     padding=True,
    #     truncation=True,
    #     return_tensors="pt"
    # )
    # with torch.no_grad():
    #     last_hidden_state = model(**batch_tokens, output_hidden_states=True, return_dict=True).last_hidden_state
    # weights = (
    #     torch.arange(start=1, end=last_hidden_state.shape[1] + 1)
    #     .unsqueeze(0)
    #     .unsqueeze(-1)
    #     .expand(last_hidden_state.size())
    #     .float().to(last_hidden_state.device)
    # )
    # input_mask_expanded = (
    #     batch_tokens["attention_mask"]
    #     .unsqueeze(-1)
    #     .expand(last_hidden_state.size())
    #     .float()
    # )
    # sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded * weights, dim=1)
    # sum_mask = torch.sum(input_mask_expanded * weights, dim=1)
    # embeddings = sum_embeddings / sum_mask
    # closest_sgpt = index.query(
    #     top_k=2,
    #     include_metadata=True,
    #     namespace="mini",
    #     vector=embeddings[0].tolist() + zos,
    # )

    cohere_items = list_me(nearest['matches'])
    # Disabled SGPT section of the response, to be appended when re-enabled:
    # '<h3>SGPT</h3><ul>' + list_me(closest_sgpt['matches']) + '</ul>'
    return f'<h3>Cohere</h3><ul>{cohere_items}</ul>'
# Gradio UI: one text input passed to query(), whose return value is shown as HTML.
iface = gr.Interface(fn=query, inputs="text", outputs="html")
iface.launch()