File size: 1,901 Bytes
1e7dec2
d9983c0
1e7dec2
 
 
 
 
 
 
 
d9983c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e7dec2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import json
import os
import urllib.error
import urllib.request

import streamlit as st
from huggingface_hub import InferenceClient

# Hugging Face inference client for the Mixtral instruct model, created when
# the module is loaded.
# NOTE(review): `client` is not referenced by any visible code --
# generate_response() talks to the Inference API over plain HTTP instead.
# Confirm whether this object is still needed.
_MIXTRAL_MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"
client = InferenceClient(_MIXTRAL_MODEL_ID)

def format_prompt(agent_id, message):
    """Wrap *message* in an instruction-style prompt tagged with *agent_id*.

    Example: ``format_prompt(2, "hi")`` returns
    ``"<s><Agent 2>[INST] hi [/INST]"``.

    NOTE(review): no visible code calls this helper; generate_response()
    sends the raw message instead.
    """
    template = "<s><Agent {}>[INST] {} [/INST]"
    return template.format(agent_id, message)

def generate_response(agent_id, message, *, timeout=60):
    """Send *message* to the Hugging Face Inference API and return the reply.

    Args:
        agent_id: Identifier of the simulated agent.
            NOTE(review): currently unused -- every agent sends the identical
            payload, so replies differ only if the model samples randomly.
            ``format_prompt`` looks intended for this but is never called.
        message: The user's question, sent as the conversational ``text``.
        timeout: Seconds to wait for the API before giving up (keyword-only).

    Returns:
        The ``generated_text`` from the API on success. Otherwise it returns a
        human-readable error string. HTTP and network failures are reported
        this way rather than raised, so the Streamlit page keeps rendering.
    """
    api_url = "https://api-inference.huggingface.co/models/mistralai/mistral-tiny"  # Replace with your model path
    headers = {"Content-Type": "application/json"}
    api_key = os.getenv("HUGGINGFACE_API_KEY")
    if api_key:
        # Only authenticate when a key is configured; otherwise the header
        # would literally read "Bearer None".
        headers["Authorization"] = f"Bearer {api_key}"
    data = {
        "inputs": {
            "past_user_inputs": [],
            "generated_responses": [],
            "text": message,
        }
    }
    request = urllib.request.Request(
        api_url,
        data=json.dumps(data).encode("utf-8"),
        headers=headers,
        method="POST",
    )

    try:
        # A timeout keeps a stalled API from hanging the app indefinitely.
        with urllib.request.urlopen(request, timeout=timeout) as response:
            status = response.status
            body = response.read()
    except urllib.error.HTTPError as err:
        # urllib raises for 4xx/5xx; surface the status code as before.
        err.close()
        return f"Error: {err.code}"
    except OSError as err:
        # URLError, timeouts and connection resets are all OSError subclasses.
        return f"Error: {err}"

    if status != 200:
        return f"Error: {status}"

    try:
        response_data = json.loads(body)
    except ValueError:
        # Covers both invalid JSON and undecodable bytes.
        return "Received an unexpected response format from the API."

    # Text-generation endpoints reply with [{"generated_text": ...}];
    # conversational ones reply with a single dict.
    if isinstance(response_data, list) and response_data:
        response_data = response_data[0]
    if isinstance(response_data, dict) and "generated_text" in response_data:
        return response_data["generated_text"]
    return "Received an unexpected response format from the API."


def generate_group_therapy_responses(message):
    """Collect one labelled reply per simulated agent (ids 1 through 4).

    Each reply comes from ``generate_response``, and the agents are queried one
    after another. Returns a list of ``"Agent <id>: <reply>"`` strings in
    agent order.
    """
    agent_ids = (1, 2, 3, 4)  # Four agents
    return [f"Agent {aid}: {generate_response(aid, message)}" for aid in agent_ids]

st.title("Group Therapy Simulation")
st.write("Ask a question and receive perspectives from four different agents.")

user_question = st.text_input("Your Question:")

if st.button("Get Responses"):
    if user_question:
        responses = generate_group_therapy_responses(user_question)
        for response in responses:
            st.write(response)
    else:
        st.write("Please enter a question.")