quotes-recsys / utils.py
batalovme's picture
Add demo
50bcd75
raw
history blame contribute delete
No virus
2.61 kB
import random
import torch
import numpy as np
import pandas as pd
from stqdm import stqdm
from torch import nn
from torch.nn import functional as F
from transformers import AutoTokenizer, AutoModel
# Device string shared by the helpers below: CUDA when available, else CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
class DSSM(nn.Module):
    """Two-tower module with one encoder for diaries and one for quotes.

    Each tower is created by its own ``base_model.from_pretrained`` call on
    the same ``base_model_name``, with ``add_pooling_layer=False``.
    """

    def __init__(self, base_model_name, base_model=AutoModel):
        """Load both towers.

        Args:
            base_model_name: Name or path handed to ``from_pretrained``.
            base_model: Class exposing ``from_pretrained``; defaults to
                ``AutoModel``.
        """
        super().__init__()
        load_kwargs = {"add_pooling_layer": False}
        # Attribute order is kept (diary first, quote second).
        self.diary_emb = base_model.from_pretrained(base_model_name, **load_kwargs)
        self.quote_emb = base_model.from_pretrained(base_model_name, **load_kwargs)

    def forward(self, diary, quote):
        """Run each input through its own tower.

        Args:
            diary: Mapping of keyword arguments for the diary encoder.
            quote: Mapping of keyword arguments for the quote encoder.

        Returns:
            Tuple ``(diary_output, quote_output)`` of the raw encoder outputs.
        """
        diary_output = self.diary_emb(**diary)
        quote_output = self.quote_emb(**quote)
        return diary_output, quote_output
def get_models_and_tokenizer(base_model_name, base_model=AutoModel, ckpt=None):
    """Build the DSSM towers and the matching tokenizer.

    Args:
        base_model_name: Name or path passed to ``AutoTokenizer`` and ``DSSM``.
        base_model: Model class forwarded to ``DSSM``.
        ckpt: Optional path to a ``DSSM`` state dict; when truthy it is loaded
            with ``torch.load`` (mapped onto ``device``) into the model.

    Returns:
        Tuple ``(diary_encoder, quote_encoder, tokenizer)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    dssm = DSSM(base_model_name, base_model=base_model)

    if ckpt:
        print("use ckpt")
        state_dict = torch.load(ckpt, map_location=device)
        dssm.load_state_dict(state_dict)

    # Both towers are submodules, so moving the wrapper moves them too.
    dssm.to(device)
    return dssm.diary_emb, dssm.quote_emb, tokenizer
def model_inference(model, tokenizer, text):
    """Embed ``text`` with ``model`` and return its first-position vector.

    The forward pass runs under ``torch.no_grad()`` because this helper is
    used for inference only; the returned tensor therefore carries no
    autograd graph.

    Args:
        model: Encoder called with the tokenizer output as keyword arguments.
            Its first output is indexed as ``[:, 0, :]``.
        tokenizer: Callable invoked with ``return_tensors="pt"`` and
            ``truncation=True``; its result must support ``.to(device)``.
        text: Input passed to the tokenizer.

    Returns:
        ``output[0][:, 0, :]``, a 2-D tensor located on ``device``.
        Presumably the first-token ([CLS]) hidden state of shape
        (batch, hidden) - confirm against the model in use.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True)
    encoded = encoded.to(device)
    # No gradients are needed here; skipping graph construction saves memory.
    with torch.no_grad():
        output = model(**encoded)
    return output[0][:, 0, :]
class Recommender:
    """Recommend a quote for a diary entry by embedding similarity.

    Quote embeddings are computed once at construction time; ``recommend``
    embeds the diary text and compares it against them.
    """

    # Quotes scoring above this cosine similarity form the random-choice pool.
    SIMILARITY_THRESHOLD = 0.8
    # Only the first MAX_QUOTES quotes are embedded, so only those can be
    # recommended.
    MAX_QUOTES = 50

    def __init__(self, quotes_df, base_model_name, base_model=AutoModel, ckpt=None):
        """Load the encoders and pre-compute the quote embeddings.

        Args:
            quotes_df: DataFrame with ``'Quote'`` and ``'Author'`` columns.
            base_model_name: Name or path of the pretrained base model.
            base_model: Model class forwarded to ``get_models_and_tokenizer``.
            ckpt: Optional checkpoint path forwarded to
                ``get_models_and_tokenizer``.

        Raises:
            ValueError: If ``quotes_df`` contains no quotes.
        """
        (self.diary_embedder,
         self.quote_embedder,
         self.tokenizer) = get_models_and_tokenizer(base_model_name, base_model, ckpt)
        self.quotes = quotes_df['Quote'].to_list()
        self.authors = quotes_df['Author'].to_list()

        candidates = self.quotes[:self.MAX_QUOTES]
        if not candidates:
            raise ValueError("quotes_df must contain at least one quote")

        # Inference only: no autograd graph is needed for the cached vectors.
        with torch.no_grad():
            rows = [
                model_inference(self.quote_embedder, self.tokenizer, q).cpu()
                for q in stqdm(candidates)
            ]
        # Each row is (1, hidden); concatenating on dim 0 gives
        # (num_quotes, hidden), the same result as stacking and squeezing dim 1.
        self.quote_embeddings = torch.cat(rows, dim=0)

    def recommend(self, d):
        """Pick a quote for the diary text ``d``.

        A random quote is chosen among those whose cosine similarity to the
        diary embedding exceeds ``SIMILARITY_THRESHOLD``; if none does, the
        most similar quote is returned.

        Args:
            d: A single diary text.

        Returns:
            Tuple ``(quote, author)``.
        """
        with torch.no_grad():
            # Shape (1, hidden) for a single text.
            d_emb = model_inference(self.diary_embedder, self.tokenizer, d).cpu()
            # Reduce over the hidden dimension (dim=1) to get one score per
            # quote. Reducing over dim=0 would give one score per feature and
            # an index unrelated to the quote list.
            similarities = F.cosine_similarity(d_emb, self.quote_embeddings, dim=1)

        above_threshold_indices = (
            (similarities > self.SIMILARITY_THRESHOLD).nonzero().flatten().tolist()
        )
        if above_threshold_indices:
            index = random.choice(above_threshold_indices)
        else:
            index = torch.argmax(similarities).item()
        return self.quotes[index], self.authors[index]
def get_quote_embeddings(model, tokenizer):
    """Embed every quote listed in ``quotes-recsys/data/quotes.csv``.

    Args:
        model: Encoder passed to ``model_inference``.
        tokenizer: Tokenizer passed to ``model_inference``.

    Returns:
        CPU tensor with one row per quote, in file order. Each
        ``model_inference`` result is a (1, hidden) row, so the result is
        (num_quotes, hidden).
    """
    quotes = pd.read_csv('quotes-recsys/data/quotes.csv')['Quote'].to_list()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        rows = [model_inference(model, tokenizer, q).cpu() for q in quotes]
    # torch.tensor() cannot build a tensor from a list of multi-element
    # tensors, so the rows are concatenated along dim 0 instead.
    return torch.cat(rows, dim=0)