---
language: zh
---

# Based on bert-base-chinese

Fine-tuned from bert-base-chinese for 5 epochs on the `message80W` dataset (binary spam classification).



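```python
import torch
from tqdm import tqdm

# evaluate: `model` and `test_loader` come from the training script;
# the model is called as model(input_ids, attention_mask) and returns raw logits
model.eval()
eval_steps = 0
pred_list = []
label_list = []
with torch.no_grad():
    for batch in tqdm(test_loader):
        input_ids, attention_mask, label = batch
        logits = model(input_ids, attention_mask)
        # predicted class = index of the largest logit; store plain ints
        pred_list.extend(torch.argmax(logits, dim=-1).tolist())
        label_list.extend(label.tolist())
        eval_steps += 1
```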



800K (80W) samples, shuffled, then split 8:3 into train and eval sets.
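A minimal sketch of the shuffle-and-split step, assuming the samples fit in memory as a Python sequence. `shuffle_split` and the seed value are hypothetical, not taken from the original training code; the 8:3 ratio is used as stated, so the train set gets 8/11 of the data.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def shuffle_split(
    samples: Sequence[T], train_parts: int = 8, eval_parts: int = 3, seed: int = 42
) -> Tuple[List[T], List[T]]:
    """Shuffle `samples` and split them train_parts:eval_parts.

    Hypothetical helper; the default seed is an assumption.
    """
    items = list(samples)
    random.Random(seed).shuffle(items)  # seeded shuffle so the split is reproducible
    n_train = len(items) * train_parts // (train_parts + eval_parts)
    return items[:n_train], items[n_train:]


train_set, eval_set = shuffle_split(range(800_000))
print(len(train_set), len(eval_set))  # 581818 218182
```

Shuffling before splitting keeps both sets drawn from the same distribution even if the raw file is ordered by label.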

Evaluation results:

![Evaluation results](image-20220512153415505.png)
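The lists collected by the evaluation loop can be turned into standard binary-classification metrics. A minimal pure-Python sketch, assuming predictions and labels are integer lists like `pred_list` and `label_list` and that spam is encoded as class 1 (that encoding is an assumption; `binary_metrics` is a hypothetical helper):

```python
from typing import Dict, Sequence


def binary_metrics(
    preds: Sequence[int], labels: Sequence[int], positive: int = 1
) -> Dict[str, float]:
    """Accuracy, precision, recall and F1 for a binary classifier.

    `positive` is the class treated as spam (assumed to be 1).
    """
    if len(preds) != len(labels):
        raise ValueError("preds and labels must have the same length")
    pairs = list(zip(preds, labels))
    tp = sum(1 for p, y in pairs if p == positive and y == positive)
    fp = sum(1 for p, y in pairs if p == positive and y != positive)
    fn = sum(1 for p, y in pairs if p != positive and y == positive)
    correct = sum(1 for p, y in pairs if p == y)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {
        "accuracy": correct / len(pairs) if pairs else 0.0,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


metrics = binary_metrics([1, 0, 1, 1, 0], [1, 0, 0, 1, 1])
print(metrics)
```

On real output this would be called as `binary_metrics(pred_list, label_list)`. scikit-learn's `sklearn.metrics.classification_report` reports the same per-class precision, recall and F1 if that dependency is acceptable.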