""" fine_tuning_app.py

Running a basic chatbot app that can compare base and fine-tuned models from Hugging face.

Note:
 - run using streamlit run fine_tuning_app.py
 - use free -h then sudo sysctl vm.drop_caches=2 to ensure I have cache space but this can mess up the venv
 - may need to run huggingface-cli login in terminal to enable access to model
 - Or: https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/discussions/130 for above
 - Hugging face can use up a lot of disc space - cd ~/.cache/huggingface/hub then rm -rf <subdir>

"""

import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import transformers
import time
import torch
from pynvml import *  # provided by the nvidia-ml-py3 package (the IDE may need a restart after installing)

# ---------------------------------------------------------------------------------------
#                                     GENERAL SETUP:
# ---------------------------------------------------------------------------------------

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
hf_token = ""  # optional: set a Hugging Face access token here, or use huggingface-cli login

# Pick the model to test:
# model_name = "thebigoed/PreFineLlama-3.1-8B"       # works badly as it does not know the chat structure
# model_name = "unsloth/Meta-Llama-3.1-8B-bnb-4bit"  # what we were fine-tuning; also bad without chat instruct
# model_name = "Qwen/Qwen2.5-7B-Instruct"            # working well now
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"   # very effective. NB: a fine-grained access token must be allowed to access gated repos
st.title("Fine Tuning Testing")
col1, col2 = st.columns(2)
if 'conversation' not in st.session_state:
    st.session_state.conversation = []
user_input = st.text_input("You:", "") # user input

def print_gpu_utilization():
    """ Basic resource monitoring: print the GPU memory currently in use (via NVML). """
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"GPU memory occupied: {info.used//1024**2} MB.")

# ---------------------------------------------------------------------------------------
#                                     MODEL SETUP:
# ---------------------------------------------------------------------------------------

@st.cache_resource(show_spinner=False)
def load_model():
    """ Load model from Hugging face."""
    print_gpu_utilization()
    # See https://huggingface.co/mlabonne/FineLlama-3.1-8B for how to run the fine-tuned model.
    # See https://huggingface.co/docs/transformers/main/en/chat_templating when deciding how to handle templating.
    success_placeholder = st.empty()
    with st.spinner("Loading model... please wait"):
        if str(DEVICE) == "cuda:0": # may not need this, need to test on CPU if device map is okay anyway
            tokenizer = AutoTokenizer.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_name, torch_dtype="auto")

        model = AutoModelForCausalLM.from_pretrained(model_name,
                                                     torch_dtype="auto",
                                                     device_map="auto"
                                                    )

        # Not using terminators at the moment
        #terminator = tokenizer.eos_token if tokenizer.eos_token else "<|endoftext|>"

    success_placeholder.success("Model loaded successfully!", icon="🔥")
    time.sleep(2)
    success_placeholder.empty()
    print_gpu_utilization()
    return model, tokenizer
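
# A minimal sketch (an assumption, not wired into the app) of loading the model in 4-bit with
# bitsandbytes, in the spirit of the "unsloth/Meta-Llama-3.1-8B-bnb-4bit" option above; this can
# reduce GPU memory use considerably, but requires the bitsandbytes package to be installed.
def load_model_4bit():
    from transformers import BitsAndBytesConfig  # local import: only needed for this sketch
    quant_config = BitsAndBytesConfig(load_in_4bit=True,
                                      bnb_4bit_compute_dtype=torch.float16)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name,
                                                 quantization_config=quant_config,
                                                 device_map="auto")
    return model, tokenizer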


def generate_response():
    """ Query the model. """

    success_placeholder = st.empty()
    with st.spinner("Thinking..."):

        # Tokenising the conversation
        if tokenizer.chat_template:
            text = tokenizer.apply_chat_template(st.session_state.conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(DEVICE)
        else: # base models do not have chat templates
            print("Assuming base model.")
            model_input = ""
            for entry in st.session_state.conversation:
                model_input += f"{entry['role']}: {entry['content']}\n"
            text = tokenizer(model_input + "assistant: ", return_tensors="pt")["input_ids"].to(DEVICE)
        outputs = model.generate(text,
                                max_new_tokens=512,
                                )
        outputs = tokenizer.batch_decode(outputs[:,text.shape[1]:], skip_special_tokens=True)[0]
        print_gpu_utilization()
    
    success_placeholder.success("Response generated!", icon="✅")
    time.sleep(2)
    success_placeholder.empty()
    return outputs
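
# A minimal sketch (an assumption, not used above) of how terminators and sampling could be passed
# to generate(), following the commented-out terminator note in load_model(); the values shown here
# are illustrative defaults, not settings taken from the original app.
#
#   terminators = [tokenizer.eos_token_id]
#   outputs = model.generate(text,
#                            max_new_tokens=512,
#                            eos_token_id=terminators,
#                            pad_token_id=tokenizer.eos_token_id,
#                            do_sample=True,
#                            temperature=0.7,
#                            top_p=0.9)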

# ---------------------------------------------------------------------------------------
#                                     RUNTIME EVENTS:
# ---------------------------------------------------------------------------------------

model, tokenizer = load_model()

# Submit button to send the query
with col1:
    if st.button("send"):
        if user_input:
            st.session_state.conversation.append({"role": "user", "content": user_input})
            st.session_state.conversation.append({"role": "assistant", "content": generate_response()})

# Clear button to reset
with col2:
    if st.button("clear chat"):
        if user_input:
            st.session_state.conversation = []

# Display conversation history
for chat in st.session_state.conversation:
    if chat['role'] == 'user':
        st.write(f"You: {chat['content']}")
    else:
        st.write(f"Assistant: {chat['content']}")