lengyue233 commited on
Commit
9bfe4ad
1 Parent(s): e90b8b5

Optimize graph

Browse files
Files changed (2) hide show
  1. app.py +5 -10
  2. tools/llama/generate.py +37 -26
app.py CHANGED
@@ -41,6 +41,9 @@ Related code are released under BSD-3-Clause License, and weights are released u
41
 
42
  We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
43
  我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.
 
 
 
44
  """
45
 
46
  TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
@@ -76,7 +79,6 @@ def inference(
76
  reference_text,
77
  max_new_tokens,
78
  chunk_length,
79
- top_k,
80
  top_p,
81
  repetition_penalty,
82
  temperature,
@@ -112,7 +114,6 @@ def inference(
112
  device=vqgan_model.device,
113
  max_new_tokens=max_new_tokens,
114
  text=text,
115
- top_k=int(top_k) if top_k > 0 else None,
116
  top_p=top_p,
117
  repetition_penalty=repetition_penalty,
118
  temperature=temperature,
@@ -194,10 +195,6 @@ def build_app():
194
  step=8,
195
  )
196
 
197
- top_k = gr.Slider(
198
- label="Top-K", minimum=0, maximum=5, value=0, step=1
199
- )
200
-
201
  top_p = gr.Slider(
202
  label="Top-P", minimum=0, maximum=1, value=0.7, step=0.01
203
  )
@@ -264,7 +261,6 @@ def build_app():
264
  reference_text,
265
  max_new_tokens,
266
  chunk_length,
267
- top_k,
268
  top_p,
269
  repetition_penalty,
270
  temperature,
@@ -310,8 +306,8 @@ if __name__ == "__main__":
310
  args.compile = True
311
  args.max_gradio_length = 1024
312
  args.tokenizer = "./checkpoints/fish-speech-1"
313
- args.llama_checkpoint_path = "./checkpoints/fish-speech-1/text2semantic-sft-large-v1-4k.pth"
314
- args.llama_config_name = "dual_ar_2_codebook_large"
315
  args.vqgan_checkpoint_path = "./checkpoints/fish-speech-1/vq-gan-group-fsq-2x1024.pth"
316
  args.vqgan_config_name = "vqgan_pretrain"
317
 
@@ -343,7 +339,6 @@ if __name__ == "__main__":
343
  reference_text="",
344
  max_new_tokens=0,
345
  chunk_length=0,
346
- top_k=0, # 0 means no limit
347
  top_p=0.7,
348
  repetition_penalty=1.5,
349
  temperature=0.7,
 
41
 
42
  We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
43
  我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.
44
+
45
+ The model running in this WebUI is Fish Speech V1 Medium SFT 4K.
46
+ 在此 WebUI 中运行的模型是 Fish Speech V1 Medium SFT 4K.
47
  """
48
 
49
  TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
 
79
  reference_text,
80
  max_new_tokens,
81
  chunk_length,
 
82
  top_p,
83
  repetition_penalty,
84
  temperature,
 
114
  device=vqgan_model.device,
115
  max_new_tokens=max_new_tokens,
116
  text=text,
 
117
  top_p=top_p,
118
  repetition_penalty=repetition_penalty,
119
  temperature=temperature,
 
195
  step=8,
196
  )
197
 
 
 
 
 
198
  top_p = gr.Slider(
199
  label="Top-P", minimum=0, maximum=1, value=0.7, step=0.01
200
  )
 
261
  reference_text,
262
  max_new_tokens,
263
  chunk_length,
 
264
  top_p,
265
  repetition_penalty,
266
  temperature,
 
306
  args.compile = True
307
  args.max_gradio_length = 1024
308
  args.tokenizer = "./checkpoints/fish-speech-1"
309
+ args.llama_checkpoint_path = "./checkpoints/fish-speech-1/text2semantic-sft-medium-v1-4k.pth"
310
+ args.llama_config_name = "dual_ar_2_codebook_medium"
311
  args.vqgan_checkpoint_path = "./checkpoints/fish-speech-1/vq-gan-group-fsq-2x1024.pth"
312
  args.vqgan_config_name = "vqgan_pretrain"
313
 
 
339
  reference_text="",
340
  max_new_tokens=0,
341
  chunk_length=0,
 
342
  top_p=0.7,
343
  repetition_penalty=1.5,
344
  temperature=0.7,
tools/llama/generate.py CHANGED
@@ -42,11 +42,11 @@ def multinomial_sample_one_no_sync(
42
  def logits_to_probs(
43
  logits,
44
  previous_tokens: Optional[torch.Tensor] = None,
45
- temperature: float = 1.0,
46
- top_k: Optional[int] = None,
47
- top_p: Optional[int] = None,
48
- repetition_penalty: float = 1.0,
49
- ):
50
  if previous_tokens is not None:
51
  previous_tokens = previous_tokens.long()
52
  score = torch.gather(logits, dim=0, index=previous_tokens)
@@ -55,11 +55,9 @@ def logits_to_probs(
55
  )
56
  logits.scatter_(dim=0, index=previous_tokens, src=score)
57
 
58
- # if top_p is not None and top_p < 1.0:
59
  sorted_logits, sorted_indices = torch.sort(logits, descending=True)
60
- cum_probs = torch.cumsum(
61
- torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
62
- )
63
  sorted_indices_to_remove = cum_probs > top_p
64
  sorted_indices_to_remove[0] = False # keep at least one option
65
  indices_to_remove = sorted_indices_to_remove.scatter(
@@ -69,11 +67,6 @@ def logits_to_probs(
69
 
70
  logits = logits / max(temperature, 1e-5)
71
 
72
- # if top_k is not None:
73
- # v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
74
- # pivot = v.select(-1, -1).unsqueeze(-1)
75
- # logits = torch.where(logits < pivot, -float("Inf"), logits)
76
-
77
  probs = torch.nn.functional.softmax(logits, dim=-1)
78
  return probs
79
 
@@ -449,7 +442,6 @@ def generate_long(
449
  text: str,
450
  num_samples: int = 1,
451
  max_new_tokens: int = 0,
452
- top_k: int = None,
453
  top_p: int = 0.7,
454
  repetition_penalty: float = 1.5,
455
  temperature: float = 0.7,
@@ -462,6 +454,10 @@ def generate_long(
462
  prompt_tokens: Optional[torch.Tensor] = None,
463
  is_streaming: bool = False,
464
  ):
 
 
 
 
465
  model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
466
  im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
467
 
@@ -493,8 +489,18 @@ def generate_long(
493
  )
494
  logger.info(f"Encoded text: {text}")
495
 
 
 
 
 
 
 
 
 
496
  for sample_idx in range(num_samples):
497
- torch.cuda.synchronize()
 
 
498
  global_encoded = []
499
  all_codes = []
500
  seg_idx = 0
@@ -540,7 +546,6 @@ def generate_long(
540
  im_end_id=im_end_id,
541
  decode_one_token=decode_one_token,
542
  temperature=temperature,
543
- top_k=top_k,
544
  top_p=top_p,
545
  repetition_penalty=repetition_penalty,
546
  )
@@ -548,7 +553,9 @@ def generate_long(
548
  if sample_idx == 0 and seg_idx == 0 and compile:
549
  logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
550
 
551
- torch.cuda.synchronize()
 
 
552
  t = time.perf_counter() - t0
553
 
554
  tokens_generated = y.size(1) - prompt_length
@@ -559,9 +566,11 @@ def generate_long(
559
  logger.info(
560
  f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
561
  )
562
- logger.info(
563
- f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
564
- )
 
 
565
 
566
  # Put the generated tokens
567
  # since there is <im_end> and <eos> tokens, we remove last 2 tokens
@@ -654,7 +663,6 @@ def launch_thread_safe_queue(
654
  )
655
  @click.option("--num-samples", type=int, default=1)
656
  @click.option("--max-new-tokens", type=int, default=0)
657
- @click.option("--top-k", type=int, default=None)
658
  @click.option("--top-p", type=float, default=0.7)
659
  @click.option("--repetition-penalty", type=float, default=1.5)
660
  @click.option("--temperature", type=float, default=0.7)
@@ -678,7 +686,6 @@ def main(
678
  prompt_tokens: Optional[Path],
679
  num_samples: int,
680
  max_new_tokens: int,
681
- top_k: int,
682
  top_p: int,
683
  repetition_penalty: float,
684
  temperature: float,
@@ -702,7 +709,10 @@ def main(
702
  model, decode_one_token = load_model(
703
  config_name, checkpoint_path, device, precision, max_length, compile=compile
704
  )
705
- torch.cuda.synchronize()
 
 
 
706
  logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
707
 
708
  prompt_tokens = (
@@ -713,7 +723,9 @@ def main(
713
 
714
  tokenizer = AutoTokenizer.from_pretrained(tokenizer)
715
  torch.manual_seed(seed)
716
- torch.cuda.manual_seed(seed)
 
 
717
 
718
  generator = generate_long(
719
  model=model,
@@ -722,7 +734,6 @@ def main(
722
  text=text,
723
  num_samples=num_samples,
724
  max_new_tokens=max_new_tokens,
725
- top_k=top_k,
726
  top_p=top_p,
727
  repetition_penalty=repetition_penalty,
728
  temperature=temperature,
 
42
  def logits_to_probs(
43
  logits,
44
  previous_tokens: Optional[torch.Tensor] = None,
45
+ temperature: torch.Tensor = 1.0,
46
+ top_p: torch.Tensor = 1.0,
47
+ repetition_penalty: torch.Tensor = 1.0,
48
+ ) -> torch.Tensor:
49
+ # Apply repetition penalty
50
  if previous_tokens is not None:
51
  previous_tokens = previous_tokens.long()
52
  score = torch.gather(logits, dim=0, index=previous_tokens)
 
55
  )
56
  logits.scatter_(dim=0, index=previous_tokens, src=score)
57
 
58
+ # Apply top-p sampling
59
  sorted_logits, sorted_indices = torch.sort(logits, descending=True)
60
+ cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
 
 
61
  sorted_indices_to_remove = cum_probs > top_p
62
  sorted_indices_to_remove[0] = False # keep at least one option
63
  indices_to_remove = sorted_indices_to_remove.scatter(
 
67
 
68
  logits = logits / max(temperature, 1e-5)
69
 
 
 
 
 
 
70
  probs = torch.nn.functional.softmax(logits, dim=-1)
71
  return probs
72
 
 
442
  text: str,
443
  num_samples: int = 1,
444
  max_new_tokens: int = 0,
 
445
  top_p: int = 0.7,
446
  repetition_penalty: float = 1.5,
447
  temperature: float = 0.7,
 
454
  prompt_tokens: Optional[torch.Tensor] = None,
455
  is_streaming: bool = False,
456
  ):
457
+ assert 0 < top_p <= 1, "top_p must be in (0, 1]"
458
+ assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
459
+ assert 0 < temperature < 2, "temperature must be in (0, 2)"
460
+
461
  model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
462
  im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
463
 
 
489
  )
490
  logger.info(f"Encoded text: {text}")
491
 
492
+ # Move temperature, top_p, repetition_penalty to device
493
+ # This is important so that changing params doesn't trigger recompile
494
+ temperature = torch.tensor(temperature, device=device, dtype=torch.float)
495
+ top_p = torch.tensor(top_p, device=device, dtype=torch.float)
496
+ repetition_penalty = torch.tensor(
497
+ repetition_penalty, device=device, dtype=torch.float
498
+ )
499
+
500
  for sample_idx in range(num_samples):
501
+ if torch.cuda.is_available():
502
+ torch.cuda.synchronize()
503
+
504
  global_encoded = []
505
  all_codes = []
506
  seg_idx = 0
 
546
  im_end_id=im_end_id,
547
  decode_one_token=decode_one_token,
548
  temperature=temperature,
 
549
  top_p=top_p,
550
  repetition_penalty=repetition_penalty,
551
  )
 
553
  if sample_idx == 0 and seg_idx == 0 and compile:
554
  logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
555
 
556
+ if torch.cuda.is_available():
557
+ torch.cuda.synchronize()
558
+
559
  t = time.perf_counter() - t0
560
 
561
  tokens_generated = y.size(1) - prompt_length
 
566
  logger.info(
567
  f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
568
  )
569
+
570
+ if torch.cuda.is_available():
571
+ logger.info(
572
+ f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
573
+ )
574
 
575
  # Put the generated tokens
576
  # since there is <im_end> and <eos> tokens, we remove last 2 tokens
 
663
  )
664
  @click.option("--num-samples", type=int, default=1)
665
  @click.option("--max-new-tokens", type=int, default=0)
 
666
  @click.option("--top-p", type=float, default=0.7)
667
  @click.option("--repetition-penalty", type=float, default=1.5)
668
  @click.option("--temperature", type=float, default=0.7)
 
686
  prompt_tokens: Optional[Path],
687
  num_samples: int,
688
  max_new_tokens: int,
 
689
  top_p: int,
690
  repetition_penalty: float,
691
  temperature: float,
 
709
  model, decode_one_token = load_model(
710
  config_name, checkpoint_path, device, precision, max_length, compile=compile
711
  )
712
+
713
+ if torch.cuda.is_available():
714
+ torch.cuda.synchronize()
715
+
716
  logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
717
 
718
  prompt_tokens = (
 
723
 
724
  tokenizer = AutoTokenizer.from_pretrained(tokenizer)
725
  torch.manual_seed(seed)
726
+
727
+ if torch.cuda.is_available():
728
+ torch.cuda.manual_seed(seed)
729
 
730
  generator = generate_long(
731
  model=model,
 
734
  text=text,
735
  num_samples=num_samples,
736
  max_new_tokens=max_new_tokens,
 
737
  top_p=top_p,
738
  repetition_penalty=repetition_penalty,
739
  temperature=temperature,