File size: 16,492 Bytes
4d6d2dc
9fa8328
 
 
4d6d2dc
9b5c8c6
3dbd18e
9b5c8c6
4d6d2dc
 
293df90
2533b7f
9fa8328
e4c230b
3dbd18e
077e2b3
79df09c
 
4d6d2dc
98858d4
cee7c56
2a69d25
4d6d2dc
af967c9
 
 
 
 
e3b129c
 
 
 
9fa8328
 
 
 
e4c230b
e3b129c
f8fba1a
9fa8328
5ba44ad
fa45463
9fa8328
f5ff0e3
 
d2266c9
5ba44ad
f5ff0e3
 
 
 
 
 
9fa8328
4d6d2dc
1a614f9
 
 
 
4009e7f
9fa8328
 
 
 
 
fa45463
9fa8328
 
b233c7d
f8fba1a
9fa8328
 
 
8f43d2f
9fa8328
b233c7d
e4c230b
 
 
b9bab55
b233c7d
b9bab55
 
9fa8328
 
b5a6906
4009e7f
fac9749
 
6b69a3c
f269195
4009e7f
3e684af
 
4009e7f
9fa8328
 
4d6d2dc
5daf90b
11b86b4
f8fba1a
03b5112
f8fba1a
11b86b4
f8fba1a
ee7058f
af967c9
 
e4cb9e0
f8fba1a
 
ee7058f
 
 
 
b30a06e
de099ae
cee7c56
c23388b
d8c5a8d
 
20c0832
4009e7f
1e4e3c2
de099ae
d75586b
 
 
1e4e3c2
 
 
 
f8fba1a
11b86b4
e3b129c
ce07d7a
4d6d2dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9fa8328
d75586b
4d6d2dc
 
d75586b
b5a6906
2bb573c
d75586b
d2266c9
e4c230b
0a22698
e36b100
f8f26a8
 
0a22698
a5aded9
7889ca8
 
f724621
 
e524716
7889ca8
0a22698
9ab090f
e4c230b
d2266c9
8e5b8b3
de099ae
d2266c9
b9e0369
542759c
8e5b8b3
11f2e9c
fac9749
cee7c56
4d6d2dc
 
 
 
8f35fab
5ba44ad
cf4e80d
049eed9
d028e6b
ee7058f
d028e6b
b30e55a
26b3274
4009e7f
 
4d6d2dc
 
0d6b098
5ae57a2
 
 
9f98ca2
b29377d
 
 
4e62c15
9f98ca2
 
 
 
868605b
fa45463
 
d2266c9
fa45463
9f98ca2
6b4003c
868605b
 
9fa8328
4d6d2dc
8670d11
62bd403
2a69d25
d2266c9
01e48f0
 
 
 
 
 
 
 
 
5ba44ad
01e48f0
 
5ba44ad
01e48f0
f724621
 
4e62c15
 
5ba44ad
4e62c15
5ba44ad
273c292
 
4d6d2dc
01e48f0
 
 
 
 
 
 
 
 
 
ae20803
 
 
 
 
21ccb1f
 
 
3e51fa7
ae20803
01e48f0
0a22698
c23388b
9136f03
5e8b4c1
4009e7f
e4068da
1e4e3c2
 
e4068da
9136f03
765296c
4009e7f
c23388b
5ba44ad
c7e88d8
5ba44ad
8670d11
 
c7e88d8
4d6d2dc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
import os
import gc
from typing import Optional
from dataclasses import dataclass, field
from copy import deepcopy
from functools import partial

import numpy as np
import spaces
import gradio as gr
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import PreTrainedModel, PreTrainedTokenizer, AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
from ctransformers import AutoModelForCausalLM as CAutoModelForCausalLM

from interpret import InterpretationPrompt
from configs import model_info, dataset_info


MAX_PROMPT_TOKENS = 60  # number of token buttons pre-allocated in the UI
MAX_NUM_LAYERS = 50  # number of interpretation text boxes pre-allocated in the UI
welcome_message = '**You are now running {model_name}!!** 🥳🥳🥳'

# Used by the layer and token importance heuristics in this file.
# The first and last few layers are usually not important, so they are
# ignored when looking for important layers.
avoid_first, avoid_last = 3, 2


@dataclass
class LocalState:
    """Per-prompt state: the hidden states cached for the prompt being analyzed."""

    # None until hidden states have been computed (or after they were reset).
    hidden_states: Optional[torch.Tensor] = None

@dataclass
class GlobalState:
    """Process-wide state shared by the callbacks in this file.

    The fields are filled in by ``reset_model``; the defaults below only
    apply until a model has been loaded.
    """

    tokenizer: Optional[PreTrainedTokenizer] = None
    model: Optional[PreTrainedModel] = None
    sentence_transformer: Optional[SentenceTransformer] = None
    # A fresh LocalState per GlobalState instance. A plain ``LocalState()``
    # default would be shared between instances and is rejected by
    # ``dataclasses`` on Python 3.11+ (unhashable default).
    local_state: LocalState = field(default_factory=LocalState)
    # When true, ``get_hidden_states`` skips the forward pass and
    # ``run_interpretation`` computes the hidden states on demand.
    # (Named to match the attribute the rest of the file reads and writes.)
    wait_with_hidden_states: bool = False
    interpretation_prompt_template: str = '{prompt}'
    original_prompt_template: str = 'User: [X]\n\nAssistant: {prompt}'
    # Passed as ``layers_format`` to ``InterpretationPrompt.generate``.
    layers_format: str = 'model.layers.{k}'


# Ready-made interpretation prompts offered as clickable examples in the UI;
# the first entry is the initial value of the interpretation prompt box.
suggested_interpretation_prompts = [
    "Sure, I'll summarize your message:",
    "The meaning of [X] is",
    "Sure, here's a bullet list of the key words in your message:",
    "Sure, here are the words in your message:",
    "Before responding, let me repeat the message you wrote:",
    "Let me repeat the message:",
]

    
## functions
@spaces.GPU
def initialize_gpu():
    """Do nothing; an empty function wrapped by the ``spaces.GPU`` decorator."""
    return None

def reset_model(model_name, load_on_gpu, *extra_components, reset_sentence_transformer=False, with_extra_components=True):
    """Load the model and tokenizer configured under ``model_name`` into ``global_state``.

    Args:
        model_name: key into ``model_info``; the entry must provide ``model_path``,
            ``original_prompt_template``, ``interpretation_prompt_template`` and
            ``layers_format``. Optional keys: ``tokenizer``, ``ctransformers``,
            ``dont_cuda``, ``wait_with_hidden_states``. Every remaining key is
            forwarded to ``from_pretrained``.
        load_on_gpu: see the NOTE(review) next to its use below.
        *extra_components: values appended unchanged to the returned list.
        reset_sentence_transformer: when true, also (re)load the sentence transformer.
        with_extra_components: when false, return ``None`` instead of UI updates.

    Returns:
        ``None`` if ``with_extra_components`` is false; otherwise a list made of
        the welcome message, one hidden textbox per interpretation bubble, one
        hidden button per token button, and ``extra_components``.
    """
    # extract model info (on a copy, so popping keys does not alter ``model_info``)
    model_args = deepcopy(model_info[model_name])
    model_path = model_args.pop('model_path')
    global_state.original_prompt_template = model_args.pop('original_prompt_template')
    global_state.interpretation_prompt_template = model_args.pop('interpretation_prompt_template')
    global_state.layers_format = model_args.pop('layers_format')
    tokenizer_path = model_args.pop('tokenizer', model_path)  # default to the model's own path
    use_ctransformers = model_args.pop('ctransformers', False)
    dont_cuda = model_args.pop('dont_cuda', False)
    global_state.wait_with_hidden_states = model_args.pop('wait_with_hidden_states', False)
    AutoModelClass = CAutoModelForCausalLM if use_ctransformers else AutoModelForCausalLM

    # get model: drop the references to the previous model/tokenizer/hidden
    # states first so they can be collected before the new model is loaded
    global_state.model, global_state.tokenizer, global_state.local_state.hidden_states = None, None, None
    gc.collect()
    global_state.model = AutoModelClass.from_pretrained(model_path, **model_args)
    if reset_sentence_transformer:
        global_state.sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2')
        gc.collect()
    if not dont_cuda:
        global_state.model.to('cuda')
    # NOTE(review): the model is moved to 'cpu' when ``load_on_gpu`` is true,
    # which reads as inverted relative to the parameter name. Left unchanged
    # because it may be deliberate for the hosting environment -- confirm.
    if load_on_gpu:
        global_state.model.to('cpu')
    # ``.get`` so that a missing ``hf_token`` variable does not raise KeyError;
    # the tokenizer is then requested without an explicit token.
    global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ.get('hf_token'))
    gc.collect()
    if not with_extra_components:
        return None
    return ([welcome_message.format(model_name=model_name)]
            + [gr.Textbox('', visible=False) for _ in range(len(interpretation_bubbles))]
            + [gr.Button('', visible=False) for _ in range(len(tokens_container))]
            + [*extra_components])


def get_hidden_states(raw_original_prompt, force_hidden_states=False):
    """Tokenize the prompt, cache its hidden states and build the token buttons.

    Args:
        raw_original_prompt: user text, inserted into
            ``global_state.original_prompt_template`` before tokenization.
        force_hidden_states: run the forward pass even when
            ``global_state.wait_with_hidden_states`` is set.

    Returns:
        A list: an empty progress string, ``MAX_PROMPT_TOKENS`` buttons (one
        visible button per prompt token, the rest hidden) and ``MAX_NUM_LAYERS``
        hidden textboxes that clear any previous interpretation.

    Side effects:
        Stores the stacked hidden states (or ``None`` when the forward pass is
        deferred) in ``global_state.local_state.hidden_states``.
    """
    model, tokenizer = global_state.model, global_state.tokenizer
    original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
    model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
    tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
    if global_state.wait_with_hidden_states and not force_hidden_states:
        global_state.local_state.hidden_states = None
        important_tokens = [] # cannot find important tokens without the hidden states
    else:
        outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
        # one entry per element of ``outputs.hidden_states``, stacked along a new first axis
        # (assumes a batch of size 1, so ``squeeze(0)`` drops the batch axis -- TODO confirm)
        hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
        # Token importance heuristic: inside the layer window that skips the
        # first/last layers, normalize every token vector to unit length and take
        # the norm of the difference between consecutive layers, i.e. how much the
        # direction of a token's representation changes from one layer to the
        # next. Tokens owning one of the 5 largest changes are highlighted.
        layer_window = hidden_states[avoid_first-1:len(hidden_states)-avoid_last]
        hidden_scores = F.normalize(layer_window, dim=-1).diff(dim=0).norm(dim=-1) # num_layers x num_tokens
        flat_scores = hidden_scores.flatten()
        # clamp k so that very short inputs cannot make topk fail
        top_flat_idxs = flat_scores.topk(k=min(5, flat_scores.numel())).indices.numpy()
        important_tokens = np.unravel_index(top_flat_idxs, hidden_scores.shape)[1]
        print(f'{important_tokens=}\t\t{hidden_states.shape=}')
        global_state.local_state.hidden_states = hidden_states

    # Only MAX_PROMPT_TOKENS buttons exist in the UI; returning more values than
    # there are output components would break the update, so extra tokens get no button.
    if len(tokens) > MAX_PROMPT_TOKENS:
        print(f'prompt has {len(tokens)} tokens; showing only the first {MAX_PROMPT_TOKENS}')
    shown_tokens = tokens[:MAX_PROMPT_TOKENS]
    important_positions = {int(pos) for pos in important_tokens}
    token_btns = [gr.Button(token, visible=True,
                            elem_classes=['token_btn'] + (['important_token'] if pos in important_positions else []))
                  for pos, token in enumerate(shown_tokens)]
    token_btns += [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(shown_tokens))]
    progress_dummy_output = ''
    invisible_bubbles = [gr.Textbox('', visible=False) for _ in range(MAX_NUM_LAYERS)]
    return [progress_dummy_output, *token_btns, *invisible_bubbles]


def _min_max_normalize(scores):
    """Rescale ``scores`` linearly to [0, 1]; constant or empty input maps to zeros/itself."""
    if scores.numel() == 0:
        return scores
    lowest, highest = scores.min(), scores.max()
    spread = highest - lowest
    if spread == 0:
        # all scores equal: avoid 0/0 (NaN), no entry stands out
        return torch.zeros_like(scores)
    return (scores - lowest) / spread


def _bag_of_words_distance(bag_a, bag_b):
    """Return minus the normalized overlap of two bags (0.0 when either bag is empty)."""
    denominator = np.sqrt(len(bag_a) * len(bag_b))
    if denominator == 0:
        return 0.0
    return -len(bag_a & bag_b) / denominator


@spaces.GPU
def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_tokens, do_sample, 
                       temperature, top_k, top_p, repetition_penalty, length_penalty, use_gpu, i, 
                       num_beams=1):
    """Generate one interpretation per layer for token ``i`` of the analyzed prompt.

    Args:
        raw_original_prompt: the analyzed prompt; only used to compute hidden
            states here when that step was deferred.
        raw_interpretation_prompt: text placed into
            ``global_state.interpretation_prompt_template``.
        max_new_tokens, do_sample, temperature, top_k, top_p, repetition_penalty,
            num_beams: forwarded as generation arguments.
        length_penalty: negated before being forwarded (see the inline comment).
        use_gpu: move the model to CUDA when true, to CPU otherwise.
        i: index of the token whose hidden states are interpreted.

    Returns:
        A list: an empty progress string followed by ``MAX_NUM_LAYERS``
        textboxes (one visible box per generated text, the rest hidden).

    Raises:
        gr.Error: if no hidden states are available for the prompt.
    """
    model = global_state.model
    tokenizer = global_state.tokenizer
    print(f'run {model}')
    model = model.cuda() if use_gpu else model.cpu()
    if global_state.wait_with_hidden_states and global_state.local_state.hidden_states is None:
        get_hidden_states(raw_original_prompt, force_hidden_states=True)
    hidden_states = global_state.local_state.hidden_states
    if hidden_states is None:
        raise gr.Error('No hidden states available; please output the token list first.')
    # clone().detach() copies the slice without the copy-construct warning of torch.tensor(tensor)
    interpreted_vectors = hidden_states[:, i].clone().detach().to(device=model.device, dtype=model.dtype)
    length_penalty = -length_penalty   # unintuitively, length_penalty > 0 will make sequences longer, so we negate it

    # generation parameters
    generation_kwargs = {
        'max_new_tokens': int(max_new_tokens),
        'do_sample': do_sample,
        'temperature': temperature,
        'top_k': int(top_k),
        'top_p': top_p,
        'repetition_penalty': repetition_penalty,
        'length_penalty': length_penalty,
        'num_beams': int(num_beams)
    }

    # create an InterpretationPrompt object from raw_interpretation_prompt (after putting it in the right template)
    interpretation_prompt = global_state.interpretation_prompt_template.format(prompt=raw_interpretation_prompt, repeat=5)
    interpretation_prompt = InterpretationPrompt(tokenizer, interpretation_prompt)

    # generate the interpretations
    generated = interpretation_prompt.generate(model, {0: interpreted_vectors}, 
                                               layers_format=global_state.layers_format, k=3, 
                                               **generation_kwargs)
    generation_texts = tokenizer.batch_decode(generated)

    # try identifying important layers
    # score 1: change of direction of the hidden state between consecutive layers
    diff_score1 = F.normalize(interpreted_vectors, dim=-1).diff(dim=0).norm(dim=-1).cpu()
    # score 2: (negated) overlap of unigrams and bigrams between consecutive generations
    tokenized_generations = [tokenizer.tokenize(text) for text in generation_texts]
    bags_of_words = [set(toks) | set(zip(toks, toks[1:])) for toks in tokenized_generations]
    diff_score2 = torch.tensor([_bag_of_words_distance(later, earlier)
                                for earlier, later in zip(bags_of_words, bags_of_words[1:])],
                               dtype=torch.float64)
    diff_score = _min_max_normalize(diff_score1) + _min_max_normalize(diff_score2)

    assert avoid_first >= 1 # due to .diff() we will not be able to compute a score for the first layer
    diff_score = diff_score[avoid_first-1:len(diff_score)-avoid_last]
    # keep the top 30% of the remaining layers; add avoid_first to map back to layer indices
    important_idxs = avoid_first + diff_score.topk(k=int(np.ceil(0.3 * len(diff_score)))).indices.cpu().numpy()

    # create GUI output
    print(f'{important_idxs=}')
    important_layers = {int(idx) for idx in important_idxs}
    progress_dummy_output = ''
    bubble_outputs = []
    # only MAX_NUM_LAYERS textboxes exist in the UI, so never return more than that
    for layer_idx, text in enumerate(generation_texts[:MAX_NUM_LAYERS]):
        bubble_classes = ['bubble', 'even_bubble' if layer_idx % 2 == 0 else 'odd_bubble']
        if layer_idx not in important_layers:
            bubble_classes.append('faded_bubble')
        bubble_outputs.append(gr.Textbox(text.replace('\n', ' '), show_label=True, visible=True,
                                         container=True, label=f'Layer {layer_idx}',
                                         elem_classes=bubble_classes))
    bubble_outputs += [gr.Textbox('', visible=False) for _ in range(MAX_NUM_LAYERS - len(bubble_outputs))]
    return [progress_dummy_output, *bubble_outputs]


## main
# Gradients are switched off globally for everything that follows.
torch.set_grad_enabled(False)
model_name = 'LLAMA2-13B'
raw_original_prompt = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')

# One hidden placeholder button per possible prompt token; they are rendered
# later inside the layout and filled in once a prompt has been tokenized.
tokens_container = [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS)]

# Build the whole UI, wire the event listeners and launch the app.
with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
    # the model is loaded once, while the layout is being built
    global_state = GlobalState()
    reset_model(model_name, load_on_gpu=True, with_extra_components=False, reset_sentence_transformer=True)
    with gr.Row():
        with gr.Column(scale=5):
            gr.Markdown('# 😎 Self-Interpreting Models')

            gr.Markdown('<b style="color: #8B0000;">Model outputs are not filtered and might include undesired language!</b>')
            
            gr.Markdown(
            '''
                **👾 This space is a simple introduction to the emerging trend of models interpreting their OWN hidden states in free form natural language!!👾**
                This idea was investigated in the paper **Patchscopes** ([Ghandeharioun et al., 2024](https://arxiv.org/abs/2401.06102)) and was further explored in **SelfIE** ([Chen et al., 2024](https://arxiv.org/abs/2403.10949)). 
                Honorary mention: **Speaking Probes** ([Dar, 2023](https://towardsdatascience.com/speaking-probes-self-interpreting-models-7a3dc6cb33d6) - my own work 🥳). It was less mature but had the same idea in mind. I think it can be a great introduction to the subject!
                We will follow the SelfIE implementation in this space for concreteness. Patchscopes are so general that they encompass many other interpretation techniques too!!! 
            ''', line_breaks=True)
            
            gr.Markdown(
            '''
                **👾 The idea is really simple: models are able to understand their own hidden states by nature! 👾**
                In line with the residual stream view ([nostalgebraist, 2020](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens)), internal representations from different layers are transferable between layers. 
                So we can inject an representation from (roughly) any layer into any layer! If we give a model a prompt of the form ``User: [X] Assistant: Sure, I'll repeat your message`` and replace the internal representation of ``[X]`` *during computation* with the hidden state we want to understand, 
                we expect to get back a summary of the information that exists inside the hidden state, despite being from a different layer and a different run!! How cool is that! 😯😯😯
            ''', line_breaks=True)

        # with gr.Column(scale=1):    
        #     gr.Markdown('<span style="font-size:180px;">🤔</span>')

    with gr.Group():
        # model_chooser = gr.Radio(label='Choose Your Model', choices=list(model_info.keys()), value=model_name)
        load_on_gpu = gr.Checkbox(label='Load on GPU', visible=False, value=True)
        welcome_model = gr.Markdown(welcome_message.format(model_name=model_name))
    with gr.Blocks() as demo_main:
        gr.Markdown('## The Prompt to Analyze')        
        # one tab of example prompts per configured dataset
        for info in dataset_info:
            with gr.Tab(info['name']):
                num_examples = 10
                # streamed, optionally filtered, then shuffled within a buffer before taking the examples
                dataset = load_dataset(info['hf_repo'], split='train', streaming=True)
                if 'filter' in info:
                    dataset = dataset.filter(info['filter'])
                dataset = dataset.shuffle(buffer_size=2000).take(num_examples)
                dataset = [[row[info['text_col']]] for row in dataset]
                gr.Examples(dataset, [raw_original_prompt], cache_examples=False)
                
        with gr.Group():
            # the prompt textbox was created before the layout; place it here
            raw_original_prompt.render()
            original_prompt_btn = gr.Button('Output Token List', variant='primary')
            gr.Markdown('**Tokens will appear in the "Tokens" section**')
            
        gr.Markdown('## Choose Your Interpretation Prompt')
        # NOTE(review): gr.Group receives a positional argument here; confirm that
        # the installed Gradio version accepts it.
        with gr.Group('Interpretation'):
            raw_interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')
            interpretation_prompt_examples = gr.Examples([[p] for p in suggested_interpretation_prompts], 
                                                         [raw_interpretation_prompt], cache_examples=False)
       
        with gr.Accordion(open=False, label='Generation Settings'):
            with gr.Row():
                num_tokens = gr.Slider(1, 100, step=1, value=20, label='Max. # of Tokens')
                repetition_penalty = gr.Slider(1., 10., value=1, label='Repetition Penalty')
                length_penalty = gr.Slider(0, 5, value=0, label='Length Penalty')
                # num_beams = gr.Slider(1, 20, value=1, step=1, label='Number of Beams')
            do_sample = gr.Checkbox(label='With sampling')
            with gr.Accordion(label='Sampling Parameters'):
                with gr.Row():
                    temperature = gr.Slider(0., 5., value=0.6, label='Temperature')
                    top_k = gr.Slider(1, 1000, value=50, step=1, label='top k')
                    top_p = gr.Slider(0., 1., value=0.95, label='top p')
     
        gr.Markdown('''
                ## Tokens
                ### Here go the tokens of the prompt (click on the one to explore)
                ''')
        with gr.Row():
            # place the pre-created (still hidden) token buttons
            for btn in tokens_container:
                btn.render()
        use_gpu = gr.Checkbox(label='Use GPU', visible=False, value=True)
        
        progress_dummy = gr.Markdown('', elem_id='progress_dummy')
        # one hidden textbox per possible layer; filled in by run_interpretation
        interpretation_bubbles = [gr.Textbox('', container=False, visible=False) for i in range(MAX_NUM_LAYERS)]
        
    # event listeners    
    # each token button runs the interpretation for its own token index,
    # bound through partial(..., i=i)
    for i, btn in enumerate(tokens_container):
        btn.click(partial(run_interpretation, i=i), [raw_original_prompt, raw_interpretation_prompt, 
                                                     num_tokens, do_sample, temperature, 
                                                     top_k, top_p, repetition_penalty, length_penalty,
                                                     use_gpu
                                                    ], [progress_dummy, *interpretation_bubbles])

    original_prompt_btn.click(get_hidden_states, 
                              [raw_original_prompt], 
                              [progress_dummy, *tokens_container, *interpretation_bubbles])
    # editing the prompt hides all token buttons until the token list is requested again
    raw_original_prompt.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)

    extra_components = [raw_interpretation_prompt, raw_original_prompt, original_prompt_btn]
    # model_chooser.change(reset_model, [model_chooser, load_on_gpu, *extra_components], 
    #                      [welcome_model, *interpretation_bubbles, *tokens_container, *extra_components])
    
    demo.launch()