Plat commited on
Commit
f8fa551
1 Parent(s): 8c0ecc1

chore: onnx model

Browse files
Files changed (3) hide show
  1. app.py +113 -16
  2. conversion.py +7 -0
  3. requirements.txt +2 -0
app.py CHANGED
@@ -1,31 +1,128 @@
1
  import gradio as gr
 
2
 
3
  # モデルのロード
4
- model = gr.load("models/Miwa-Keita/zenz-v1-checkpoints")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  # 入力を調整する関数
7
- def preprocess_input(user_input):
8
  prefix = "\uEE00" # 前に付与する文字列
9
  suffix = "\uEE01" # 後ろに付与する文字列
10
- processed_input = prefix + user_input + suffix
11
- return model(processed_input)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  # 出力を調整する関数
14
- def postprocess_output(model_output):
15
  suffix = "\uEE01"
16
  # \uEE01の後の部分を抽出
17
- if suffix in model_output:
18
- return model_output.split(suffix)[1]
19
- return model_output
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  # インターフェースを定義
22
- iface = gr.Interface(
23
- fn=lambda x: postprocess_output(preprocess_input(x)),
24
- inputs=gr.Textbox(label="変換する文字列(カタカナ)"),
25
- outputs=gr.Textbox(label="変換結果"),
26
- title="ニューラルかな漢字変換モデルzenz-v1のデモ",
27
- description="変換したい文字列をカタカナを入力してください"
28
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  # ローンチ
31
- iface.launch()
 
 
1
  import gradio as gr
2
+ from optimum.pipelines import pipeline
3
 
4
  # モデルのロード
5
+ MODEL_NAME = "p1atdev/zenz-v1-onnx"
6
+ pipe = pipeline("text-generation", MODEL_NAME)
7
+
8
+
9
+ # ひらがなをカタカナに変換する関数
10
+ def hiragana_to_katakana(hiragana: str):
11
+ katakana = ""
12
+ for char in hiragana:
13
+ # ひらがなの文字コードの範囲はU+3041からU+309F
14
+ if 0x3041 <= ord(char) <= 0x309F:
15
+ katakana += chr(ord(char) + 0x60)
16
+ else:
17
+ katakana += char
18
+ return katakana
19
+
20
 
21
  # 入力を調整する関数
22
+ def preprocess_input(user_input: str):
23
  prefix = "\uEE00" # 前に付与する文字列
24
  suffix = "\uEE01" # 後ろに付与する文字列
25
+ processed_input = prefix + hiragana_to_katakana(user_input) + suffix
26
+
27
+ return processed_input
28
+
29
+
30
+ # 出力を生成する関数
31
+ def generate_output(input_text: str, num_beams: int = 4):
32
+ generated_outputs = pipe(
33
+ input_text,
34
+ max_new_tokens=len(input_text) * 2,
35
+ num_beams=num_beams,
36
+ num_return_sequences=num_beams,
37
+ )
38
+ generated_texts = [output["generated_text"] for output in generated_outputs] # type: ignore
39
+ return generated_texts
40
+
41
 
42
  # 出力を調整する関数
43
+ def postprocess_output(model_outputs: list[str]):
44
  suffix = "\uEE01"
45
  # \uEE01の後の部分を抽出
46
+ for i, model_output in enumerate(model_outputs):
47
+ if suffix in model_output:
48
+ model_outputs[i] = model_output.split(suffix)[1]
49
+ return "\n".join(
50
+ [f"{i+1}: {model_output}" for i, model_output in enumerate(model_outputs)]
51
+ )
52
+
53
+
54
+ # 変換処理をまとめる関数
55
+ def process_conversion(user_input: str, num_beams: int = 4):
56
+ processed_input = preprocess_input(user_input)
57
+ generated_outputs = generate_output(processed_input, num_beams)
58
+ postprocessed_output = postprocess_output(generated_outputs)
59
+ return postprocessed_output
60
+
61
 
62
  # インターフェースを定義
63
+ def interface():
64
+ with gr.Blocks() as ui:
65
+ gr.Markdown(
66
+ """## ニューラルかな漢字変換モデルzenz-v1のデモ (ONNX版)
67
+ 変換したい文字列をひらがな・カタカナを入力してください"""
68
+ )
69
+
70
+ with gr.Row():
71
+ with gr.Column():
72
+ input_text = gr.TextArea(
73
+ label="変換する文字列(ひらがな・カタカナ)",
74
+ info="変換したいテキストをひらがなかカタカナで入力します。入力すると右に反映されます。",
75
+ )
76
+ num_beams = gr.Slider(
77
+ label="候補数",
78
+ info="多くするとより変換に時間がかかります",
79
+ minimum=1,
80
+ maximum=20,
81
+ step=1,
82
+ value=4,
83
+ )
84
+
85
+ with gr.Column():
86
+ output_text = gr.TextArea(
87
+ label="変換結果 (リアルタイム反映)",
88
+ info="変換候補が出力されます。上の候補ほど確信度が高いです。",
89
+ )
90
+
91
+ gr.Examples(
92
+ examples=[
93
+ ["きめつのえいがをみました"],
94
+ ["はがいたいのでしかいにみてもらった"],
95
+ ["くつろぐにふといでかんたといいます"],
96
+ ["けんかをかった"],
97
+ ["けんかにかった"],
98
+ ["こうえんをおねがいする"],
99
+ ["こうえんでおねがいする"],
100
+ ["つきむらてまり"],
101
+ ],
102
+ inputs=[input_text],
103
+ )
104
+
105
+ gr.Markdown(
106
+ """\
107
+ - 使用しているモデル (ONNX): [p1atdev/zenz-v1-onnx](https://huggingface.co/p1atdev/zenz-v1-onnx)
108
+ - オリジナル(変換元)のモデル: [Miwa-Keita/zenz-v1-checkpoints](https://huggingface.co/Miwa-Keita/zenz-v1-checkpoints)
109
+ """
110
+ )
111
+
112
+ input_text.change(
113
+ fn=process_conversion,
114
+ inputs=[input_text, num_beams],
115
+ outputs=output_text,
116
+ )
117
+ num_beams.change(
118
+ fn=process_conversion,
119
+ inputs=[input_text, num_beams],
120
+ outputs=output_text,
121
+ )
122
+
123
+ ui.launch()
124
+
125
 
126
  # ローンチ
127
+ if __name__ == "__main__":
128
+ interface()
conversion.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from optimum.onnxruntime import ORTModelForCausalLM
2
+
3
+ MODEL_NAME = "Miwa-Keita/zenz-v1-checkpoints"
4
+
5
+ ort_model = ORTModelForCausalLM.from_pretrained(MODEL_NAME, export=True)
6
+
7
+ ort_model.save_pretrained(save_directory="./onnx")
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch
2
+ optimum[onnxruntime]