# mixtherapy / app.py
# Author: sifujohn
# Commit b041f65: "Update app.py" (verified)
# Hugging Face file-page residue: raw | history | blame | contribute | delete
# Security scan: no virus detected
# Size: 1.9 kB
import os

import requests
import streamlit as st
from huggingface_hub import InferenceClient
# Shared Hugging Face inference client for the Mixtral instruct model.
# NOTE(review): not referenced by any function visible in this file; the
# HTTP calls in generate_response go through `requests` instead.
client = InferenceClient(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
def format_prompt(agent_id, message):
    """Wrap *message* in ``[INST]`` tags, prefixed with an agent label.

    Example: ``format_prompt(2, "Hi")`` yields
    ``"<s><Agent 2>[INST] Hi [/INST]"``.

    NOTE(review): not called by any code visible in this file.
    """
    template = "<s><Agent {}>[INST] {} [/INST]"
    return template.format(agent_id, message)
def generate_response(agent_id, message, timeout=30):
    """Send *message* to the hosted model and return its reply text.

    Args:
        agent_id: Number of the simulated agent asking. NOTE(review): this
            value is not included in the request payload, so every agent
            sends an identical request; ``format_prompt`` may have been
            intended here -- confirm before relying on distinct replies.
        message: The user's question, sent as the conversational ``text``.
        timeout: Seconds to wait for the HTTP request before giving up.
            Without a timeout, ``requests`` can block indefinitely and
            freeze the Streamlit page.

    Returns:
        The API's ``generated_text`` on success. On failure it returns a
        human-readable string instead of raising, so the caller can simply
        display it:
        - ``"Error: <status>"`` for a non-200 response;
        - ``"Error: request failed (...)"`` for a network error or timeout;
        - ``"Received an unexpected response format from the API."`` when
          the body is not JSON or has no ``generated_text``.
    """
    api_url = "https://api-inference.huggingface.co/models/mistralai/mistral-tiny"  # Replace with your model path
    headers = {}
    api_key = os.getenv("HUGGINGFACE_API_KEY")
    if api_key:
        # Only send the header when a key is configured; interpolating a
        # missing key would send the literal string "Bearer None".
        headers["Authorization"] = f"Bearer {api_key}"
    data = {
        "inputs": {
            "past_user_inputs": [],
            "generated_responses": [],
            "text": message,
        }
    }
    try:
        response = requests.post(api_url, headers=headers, json=data, timeout=timeout)
    except requests.RequestException as err:
        # Report connection failures and timeouts like HTTP errors
        # instead of letting them crash the app.
        return f"Error: request failed ({err})"

    if response.status_code != 200:
        return f"Error: {response.status_code}"

    unexpected = "Received an unexpected response format from the API."
    try:
        response_data = response.json()
    except ValueError:  # body was not valid JSON
        return unexpected

    # Conversational endpoints reply with a dict; the text-generation task
    # replies with a list such as [{"generated_text": ...}]. Accept both.
    if isinstance(response_data, list) and response_data:
        response_data = response_data[0]
    if isinstance(response_data, dict) and "generated_text" in response_data:
        return response_data["generated_text"]
    return unexpected
def generate_group_therapy_responses(message):
    """Put the same *message* to each simulated agent, in order.

    Returns a list of four strings of the form ``"Agent N: <reply>"``,
    where N runs from 1 to 4 and each reply comes from generate_response.
    """
    return [
        f"Agent {agent_num}: {generate_response(agent_num, message)}"
        for agent_num in range(1, 5)  # agents are numbered 1 through 4
    ]
# Streamlit page: a question box plus a button that shows every agent's reply.
st.title("Group Therapy Simulation")
st.write("Ask a question and receive perspectives from four different agents.")
user_question = st.text_input("Your Question:")
if st.button("Get Responses"):
    # Treat whitespace-only input as empty, so it does not trigger four
    # API calls with a meaningless prompt.
    question = user_question.strip()
    if question:
        responses = generate_group_therapy_responses(question)
        for response in responses:
            st.write(response)
    else:
        st.write("Please enter a question.")