mgoin committed
Commit 185a29e
1 Parent(s): e6ebb47

Create README.md

Files changed (1): README.md (new file, +300 lines)
 
---
tags:
- fp8
---
Quantized using the script below:

Command:
```bash
python quantize.py --model-id mistralai/Mixtral-8x7B-Instruct-v0.1 --save-dir Mixtral-8x7B-Instruct-v0.1-FP8 --num-samples 512
```
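
Running this command writes the FP8 checkpoint to the `--save-dir` directory and records a `quantization_config` block in its `config.json` (see `save_quantized_model` in the script below). A minimal sketch for inspecting that block, assuming the save directory from the command above:

```python
import json

# Path is an assumption: the --save-dir used in the command above.
with open("Mixtral-8x7B-Instruct-v0.1-FP8/config.json") as f:
    config = json.load(f)

# Expected, given save_quantized_model and the default --activation-scheme:
# {"quant_method": "fp8", "activation_scheme": "static"}
print(config["quantization_config"])
```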

Script:
```python
import argparse
import gc
import re
from typing import Tuple

import torch
import torch.nn.functional as F
import transformers
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer


# HACK: override the dtype_byte_size function in transformers to support float8 types
def new_dtype_byte_size(dtype):
    if dtype == torch.bool:
        return 1 / 8
    bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    bit_size = int(bit_search.groups()[0])
    return bit_size // 8


transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size


def cleanup_memory():
    gc.collect()
    torch.cuda.empty_cache()


def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
    """Quantize a tensor using a per-tensor static scaling factor.

    Args:
        tensor: The input tensor.
    """
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Calculate the scale as dtype max divided by absmax.
    # Since .abs() creates a new tensor, we use aminmax to get
    # the min and max first and then calculate the absmax.
    if tensor.numel() == 0:
        # Deal with empty tensors (triggered by empty MoE experts)
        min_val, max_val = (
            torch.tensor(0.0, dtype=tensor.dtype),
            torch.tensor(1.0, dtype=tensor.dtype),
        )
    else:
        min_val, max_val = tensor.aminmax()
    amax = min_val.abs().max(max_val.abs())
    scale = finfo.max / amax.clamp(min=1e-12)
    # Scale and clamp the tensor to bring it to
    # the representative range of the float8 data type
    # (as the default cast is unsaturated).
    qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
    # Return both the float8 data and the inverse scale (as float),
    # as both are required as inputs to torch._scaled_mm.
    qweight = qweight.to(torch.float8_e4m3fn)
    scale = scale.float().reciprocal()
    return qweight, scale


def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
    # Use the native FP8 GEMM (torch._scaled_mm) on devices with compute
    # capability 9.0+; otherwise dequantize and fall back to a regular linear.
    cuda_compute_capability = torch.cuda.get_device_capability()
    if cuda_compute_capability >= (9, 0):
        output, _ = torch._scaled_mm(
            A,
            B.t(),
            out_dtype=out_dtype,
            scale_a=A_scale,
            scale_b=B_scale,
            bias=bias,
        )
    else:
        output = torch.nn.functional.linear(
            A.to(out_dtype) * A_scale,
            B.to(out_dtype) * B_scale.to(out_dtype),
            bias=bias,
        )
    return output


# Observer used during calibration: quantizes activations on the fly while
# tracking the largest activation scale seen across the calibration set.
class FP8StaticLinearQuantizer(torch.nn.Module):
    def __init__(self, qweight, weight_scale):
        super().__init__()
        self.weight = torch.nn.Parameter(qweight, requires_grad=False)
        self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
        self.act_scale = None

    def forward(self, x):
        # Dynamically quantize the activations.
        qinput, x_act_scale = per_tensor_quantize(x)

        # Keep the largest activation scale observed so far.
        if self.act_scale is None:
            self.act_scale = torch.nn.Parameter(x_act_scale)
        elif x_act_scale > self.act_scale:
            self.act_scale = torch.nn.Parameter(x_act_scale)

        # Pass the quantized output to the next layer so it sees realistic data.
        output = fp8_gemm(
            A=qinput,
            A_scale=self.act_scale,
            B=self.weight,
            B_scale=self.weight_scale,
            bias=None,
            out_dtype=x.dtype,
        )
        return output


# Linear layer with FP8 weights and a fixed (static) activation scale
# recorded during calibration.
class FP8StaticLinear(torch.nn.Module):
    def __init__(self, qweight, weight_scale, act_scale=0.0):
        super().__init__()
        self.weight = torch.nn.Parameter(qweight, requires_grad=False)
        self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
        self.act_scale = torch.nn.Parameter(act_scale, requires_grad=False)

    def per_tensor_quantize(
        self, tensor: torch.Tensor, inv_scale: float
    ) -> torch.Tensor:
        # Scale and clamp the tensor to bring it to
        # the representative range of the float8 data type
        # (as the default cast is unsaturated).
        finfo = torch.finfo(torch.float8_e4m3fn)
        qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
        return qweight.to(torch.float8_e4m3fn)

    def forward(self, x):
        qinput = self.per_tensor_quantize(x, inv_scale=self.act_scale)
        output = fp8_gemm(
            A=qinput,
            A_scale=self.act_scale,
            B=self.weight,
            B_scale=self.weight_scale,
            bias=None,
            out_dtype=x.dtype,
        )
        return output


# Linear layer with FP8 weights; activations are quantized dynamically
# on every forward pass.
class FP8DynamicLinear(torch.nn.Module):
    def __init__(self, qweight, scale):
        super().__init__()
        self.weight = torch.nn.Parameter(qweight, requires_grad=False)
        self.weight_scale = torch.nn.Parameter(scale, requires_grad=False)

    def forward(self, x):
        qinput, x_scale = per_tensor_quantize(x)
        output = fp8_gemm(
            A=qinput,
            A_scale=x_scale,
            B=self.weight,
            B_scale=self.weight_scale,
            bias=None,
            out_dtype=x.dtype,
        )
        return output


def replace_module(model, name, new_module):
    if "." in name:
        parent_name = name.rsplit(".", 1)[0]
        child_name = name[len(parent_name) + 1 :]
        parent = model.model.get_submodule(parent_name)
    else:
        parent_name = ""
        parent = model.model
        child_name = name
    setattr(parent, child_name, new_module)


def quantize_weights(model):
    # Quantize every Linear except the MoE router ("gate") projections.
    for name, linear in model.model.named_modules():
        if "gate" in name or not isinstance(linear, torch.nn.Linear):
            continue
        quant_weight, quant_scale = per_tensor_quantize(linear.weight)
        quant_linear = FP8DynamicLinear(quant_weight, quant_scale)
        replace_module(model, name, quant_linear)
        del linear
    cleanup_memory()


def quantize_activations(model, calibration_tokens):
    # Replace layers with the calibration quantizer.
    for name, dynamic_quant_linear in model.model.named_modules():
        if "gate" in name or not isinstance(dynamic_quant_linear, FP8DynamicLinear):
            continue
        quantizer = FP8StaticLinearQuantizer(
            dynamic_quant_linear.weight, dynamic_quant_linear.weight_scale
        )
        replace_module(model, name, quantizer)
        del dynamic_quant_linear
    cleanup_memory()

    # Calibration.
    for row_idx in range(calibration_tokens.shape[0]):
        _ = model(calibration_tokens[row_idx].reshape(1, -1))

    # Replace each quantizer with a static FP8 linear layer.
    for name, quantizer in model.model.named_modules():
        if "gate" in name or not isinstance(quantizer, FP8StaticLinearQuantizer):
            continue
        static_proj = FP8StaticLinear(
            quantizer.weight, quantizer.weight_scale, quantizer.act_scale
        )
        replace_module(model, name, static_proj)
        del quantizer
    cleanup_memory()


def save_quantized_model(model, activation_scheme, save_dir):
    # Note: `tokenizer` is the module-level tokenizer created in the
    # __main__ block below.
    print(f"Saving the model to {save_dir}")
    static_q_dict = {
        "quantization_config": {
            "quant_method": "fp8",
            "activation_scheme": activation_scheme,
        }
    }
    model.config.update(static_q_dict)
    model.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-id", type=str)
    parser.add_argument("--save-dir", type=str)
    parser.add_argument(
        "--activation-scheme", type=str, default="static", choices=["static", "dynamic"]
    )
    parser.add_argument("--num-samples", type=int, default=512)
    parser.add_argument("--max-seq-len", type=int, default=512)
    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    sample_input_tokens = tokenizer.apply_chat_template(
        [{"role": "user", "content": "What is your name?"}],
        add_generation_prompt=True,
        return_tensors="pt",
    ).to("cuda")

    ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
    ds = ds.shuffle(seed=42).select(range(args.num_samples))
    ds = ds.map(
        lambda batch: {
            "text": tokenizer.apply_chat_template(batch["messages"], tokenize=False)
        }
    )
    tokenizer.pad_token_id = tokenizer.eos_token_id
    calibration_tokens = tokenizer(
        ds["text"],
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=args.max_seq_len,
        add_special_tokens=False,
    ).input_ids.to("cuda")
    print("Calibration tokens:", calibration_tokens.shape)

    # Load and test the model.
    model = AutoModelForCausalLM.from_pretrained(
        args.model_id, torch_dtype="auto", device_map="auto"
    )
    print(model)
    output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20)
    print("ORIGINAL:\n", tokenizer.decode(output[0]), "\n\n")

    # Quantize weights.
    quantize_weights(model)
    print(model)
    output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20)
    print("WEIGHT QUANT:\n", tokenizer.decode(output[0]), "\n\n")

    if args.activation_scheme == "dynamic":
        print("Exporting model with static weights and dynamic activations")
        save_quantized_model(model, args.activation_scheme, args.save_dir)
    else:
        assert args.activation_scheme == "static"
        # Quantize activations.
        quantize_activations(model, calibration_tokens=calibration_tokens)
        output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20)
        print("ACT QUANT:\n", tokenizer.decode(output[0]), "\n\n")

        print("Exporting model with static weights and static activations")
        save_quantized_model(model, args.activation_scheme, args.save_dir)
```
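
For reference, the core of the scheme is the per-tensor scale computed in `per_tensor_quantize`: a tensor is multiplied by `finfo.max / absmax`, saturated to the FP8 range, and stored together with the inverse scale used later for dequantization. A minimal standalone sketch of that round trip, assuming a PyTorch build with `torch.float8_e4m3fn` casts (no FP8 hardware is needed just to cast):

```python
import torch

finfo = torch.finfo(torch.float8_e4m3fn)
w = torch.randn(16, 32, dtype=torch.float16)

# Forward scale: FP8 max divided by the tensor's absmax (as in per_tensor_quantize).
amax = w.abs().max().clamp(min=1e-12)
scale = finfo.max / amax
qw = (w * scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)

# The script stores the inverse scale; dequantize to check the round-trip error.
inv_scale = scale.float().reciprocal()
w_dq = qw.to(torch.float16) * inv_scale
print("max abs error:", (w - w_dq).abs().max().item())
```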