# m_dialo_m / app.py
# (Header scraped from the Hugging Face file page: last commit "Update app.py",
#  07a4da5 (verified), by raushanrajcareer; file size 2.91 kB.)
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
from datafile import QL, resume_data_dict
import streamlit as st
from tensorflow import keras
import tensorflow as tf
# Load your custom data: each QL["queries"][i] is paired with the answer stored
# under the key QL["labels"][i] in resume_data_dict.
_queries = QL["queries"]
_label_keys = QL["labels"]
if len(_queries) != len(_label_keys):
    # zip() would silently drop the unmatched tail, so fail loudly instead.
    raise ValueError(
        f"QL has {len(_queries)} queries but {len(_label_keys)} labels; "
        "they must pair up one-to-one"
    )
data = [
    {"question": question, "answer": resume_data_dict[key]}
    for question, key in zip(_queries, _label_keys)
]
# Create a Dataset
dataset = Dataset.from_list(data)
def _load_base_tokenizer(checkpoint):
    """Load the tokenizer for *checkpoint*, guaranteeing it has a pad token.

    Padding batches to a fixed length needs a pad token; when the checkpoint
    does not define one, its end-of-sequence token is reused for padding.
    """
    tok = AutoTokenizer.from_pretrained(checkpoint)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    return tok


# Load the tokenizer of the base conversational model.
tokenizer = _load_base_tokenizer("microsoft/DialoGPT-medium")
def preprocess_function(examples, *, max_length=128, tok=None):
    """Tokenize a batch of Q&A pairs into causal-LM training examples.

    Each pair becomes ONE sequence, ``question + EOS + answer + EOS``. This
    matches the prompt format the chat UI feeds the model at inference time
    (``user_input + eos_token``).

    A causal LM shifts ``labels`` against its own ``input_ids`` internally, so
    the labels must be the same token sequence as the inputs. ``labels`` is
    therefore a copy of ``input_ids``, except that the question tokens and the
    padding are replaced by -100 (the index the loss ignores). Only the answer
    and its closing EOS are learned. Padding must be masked explicitly because
    the pad token may be the EOS token.

    Args:
        examples: Batch dict with parallel ``"question"`` and ``"answer"`` lists,
            as passed by ``Dataset.map(..., batched=True)``.
        max_length: Length every sequence is padded/truncated to.
        tok: Tokenizer to use; defaults to the module-level ``tokenizer``.

    Returns:
        The tokenizer output (``input_ids``, ``attention_mask``) with an added
        ``labels`` entry of the same shape as ``input_ids``.
    """
    if tok is None:
        tok = tokenizer
    eos = tok.eos_token
    prompts = [f"{question}{eos}" for question in examples["question"]]
    texts = [f"{prompt}{answer}{eos}" for prompt, answer in zip(prompts, examples["answer"])]
    model_inputs = tok(texts, padding="max_length", truncation=True, max_length=max_length)

    # Token count of each prompt, used to mask the question part of the labels.
    prompt_lens = [
        len(ids) for ids in tok(prompts, truncation=True, max_length=max_length)["input_ids"]
    ]

    all_labels = []
    for ids, mask, n_prompt in zip(
        model_inputs["input_ids"], model_inputs["attention_mask"], prompt_lens
    ):
        row = []
        n_real = 0  # real (non-padding) tokens seen so far; works for either padding side
        for token_id, is_real in zip(ids, mask):
            if not is_real:
                row.append(-100)  # padding: never a training target
                continue
            row.append(token_id if n_real >= n_prompt else -100)
            n_real += 1
        all_labels.append(row)

    model_inputs["labels"] = all_labels
    return model_inputs
@st.cache_resource
def _fine_tune_and_load(output_dir="./resume_bot"):
    """Fine-tune DialoGPT-medium on the resume Q&A pairs and return the result.

    Streamlit re-executes this whole script on every widget interaction. The
    training run is therefore cached with ``st.cache_resource``: it runs once
    per server process, and later reruns reuse the returned objects instead of
    retraining.

    The final weights and tokenizer are written to the top level of
    ``output_dir`` before being reloaded. ``trainer.train()`` on its own only
    writes intermediate ``checkpoint-<step>`` subdirectories (every
    ``save_steps``), so ``from_pretrained(output_dir)`` would find no model.

    Args:
        output_dir: Directory for training checkpoints and the final model.

    Returns:
        A ``(tokenizer, model)`` pair loaded from ``output_dir``.
    """
    # Apply preprocessing
    tokenized_dataset = dataset.map(preprocess_function, batched=True)
    base_model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
    base_model.resize_token_embeddings(len(tokenizer))
    # Define training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        warmup_steps=10,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=10,
        save_steps=500,
        # NOTE(review): newer transformers releases renamed this argument to
        # `eval_strategy`; adjust if the installed version rejects it.
        evaluation_strategy="steps",
    )
    # Initialize the Trainer
    trainer = Trainer(
        model=base_model,
        args=training_args,
        train_dataset=tokenized_dataset,
        # NOTE(review): evaluation reuses the training set, so eval loss is not a
        # held-out measure.
        eval_dataset=tokenized_dataset,
        tokenizer=tokenizer,
    )
    # Train the model, then persist the final weights where from_pretrained looks.
    trainer.train()
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    # Reload the fine-tuned model (path to your fine-tuned model).
    return (
        AutoTokenizer.from_pretrained(output_dir),
        AutoModelForCausalLM.from_pretrained(output_dir),
    )


tokenizer, model = _fine_tune_and_load()
st.title("Resume Chatbot")
if "history" not in st.session_state:
    st.session_state.history = []
# Last question already answered. Streamlit reruns this script on every
# interaction (including the reset button) and text_input keeps returning the
# same value, so without this each rerun would regenerate and re-append it.
if "last_input" not in st.session_state:
    st.session_state.last_input = None


def _reset_conversation():
    """Clear the chat history and the input box; runs before the next rerun."""
    st.session_state.history = []
    st.session_state.last_input = None
    st.session_state.user_input = ""


user_input = st.text_input("You: ", key="user_input")
if user_input and user_input != st.session_state.last_input:
    # Encode the input; the prompt ends with the EOS token. Calling the tokenizer
    # (rather than .encode) also yields an attention_mask for generate().
    encoded = tokenizer(user_input + tokenizer.eos_token, return_tensors="pt")
    prompt_len = encoded["input_ids"].shape[-1]
    try:
        response_ids = model.generate(
            **encoded, max_length=1000, pad_token_id=tokenizer.eos_token_id
        )
    except Exception as e:
        st.error(f"Error generating response: {e}")
    else:
        # generate() returns prompt + continuation; decode only the new tokens so
        # the user's question is not echoed back as part of the answer.
        bot_response = tokenizer.decode(
            response_ids[0, prompt_len:], skip_special_tokens=True
        )
        st.session_state.last_input = user_input
        # Update the chat history
        st.session_state.history.append(f"You: {user_input}")
        st.session_state.history.append(f"Bot: {bot_response}")

# Display the whole conversation, oldest turn first.
for line in st.session_state.history:
    st.write(line)

# Add a button to clear the conversation
st.button("Reset Conversation", on_click=_reset_conversation)