KingNish committed on
Commit: f3cfba4
1 Parent(s): faa634d

Update chatbot.py

Files changed (1):
  chatbot.py (+13, -12)
chatbot.py CHANGED
@@ -315,7 +315,7 @@ def model_inference(
     temperature,
     max_new_tokens,
     repetition_penalty,
-    top_p,
+    min_p,
     web_search,
 ):
     # Define generation_args at the beginning of the function
@@ -332,6 +332,7 @@ def model_inference(
         generate_kwargs = dict(
             max_new_tokens=4000,
             do_sample=True,
+            min_p=0.08,
         )
         # Format the prompt for the language model
         formatted_prompt = format_prompt(
@@ -351,6 +352,7 @@ def model_inference(
         generate_kwargs = dict(
             max_new_tokens=5000,
             do_sample=True,
+            min_p=0.08,
        )
         # Format the prompt for the language model
         formatted_prompt = format_prompt(
@@ -389,16 +391,15 @@ def model_inference(
     }
     assert decoding_strategy in [
         "Greedy",
-        "Top P Sampling",
+        "Min P Sampling",
     ]
 
     if decoding_strategy == "Greedy":
         generation_args["do_sample"] = False
-    elif decoding_strategy == "Top P Sampling":
+    elif decoding_strategy == "Min P Sampling":
         generation_args["temperature"] = temperature
         generation_args["do_sample"] = True
-        generation_args["top_p"] = top_p
-    # Creating model inputs
+        generation_args["min_p"] = min_p
     (
         resulting_text,
         resulting_images,
@@ -440,7 +441,7 @@ FEATURES = datasets.Features(
         "temperature": datasets.Value("float32"),
         "max_new_tokens": datasets.Value("int32"),
         "repetition_penalty": datasets.Value("float32"),
-        "top_p": datasets.Value("int32"),
+        "min_p": datasets.Value("int32"),
     }
 )
 
@@ -465,9 +466,9 @@ repetition_penalty = gr.Slider(
 decoding_strategy = gr.Radio(
     [
         "Greedy",
-        "Top P Sampling",
+        "Min P Sampling",
     ],
-    value="Top P Sampling",
+    value="Min P Sampling",
     label="Decoding strategy",
     interactive=True,
     info="Higher values are equivalent to sampling more low-probability tokens.",
@@ -482,14 +483,14 @@ temperature = gr.Slider(
     label="Sampling temperature",
     info="Higher values will produce more diverse outputs.",
 )
-top_p = gr.Slider(
+min_p = gr.Slider(
     minimum=0.01,
-    maximum=0.99,
-    value=0.9,
+    maximum=0.49,
+    value=0.08,
     step=0.01,
     visible=True,
     interactive=True,
-    label="Top P",
+    label="Min P",
     info="Higher values are equivalent to sampling more low-probability tokens.",
 )
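
For context on the change (an editorial sketch, not part of the commit): min-p sampling keeps only those tokens whose probability is at least min_p times the probability of the most likely token, so the cutoff tightens when the model is confident and loosens when it is not. The snippet below is a minimal PyTorch illustration of that filtering step over a hypothetical 32k-token vocabulary; chatbot.py does not contain this code, it simply passes min_p=0.08 through generate_kwargs and generation_args to the text-generation backend.

    import torch

    def min_p_filter(logits: torch.Tensor, min_p: float = 0.08) -> torch.Tensor:
        # Keep tokens whose probability is >= min_p * p(top token); mask the rest.
        probs = torch.softmax(logits, dim=-1)
        top_prob = probs.max(dim=-1, keepdim=True).values
        mask = probs < (min_p * top_prob)
        return logits.masked_fill(mask, float("-inf"))

    # Example: draw one token from the filtered distribution.
    logits = torch.randn(1, 32000)  # hypothetical vocabulary size
    probs = torch.softmax(min_p_filter(logits), dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)

Whether the new generate_kwargs entries actually take effect still depends on the serving backend accepting a min_p parameter; recent transformers releases expose an equivalent min_p argument on generate().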