liujch1998 commited on
Commit
9bebd75
1 Parent(s): 7eae3e8

Add accelerate

Browse files
Files changed (2) hide show
  1. app.py +0 -9
  2. requirements.txt +3 -1
app.py CHANGED
@@ -3,15 +3,6 @@ import os
3
  import torch
4
  import transformers
5
 
6
- def reduce_sum(value, mask, axis=None):
7
- if axis is None:
8
- return torch.sum(value * mask)
9
- return torch.sum(value * mask, axis)
10
- def reduce_mean(value, mask, axis=None):
11
- if axis is None:
12
- return torch.sum(value * mask) / torch.sum(mask)
13
- return reduce_sum(value, mask, axis) / torch.sum(mask, axis)
14
-
15
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
16
 
17
  HF_TOKEN_DOWNLOAD = os.environ.get('HF_TOKEN_DOWNLOAD')
 
3
  import torch
4
  import transformers
5
 
 
 
 
 
 
 
 
 
 
6
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
7
 
8
  HF_TOKEN_DOWNLOAD = os.environ.get('HF_TOKEN_DOWNLOAD')
requirements.txt CHANGED
@@ -1,4 +1,6 @@
1
  torch
2
  transformers
3
  tokenizers
4
- sentencepiece
 
 
 
1
  torch
2
  transformers
3
  tokenizers
4
+ sentencepiece
5
+ huggingface_hub
6
+ accelerate