import asyncio
import json
import logging
import os
import random

import aiohttp
import gradio as gr
import numpy as np
import spaces
import torch
from huggingface_hub import AsyncInferenceClient, InferenceClient
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer

logger = logging.getLogger(__name__)


async def query_llm(payload, model_name):
    """POST ``payload`` to the HF Inference API endpoint for ``model_name``.

    Args:
        payload: JSON-serialisable request body (e.g. ``{"inputs": "..."}``).
        model_name: Hub repository id of the model to query.

    Returns:
        The decoded JSON response body. No status check is performed, so an
        error response whose body is JSON is returned as-is; callers must
        handle non-success shapes themselves.

    Raises:
        aiohttp.ContentTypeError: if the response is not served as JSON
            (``ClientResponse.json`` checks the content type by default).
    """
    headers = {"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"}
    url = f"https://api-inference.huggingface.co/models/{model_name}"
    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json=payload) as response:
            return await response.json()


def _first_generated_text(output, max_chars=None):
    """Return ``output[0]["generated_text"]``, optionally truncated to ``max_chars``.

    If ``output`` does not have that shape (e.g. the API returned an error
    dict, or ``None``), return ``str(output)`` so the caller can surface it.
    """
    try:
        text = output[0]["generated_text"]
    except (IndexError, KeyError, TypeError):
        return str(output)
    return text if max_chars is None else text[:max_chars]


async def _stream_chat(model_id, system_input, user_input, *, swallow_decode_errors=False):
    """Stream a chat completion from ``model_id`` and yield each non-empty text delta.

    Args:
        model_id: Hub repository id of the chat model.
        system_input: Content of the system message.
        user_input: Content of the user message.
        swallow_decode_errors: If True, a ``json.JSONDecodeError`` raised while
            starting or reading the stream ends the stream (after logging a
            warning) instead of propagating to the caller.
    """
    client = AsyncInferenceClient(model_id, token=os.getenv("HF_TOKEN"))
    messages = [
        {"role": "system", "content": system_input},
        {"role": "user", "content": user_input},
    ]
    try:
        stream = await client.chat_completion(messages=messages, max_tokens=256, stream=True)
        async for chunk in stream:
            # A chunk may carry no choices or a None/empty delta; yielding None
            # would break callers that concatenate the streamed text.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content:
                yield content
    except json.JSONDecodeError:
        if not swallow_decode_errors:
            raise
        # Deliberate best-effort: end the stream, but leave a trace instead of silence.
        logger.warning("Stopped streaming from %s after a JSON decode error", model_id, exc_info=True)


async def generate_mistral_7bvo1(system_input, user_input):
    """Stream a reply from Mistral-7B-Instruct-v0.1 (up to 256 new tokens)."""
    async for text in _stream_chat("mistralai/Mistral-7B-Instruct-v0.1", system_input, user_input):
        yield text


async def generate_t5(system_input, user_input):
    """Yield a single FLAN-T5-XXL completion, or the raw response text if it is not a completion."""
    output = await query_llm({"inputs": f"{system_input}\n{user_input}"}, "google/flan-t5-xxl")
    yield _first_generated_text(output)


async def generate_gpt2(system_input, user_input):
    """Yield a single GPT-2 completion truncated to 532 characters.

    Falls back to the raw response text (e.g. an API error payload) instead of
    raising when the response is not a list of generations.
    """
    output = await query_llm({"inputs": f"{system_input}\n{user_input}"}, "openai-community/gpt2")
    yield _first_generated_text(output, max_chars=532)


async def generate_llama2(system_input, user_input):
    """Stream a reply from Llama-2-7b-chat (up to 256 new tokens)."""
    async for text in _stream_chat("meta-llama/Llama-2-7b-chat-hf", system_input, user_input):
        yield text


async def generate_llama3(system_input, user_input):
    """Stream a reply from Llama-3.1-8B-Instruct; a JSON decode error ends the stream quietly."""
    async for text in _stream_chat(
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        system_input,
        user_input,
        swallow_decode_errors=True,
    ):
        yield text


async def generate_mixtral(system_input, user_input):
    """Stream a reply from Mixtral-8x7B-Instruct; a JSON decode error ends the stream quietly."""
    async for text in _stream_chat(
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
        system_input,
        user_input,
        swallow_decode_errors=True,
    ):
        yield text