lengyue233 commited on
Commit
e90b8b5
1 Parent(s): 75e9ff1

fix prev token

Browse files
Files changed (1) hide show
  1. tools/llama/generate.py +7 -7
tools/llama/generate.py CHANGED
@@ -47,13 +47,13 @@ def logits_to_probs(
47
  top_p: Optional[int] = None,
48
  repetition_penalty: float = 1.0,
49
  ):
50
- # if previous_tokens is not None and repetition_penalty != 1.0:
51
- previous_tokens = previous_tokens.long()
52
- score = torch.gather(logits, dim=0, index=previous_tokens)
53
- score = torch.where(
54
- score < 0, score * repetition_penalty, score / repetition_penalty
55
- )
56
- logits.scatter_(dim=0, index=previous_tokens, src=score)
57
 
58
  # if top_p is not None and top_p < 1.0:
59
  sorted_logits, sorted_indices = torch.sort(logits, descending=True)
 
47
  top_p: Optional[int] = None,
48
  repetition_penalty: float = 1.0,
49
  ):
50
+ if previous_tokens is not None:
51
+ previous_tokens = previous_tokens.long()
52
+ score = torch.gather(logits, dim=0, index=previous_tokens)
53
+ score = torch.where(
54
+ score < 0, score * repetition_penalty, score / repetition_penalty
55
+ )
56
+ logits.scatter_(dim=0, index=previous_tokens, src=score)
57
 
58
  # if top_p is not None and top_p < 1.0:
59
  sorted_logits, sorted_indices = torch.sort(logits, descending=True)