---
language: zh
---
# Based on bert-base-chinese
This model is bert-base-chinese fine-tuned for 5 epochs on the `message80W` dataset (binary spam classification).
```python
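import torch
from tqdm import tqdm

# Evaluate on the held-out set; `model` and `test_loader` come from the training script.
model.eval()
eval_steps = 0
pred_list = []
label_list = []
with torch.no_grad():
    for i, batch in enumerate(tqdm(test_loader)):
        input_ids, attention_mask, label = batch
        # The model is expected to return raw logits, one row per example.
        logits = model(input_ids, attention_mask)
        # Store plain Python ints rather than 0-d tensors.
        pred_list += torch.argmax(logits, dim=-1).tolist()
        label_list += label.tolist()
        eval_steps += 1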
```
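The loop above only collects predictions and labels. One way to turn them into scores is sketched below; it is not necessarily how the reported results were computed. The helper name `binary_metrics` is hypothetical, and treating label `1` as the spam class is an assumption, since the label mapping is not stated here.

```python
def binary_metrics(preds, labels, positive=1):
    """Return accuracy, precision, recall and F1 for a binary task.

    `positive=1` (spam) is an assumption; the label mapping is not documented.
    """
    # int() accepts both plain ints and 0-d tensors.
    preds = [int(p) for p in preds]
    labels = [int(y) for y in labels]
    if len(preds) != len(labels):
        raise ValueError("preds and labels must have the same length")

    pairs = list(zip(preds, labels))
    tp = sum(1 for p, y in pairs if p == positive and y == positive)
    fp = sum(1 for p, y in pairs if p == positive and y != positive)
    fn = sum(1 for p, y in pairs if p != positive and y == positive)
    correct = sum(1 for p, y in pairs if p == y)

    # Guard every division so empty or one-sided inputs return 0.0.
    accuracy = correct / len(pairs) if pairs else 0.0
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) else 0.0)
    return {"accuracy": accuracy, "precision": precision,
            "recall": recall, "f1": f1}
```

After the evaluation loop, `binary_metrics(pred_list, label_list)` would return the four scores as a dictionary.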
The 800K samples were shuffled and split 8:3 into train and eval sets.
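A shuffle-and-split step of this kind could look like the sketch below. It takes the stated 8:3 ratio literally, so 8/11 of the samples go to training. The function name `shuffle_split`, its parameters, and the fixed seed are all hypothetical; the original preprocessing code is not shown.

```python
import random


def shuffle_split(samples, train_parts=8, eval_parts=3, seed=42):
    """Shuffle a copy of `samples` and split it train_parts:eval_parts.

    The seed and the 8:3 defaults are illustrative assumptions.
    """
    if train_parts <= 0 or eval_parts <= 0:
        raise ValueError("both parts of the ratio must be positive")
    shuffled = list(samples)               # leave the caller's data untouched
    random.Random(seed).shuffle(shuffled)  # deterministic for a given seed
    n_train = len(shuffled) * train_parts // (train_parts + eval_parts)
    return shuffled[:n_train], shuffled[n_train:]
```

Using a seeded `random.Random` instance keeps the split reproducible without changing the global random state.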
The eval results are shown below.
![image-20220512153415505](image-20220512153415505.png)