# LLMhistory / generators.py
# freQuensy23's picture
# Update generators.py
# c1e500d verified
# raw
# history blame
# No virus
# 3.32 kB
import asyncio
import json
import os
import aiohttp
import gradio as gr
import numpy as np
import spaces
from huggingface_hub import InferenceClient
import random
import torch
from huggingface_hub import AsyncInferenceClient
from transformers import LlamaTokenizer, LlamaForCausalLM, AutoTokenizer, AutoModelForCausalLM
async def query_llm(payload, model_name):
    """POST ``payload`` to the Hugging Face Inference API and return the reply.

    Args:
        payload: JSON-serialisable request body, sent via ``json=``.
        model_name: Model repository id appended to the endpoint URL.

    Returns:
        The response body decoded as JSON. The HTTP status is not inspected
        here, so whatever JSON the server sends back is returned as-is.
    """
    # The bearer token is read from the HF_TOKEN environment variable.
    auth_headers = {"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"}
    endpoint = f"https://api-inference.huggingface.co/models/{model_name}"
    async with aiohttp.ClientSession() as http:
        async with http.post(endpoint, headers=auth_headers, json=payload) as reply:
            return await reply.json()
async def generate_mistral_7bvo1(system_input, user_input):
    """Stream a chat completion from ``mistralai/Mistral-7B-Instruct-v0.1``.

    The two inputs are sent as a system message followed by a user message,
    with ``max_tokens=256`` and ``stream=True``.

    Args:
        system_input: Content of the ``system`` message.
        user_input: Content of the ``user`` message.

    Yields:
        ``choices[0].delta.content`` of each streamed chunk, unmodified.
    """
    hf_client = AsyncInferenceClient(
        "mistralai/Mistral-7B-Instruct-v0.1",
        token=os.getenv('HF_TOKEN'),
    )
    conversation = [
        {"role": "system", "content": system_input},
        {"role": "user", "content": user_input},
    ]
    chunk_stream = await hf_client.chat_completion(
        messages=conversation,
        max_tokens=256,
        stream=True,
    )
    async for chunk in chunk_stream:
        # NOTE(review): the delta content is forwarded as-is; confirm whether
        # consumers need to cope with an empty/None value on some chunks.
        yield chunk.choices[0].delta.content
async def generate_t5(system_input, user_input):
    """Yield a single completion from ``google/flan-t5-xxl``.

    The prompt is ``system_input`` and ``user_input`` joined by a newline and
    sent through ``query_llm``.

    Args:
        system_input: First line(s) of the prompt.
        user_input: Text appended after a newline.

    Yields:
        Exactly one string: ``output[0]["generated_text"]`` when the reply
        has that shape, otherwise ``str(output)`` so the caller still sees
        what the API returned (e.g. an error payload).
    """
    output = await query_llm({
        "inputs": f"{system_input}\n{user_input}",
    }, "google/flan-t5-xxl")
    # Only the lookup is guarded; the yield itself stays outside the try so
    # exceptions unrelated to the reply's shape are not swallowed.
    try:
        text = output[0]["generated_text"]
    except (IndexError, KeyError, TypeError):
        # KeyError: dict reply without index 0 or without "generated_text";
        # IndexError: empty list; TypeError: reply is not subscriptable that way.
        text = str(output)
    yield text
async def generate_gpt2(system_input, user_input):
    """Yield a single, length-limited completion from ``openai-community/gpt2``.

    The prompt is ``system_input`` and ``user_input`` joined by a newline and
    sent through ``query_llm``.

    Args:
        system_input: First line(s) of the prompt.
        user_input: Text appended after a newline.

    Yields:
        Exactly one string: the first 532 characters of
        ``output[0]["generated_text"]`` when the reply has that shape,
        otherwise ``str(output)`` (same fallback as ``generate_t5``) instead
        of raising out of the generator.
    """
    output = await query_llm({
        "inputs": f"{system_input}\n{user_input}",
    }, "openai-community/gpt2")
    try:
        # NOTE(review): the 532-character cap is kept from the original; its
        # rationale is not visible in this file.
        text = output[0]["generated_text"][:532]
    except (IndexError, KeyError, TypeError):
        # The reply did not have the expected list-of-dicts shape (for
        # example a dict without index 0); surface the raw reply instead.
        text = str(output)
    yield text
async def generate_llama2(system_input, user_input):
    """Stream a chat completion from ``meta-llama/Llama-2-7b-chat-hf``.

    The two inputs are sent as a system message followed by a user message,
    with ``max_tokens=256`` and ``stream=True``.

    Args:
        system_input: Content of the ``system`` message.
        user_input: Content of the ``user`` message.

    Yields:
        ``choices[0].delta.content`` of each streamed chunk, unmodified.
    """
    hf_client = AsyncInferenceClient(
        "meta-llama/Llama-2-7b-chat-hf",
        token=os.getenv('HF_TOKEN')
    )
    conversation = [
        {"role": "system", "content": system_input},
        {"role": "user", "content": user_input},
    ]
    chunk_stream = await hf_client.chat_completion(
        messages=conversation,
        max_tokens=256,
        stream=True,
    )
    async for chunk in chunk_stream:
        # NOTE(review): the delta content is forwarded as-is; confirm whether
        # consumers need to cope with an empty/None value on some chunks.
        yield chunk.choices[0].delta.content
async def generate_llama3(system_input, user_input):
    """Stream a chat completion from ``meta-llama/Meta-Llama-3.1-8B-Instruct``.

    The two inputs are sent as a system message followed by a user message,
    with ``max_tokens=256`` and ``stream=True``.

    Args:
        system_input: Content of the ``system`` message.
        user_input: Content of the ``user`` message.

    Yields:
        ``choices[0].delta.content`` of each streamed chunk, unmodified.
        A ``json.JSONDecodeError`` raised while opening or consuming the
        stream ends the generator quietly; chunks already yielded stand.
    """
    hf_client = AsyncInferenceClient(
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        token=os.getenv('HF_TOKEN')
    )
    conversation = [
        {"role": "system", "content": system_input},
        {"role": "user", "content": user_input},
    ]
    try:
        chunk_stream = await hf_client.chat_completion(
            messages=conversation,
            max_tokens=256,
            stream=True,
        )
        async for chunk in chunk_stream:
            yield chunk.choices[0].delta.content
    except json.JSONDecodeError:
        # Deliberate best-effort: stop streaming without raising.
        # NOTE(review): presumably guards against a malformed stream
        # chunk — confirm the cause before tightening this.
        return
async def generate_mixtral(system_input, user_input):
    """Stream a chat completion from ``mistralai/Mixtral-8x7B-Instruct-v0.1``.

    The two inputs are sent as a system message followed by a user message,
    with ``max_tokens=256`` and ``stream=True``.

    Args:
        system_input: Content of the ``system`` message.
        user_input: Content of the ``user`` message.

    Yields:
        ``choices[0].delta.content`` of each streamed chunk, unmodified.
        A ``json.JSONDecodeError`` raised while opening or consuming the
        stream ends the generator quietly; chunks already yielded stand.
    """
    hf_client = AsyncInferenceClient(
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
        token=os.getenv('HF_TOKEN')
    )
    conversation = [
        {"role": "system", "content": system_input},
        {"role": "user", "content": user_input},
    ]
    try:
        chunk_stream = await hf_client.chat_completion(
            messages=conversation,
            max_tokens=256,
            stream=True,
        )
        async for chunk in chunk_stream:
            yield chunk.choices[0].delta.content
    except json.JSONDecodeError:
        # Deliberate best-effort: stop streaming without raising.
        # NOTE(review): presumably guards against a malformed stream
        # chunk — confirm the cause before tightening this.
        return