raynardj
🍶 baseline
735386f
raw
history blame
1.26 kB
from transformers import (
EncoderDecoderModel,
AutoTokenizer
)
import torch
# Hugging Face Hub id shared by the tokenizer and the EncoderDecoder model.
PRETRAINED = "raynardj/wenyanwen-chinese-translate-to-ancient"

# Passed to generate() as bos_token_id. Presumably the [CLS] id of the BERT-style
# Chinese vocabulary this checkpoint uses — TODO confirm against tokenizer.cls_token_id.
_BOS_TOKEN_ID = 101


def inference(text, *, tokenizer=None, model=None, max_length=128, num_beams=3):
    """Translate a modern-Chinese sentence into classical Chinese (wenyanwen).

    Args:
        text: Source sentence. Input longer than ``max_length`` tokens is truncated.
        tokenizer: Tokenizer used to encode the input and decode the output.
            Defaults to the module-level ``tokenizer`` global.
        model: Model exposing ``generate``. Defaults to the module-level
            ``model`` global.
        max_length: Maximum number of input tokens kept after truncation.
        num_beams: Beam width passed to ``generate``.

    Returns:
        The ``batch_decode`` output for a batch of one: a list holding the
        decoded translation, with special tokens removed.
    """
    # Resolve lazily so callers can inject their own objects, while the Streamlit
    # script keeps working with the globals it assigns after load_model().
    if tokenizer is None:
        tokenizer = globals()["tokenizer"]
    if model is None:
        model = globals()["model"]

    # A single sentence needs no padding. "longest" avoids padding every request
    # out to max_length tokens; the attention mask hides padding either way.
    inputs = tokenizer(
        [text],
        truncation=True,
        max_length=max_length,
        padding="longest",
        return_tensors="pt",
    )
    with torch.no_grad():
        generated = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            num_beams=num_beams,
            bos_token_id=_BOS_TOKEN_ID,
            # Generation stops at the tokenizer's [SEP]-style separator token.
            eos_token_id=tokenizer.sep_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    return tokenizer.batch_decode(generated, skip_special_tokens=True)
import streamlit as st
# Page heading, then a one-line bilingual tagline linking the source repository.
st.title("Wenyanwen Translator")
st.markdown(
    "\n"
    "# Translate from Chinese to Ancient Chinese / 还你古朴清雅壮丽的文言文, 这[github](https://github.com/raynardj/yuan)"
    "\n"
)
# Models and tokenizers are shared, unhashable resources. Newer Streamlit exposes
# st.cache_resource for exactly this. On older versions, plain st.cache re-hashes
# the returned objects on every rerun to detect mutation, which is slow and can fail
# on tokenizer objects. allow_output_mutation=True disables that check.
_cache_model = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_model
def load_model():
    """Download (or reuse the cached) tokenizer and model for PRETRAINED.

    Wrapped in a Streamlit cache so the weights are loaded once per process
    instead of on every script rerun.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED)
    model = EncoderDecoderModel.from_pretrained(PRETRAINED)
    return tokenizer, model
# Module-level globals that inference() falls back to by default.
tokenizer, model = load_model()

# Sample sentence pre-filled in the input box.
EXAMPLE_TEXT = "轻轻地我走了,正如我轻轻地来。我挥一挥衣袖,不带走一片云彩。"

# st.text_area's first positional parameter is the widget *label*, not its contents,
# so the sample must go in value= to appear as the default input.
text = st.text_area("Modern Chinese text / 白话文", value=EXAMPLE_TEXT)

# Nothing to translate for empty or whitespace-only input. Skip the beam search
# that would otherwise run on every Streamlit rerun.
if text.strip():
    st.write(inference(text)[0])