import argparse
import time
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig

MODEL_PATH = 'THUDM/glm-4-9b-chat'


def stress_test(token_len, n, num_gpu):
    # The test runs on a single device: `--num_gpu N` selects GPU index N - 1.
    device = torch.device(f"cuda:{num_gpu - 1}" if torch.cuda.is_available() and num_gpu > 0 else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH,
        trust_remote_code=True,
        padding_side="left"
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16
    ).to(device).eval()

    # Alternative: INT4 weight-only inference via bitsandbytes.
    # model = AutoModelForCausalLM.from_pretrained(
    #     MODEL_PATH,
    #     trust_remote_code=True,
    #     quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    #     low_cpu_mem_usage=True,
    # ).eval()

    times = []          # time to first token (prefill latency) per iteration
    decode_speeds = []  # decode throughput (tokens/second) per iteration

    print("Warming up...")
    vocab_size = tokenizer.vocab_size
    warmup_token_len = 20
    # Sample random token ids, staying clear of the special-token range at the top of the vocab.
    random_token_ids = torch.randint(3, vocab_size - 200, (warmup_token_len - 5,), dtype=torch.long)
    # GLM-4 chat prompt scaffold: [gMASK] <sop> <|user|> "\n" ... <|assistant|>
    start_tokens = [151331, 151333, 151336, 198]
    end_tokens = [151337]

    input_ids = torch.tensor(start_tokens + random_token_ids.tolist() + end_tokens,
                             dtype=torch.long).unsqueeze(0).to(device)
    attention_mask = torch.ones_like(input_ids)  # must be integer-typed, not bfloat16
    # Built for completeness; generate() derives positions itself, so this is not passed below.
    position_ids = torch.arange(input_ids.shape[1], dtype=torch.long).unsqueeze(0).to(device)
    warmup_inputs = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'position_ids': position_ids
    }
    with torch.no_grad():
        _ = model.generate(
            input_ids=warmup_inputs['input_ids'],
            attention_mask=warmup_inputs['attention_mask'],
            max_new_tokens=2048,
            do_sample=False,
            repetition_penalty=1.0,
            eos_token_id=[151329, 151336, 151338]
        )
    print("Warming up complete. Starting stress test...")

    for i in range(n):
        random_token_ids = torch.randint(3, vocab_size - 200, (token_len - 5,), dtype=torch.long)
        input_ids = torch.tensor(start_tokens + random_token_ids.tolist() + end_tokens,
                                 dtype=torch.long).unsqueeze(0).to(device)
        attention_mask = torch.ones_like(input_ids)
        position_ids = torch.arange(input_ids.shape[1], dtype=torch.long).unsqueeze(0).to(device)
        test_inputs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'position_ids': position_ids
        }

        streamer = TextIteratorStreamer(
            tokenizer=tokenizer,
            timeout=36000,  # seconds to wait on the streamer iterator before giving up
            skip_prompt=True,
            skip_special_tokens=True
        )

        generate_kwargs = {
            "input_ids": test_inputs['input_ids'],
            "attention_mask": test_inputs['attention_mask'],
            "max_new_tokens": 512,
            "do_sample": False,
            "repetition_penalty": 1.0,
            "eos_token_id": [151329, 151336, 151338],
            "streamer": streamer
        }

        # Run generation in a background thread so the streamed output can be timed here.
        start_time = time.time()
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()

        first_token_time = None
        all_token_times = []

        # The streamer yields decoded text chunks, roughly one per generated token.
        for _chunk in streamer:
            current_time = time.time()
            if first_token_time is None:
                first_token_time = current_time
                times.append(first_token_time - start_time)
            all_token_times.append(current_time)

        t.join()
        end_time = time.time()

        # Approximate decode throughput: streamed chunks (~tokens) per second,
        # measured from the first token to the end of generation.
        decode_speed = len(all_token_times) / (end_time - first_token_time) if all_token_times else 0
        decode_speeds.append(decode_speed)
        print(f"Iteration {i + 1}/{n} - Prefill Time: {times[-1]:.4f} seconds - "
              f"Decode Speed: {decode_speed:.4f} tokens/second")

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    avg_first_token_time = sum(times) / n
    avg_decode_speed = sum(decode_speeds) / n
    print(f"\nAverage First Token Time over {n} iterations: {avg_first_token_time:.4f} seconds")
    print(f"Average Decode Speed over {n} iterations: {avg_decode_speed:.4f} tokens/second")
    return times, avg_first_token_time, decode_speeds, avg_decode_speed


def main():
    parser = argparse.ArgumentParser(description="Stress test for model inference")
    parser.add_argument('--token_len', type=int, default=1000, help='Number of tokens in each test prompt')
    parser.add_argument('--n', type=int, default=3, help='Number of iterations for the stress test')
    parser.add_argument('--num_gpu', type=int, default=1, help='GPU to use (the test runs on GPU index num_gpu - 1)')

    args = parser.parse_args()

    stress_test(args.token_len, args.n, args.num_gpu)


if __name__ == "__main__":
    main()