Update app.py
Browse files
app.py
CHANGED
@@ -4,7 +4,8 @@ import torch
|
|
4 |
import gradio as gr
|
5 |
import torch.nn.functional as F
|
6 |
|
7 |
-
from transformers import GPT2LMHeadModel,
|
|
|
8 |
|
9 |
def top_k_top_p_filtering( logits, top_k=0, top_p=0.0, filter_value=-float('Inf') ):
|
10 |
assert logits.dim() == 1
|
@@ -52,13 +53,15 @@ def generate(title, context, max_len):
|
|
52 |
return result
|
53 |
|
54 |
if __name__ == '__main__':
|
55 |
-
tokenizer =
|
56 |
-
eod_id = tokenizer.convert_tokens_to_ids("<eod>")
|
57 |
-
sep_id = tokenizer.sep_token_id
|
58 |
-
unk_id = tokenizer.unk_token_id
|
59 |
|
|
|
60 |
model = GPT2LMHeadModel.from_pretrained("supermy/poetry")
|
61 |
model.eval()
|
|
|
62 |
gr.Interface(
|
63 |
fn=generate,
|
64 |
inputs=[
|
|
|
4 |
import gradio as gr
|
5 |
import torch.nn.functional as F
|
6 |
|
7 |
+
from transformers import GPT2LMHeadModel, BertTokenizer
|
8 |
+
|
9 |
|
10 |
def top_k_top_p_filtering( logits, top_k=0, top_p=0.0, filter_value=-float('Inf') ):
|
11 |
assert logits.dim() == 1
|
|
|
53 |
return result
|
54 |
|
55 |
if __name__ == '__main__':
|
56 |
+
# tokenizer = BertTokenizer(vocab_file="chinese_vocab.model")
|
57 |
+
# eod_id = tokenizer.convert_tokens_to_ids("<eod>")
|
58 |
+
# sep_id = tokenizer.sep_token_id
|
59 |
+
# unk_id = tokenizer.unk_token_id
|
60 |
|
61 |
+
tokenizer = BertTokenizer.from_pretrained("supermy/poetry")
|
62 |
model = GPT2LMHeadModel.from_pretrained("supermy/poetry")
|
63 |
model.eval()
|
64 |
+
|
65 |
gr.Interface(
|
66 |
fn=generate,
|
67 |
inputs=[
|