|
import asyncio |
|
import json |
|
import logging |
|
|
|
import click |
|
import requests |
|
from dotenv import load_dotenv |
|
|
|
from hugginggpt import generate_response, infer, plan_tasks |
|
from hugginggpt.history import ConversationHistory |
|
from hugginggpt.llm_factory import LLMs, create_llms |
|
from hugginggpt.log import setup_logging |
|
from hugginggpt.model_inference import TaskSummary |
|
from hugginggpt.model_selection import select_hf_models |
|
from hugginggpt.response_generation import format_response |
|
|
|
load_dotenv() |
|
setup_logging() |
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
@click.command() |
|
@click.option("-p", "--prompt", type=str, help="Prompt for huggingGPT") |
|
def main(prompt): |
|
_print_banner() |
|
llms = create_llms() |
|
if prompt: |
|
standalone_mode(user_input=prompt, llms=llms) |
|
|
|
else: |
|
interactive_mode(llms=llms) |
|
|
|
|
|
def standalone_mode(user_input: str, llms: LLMs) -> str: |
|
try: |
|
response, task_summaries = compute( |
|
user_input=user_input, |
|
history=ConversationHistory(), |
|
llms=llms, |
|
) |
|
pretty_response = format_response(response) |
|
print(pretty_response) |
|
return pretty_response |
|
except Exception as e: |
|
logger.exception("") |
|
print( |
|
f"Sorry, encountered error: {e}. Please try again. Check logs if problem persists." |
|
) |
|
|
|
|
|
def interactive_mode(llms: LLMs): |
|
print("Please enter your request. End the conversation with 'exit'") |
|
history = ConversationHistory() |
|
while True: |
|
try: |
|
user_input = click.prompt("User") |
|
if user_input.lower() == "exit": |
|
break |
|
|
|
logger.info(f"User input: {user_input}") |
|
response, task_summaries = compute( |
|
user_input=user_input, |
|
history=history, |
|
llms=llms, |
|
) |
|
pretty_response = format_response(response) |
|
print(f"Assistant:{pretty_response}") |
|
|
|
history.add(role="user", content=user_input) |
|
history.add(role="assistant", content=response) |
|
except Exception as e: |
|
logger.exception("") |
|
print( |
|
f"Sorry, encountered error: {e}. Please try again. Check logs if problem persists." |
|
) |
|
|
|
|
|
def compute( |
|
user_input: str, |
|
history: ConversationHistory, |
|
llms: LLMs, |
|
) -> (str, list[TaskSummary]): |
|
tasks = plan_tasks( |
|
user_input=user_input, history=history, llm=llms.task_planning_llm |
|
) |
|
|
|
sorted(tasks, key=lambda t: max(t.dep)) |
|
logger.info(f"Sorted tasks: {tasks}") |
|
|
|
hf_models = asyncio.run( |
|
select_hf_models( |
|
user_input=user_input, |
|
tasks=tasks, |
|
model_selection_llm=llms.model_selection_llm, |
|
output_fixing_llm=llms.output_fixing_llm, |
|
) |
|
) |
|
|
|
task_summaries = [] |
|
with requests.Session() as session: |
|
for task in tasks: |
|
logger.info(f"Starting task: {task}") |
|
if task.depends_on_generated_resources(): |
|
task = task.replace_generated_resources(task_summaries=task_summaries) |
|
model = hf_models[task.id] |
|
inference_result = infer( |
|
task=task, |
|
model_id=model.id, |
|
llm=llms.model_inference_llm, |
|
session=session, |
|
) |
|
task_summaries.append( |
|
TaskSummary( |
|
task=task, |
|
model=model, |
|
inference_result=json.dumps(inference_result), |
|
) |
|
) |
|
logger.info(f"Finished task: {task}") |
|
logger.info("Finished all tasks") |
|
logger.debug(f"Task summaries: {task_summaries}") |
|
|
|
response = generate_response( |
|
user_input=user_input, |
|
task_summaries=task_summaries, |
|
llm=llms.response_generation_llm, |
|
) |
|
return response, task_summaries |
|
|
|
|
|
def _print_banner(): |
|
with open("resources/banner.txt", "r") as f: |
|
banner = f.read() |
|
logger.info("\n" + banner) |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|