vicuna / modules /extensions.py
albertsoma's picture
Upload 123 files
16c358c
raw
history blame contribute delete
No virus
4.27 kB
import traceback
from functools import partial
import gradio as gr
import extensions
import modules.shared as shared
# Maps an extension name to a two-item list ``[enabled, position]``, where
# ``position`` is the extension's index on the command line.
state = dict()

# Names of the extensions that are allowed to be loaded; nothing in this
# part of the file fills it in.
available_extensions = list()

# Extension script modules whose ``setup()`` has already been invoked.
setup_called = set()
def load_extensions():
    """Import and initialise every extension requested on the command line.

    Walks ``shared.args.extensions`` in order. Names that are not listed in
    ``available_extensions`` are skipped silently. For every other name the
    module ``extensions.<name>.script`` is imported, its ``setup()`` function
    (if it defines one) is called at most once, and ``state[name]`` is set to
    ``[True, i]`` where ``i`` is the position on the command line.

    A failure while importing or setting up one extension is reported with a
    traceback and does not prevent the remaining extensions from loading.
    Progress messages are suppressed for the extension named ``'api'``.
    """
    # Function-scope import keeps this block self-contained.
    import importlib

    for i, name in enumerate(shared.args.extensions):
        if name not in available_extensions:
            continue

        announce = name != 'api'
        if announce:
            print(f'Loading the extension "{name}"... ', end='')

        try:
            # import_module() replaces exec() on a formatted string: it never
            # evaluates the name as code. It still binds the submodule on the
            # ``extensions`` package, which iterator() relies on.
            extension = importlib.import_module(f"extensions.{name}.script")
            if extension not in setup_called and hasattr(extension, "setup"):
                # Record before calling so a failing setup() is not retried.
                setup_called.add(extension)
                extension.setup()

            state[name] = [True, i]
            if announce:
                print('Ok.')
        except Exception:
            # ``Exception`` rather than a bare ``except`` so Ctrl-C and
            # SystemExit still propagate out of the loader.
            if announce:
                print('Fail.')
            traceback.print_exc()
# Yields the enabled extensions in the order given on the command line
def iterator():
    """Yield ``(script_module, name)`` for each enabled extension.

    Names are ordered by the position stored in ``state[name][1]``; an entry
    is skipped when its ``state[name][0]`` flag is falsy.
    """
    def position(ext_name):
        return state[ext_name][1]

    for ext_name in sorted(state, key=position):
        enabled = state[ext_name][0]
        if not enabled:
            continue
        script = getattr(extensions, ext_name).script
        yield script, ext_name
# String -> string hooks, chained across extensions
def _apply_string_extensions(function_name, text):
    """Pass ``text`` through the ``function_name`` hook of each extension.

    Extensions are visited in ``iterator()`` order; each one that defines an
    attribute called ``function_name`` receives the current text and its
    return value becomes the text handed to the next one. Returns the final
    text.
    """
    for module, _name in iterator():
        if not hasattr(module, function_name):
            continue
        modifier = getattr(module, function_name)
        text = modifier(text)
    return text
# One-shot replacement of the user input by extensions
def _apply_input_hijack(text, visible_text):
    """Let extensions with an armed ``input_hijack`` replace the input.

    For each extension whose ``input_hijack['state']`` is truthy, the flag is
    reset to ``False`` and ``input_hijack['value']`` is applied: a callable is
    invoked with ``(text, visible_text)``, anything else is unpacked directly
    as the new ``(text, visible_text)`` pair. Returns the resulting pair.
    """
    for module, _name in iterator():
        if not hasattr(module, 'input_hijack'):
            continue
        if not module.input_hijack['state']:
            continue

        # Disarm first so the hijack fires only once.
        module.input_hijack['state'] = False
        replacement = module.input_hijack['value']
        if callable(replacement):
            text, visible_text = replacement(text, visible_text)
        else:
            text, visible_text = replacement

    return text, visible_text
# Delegates chat prompt generation to an extension, if one offers it
def _apply_custom_generate_chat_prompt(text, state, **kwargs):
    """Call the first non-``None`` ``custom_generate_chat_prompt`` hook found.

    Extensions are scanned in ``iterator()`` order. The selected hook is
    called as ``hook(text, state, **kwargs)`` and its result is returned;
    ``None`` is returned when no extension supplies a hook.

    Note that the ``state`` parameter shadows the module-level ``state``
    dict inside this function.
    """
    handler = None
    for module, _name in iterator():
        if handler is not None:
            # Already chosen; the remaining extensions are still visited.
            continue
        if hasattr(module, 'custom_generate_chat_prompt'):
            handler = module.custom_generate_chat_prompt

    if handler is None:
        return None
    return handler(text, state, **kwargs)
# Hooks that rewrite the tokenizer output, chained across extensions
def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
    """Thread the tokenizer triple through each extension's hook.

    Every extension that defines ``function_name`` is called with
    ``(state, prompt, input_ids, input_embeds)`` and must return a new
    ``(prompt, input_ids, input_embeds)`` triple, which is passed on to the
    next extension. Returns the final triple.
    """
    for module, _name in iterator():
        if not hasattr(module, function_name):
            continue
        hook = getattr(module, function_name)
        result = hook(state, prompt, input_ids, input_embeds)
        prompt, input_ids, input_embeds = result
    return prompt, input_ids, input_embeds
# Dispatch table used by apply_extensions(): maps an extension type to the
# helper that runs the matching hook on every loaded extension.
EXTENSION_MAP = dict(
    input=partial(_apply_string_extensions, "input_modifier"),
    output=partial(_apply_string_extensions, "output_modifier"),
    bot_prefix=partial(_apply_string_extensions, "bot_prefix_modifier"),
    tokenizer=partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
    input_hijack=_apply_input_hijack,
    custom_generate_chat_prompt=_apply_custom_generate_chat_prompt,
)
def apply_extensions(typ, *args, **kwargs):
    """Run the extension hooks registered for ``typ``.

    ``typ`` must be a key of ``EXTENSION_MAP``; the remaining positional and
    keyword arguments are forwarded unchanged to the mapped helper, whose
    return value is returned.

    Raises:
        ValueError: if ``typ`` is not a key of ``EXTENSION_MAP``.
    """
    handler = EXTENSION_MAP.get(typ)
    if handler is None:
        raise ValueError(f"Invalid extension type {typ}")
    return handler(*args, **kwargs)
def create_extensions_block():
    """Apply saved settings to the extensions and build their UI section.

    First, for every extension that has a ``params`` mapping, each parameter
    is overwritten with ``shared.settings["<name>-<param>"]`` when that key
    exists. Then, if at least one extension defines ``ui``, a
    ``gr.Column(elem_id="extensions")`` is created in which a Markdown
    heading is emitted for every loaded extension and ``ui()`` is called for
    those that define it.
    """
    # Updating the default values
    for module, ext_name in iterator():
        if not hasattr(module, 'params'):
            continue
        for key in module.params:
            setting_key = f"{ext_name}-{key}"
            if setting_key in shared.settings:
                module.params[key] = shared.settings[setting_key]

    # Built as a list so that every loaded extension is visited.
    has_ui = any([hasattr(module, "ui") for module, _ in iterator()])
    if not has_ui:
        return

    # Creating the extension ui elements
    with gr.Column(elem_id="extensions"):
        for module, ext_name in iterator():
            gr.Markdown(f"\n### {ext_name}")
            if hasattr(module, "ui"):
                module.ui()