|
from transformers import ( |
|
EncoderDecoderModel, |
|
AutoTokenizer |
|
) |
|
import torch |
|
|
|
PRETRAINED = "raynardj/wenyanwen-chinese-translate-to-ancient" |
|
|
|
def inference(text):
    """Translate modern Chinese *text* into Classical Chinese (wenyanwen).

    Relies on the module-level ``tokenizer`` and ``model`` created further
    down in this script.

    Args:
        text: Modern Chinese source sentence.

    Returns:
        list[str]: Batch-decoded translations — a single-element list,
        since a one-item batch is submitted.
    """
    tk_kwargs = dict(
        truncation=True,
        max_length=128,
        padding="max_length",
        return_tensors="pt",
    )

    inputs = tokenizer([text], **tk_kwargs)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output_ids = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            num_beams=3,
            # Derive the decoder start token from the tokenizer instead of
            # the magic number 101 (101 is [CLS] in BERT-style Chinese
            # vocabularies, which this checkpoint uses).
            bos_token_id=tokenizer.cls_token_id,
            eos_token_id=tokenizer.sep_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
        return tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
|
|
import streamlit as st

# Page header: app title plus a markdown blurb linking to the project repo.
TITLE = "Wenyanwen Translator"
INTRO = """

# Translate from Chinese to Ancient Chinese / 还你古朴清雅壮丽的文言文, 这[github](https://github.com/raynardj/yuan)

"""

st.title(TITLE)
st.markdown(INTRO)
|
|
|
@st.cache(allow_output_mutation=True)
def load_model():
    """Load and cache the tokenizer and seq2seq model for PRETRAINED.

    ``allow_output_mutation=True`` stops Streamlit's classic cache from
    hashing the returned torch model on every rerun — hashing a large,
    mutable model is slow and raises CachedObjectMutationWarning as soon
    as transformers touches its internal state.

    Returns:
        tuple: ``(tokenizer, model)`` ready for inference.
    """
    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED)
    model = EncoderDecoderModel.from_pretrained(PRETRAINED)
    return tokenizer, model
|
|
|
tokenizer, model = load_model()

# NOTE(review): the poem string below is the widget *label*, not a default
# value — if a pre-filled box was intended, pass it via ``value=`` instead;
# confirm before changing, since the label is what users currently see.
text = st.text_area("轻轻地我走了,正如我轻轻地来。我挥一挥衣袖,不带走一片云彩。")

# Only translate when the user has entered something: generating from an
# empty string wastes a model call and renders a meaningless result.
if text:
    st.write(inference(text)[0])