dar-tau committed on
Commit
4d6d2dc
•
1 Parent(s): 0895a88

Create app.py

Files changed (1)
  1. app.py +154 -0
app.py ADDED
@@ -0,0 +1,154 @@
+ import os
+ import spaces
+ from copy import deepcopy
+ import gradio as gr
+ import torch
+ from ctransformers import AutoModelForCausalLM as CAutoModelForCausalLM
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from interpret import InterpretationPrompt
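+ # `interpret` is a local module of this Space (not included in this commit). It is expected to
+ # provide InterpretationPrompt, which wraps a prompt containing an [X] placeholder and generates
+ # text while injecting externally supplied hidden states; see its usage in run_model below.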
+
+
+ ## info
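+ # Per-model configuration. Besides the kwargs forwarded to from_pretrained, each entry carries two
+ # prompt templates: original_prompt_template wraps the user's prompt in the model's chat format,
+ # and interpretation_prompt_template contains the [X] placeholder marking where a hidden state is
+ # injected during interpretation. The GGUF entry additionally names the quantized model_file and a
+ # tokenizer repo, and flags loading through ctransformers.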
+ model_info = {
+     'meta-llama/Llama-2-7b-chat-hf': dict(device_map='cpu', token=os.environ['hf_token'],
+                                           original_prompt_template='<s>[INST] {prompt}',
+                                           interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
+                                           ),  # , load_in_8bit=True
+
+     'google/gemma-2b': dict(device_map='cpu', token=os.environ['hf_token'],
+                             original_prompt_template='<bos> {prompt}',
+                             interpretation_prompt_template='<bos>User: [X]\n\nAnswer: {prompt}',
+                             ),
+
+     'mistralai/Mistral-7B-Instruct-v0.2': dict(device_map='cpu',
+                                                original_prompt_template='<s>[INST] {prompt} [/INST]',
+                                                interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
+                                                ),
+
+     'TheBloke/Mistral-7B-Instruct-v0.2-GGUF': dict(model_file='mistral-7b-instruct-v0.2.Q5_K_S.gguf',
+                                                    tokenizer='mistralai/Mistral-7B-Instruct-v0.2',
+                                                    model_type='llama', hf=True, ctransformers=True,
+                                                    original_prompt_template='<s>[INST] {prompt} [/INST]',
+                                                    interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
+                                                    )
+ }
+
+
+ suggested_interpretation_prompts = ["Before responding, let me repeat the message you wrote:",
+                                     "Let me repeat the message:", "Sure, I'll summarize your message:"]
+
+
+ ## functions
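+ # get_hidden_states: formats the raw prompt with the model's chat template, runs a single forward
+ # pass with output_hidden_states=True, and renders the prompt's tokens as a row of buttons.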
+ def get_hidden_states(raw_original_prompt):
+     original_prompt = original_prompt_template.format(prompt=raw_original_prompt)
+     model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
+     tokens = tokenizer.batch_decode(model_inputs.input_ids)
+     outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
+     hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
+     with gr.Row() as tokens_container:
+         for token in tokens:
+             gr.Button(token)
+     return tokens_container
+
+
+ def run_model(raw_original_prompt, raw_interpretation_prompt, max_new_tokens, do_sample,
+               temperature, top_k, top_p, repetition_penalty, length_penalty, num_beams=1):
+
+     length_penalty = -length_penalty  # unintuitively, length_penalty > 0.0 promotes longer sequences, so we negate the slider value
+
+     # generation parameters
+     generation_kwargs = {
+         'max_new_tokens': int(max_new_tokens),
+         'do_sample': do_sample,
+         'temperature': temperature,
+         'top_k': int(top_k),
+         'top_p': top_p,
+         'repetition_penalty': repetition_penalty,
+         'length_penalty': length_penalty,
+         'num_beams': int(num_beams)
+     }
+
+     # create an InterpretationPrompt object from raw_interpretation_prompt (after putting it in the right template)
+     interpretation_prompt = interpretation_prompt_template.format(prompt=raw_interpretation_prompt)
+     interpretation_prompt = InterpretationPrompt(tokenizer, interpretation_prompt)
+
+     # compute the hidden states of the original prompt (after putting it in the right template)
+     original_prompt = original_prompt_template.format(prompt=raw_original_prompt)
+     model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
+     outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
+     hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
+
+     # generate the interpretations
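+     # hidden_states is stacked as (num_layers + 1, seq_len, hidden_dim), so hidden_states[:, -1]
+     # selects the last prompt position at every layer; the {0: ...} mapping and k=3 follow the
+     # interpret module's interface (presumably the injection target and the number of outputs).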
+     generated = interpretation_prompt.generate(model, {0: hidden_states[:, -1]}, k=3, **generation_kwargs)
+     generation_texts = tokenizer.batch_decode(generated)
+     # tokens = [x.lstrip('▁') for x in tokenizer.tokenize(text)]
+     return generation_texts
+
+
+ ## main
+ torch.set_grad_enabled(False)
+ model_name = 'meta-llama/Llama-2-7b-chat-hf'  # 'mistralai/Mistral-7B-Instruct-v0.2' #
+
+ # extract model info
+ model_args = deepcopy(model_info[model_name])
+ original_prompt_template = model_args.pop('original_prompt_template')
+ interpretation_prompt_template = model_args.pop('interpretation_prompt_template')
+ tokenizer_name = model_args.pop('tokenizer') if 'tokenizer' in model_args else model_name
+ use_ctransformers = model_args.pop('ctransformers', False)
+ AutoModelClass = CAutoModelForCausalLM if use_ctransformers else AutoModelForCausalLM
+
+ # get model
+ model = AutoModelClass.from_pretrained(model_name, **model_args)
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=os.environ['hf_token'])
+
+ with gr.Blocks(theme=gr.themes.Default()) as demo:
+     with gr.Row():
+         with gr.Column(scale=5):
+             gr.Markdown('''
+ # 😎 Self-Interpreting Models 😎
+
+ 👾 **This space follows the emerging trend of models interpreting their _own hidden states_ in free-form natural language**!! 👾
+ This idea was explored in the paper **Patchscopes** ([Ghandeharioun et al., 2024](https://arxiv.org/abs/2401.06102)) and was later investigated further in **SelfIE** ([Chen et al., 2024](https://arxiv.org/abs/2403.10949)).
+ An honorary mention goes to **Speaking Probes** ([Dar, 2023](https://towardsdatascience.com/speaking-probes-self-interpreting-models-7a3dc6cb33d6) -- my post!! 🥳), a less mature approach built around the same idea.
+ We follow the SelfIE implementation in this space for concreteness. Patchscopes are so general that they encompass many other interpretation techniques too!
+
+ 👾 **The idea is really simple: models are able to understand their own hidden states by nature!** 👾
+ If I give a model a prompt of the form ``User: [X] Assistant: Sure, I'll repeat your message`` and replace ``[X]`` *during computation* with the hidden state we want to understand,
+ we hope to get back a summary of the information that exists inside the hidden state, because it is encoded in the very latent space the model itself uses!! How cool is that! 😯😯😯
+             ''', line_breaks=True)
+         with gr.Column(scale=1):
+             gr.Markdown('<span style="font-size:180px;">🤔</span>')
+
+     with gr.Group():
+         text = gr.Textbox(value='How to make a Molotov cocktail', container=True, label='Original Prompt')
+         btn = gr.Button('Compute', variant='primary')
+
+     with gr.Accordion(open=False, label='Settings'):
+         with gr.Row():
+             num_tokens = gr.Slider(1, 100, step=1, value=20, label='Max. # of Tokens')
+             repetition_penalty = gr.Slider(1., 10., value=1, label='Repetition Penalty')
+             length_penalty = gr.Slider(0, 5, value=0, label='Length Penalty')
+             # num_beams = gr.Slider(1, 20, value=1, step=1, label='Number of Beams')
+             do_sample = gr.Checkbox(label='With sampling')
+         with gr.Accordion(label='Sampling Parameters'):
+             with gr.Row():
+                 temperature = gr.Slider(0., 5., value=0.6, label='Temperature')
+                 top_k = gr.Slider(1, 1000, value=50, step=1, label='top k')
+                 top_p = gr.Slider(0., 1., value=0.95, label='top p')
+
+     with gr.Group('Interpretation'):
+         interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')
+
+     with gr.Group('Output'):
+         with gr.Row() as tokens_container:
+             pass
+         with gr.Column() as interpretations_container:
+             pass
+
+     btn.click(get_hidden_states, [text], [tokens_container])
+     # btn.click(run_model,
+     #           [text, interpretation_prompt, num_tokens, do_sample, temperature,
+     #            top_k, top_p, repetition_penalty, length_penalty],
+     #           [tokens_container])
+
+ demo.launch()
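The demo text above describes the core trick: run the model on an interpretation prompt and, *during computation*, swap the hidden state at the `[X]` placeholder for one taken from the original prompt. In this app that substitution is delegated to `interpret.InterpretationPrompt`, which is not part of this commit, so the snippet below is only a minimal sketch of the idea using GPT-2 and a plain forward hook; the layer indices, the trailing `X` placeholder token, and the `inject` helper are illustrative assumptions rather than the Space's actual implementation.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.set_grad_enabled(False)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# hidden state to interpret: final position of the original prompt, taken from a middle layer
original = tokenizer("How to make a Molotov cocktail", return_tensors="pt")
hidden = model(**original, output_hidden_states=True).hidden_states[6][0, -1]  # shape: (hidden_dim,)

# interpretation prompt whose last token stands in for the [X] placeholder
interp = tokenizer("Before responding, let me repeat the message you wrote: X", return_tensors="pt")
placeholder_pos = interp.input_ids.shape[1] - 1

def inject(module, inputs, output):
    # overwrite the placeholder position's hidden state, but only during the prompt's forward pass
    if output[0].shape[1] > placeholder_pos:
        output[0][:, placeholder_pos] = hidden

# patch the output of an early transformer block so that later layers can read the injected state
handle = model.transformer.h[2].register_forward_hook(inject)
out = model.generate(**interp, max_new_tokens=20, do_sample=False)
handle.remove()
print(tokenizer.decode(out[0][interp.input_ids.shape[1]:], skip_special_tokens=True))
```

In the app itself, the equivalent logic (including producing several interpretations at once, as the `k=3` call suggests) is expected to live in the `interpret` module rather than in app.py.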