zhzluke96 commited on
Commit
8f52106
1 Parent(s): 0129fb6
Files changed (1) hide show
  1. modules/models.py +35 -8
modules/models.py CHANGED
@@ -1,18 +1,23 @@
 
1
  import torch
2
  from modules.ChatTTS import ChatTTS
3
  from modules import config
4
  from modules.devices import devices
5
 
6
  import logging
 
7
 
8
  logger = logging.getLogger(__name__)
 
9
  chat_tts = None
 
10
 
11
 
12
- def load_chat_tts():
13
  global chat_tts
14
  if chat_tts:
15
- return chat_tts
 
16
 
17
  chat_tts = ChatTTS.Chat()
18
  chat_tts.load_models(
@@ -28,18 +33,40 @@ def load_chat_tts():
28
  )
29
 
30
  devices.torch_gc()
 
 
 
 
 
 
31
 
 
 
 
 
 
32
  return chat_tts
33
 
34
 
35
- def reload_chat_tts():
36
- logging.info("Reloading ChatTTS models")
37
  global chat_tts
 
38
  if chat_tts:
 
 
 
 
39
  if torch.cuda.is_available():
40
- for model_name, model in chat_tts.pretrain_models.items():
41
- if isinstance(model, torch.nn.Module):
42
- model.cpu()
43
  torch.cuda.empty_cache()
 
44
  chat_tts = None
45
- return load_chat_tts()
 
 
 
 
 
 
 
 
 
1
+ import threading
2
  import torch
3
  from modules.ChatTTS import ChatTTS
4
  from modules import config
5
  from modules.devices import devices
6
 
7
  import logging
8
+ import gc
9
 
10
  logger = logging.getLogger(__name__)
11
+
12
  chat_tts = None
13
+ load_event = threading.Event()
14
 
15
 
16
+ def load_chat_tts_in_thread():
17
  global chat_tts
18
  if chat_tts:
19
+ load_event.set() # 如果已经加载过,直接设置事件
20
+ return
21
 
22
  chat_tts = ChatTTS.Chat()
23
  chat_tts.load_models(
 
33
  )
34
 
35
  devices.torch_gc()
36
+ load_event.set() # 设置事件,表示加载完成
37
+
38
+
39
+ def initialize_chat_tts():
40
+ model_thread = threading.Thread(target=load_chat_tts_in_thread)
41
+ model_thread.start()
42
 
43
+
44
+ def load_chat_tts():
45
+ if chat_tts is None:
46
+ initialize_chat_tts()
47
+ load_event.wait()
48
  return chat_tts
49
 
50
 
51
+ def unload_chat_tts():
52
+ logging.info("Unloading ChatTTS models")
53
  global chat_tts
54
+
55
  if chat_tts:
56
+ for model_name, model in chat_tts.pretrain_models.items():
57
+ if isinstance(model, torch.nn.Module):
58
+ model.cpu()
59
+ del model
60
  if torch.cuda.is_available():
 
 
 
61
  torch.cuda.empty_cache()
62
+ gc.collect()
63
  chat_tts = None
64
+ logger.info("ChatTTS models unloaded")
65
+
66
+
67
+ def reload_chat_tts():
68
+ logging.info("Reloading ChatTTS models")
69
+ unload_chat_tts()
70
+ instance = load_chat_tts()
71
+ logger.info("ChatTTS models reloaded")
72
+ return instance