zhzluke96 committed on
Commit
f34bda5
1 Parent(s): b44532e
modules/api/impl/openai_api.py CHANGED
@@ -1,4 +1,4 @@
1
- from fastapi import HTTPException, Body
2
  from fastapi.responses import StreamingResponse
3
 
4
  import io
@@ -14,7 +14,7 @@ from modules.normalization import text_normalize
14
  from modules import generate_audio as generate
15
 
16
 
17
- from typing import Literal
18
  import pyrubberband as pyrb
19
 
20
  from modules.api import utils as api_utils
@@ -106,8 +106,29 @@ async def openai_speech_api(
106
  raise HTTPException(status_code=500, detail=str(e))
107
 
108
 
109
- def setup(api_manager: APIManager):
110
- api_manager.post(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  "/v1/audio/speech",
112
  response_class=FileResponse,
113
  description="""
@@ -122,3 +143,28 @@ openai api document:
122
  > model 可填任意值
123
  """,
124
  )(openai_speech_api)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import File, Form, HTTPException, Body, UploadFile
2
  from fastapi.responses import StreamingResponse
3
 
4
  import io
 
14
  from modules import generate_audio as generate
15
 
16
 
17
+ from typing import List, Literal, Optional, Union
18
  import pyrubberband as pyrb
19
 
20
  from modules.api import utils as api_utils
 
106
  raise HTTPException(status_code=500, detail=str(e))
107
 
108
 
109
class TranscribeSegment(BaseModel):
    """One timed segment of a transcription result.

    NOTE(review): field names match the segment objects of OpenAI's
    `verbose_json` transcription response (Whisper-style) — confirm
    against the target API spec before relying on the schema.
    """

    id: int  # segment index within the transcription
    seek: float
    start: float  # segment start time (presumably seconds — TODO confirm)
    end: float  # segment end time (presumably seconds — TODO confirm)
    text: str  # transcribed text of this segment
    tokens: List[int]
    temperature: float
    avg_logprob: float
    compression_ratio: float
    no_speech_prob: float
120
+
121
+
122
class TranscriptionsVerboseResponse(BaseModel):
    """Top-level `verbose_json`-style transcription response.

    NOTE(review): mirrors the shape of OpenAI's verbose transcription
    response — confirm field semantics against the API spec.
    """

    task: str  # e.g. the kind of job performed — TODO confirm values
    language: str  # detected/selected language code
    duration: float  # total audio duration (presumably seconds — TODO confirm)
    text: str  # full transcribed text
    segments: List[TranscribeSegment]  # per-segment breakdown
128
+
129
+
130
+ def setup(app: APIManager):
131
+ app.post(
132
  "/v1/audio/speech",
133
  response_class=FileResponse,
134
  description="""
 
143
  > model 可填任意值
144
  """,
145
  )(openai_speech_api)
146
+
147
+ @app.post(
148
+ "/v1/audio/transcriptions",
149
+ response_class=TranscriptionsVerboseResponse,
150
+ description="WIP",
151
+ )
152
+ async def transcribe(
153
+ file: UploadFile = File(...),
154
+ model: str = Form(...),
155
+ language: Optional[str] = Form(None),
156
+ prompt: Optional[str] = Form(None),
157
+ response_format: str = Form("json"),
158
+ temperature: float = Form(0),
159
+ timestamp_granularities: List[str] = Form(["segment"]),
160
+ ):
161
+ # TODO: Implement transcribe
162
+ return {
163
+ "file": file.filename,
164
+ "model": model,
165
+ "language": language,
166
+ "prompt": prompt,
167
+ "response_format": response_format,
168
+ "temperature": temperature,
169
+ "timestamp_granularities": timestamp_granularities,
170
+ }
modules/normalization.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from modules.utils.zh_normalization.text_normlization import *
2
  import emojiswitch
3
  from modules.utils.markdown import markdown_to_text
@@ -5,12 +6,28 @@ from modules import models
5
  import re
6
 
7
 
 
8
  def is_chinese(text):
9
  # 中文字符的 Unicode 范围是 \u4e00-\u9fff
10
  chinese_pattern = re.compile(r"[\u4e00-\u9fff]")
11
  return bool(chinese_pattern.search(text))
12
 
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  post_normalize_pipeline = []
15
  pre_normalize_pipeline = []
16
 
@@ -123,7 +140,7 @@ def apply_character_map(text):
123
 
124
  @post_normalize()
125
  def apply_emoji_map(text):
126
- lang = "zh" if is_chinese(text) else "en"
127
  return emojiswitch.demojize(text, delimiters=("", ""), lang=lang)
128
 
129
 
@@ -144,6 +161,8 @@ def replace_unk_tokens(text):
144
  """
145
  chat_tts = models.load_chat_tts()
146
  if "tokenizer" not in chat_tts.pretrain_models:
 
 
147
  return text
148
  tokenizer = chat_tts.pretrain_models["tokenizer"]
149
  vocab = tokenizer.get_vocab()
@@ -223,7 +242,7 @@ def sentence_normalize(sentence_text: str):
223
  pattern = re.compile(r"(\[.+?\])|([^[]+)")
224
 
225
  def normalize_part(part):
226
- sentences = tx.normalize(part) if is_chinese(part) else [part]
227
  dest_text = ""
228
  for sentence in sentences:
229
  sentence = apply_post_normalize(sentence)
 
1
+ from functools import lru_cache
2
  from modules.utils.zh_normalization.text_normlization import *
3
  import emojiswitch
4
  from modules.utils.markdown import markdown_to_text
 
6
  import re
7
 
8
 
9
@lru_cache(maxsize=64)
def is_chinese(text):
    """Return True if *text* contains at least one CJK unified ideograph.

    Memoized because the same snippets are normalized repeatedly.
    """
    # Chinese characters occupy the Unicode range \u4e00-\u9fff.
    return re.search(r"[\u4e00-\u9fff]", text) is not None
14
 
15
 
16
@lru_cache(maxsize=64)
def is_eng(text):
    """Return True if *text* contains at least one ASCII letter (memoized)."""
    # Equivalent to searching for [a-zA-Z]: an ASCII character that is
    # alphabetic is exactly an ASCII letter.
    return any(ch.isascii() and ch.isalpha() for ch in text)
20
+
21
+
22
@lru_cache(maxsize=64)
def guess_lang(text):
    """Guess the language of *text*.

    Returns "zh" when Chinese characters are present, otherwise "en"
    when ASCII letters are present, and falls back to "zh" (memoized).
    """
    if is_chinese(text):
        return "zh"
    return "en" if is_eng(text) else "zh"
29
+
30
+
31
  post_normalize_pipeline = []
32
  pre_normalize_pipeline = []
33
 
 
140
 
141
@post_normalize()
def apply_emoji_map(text):
    """Replace emoji in *text* with words in the guessed language."""
    # Empty delimiters splice the emoji's word form straight into the text.
    return emojiswitch.demojize(text, delimiters=("", ""), lang=guess_lang(text))
145
 
146
 
 
161
  """
162
  chat_tts = models.load_chat_tts()
163
  if "tokenizer" not in chat_tts.pretrain_models:
164
+ # 这个地方只有在 huggingface spaces 中才会触发
165
+ # 因为 huggingface 自动处理模型卸载加载,所以如果拿不到就算了...
166
  return text
167
  tokenizer = chat_tts.pretrain_models["tokenizer"]
168
  vocab = tokenizer.get_vocab()
 
242
  pattern = re.compile(r"(\[.+?\])|([^[]+)")
243
 
244
  def normalize_part(part):
245
+ sentences = tx.normalize(part) if guess_lang(part) == "zh" else [part]
246
  dest_text = ""
247
  for sentence in sentences:
248
  sentence = apply_post_normalize(sentence)
modules/utils/zh_normalization/num.py CHANGED
@@ -144,13 +144,22 @@ def replace_number(match) -> str:
144
  sign = match.group(1)
145
  number = match.group(2)
146
  pure_decimal = match.group(5)
147
- if pure_decimal:
148
- result = num2str(pure_decimal)
149
- else:
150
- sign: str = "负" if sign else ""
151
- number: str = num2str(number)
152
- result = f"{sign}{number}"
 
 
 
 
 
 
 
153
  return result
 
 
154
 
155
 
156
  # 范围表达式
 
144
  sign = match.group(1)
145
  number = match.group(2)
146
  pure_decimal = match.group(5)
147
+
148
+ # TODO 也许可以把 num2str 完全替换成 cn2an
149
+ import cn2an
150
+ text = pure_decimal if pure_decimal else f"{sign}{number}"
151
+ try:
152
+ result = cn2an.an2cn(text, "low")
153
+ except ValueError:
154
+ if pure_decimal:
155
+ result = num2str(pure_decimal)
156
+ else:
157
+ sign: str = "负" if sign else ""
158
+ number: str = num2str(number)
159
+ result = f"{sign}{number}"
160
  return result
161
+
162
+
163
 
164
 
165
  # 范围表达式
webui.py CHANGED
@@ -45,6 +45,9 @@ from modules import refiner, config
45
  from modules.utils import env, audio
46
  from modules.SentenceSplitter import SentenceSplitter
47
 
 
 
 
48
  torch._dynamo.config.cache_size_limit = 64
49
  torch._dynamo.config.suppress_errors = True
50
  torch.set_float32_matmul_precision("high")
 
45
  from modules.utils import env, audio
46
  from modules.SentenceSplitter import SentenceSplitter
47
 
48
+ # fix: If the system proxy is enabled in the Windows system, you need to skip these
49
+ os.environ["NO_PROXY"] = "localhost,127.0.0.1,0.0.0.0"
50
+
51
  torch._dynamo.config.cache_size_limit = 64
52
  torch._dynamo.config.suppress_errors = True
53
  torch.set_float32_matmul_precision("high")