from transformers import EncoderDecoderModel, AutoTokenizer
import torch
import streamlit as st

PRETRAINED = "raynardj/wenyanwen-chinese-translate-to-ancient"


def inference(text):
    # Tokenize the input, generate the Classical Chinese translation with
    # beam search, then decode and strip special tokens and spaces.
    tk_kwargs = dict(
        truncation=True,
        max_length=128,
        padding="max_length",
        return_tensors="pt",
    )
    inputs = tokenizer([text], **tk_kwargs)
    with torch.no_grad():
        return tokenizer.batch_decode(
            model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                num_beams=3,
                bos_token_id=101,
                eos_token_id=tokenizer.sep_token_id,
                pad_token_id=tokenizer.pad_token_id,
            ),
            skip_special_tokens=True,
        )[0].replace(" ", "")


st.title("🪕古朴 ❄️清雅 🌊壮丽")

st.markdown("""
> Translate from Chinese to Ancient Chinese / 还你古朴清雅壮丽的文言文, 这[github](https://github.com/raynardj/yuan)
> 最多100个中文字符
""")


@st.cache(allow_output_mutation=True)
def load_model():
    # Load the tokenizer and encoder-decoder model once and cache them
    # across Streamlit reruns.
    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED)
    model = EncoderDecoderModel.from_pretrained(PRETRAINED)
    return tokenizer, model


tokenizer, model = load_model()

text = st.text_area(
    value="轻轻地我走了,正如我轻轻地来。我挥一挥衣袖,不带走一片云彩。",
    label="输入文本",  # "Input text"
)

if st.button("曰"):  # "Say"
    if len(text) > 100:
        # Inputs are limited to 100 characters, as stated in the page notes.
        st.error("无过百字,若过则当答此言。")
    else:
        st.write(inference(text))
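# Usage note (a sketch, not part of the original script): assuming this file
# is saved as app.py (the filename is an assumption), it can be launched
# locally with Streamlit's standard CLI:
#
#   streamlit run app.py
#
# The first run downloads the PRETRAINED checkpoint from the Hugging Face Hub;
# subsequent reruns reuse the cached tokenizer and model from load_model().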