"""Text-generation helpers for a Hugging Face Space.

Mistral, Llama 2, Llama 3.1, Flan-T5 and GPT-2 are queried remotely through the
Inference API; OpenLLaMA and BLOOM are loaded locally on ZeroGPU via transformers.
"""
import asyncio
import json
import os
import random

import aiohttp
import gradio as gr
import numpy as np
import spaces
import torch
from huggingface_hub import AsyncInferenceClient, InferenceClient
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer


async def query_llm(payload, model_name):
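    """POST `payload` to the HF serverless Inference API for `model_name` and return the parsed JSON response."""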
    headers = {"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"}
    async with aiohttp.ClientSession() as session:
        async with session.post(f"https://api-inference.huggingface.co/models/{model_name}", headers=headers,
                                json=payload) as response:
            return await response.json()


async def generate_mistral_7bvo1(system_input, user_input):
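    """Stream a chat completion from Mistral-7B-Instruct-v0.1, yielding each text delta as it arrives."""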
    client = AsyncInferenceClient(
        "mistralai/Mistral-7B-Instruct-v0.1",
        token=os.getenv('HF_TOKEN'),
    )

    async for message in await client.chat_completion(
            messages=[
                {"role": "system", "content": system_input},
                {"role": "user", "content": user_input}, ],
            max_tokens=256,
            stream=True,
    ):
        # `delta.content` may be None on some streamed chunks (e.g. the final one); yield "" instead.
        yield message.choices[0].delta.content or ""


async def generate_t5(system_input, user_input):
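    """Query google/flan-t5-large via the Inference API (non-streaming) and yield its single completion."""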
    output = await query_llm(
        {"inputs": f"{system_input}\n{user_input}"},
        "google/flan-t5-large",
    )
    yield output[0]["generated_text"]


async def generate_gpt2(system_input, user_input):
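    """Query openai-community/gpt2 via the Inference API and yield a truncated completion."""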
    output = await query_llm(
        {"inputs": f"{system_input}\n{user_input}"},
        "openai-community/gpt2",
    )
    # GPT-2 is not chat-tuned, so its raw completion is capped at a fixed character length.
    yield output[0]["generated_text"][:532]


async def generate_llama2(system_input, user_input):
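    """Stream a chat completion from Llama-2-7b-chat-hf, yielding each text delta as it arrives."""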
    client = AsyncInferenceClient(
        "meta-llama/Llama-2-7b-chat-hf",
        token=os.getenv('HF_TOKEN')
    )
    async for message in await client.chat_completion(
            messages=[
                {"role": "system", "content": system_input},
                {"role": "user", "content": user_input}, ],
            max_tokens=256,
            stream=True,
    ):
        # `delta.content` may be None on some streamed chunks (e.g. the final one); yield "" instead.
        yield message.choices[0].delta.content or ""


@spaces.GPU(duration=120)
def generate_openllama(system_input, user_input):
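    """Generate locally with open_llama_3b_v2 in fp16 on the ZeroGPU worker (model is loaded on each call)."""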
    model_path = 'openlm-research/open_llama_3b_v2'
    tokenizer = LlamaTokenizer.from_pretrained(model_path)
    model = LlamaForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16, device_map='cuda',
    )
    print('model openllama loaded')
    input_text = f"{system_input}\n{user_input}"
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")
    output = model.generate(input_ids, max_length=128)
    return tokenizer.decode(output[0], skip_special_tokens=True)


@spaces.GPU(duration=120)
def generate_bloom(system_input, user_input):
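    """Generate locally with bloom-7b1 in fp16 on the ZeroGPU worker (model is loaded on each call)."""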
    model_path = 'bigscience/bloom-7b1'
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16, device_map='cuda',
    )
    input_text = f"{system_input}\n{user_input}"
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")
    output = model.generate(input_ids, max_length=128)
    return tokenizer.decode(output[0], skip_special_tokens=True)


async def generate_llama3(system_input, user_input):
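    """Stream a chat completion from Meta-Llama-3.1-8B-Instruct, yielding each text delta as it arrives."""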
    client = AsyncInferenceClient(
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        token=os.getenv('HF_TOKEN')
    )
    try:
        async for message in await client.chat_completion(
                messages=[
                    {"role": "system", "content": system_input},
                    {"role": "user", "content": user_input}, ],
                max_tokens=256,
                stream=True,
        ):
            # `delta.content` may be None on some streamed chunks (e.g. the final one); yield "" instead.
            yield message.choices[0].delta.content or ""
    except json.JSONDecodeError:
        # Malformed or partial chunks from the streaming endpoint are silently ignored.
        pass
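

# Illustrative usage only (not part of the original Space): a minimal sketch of how the
# streaming helpers above can be consumed outside Gradio, assuming HF_TOKEN is set in the
# environment. The Space itself presumably wires these generators into a Gradio UI.
if __name__ == "__main__":
    async def _demo():
        async for chunk in generate_mistral_7bvo1("You are a helpful assistant.", "Say hello."):
            print(chunk, end="", flush=True)
        print()

    asyncio.run(_demo())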