.pre-commit-config.yaml CHANGED
@@ -1,7 +1,7 @@
1
  exclude: patch
2
  repos:
3
  - repo: https://github.com/pre-commit/pre-commit-hooks
4
- rev: v4.4.0
5
  hooks:
6
  - id: check-executables-have-shebangs
7
  - id: check-json
@@ -9,43 +9,29 @@ repos:
9
  - id: check-shebang-scripts-are-executable
10
  - id: check-toml
11
  - id: check-yaml
 
12
  - id: end-of-file-fixer
13
  - id: mixed-line-ending
14
- args: ["--fix=lf"]
15
  - id: requirements-txt-fixer
16
  - id: trailing-whitespace
17
  - repo: https://github.com/myint/docformatter
18
- rev: v1.7.5
19
  hooks:
20
  - id: docformatter
21
- args: ["--in-place"]
22
  - repo: https://github.com/pycqa/isort
23
  rev: 5.12.0
24
  hooks:
25
  - id: isort
26
- args: ["--profile", "black"]
27
  - repo: https://github.com/pre-commit/mirrors-mypy
28
- rev: v1.5.1
29
  hooks:
30
  - id: mypy
31
- args: ["--ignore-missing-imports"]
32
- additional_dependencies: ["types-python-slugify", "types-requests", "types-PyYAML"]
33
- - repo: https://github.com/psf/black
34
- rev: 23.9.0
35
  hooks:
36
- - id: black
37
- language_version: python3.10
38
- args: ["--line-length", "119"]
39
- - repo: https://github.com/kynan/nbstripout
40
- rev: 0.6.1
41
- hooks:
42
- - id: nbstripout
43
- args: ["--extra-keys", "metadata.interpreter metadata.kernelspec cell.metadata.pycharm"]
44
- - repo: https://github.com/nbQA-dev/nbQA
45
- rev: 1.7.0
46
- hooks:
47
- - id: nbqa-black
48
- - id: nbqa-pyupgrade
49
- args: ["--py37-plus"]
50
- - id: nbqa-isort
51
- args: ["--float-to-top"]
 
1
  exclude: patch
2
  repos:
3
  - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.2.0
5
  hooks:
6
  - id: check-executables-have-shebangs
7
  - id: check-json
 
9
  - id: check-shebang-scripts-are-executable
10
  - id: check-toml
11
  - id: check-yaml
12
+ - id: double-quote-string-fixer
13
  - id: end-of-file-fixer
14
  - id: mixed-line-ending
15
+ args: ['--fix=lf']
16
  - id: requirements-txt-fixer
17
  - id: trailing-whitespace
18
  - repo: https://github.com/myint/docformatter
19
+ rev: v1.4
20
  hooks:
21
  - id: docformatter
22
+ args: ['--in-place']
23
  - repo: https://github.com/pycqa/isort
24
  rev: 5.12.0
25
  hooks:
26
  - id: isort
 
27
  - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v0.991
29
  hooks:
30
  - id: mypy
31
+ args: ['--ignore-missing-imports']
32
+ additional_dependencies: ['types-python-slugify']
33
+ - repo: https://github.com/google/yapf
34
+ rev: v0.32.0
35
  hooks:
36
+ - id: yapf
37
+ args: ['--parallel', '--in-place']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
.vscode/settings.json DELETED
@@ -1,21 +0,0 @@
1
- {
2
- "[python]": {
3
- "editor.defaultFormatter": "ms-python.black-formatter",
4
- "editor.formatOnType": true,
5
- "editor.codeActionsOnSave": {
6
- "source.organizeImports": true
7
- }
8
- },
9
- "black-formatter.args": [
10
- "--line-length=119"
11
- ],
12
- "isort.args": ["--profile", "black"],
13
- "flake8.args": [
14
- "--max-line-length=119"
15
- ],
16
- "ruff.args": [
17
- "--line-length=119"
18
- ],
19
- "editor.formatOnSave": true,
20
- "files.insertFinalNewline": true
21
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Dockerfile CHANGED
@@ -35,7 +35,7 @@ WORKDIR ${HOME}/app
35
 
36
  RUN curl https://pyenv.run | bash
37
  ENV PATH=${HOME}/.pyenv/shims:${HOME}/.pyenv/bin:${PATH}
38
- ARG PYTHON_VERSION=3.10.11
39
  RUN pyenv install ${PYTHON_VERSION} && \
40
  pyenv global ${PYTHON_VERSION} && \
41
  pyenv rehash && \
@@ -44,8 +44,6 @@ RUN pyenv install ${PYTHON_VERSION} && \
44
  RUN pip install --no-cache-dir -U torch==1.13.1 torchvision==0.14.1
45
  COPY --chown=1000 requirements.txt /tmp/requirements.txt
46
  RUN pip install --no-cache-dir -U -r /tmp/requirements.txt
47
- COPY --chown=1000 requirements-monitor.txt /tmp/requirements-monitor.txt
48
- RUN pip install --no-cache-dir -U -r /tmp/requirements-monitor.txt
49
 
50
  COPY --chown=1000 . ${HOME}/app
51
  RUN cd Tune-A-Video && patch -p1 < ../patch
 
35
 
36
  RUN curl https://pyenv.run | bash
37
  ENV PATH=${HOME}/.pyenv/shims:${HOME}/.pyenv/bin:${PATH}
38
+ ENV PYTHON_VERSION=3.10.9
39
  RUN pyenv install ${PYTHON_VERSION} && \
40
  pyenv global ${PYTHON_VERSION} && \
41
  pyenv rehash && \
 
44
  RUN pip install --no-cache-dir -U torch==1.13.1 torchvision==0.14.1
45
  COPY --chown=1000 requirements.txt /tmp/requirements.txt
46
  RUN pip install --no-cache-dir -U -r /tmp/requirements.txt
 
 
47
 
48
  COPY --chown=1000 . ${HOME}/app
49
  RUN cd Tune-A-Video && patch -p1 < ../patch
app.py CHANGED
@@ -9,43 +9,43 @@ import gradio as gr
9
  import torch
10
 
11
  from app_inference import create_inference_demo
12
- from app_system_monitor import create_monitor_demo
13
  from app_training import create_training_demo
14
  from app_upload import create_upload_demo
15
  from inference import InferencePipeline
16
  from trainer import Trainer
17
 
18
- TITLE = "# [Tune-A-Video](https://tuneavideo.github.io/)"
19
 
20
- ORIGINAL_SPACE_ID = "Tune-A-Video-library/Tune-A-Video-Training-UI"
21
- SPACE_ID = os.getenv("SPACE_ID")
22
- GPU_DATA = getoutput("nvidia-smi")
23
- SHARED_UI_WARNING = f"""## Attention - Training doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
24
 
25
  <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></center>
26
- """
27
 
28
- IS_SHARED_UI = SPACE_ID == ORIGINAL_SPACE_ID
29
- if os.getenv("SYSTEM") == "spaces" and SPACE_ID != ORIGINAL_SPACE_ID:
30
  SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
31
  else:
32
- SETTINGS = "Settings"
33
 
34
- INVALID_GPU_WARNING = f"""## Attention - the specified GPU is invalid. Training may not work. Make sure you have selected a `T4 GPU` for this task."""
35
 
36
- CUDA_NOT_AVAILABLE_WARNING = f"""## Attention - Running on CPU.
37
  <center>
38
  You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
39
  You can use "T4 small/medium" to run this demo.
40
  </center>
41
- """
42
 
43
- HF_TOKEN_NOT_SPECIFIED_WARNING = f"""The environment variable `HF_TOKEN` is not specified. Feel free to specify your Hugging Face token with write permission if you don't want to manually provide it for every run.
44
-
45
- You can check and create your Hugging Face tokens <a href="https://huggingface.co/settings/tokens" target="_blank">here</a>. You can specify environment variables in the "Repository secrets" section of the {SETTINGS} tab.
46
- """
 
 
47
 
48
- HF_TOKEN = os.getenv("HF_TOKEN")
49
 
50
 
51
  def show_warning(warning_text: str) -> gr.Blocks:
@@ -56,36 +56,30 @@ def show_warning(warning_text: str) -> gr.Blocks:
56
 
57
 
58
  pipe = InferencePipeline(HF_TOKEN)
59
- trainer = Trainer()
60
 
61
- with gr.Blocks(css="style.css") as demo:
62
- if IS_SHARED_UI:
63
  show_warning(SHARED_UI_WARNING)
64
  elif not torch.cuda.is_available():
65
  show_warning(CUDA_NOT_AVAILABLE_WARNING)
66
- elif "T4" not in GPU_DATA:
67
  show_warning(INVALID_GPU_WARNING)
 
68
 
69
  gr.Markdown(TITLE)
70
  with gr.Tabs():
71
- with gr.TabItem("Train"):
72
- create_training_demo(trainer, pipe, disable_run_button=IS_SHARED_UI)
73
- with gr.TabItem("Run"):
74
- create_inference_demo(pipe, HF_TOKEN, disable_run_button=IS_SHARED_UI)
75
- with gr.TabItem("Upload"):
76
- gr.Markdown(
77
- """
78
  - You can use this tab to upload models later if you choose not to upload models in training time or if upload in training time failed.
79
- """
80
- )
81
- create_upload_demo(disable_run_button=IS_SHARED_UI)
82
-
83
- with gr.Row():
84
- if not IS_SHARED_UI and not os.getenv("DISABLE_SYSTEM_MONITOR"):
85
- with gr.Accordion(label="System info", open=False):
86
- create_monitor_demo()
87
-
88
  if not HF_TOKEN:
89
  show_warning(HF_TOKEN_NOT_SPECIFIED_WARNING)
90
 
91
- demo.queue(api_open=False, max_size=1).launch()
 
9
  import torch
10
 
11
  from app_inference import create_inference_demo
 
12
  from app_training import create_training_demo
13
  from app_upload import create_upload_demo
14
  from inference import InferencePipeline
15
  from trainer import Trainer
16
 
17
+ TITLE = '# [Tune-A-Video](https://tuneavideo.github.io/) UI'
18
 
19
+ ORIGINAL_SPACE_ID = 'Tune-A-Video-library/Tune-A-Video-Training-UI'
20
+ SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
21
+ GPU_DATA = getoutput('nvidia-smi')
22
+ SHARED_UI_WARNING = f'''## Attention - Training doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
23
 
24
  <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></center>
25
+ '''
26
 
27
+ if os.getenv('SYSTEM') == 'spaces' and SPACE_ID != ORIGINAL_SPACE_ID:
 
28
  SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
29
  else:
30
+ SETTINGS = 'Settings'
31
 
32
+ INVALID_GPU_WARNING = f'''## Attention - the specified GPU is invalid. Training may not work. Make sure you have selected a `T4 GPU` for this task.'''
33
 
34
+ CUDA_NOT_AVAILABLE_WARNING = f'''## Attention - Running on CPU.
35
  <center>
36
  You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
37
  You can use "T4 small/medium" to run this demo.
38
  </center>
39
+ '''
40
 
41
+ HF_TOKEN_NOT_SPECIFIED_WARNING = f'''The environment variable `HF_TOKEN` is not specified. Feel free to specify your Hugging Face token with write permission if you don't want to manually provide it for every run.
42
+ <center>
43
+ You can check and create your Hugging Face tokens <a href="https://huggingface.co/settings/tokens" target="_blank">here</a>.
44
+ You can specify environment variables in the "Repository secrets" section of the {SETTINGS} tab.
45
+ </center>
46
+ '''
47
 
48
+ HF_TOKEN = os.getenv('HF_TOKEN')
49
 
50
 
51
  def show_warning(warning_text: str) -> gr.Blocks:
 
56
 
57
 
58
  pipe = InferencePipeline(HF_TOKEN)
59
+ trainer = Trainer(HF_TOKEN)
60
 
61
+ with gr.Blocks(css='style.css') as demo:
62
+ if SPACE_ID == ORIGINAL_SPACE_ID:
63
  show_warning(SHARED_UI_WARNING)
64
  elif not torch.cuda.is_available():
65
  show_warning(CUDA_NOT_AVAILABLE_WARNING)
66
+ elif(not "T4" in GPU_DATA):
67
  show_warning(INVALID_GPU_WARNING)
68
+
69
 
70
  gr.Markdown(TITLE)
71
  with gr.Tabs():
72
+ with gr.TabItem('Train'):
73
+ create_training_demo(trainer, pipe)
74
+ with gr.TabItem('Run'):
75
+ create_inference_demo(pipe, HF_TOKEN)
76
+ with gr.TabItem('Upload'):
77
+ gr.Markdown('''
 
78
  - You can use this tab to upload models later if you choose not to upload models in training time or if upload in training time failed.
79
+ ''')
80
+ create_upload_demo(HF_TOKEN)
81
+
 
 
 
 
 
 
82
  if not HF_TOKEN:
83
  show_warning(HF_TOKEN_NOT_SPECIFIED_WARNING)
84
 
85
+ demo.queue(max_size=1).launch(share=False)
app_inference.py CHANGED
@@ -14,7 +14,7 @@ from utils import find_exp_dirs
14
 
15
  class ModelSource(enum.Enum):
16
  HUB_LIB = UploadTarget.MODEL_LIBRARY.value
17
- LOCAL = "Local"
18
 
19
 
20
  class InferenceUtil:
@@ -23,13 +23,18 @@ class InferenceUtil:
23
 
24
  def load_hub_model_list(self) -> dict:
25
  api = HfApi(token=self.hf_token)
26
- choices = [info.modelId for info in api.list_models(author=MODEL_LIBRARY_ORG_NAME)]
27
- return gr.update(choices=choices, value=choices[0] if choices else None)
 
 
 
 
28
 
29
  @staticmethod
30
  def load_local_model_list() -> dict:
31
  choices = find_exp_dirs()
32
- return gr.update(choices=choices, value=choices[0] if choices else None)
 
33
 
34
  def reload_model_list(self, model_source: str) -> dict:
35
  if model_source == ModelSource.HUB_LIB.value:
@@ -43,21 +48,21 @@ class InferenceUtil:
43
  try:
44
  card = InferencePipeline.get_model_card(model_id, self.hf_token)
45
  except Exception:
46
- return "", ""
47
- base_model = getattr(card.data, "base_model", "")
48
- training_prompt = getattr(card.data, "training_prompt", "")
49
  return base_model, training_prompt
50
 
51
- def reload_model_list_and_update_model_info(self, model_source: str) -> tuple[dict, str, str]:
 
52
  model_list_update = self.reload_model_list(model_source)
53
- model_list = model_list_update["choices"]
54
- model_info = self.load_model_info(model_list[0] if model_list else "")
55
  return model_list_update, *model_info
56
 
57
 
58
- def create_inference_demo(
59
- pipe: InferencePipeline, hf_token: str | None = None, disable_run_button: bool = False
60
- ) -> gr.Blocks:
61
  app = InferenceUtil(hf_token)
62
 
63
  with gr.Blocks() as demo:
@@ -65,60 +70,83 @@ def create_inference_demo(
65
  with gr.Column():
66
  with gr.Box():
67
  model_source = gr.Radio(
68
- label="Model Source", choices=[_.value for _ in ModelSource], value=ModelSource.HUB_LIB.value
69
- )
70
- reload_button = gr.Button("Reload Model List")
71
- model_id = gr.Dropdown(label="Model ID", choices=None, value=None)
72
- with gr.Accordion(label="Model info (Base model and prompt used for training)", open=False):
 
 
 
 
 
 
73
  with gr.Row():
74
- base_model_used_for_training = gr.Text(label="Base model", interactive=False)
75
- prompt_used_for_training = gr.Text(label="Training prompt", interactive=False)
76
- prompt = gr.Textbox(label="Prompt", max_lines=1, placeholder='Example: "A panda is surfing"')
77
- video_length = gr.Slider(label="Video length", minimum=4, maximum=12, step=1, value=8)
78
- fps = gr.Slider(label="FPS", minimum=1, maximum=12, step=1, value=1)
79
- seed = gr.Slider(label="Seed", minimum=0, maximum=100000, step=1, value=0)
80
- with gr.Accordion("Advanced options", open=False):
81
- num_steps = gr.Slider(label="Number of Steps", minimum=0, maximum=100, step=1, value=50)
82
- guidance_scale = gr.Slider(label="Guidance scale", minimum=0, maximum=50, step=0.1, value=7.5)
83
-
84
- run_button = gr.Button("Generate", interactive=not disable_run_button)
85
-
86
- gr.Markdown(
87
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  - After training, you can press "Reload Model List" button to load your trained model names.
89
  - It takes a few minutes to download model first.
90
  - Expected time to generate an 8-frame video: 70 seconds with T4, 24 seconds with A10G, (10 seconds with A100)
91
- """
92
- )
93
  with gr.Column():
94
- result = gr.Video(label="Result")
95
-
96
- model_source.change(
97
- fn=app.reload_model_list_and_update_model_info,
98
- inputs=model_source,
99
- outputs=[
100
- model_id,
101
- base_model_used_for_training,
102
- prompt_used_for_training,
103
- ],
104
- )
105
- reload_button.click(
106
- fn=app.reload_model_list_and_update_model_info,
107
- inputs=model_source,
108
- outputs=[
109
- model_id,
110
- base_model_used_for_training,
111
- prompt_used_for_training,
112
- ],
113
- )
114
- model_id.change(
115
- fn=app.load_model_info,
116
- inputs=model_id,
117
- outputs=[
118
- base_model_used_for_training,
119
- prompt_used_for_training,
120
- ],
121
- )
122
  inputs = [
123
  model_id,
124
  prompt,
@@ -133,10 +161,10 @@ def create_inference_demo(
133
  return demo
134
 
135
 
136
- if __name__ == "__main__":
137
  import os
138
 
139
- hf_token = os.getenv("HF_TOKEN")
140
  pipe = InferencePipeline(hf_token)
141
  demo = create_inference_demo(pipe, hf_token)
142
- demo.queue(api_open=False, max_size=10).launch()
 
14
 
15
  class ModelSource(enum.Enum):
16
  HUB_LIB = UploadTarget.MODEL_LIBRARY.value
17
+ LOCAL = 'Local'
18
 
19
 
20
  class InferenceUtil:
 
23
 
24
  def load_hub_model_list(self) -> dict:
25
  api = HfApi(token=self.hf_token)
26
+ choices = [
27
+ info.modelId
28
+ for info in api.list_models(author=MODEL_LIBRARY_ORG_NAME)
29
+ ]
30
+ return gr.update(choices=choices,
31
+ value=choices[0] if choices else None)
32
 
33
  @staticmethod
34
  def load_local_model_list() -> dict:
35
  choices = find_exp_dirs()
36
+ return gr.update(choices=choices,
37
+ value=choices[0] if choices else None)
38
 
39
  def reload_model_list(self, model_source: str) -> dict:
40
  if model_source == ModelSource.HUB_LIB.value:
 
48
  try:
49
  card = InferencePipeline.get_model_card(model_id, self.hf_token)
50
  except Exception:
51
+ return '', ''
52
+ base_model = getattr(card.data, 'base_model', '')
53
+ training_prompt = getattr(card.data, 'training_prompt', '')
54
  return base_model, training_prompt
55
 
56
+ def reload_model_list_and_update_model_info(
57
+ self, model_source: str) -> tuple[dict, str, str]:
58
  model_list_update = self.reload_model_list(model_source)
59
+ model_list = model_list_update['choices']
60
+ model_info = self.load_model_info(model_list[0] if model_list else '')
61
  return model_list_update, *model_info
62
 
63
 
64
+ def create_inference_demo(pipe: InferencePipeline,
65
+ hf_token: str | None = None) -> gr.Blocks:
 
66
  app = InferenceUtil(hf_token)
67
 
68
  with gr.Blocks() as demo:
 
70
  with gr.Column():
71
  with gr.Box():
72
  model_source = gr.Radio(
73
+ label='Model Source',
74
+ choices=[_.value for _ in ModelSource],
75
+ value=ModelSource.HUB_LIB.value)
76
+ reload_button = gr.Button('Reload Model List')
77
+ model_id = gr.Dropdown(label='Model ID',
78
+ choices=None,
79
+ value=None)
80
+ with gr.Accordion(
81
+ label=
82
+ 'Model info (Base model and prompt used for training)',
83
+ open=False):
84
  with gr.Row():
85
+ base_model_used_for_training = gr.Text(
86
+ label='Base model', interactive=False)
87
+ prompt_used_for_training = gr.Text(
88
+ label='Training prompt', interactive=False)
89
+ prompt = gr.Textbox(
90
+ label='Prompt',
91
+ max_lines=1,
92
+ placeholder='Example: "A panda is surfing"')
93
+ video_length = gr.Slider(label='Video length',
94
+ minimum=4,
95
+ maximum=12,
96
+ step=1,
97
+ value=8)
98
+ fps = gr.Slider(label='FPS',
99
+ minimum=1,
100
+ maximum=12,
101
+ step=1,
102
+ value=1)
103
+ seed = gr.Slider(label='Seed',
104
+ minimum=0,
105
+ maximum=100000,
106
+ step=1,
107
+ value=0)
108
+ with gr.Accordion('Other Parameters', open=False):
109
+ num_steps = gr.Slider(label='Number of Steps',
110
+ minimum=0,
111
+ maximum=100,
112
+ step=1,
113
+ value=50)
114
+ guidance_scale = gr.Slider(label='CFG Scale',
115
+ minimum=0,
116
+ maximum=50,
117
+ step=0.1,
118
+ value=7.5)
119
+
120
+ run_button = gr.Button('Generate')
121
+
122
+ gr.Markdown('''
123
  - After training, you can press "Reload Model List" button to load your trained model names.
124
  - It takes a few minutes to download model first.
125
  - Expected time to generate an 8-frame video: 70 seconds with T4, 24 seconds with A10G, (10 seconds with A100)
126
+ ''')
 
127
  with gr.Column():
128
+ result = gr.Video(label='Result')
129
+
130
+ model_source.change(fn=app.reload_model_list_and_update_model_info,
131
+ inputs=model_source,
132
+ outputs=[
133
+ model_id,
134
+ base_model_used_for_training,
135
+ prompt_used_for_training,
136
+ ])
137
+ reload_button.click(fn=app.reload_model_list_and_update_model_info,
138
+ inputs=model_source,
139
+ outputs=[
140
+ model_id,
141
+ base_model_used_for_training,
142
+ prompt_used_for_training,
143
+ ])
144
+ model_id.change(fn=app.load_model_info,
145
+ inputs=model_id,
146
+ outputs=[
147
+ base_model_used_for_training,
148
+ prompt_used_for_training,
149
+ ])
 
 
 
 
 
 
150
  inputs = [
151
  model_id,
152
  prompt,
 
161
  return demo
162
 
163
 
164
+ if __name__ == '__main__':
165
  import os
166
 
167
+ hf_token = os.getenv('HF_TOKEN')
168
  pipe = InferencePipeline(hf_token)
169
  demo = create_inference_demo(pipe, hf_token)
170
+ demo.queue(max_size=10).launch(share=False)
app_system_monitor.py DELETED
@@ -1,86 +0,0 @@
1
- #!/usr/bin/env python
2
-
3
- from __future__ import annotations
4
-
5
- import collections
6
-
7
- import gradio as gr
8
- import nvitop
9
- import pandas as pd
10
- import plotly.express as px
11
- import psutil
12
-
13
-
14
- class SystemMonitor:
15
- MAX_SIZE = 61
16
-
17
- def __init__(self):
18
- self.devices = nvitop.Device.all()
19
- self.cpu_memory_usage = collections.deque([0 for _ in range(self.MAX_SIZE)], maxlen=self.MAX_SIZE)
20
- self.cpu_memory_usage_str = ""
21
- self.gpu_memory_usage = collections.deque([0 for _ in range(self.MAX_SIZE)], maxlen=self.MAX_SIZE)
22
- self.gpu_util = collections.deque([0 for _ in range(self.MAX_SIZE)], maxlen=self.MAX_SIZE)
23
- self.gpu_memory_usage_str = ""
24
- self.gpu_util_str = ""
25
-
26
- def update(self) -> None:
27
- self.update_cpu()
28
- self.update_gpu()
29
-
30
- def update_cpu(self) -> None:
31
- memory = psutil.virtual_memory()
32
- self.cpu_memory_usage.append(memory.percent)
33
- self.cpu_memory_usage_str = (
34
- f"{memory.used / 1024**3:0.2f}GiB / {memory.total / 1024**3:0.2f}GiB ({memory.percent}%)"
35
- )
36
-
37
- def update_gpu(self) -> None:
38
- if not self.devices:
39
- return
40
- device = self.devices[0]
41
- self.gpu_memory_usage.append(device.memory_percent())
42
- self.gpu_util.append(device.gpu_utilization())
43
- self.gpu_memory_usage_str = f"{device.memory_usage()} ({device.memory_percent()}%)"
44
- self.gpu_util_str = f"{device.gpu_utilization()}%"
45
-
46
- def get_json(self) -> dict[str, str]:
47
- return {
48
- "CPU memory usage": self.cpu_memory_usage_str,
49
- "GPU memory usage": self.gpu_memory_usage_str,
50
- "GPU Util": self.gpu_util_str,
51
- }
52
-
53
- def get_graph_data(self) -> dict[str, list[int | float]]:
54
- return {
55
- "index": list(range(-self.MAX_SIZE + 1, 1)),
56
- "CPU memory usage": self.cpu_memory_usage,
57
- "GPU memory usage": self.gpu_memory_usage,
58
- "GPU Util": self.gpu_util,
59
- }
60
-
61
- def get_graph(self):
62
- df = pd.DataFrame(self.get_graph_data())
63
- return px.line(
64
- df,
65
- x="index",
66
- y=[
67
- "CPU memory usage",
68
- "GPU memory usage",
69
- "GPU Util",
70
- ],
71
- range_y=[-5, 105],
72
- ).update_layout(xaxis_title="Time", yaxis_title="Percentage")
73
-
74
-
75
- def create_monitor_demo() -> gr.Blocks:
76
- monitor = SystemMonitor()
77
- with gr.Blocks() as demo:
78
- gr.JSON(value=monitor.update, every=1, visible=False)
79
- gr.JSON(value=monitor.get_json, show_label=False, every=1)
80
- gr.Plot(value=monitor.get_graph, show_label=False, every=1)
81
- return demo
82
-
83
-
84
- if __name__ == "__main__":
85
- demo = create_monitor_demo()
86
- demo.queue(api_open=False).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app_training.py CHANGED
@@ -6,125 +6,135 @@ import os
6
 
7
  import gradio as gr
8
 
9
- from constants import UploadTarget
10
  from inference import InferencePipeline
11
  from trainer import Trainer
12
 
13
 
14
- def create_training_demo(
15
- trainer: Trainer, pipe: InferencePipeline | None = None, disable_run_button: bool = False
16
- ) -> gr.Blocks:
17
- def read_log() -> str:
18
- with open(trainer.log_file) as f:
19
- lines = f.readlines()
20
- return "".join(lines[-10:])
21
-
22
  with gr.Blocks() as demo:
23
  with gr.Row():
24
  with gr.Column():
25
  with gr.Box():
26
- gr.Markdown("Training Data")
27
- training_video = gr.File(label="Training video")
28
- training_prompt = gr.Textbox(label="Training prompt", max_lines=1, placeholder="A man is surfing")
29
- gr.Markdown(
30
- """
 
 
31
  - Upload a video and write a `Training Prompt` that describes the video.
32
- """
33
- )
34
 
35
  with gr.Column():
36
  with gr.Box():
37
- gr.Markdown("Training Parameters")
38
  with gr.Row():
39
- base_model = gr.Text(label="Base Model", value="CompVis/stable-diffusion-v1-4", max_lines=1)
40
- resolution = gr.Dropdown(
41
- choices=["512", "768"], value="512", label="Resolution", visible=False
42
- )
43
-
44
- hf_token = gr.Text(
45
- label="Hugging Face Write Token", type="password", visible=os.getenv("HF_TOKEN") is None
46
- )
47
- with gr.Accordion(label="Advanced options", open=False):
48
- num_training_steps = gr.Number(label="Number of Training Steps", value=300, precision=0)
49
- learning_rate = gr.Number(label="Learning Rate", value=0.000035)
 
 
 
50
  gradient_accumulation = gr.Number(
51
- label="Number of Gradient Accumulation", value=1, precision=0
52
- )
53
- seed = gr.Slider(label="Seed", minimum=0, maximum=100000, step=1, randomize=True, value=0)
54
- fp16 = gr.Checkbox(label="FP16", value=True)
55
- use_8bit_adam = gr.Checkbox(label="Use 8bit Adam", value=False)
56
- checkpointing_steps = gr.Number(label="Checkpointing Steps", value=1000, precision=0)
57
- validation_epochs = gr.Number(label="Validation Epochs", value=100, precision=0)
58
- gr.Markdown(
59
- """
 
 
 
 
 
 
 
 
 
60
  - The base model must be a Stable Diffusion model compatible with [diffusers](https://github.com/huggingface/diffusers) library.
61
  - Expected time to train a model for 300 steps: ~20 minutes with T4
62
  - You can check the training status by pressing the "Open logs" button if you are running this on your Space.
63
- """
64
- )
65
-
66
  with gr.Row():
67
  with gr.Column():
68
- gr.Markdown("Output Model")
69
- output_model_name = gr.Text(label="Name of your model", placeholder="The surfer man", max_lines=1)
70
- validation_prompt = gr.Text(
71
- label="Validation Prompt", placeholder="prompt to test the model, e.g: a dog is surfing"
72
- )
73
  with gr.Column():
74
- gr.Markdown("Upload Settings")
75
  with gr.Row():
76
- upload_to_hub = gr.Checkbox(label="Upload model to Hub", value=True)
77
- use_private_repo = gr.Checkbox(label="Private", value=True)
78
- delete_existing_repo = gr.Checkbox(label="Delete existing repo of the same name", value=False)
 
 
 
 
79
  upload_to = gr.Radio(
80
- label="Upload to",
81
  choices=[_.value for _ in UploadTarget],
82
- value=UploadTarget.MODEL_LIBRARY.value,
83
- )
84
-
85
- pause_space_after_training = gr.Checkbox(
86
- label="Pause this Space after training",
87
  value=False,
88
- interactive=bool(os.getenv("SPACE_ID")),
89
- visible=False,
90
- )
91
- run_button = gr.Button("Start Training", interactive=not disable_run_button)
92
-
93
  with gr.Box():
94
- gr.Text(label="Log", value=read_log, lines=10, max_lines=10, every=1)
 
95
 
96
  if pipe is not None:
97
  run_button.click(fn=pipe.clear)
98
- run_button.click(
99
- fn=trainer.run,
100
- inputs=[
101
- training_video,
102
- training_prompt,
103
- output_model_name,
104
- delete_existing_repo,
105
- validation_prompt,
106
- base_model,
107
- resolution,
108
- num_training_steps,
109
- learning_rate,
110
- gradient_accumulation,
111
- seed,
112
- fp16,
113
- use_8bit_adam,
114
- checkpointing_steps,
115
- validation_epochs,
116
- upload_to_hub,
117
- use_private_repo,
118
- delete_existing_repo,
119
- upload_to,
120
- pause_space_after_training,
121
- hf_token,
122
- ],
123
- )
124
  return demo
125
 
126
 
127
- if __name__ == "__main__":
128
- trainer = Trainer()
 
129
  demo = create_training_demo(trainer)
130
- demo.queue(api_open=False, max_size=1).launch()
 
6
 
7
  import gradio as gr
8
 
9
+ from constants import MODEL_LIBRARY_ORG_NAME, SAMPLE_MODEL_REPO, UploadTarget
10
  from inference import InferencePipeline
11
  from trainer import Trainer
12
 
13
 
14
+ def create_training_demo(trainer: Trainer,
15
+ pipe: InferencePipeline | None = None) -> gr.Blocks:
16
+ hf_token = os.getenv('HF_TOKEN')
 
 
 
 
 
17
  with gr.Blocks() as demo:
18
  with gr.Row():
19
  with gr.Column():
20
  with gr.Box():
21
+ gr.Markdown('Training Data')
22
+ training_video = gr.File(label='Training video')
23
+ training_prompt = gr.Textbox(
24
+ label='Training prompt',
25
+ max_lines=1,
26
+ placeholder='A man is surfing')
27
+ gr.Markdown('''
28
  - Upload a video and write a `Training Prompt` that describes the video.
29
+ ''')
 
30
 
31
  with gr.Column():
32
  with gr.Box():
33
+ gr.Markdown('Training Parameters')
34
  with gr.Row():
35
+ base_model = gr.Text(label='Base Model',
36
+ value='CompVis/stable-diffusion-v1-4',
37
+ max_lines=1)
38
+ resolution = gr.Dropdown(choices=['512', '768'],
39
+ value='512',
40
+ label='Resolution',
41
+ visible=False)
42
+
43
+ input_token = gr.Text(label="Hugging Face Write Token", placeholder="", visible=False if hf_token else True)
44
+ with gr.Accordion("Advanced settings", open=False):
45
+ num_training_steps = gr.Number(
46
+ label='Number of Training Steps', value=300, precision=0)
47
+ learning_rate = gr.Number(label='Learning Rate',
48
+ value=0.000035)
49
  gradient_accumulation = gr.Number(
50
+ label='Number of Gradient Accumulation',
51
+ value=1,
52
+ precision=0)
53
+ seed = gr.Slider(label='Seed',
54
+ minimum=0,
55
+ maximum=100000,
56
+ step=1,
57
+ randomize=True,
58
+ value=0)
59
+ fp16 = gr.Checkbox(label='FP16', value=True)
60
+ use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=False)
61
+ checkpointing_steps = gr.Number(label='Checkpointing Steps',
62
+ value=1000,
63
+ precision=0)
64
+ validation_epochs = gr.Number(label='Validation Epochs',
65
+ value=100,
66
+ precision=0)
67
+ gr.Markdown('''
68
  - The base model must be a Stable Diffusion model compatible with [diffusers](https://github.com/huggingface/diffusers) library.
69
  - Expected time to train a model for 300 steps: ~20 minutes with T4
70
  - You can check the training status by pressing the "Open logs" button if you are running this on your Space.
71
+ ''')
72
+
 
73
  with gr.Row():
74
  with gr.Column():
75
+ gr.Markdown('Output Model')
76
+ output_model_name = gr.Text(label='Name of your model',
77
+ placeholder='The surfer man',
78
+ max_lines=1)
79
+ validation_prompt = gr.Text(label='Validation Prompt', placeholder='prompt to test the model, e.g: a dog is surfing')
80
  with gr.Column():
81
+ gr.Markdown('Upload Settings')
82
  with gr.Row():
83
+ upload_to_hub = gr.Checkbox(
84
+ label='Upload model to Hub', value=True)
85
+ use_private_repo = gr.Checkbox(label='Private',
86
+ value=True)
87
+ delete_existing_repo = gr.Checkbox(
88
+ label='Delete existing repo of the same name',
89
+ value=False)
90
  upload_to = gr.Radio(
91
+ label='Upload to',
92
  choices=[_.value for _ in UploadTarget],
93
+ value=UploadTarget.MODEL_LIBRARY.value)
94
+
95
+ remove_gpu_after_training = gr.Checkbox(
96
+ label='Remove GPU after training',
 
97
  value=False,
98
+ interactive=bool(os.getenv('SPACE_ID')),
99
+ visible=False)
100
+ run_button = gr.Button('Start Training')
101
+
 
102
  with gr.Box():
103
+ gr.Markdown('Output message')
104
+ output_message = gr.Markdown()
105
 
106
  if pipe is not None:
107
  run_button.click(fn=pipe.clear)
108
+ run_button.click(fn=trainer.run,
109
+ inputs=[
110
+ training_video,
111
+ training_prompt,
112
+ output_model_name,
113
+ delete_existing_repo,
114
+ validation_prompt,
115
+ base_model,
116
+ resolution,
117
+ num_training_steps,
118
+ learning_rate,
119
+ gradient_accumulation,
120
+ seed,
121
+ fp16,
122
+ use_8bit_adam,
123
+ checkpointing_steps,
124
+ validation_epochs,
125
+ upload_to_hub,
126
+ use_private_repo,
127
+ delete_existing_repo,
128
+ upload_to,
129
+ remove_gpu_after_training,
130
+ input_token
131
+ ],
132
+ outputs=output_message)
 
133
  return demo
134
 
135
 
136
+ if __name__ == '__main__':
137
+ hf_token = os.getenv('HF_TOKEN')
138
+ trainer = Trainer(hf_token)
139
  demo = create_training_demo(trainer)
140
+ demo.queue(max_size=1).launch(share=False)
app_upload.py CHANGED
@@ -2,68 +2,103 @@
2
 
3
  from __future__ import annotations
4
 
5
- import os
6
 
7
  import gradio as gr
 
8
 
9
  from constants import MODEL_LIBRARY_ORG_NAME, UploadTarget
10
- from uploader import upload
11
  from utils import find_exp_dirs
12
 
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def load_local_model_list() -> dict:
15
  choices = find_exp_dirs()
16
  return gr.update(choices=choices, value=choices[0] if choices else None)
17
 
18
 
19
- def create_upload_demo(disable_run_button: bool = False) -> gr.Blocks:
 
20
  model_dirs = find_exp_dirs()
21
 
22
  with gr.Blocks() as demo:
23
  with gr.Box():
24
- gr.Markdown("Local Models")
25
- reload_button = gr.Button("Reload Model List")
26
  model_dir = gr.Dropdown(
27
- label="Model names", choices=model_dirs, value=model_dirs[0] if model_dirs else None
28
- )
 
29
  with gr.Box():
30
- gr.Markdown("Upload Settings")
31
  with gr.Row():
32
- use_private_repo = gr.Checkbox(label="Private", value=True)
33
- delete_existing_repo = gr.Checkbox(label="Delete existing repo of the same name", value=False)
34
- upload_to = gr.Radio(
35
- label="Upload to", choices=[_.value for _ in UploadTarget], value=UploadTarget.MODEL_LIBRARY.value
36
- )
37
- model_name = gr.Textbox(label="Model Name")
38
- hf_token = gr.Text(
39
- label="Hugging Face Write Token", type="password", visible=os.getenv("HF_TOKEN") is None
40
- )
41
- upload_button = gr.Button("Upload", interactive=not disable_run_button)
42
- gr.Markdown(
43
- f"""
44
- - You can upload your trained model to your personal profile (i.e. `https://huggingface.co/{{your_username}}/{{model_name}}`) or to the public [Tune-A-Video Library](https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}) (i.e. `https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}/{{model_name}}`).
45
- """
46
- )
47
  with gr.Box():
48
- gr.Markdown("Output message")
49
  output_message = gr.Markdown()
50
 
51
- reload_button.click(fn=load_local_model_list, inputs=None, outputs=model_dir)
52
- upload_button.click(
53
- fn=upload,
54
- inputs=[
55
- model_dir,
56
- model_name,
57
- upload_to,
58
- use_private_repo,
59
- delete_existing_repo,
60
- hf_token,
61
- ],
62
- outputs=output_message,
63
- )
 
64
  return demo
65
 
66
 
67
- if __name__ == "__main__":
68
- demo = create_upload_demo()
69
- demo.queue(api_open=False, max_size=1).launch()
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
5
+ import pathlib
6
 
7
  import gradio as gr
8
+ import slugify
9
 
10
  from constants import MODEL_LIBRARY_ORG_NAME, UploadTarget
11
+ from uploader import Uploader
12
  from utils import find_exp_dirs
13
 
14
 
15
+ class ModelUploader(Uploader):
16
+ def upload_model(
17
+ self,
18
+ folder_path: str,
19
+ repo_name: str,
20
+ upload_to: str,
21
+ private: bool,
22
+ delete_existing_repo: bool,
23
+ input_token: str | None = None,
24
+ ) -> str:
25
+ if not folder_path:
26
+ raise ValueError
27
+ if not repo_name:
28
+ repo_name = pathlib.Path(folder_path).name
29
+ repo_name = slugify.slugify(repo_name)
30
+
31
+ if upload_to == UploadTarget.PERSONAL_PROFILE.value:
32
+ organization = ''
33
+ elif upload_to == UploadTarget.MODEL_LIBRARY.value:
34
+ organization = MODEL_LIBRARY_ORG_NAME
35
+ else:
36
+ raise ValueError
37
+
38
+ return self.upload(folder_path,
39
+ repo_name,
40
+ organization=organization,
41
+ private=private,
42
+ delete_existing_repo=delete_existing_repo,
43
+ input_token=input_token)
44
+
45
+
46
  def load_local_model_list() -> dict:
47
  choices = find_exp_dirs()
48
  return gr.update(choices=choices, value=choices[0] if choices else None)
49
 
50
 
51
+ def create_upload_demo(hf_token: str | None) -> gr.Blocks:
52
+ uploader = ModelUploader(hf_token)
53
  model_dirs = find_exp_dirs()
54
 
55
  with gr.Blocks() as demo:
56
  with gr.Box():
57
+ gr.Markdown('Local Models')
58
+ reload_button = gr.Button('Reload Model List')
59
  model_dir = gr.Dropdown(
60
+ label='Model names',
61
+ choices=model_dirs,
62
+ value=model_dirs[0] if model_dirs else None)
63
  with gr.Box():
64
+ gr.Markdown('Upload Settings')
65
  with gr.Row():
66
+ use_private_repo = gr.Checkbox(label='Private', value=True)
67
+ delete_existing_repo = gr.Checkbox(
68
+ label='Delete existing repo of the same name', value=False)
69
+ upload_to = gr.Radio(label='Upload to',
70
+ choices=[_.value for _ in UploadTarget],
71
+ value=UploadTarget.MODEL_LIBRARY.value)
72
+ model_name = gr.Textbox(label='Model Name')
73
+ input_token = gr.Text(label="Hugging Face Write Token", placeholder="", visible=False if hf_token else True)
74
+ upload_button = gr.Button('Upload')
75
+ gr.Markdown(f'''
76
+ - You can upload your trained model to your personal profile (i.e. https://huggingface.co/{{your_username}}/{{model_name}}) or to the public [Tune-A-Video Library](https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}) (i.e. https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}/{{model_name}}).
77
+ ''')
 
 
 
78
  with gr.Box():
79
+ gr.Markdown('Output message')
80
  output_message = gr.Markdown()
81
 
82
+ reload_button.click(fn=load_local_model_list,
83
+ inputs=None,
84
+ outputs=model_dir)
85
+ upload_button.click(fn=uploader.upload_model,
86
+ inputs=[
87
+ model_dir,
88
+ model_name,
89
+ upload_to,
90
+ use_private_repo,
91
+ delete_existing_repo,
92
+ input_token,
93
+ ],
94
+ outputs=output_message)
95
+
96
  return demo
97
 
98
 
99
+ if __name__ == '__main__':
100
+ import os
101
+
102
+ hf_token = os.getenv('HF_TOKEN')
103
+ demo = create_upload_demo(hf_token)
104
+ demo.queue(max_size=1).launch(share=False)
constants.py CHANGED
@@ -2,12 +2,9 @@ import enum
2
 
3
 
4
  class UploadTarget(enum.Enum):
5
- PERSONAL_PROFILE = "Personal Profile"
6
- MODEL_LIBRARY = "Tune-A-Video Library"
7
 
8
 
9
- MODEL_LIBRARY_ORG_NAME = "Tune-A-Video-library"
10
- SAMPLE_MODEL_REPO = "Tune-A-Video-library/a-man-is-surfing"
11
- URL_TO_JOIN_MODEL_LIBRARY_ORG = (
12
- "https://huggingface.co/organizations/Tune-A-Video-library/share/YjTcaNJmKyeHFpMBioHhzBcTzCYddVErEk"
13
- )
 
2
 
3
 
4
  class UploadTarget(enum.Enum):
5
+ PERSONAL_PROFILE = 'Personal Profile'
6
+ MODEL_LIBRARY = 'Tune-A-Video Library'
7
 
8
 
9
+ MODEL_LIBRARY_ORG_NAME = 'Tune-A-Video-library'
10
+ SAMPLE_MODEL_REPO = 'Tune-A-Video-library/a-man-is-surfing'
 
 
 
inference.py CHANGED
@@ -13,7 +13,7 @@ from diffusers.utils.import_utils import is_xformers_available
13
  from einops import rearrange
14
  from huggingface_hub import ModelCard
15
 
16
- sys.path.append("Tune-A-Video")
17
 
18
  from tuneavideo.models.unet import UNet3DConditionModel
19
  from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
@@ -23,7 +23,8 @@ class InferencePipeline:
23
  def __init__(self, hf_token: str | None = None):
24
  self.hf_token = hf_token
25
  self.pipe = None
26
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
27
  self.model_id = None
28
 
29
  def clear(self) -> None:
@@ -38,9 +39,10 @@ class InferencePipeline:
38
  return pathlib.Path(model_id).exists()
39
 
40
  @staticmethod
41
- def get_model_card(model_id: str, hf_token: str | None = None) -> ModelCard:
 
42
  if InferencePipeline.check_if_model_is_local(model_id):
43
- card_path = (pathlib.Path(model_id) / "README.md").as_posix()
44
  else:
45
  card_path = model_id
46
  return ModelCard.load(card_path, token=hf_token)
@@ -55,11 +57,14 @@ class InferencePipeline:
55
  return
56
  base_model_id = self.get_base_model_info(model_id, self.hf_token)
57
  unet = UNet3DConditionModel.from_pretrained(
58
- model_id, subfolder="unet", torch_dtype=torch.float16, use_auth_token=self.hf_token
59
- )
60
- pipe = TuneAVideoPipeline.from_pretrained(
61
- base_model_id, unet=unet, torch_dtype=torch.float16, use_auth_token=self.hf_token
62
- )
 
 
 
63
  pipe = pipe.to(self.device)
64
  if is_xformers_available():
65
  pipe.unet.enable_xformers_memory_efficient_attention()
@@ -77,7 +82,7 @@ class InferencePipeline:
77
  guidance_scale: float,
78
  ) -> PIL.Image.Image:
79
  if not torch.cuda.is_available():
80
- raise gr.Error("CUDA is not available.")
81
 
82
  self.load_pipe(model_id)
83
 
@@ -92,10 +97,10 @@ class InferencePipeline:
92
  generator=generator,
93
  ) # type: ignore
94
 
95
- frames = rearrange(out.videos[0], "c t h w -> t h w c")
96
  frames = (frames * 255).to(torch.uint8).numpy()
97
 
98
- out_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
99
  writer = imageio.get_writer(out_file.name, fps=fps)
100
  for frame in frames:
101
  writer.append_data(frame)
 
13
  from einops import rearrange
14
  from huggingface_hub import ModelCard
15
 
16
+ sys.path.append('Tune-A-Video')
17
 
18
  from tuneavideo.models.unet import UNet3DConditionModel
19
  from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
 
23
  def __init__(self, hf_token: str | None = None):
24
  self.hf_token = hf_token
25
  self.pipe = None
26
+ self.device = torch.device(
27
+ 'cuda:0' if torch.cuda.is_available() else 'cpu')
28
  self.model_id = None
29
 
30
  def clear(self) -> None:
 
39
  return pathlib.Path(model_id).exists()
40
 
41
  @staticmethod
42
+ def get_model_card(model_id: str,
43
+ hf_token: str | None = None) -> ModelCard:
44
  if InferencePipeline.check_if_model_is_local(model_id):
45
+ card_path = (pathlib.Path(model_id) / 'README.md').as_posix()
46
  else:
47
  card_path = model_id
48
  return ModelCard.load(card_path, token=hf_token)
 
57
  return
58
  base_model_id = self.get_base_model_info(model_id, self.hf_token)
59
  unet = UNet3DConditionModel.from_pretrained(
60
+ model_id,
61
+ subfolder='unet',
62
+ torch_dtype=torch.float16,
63
+ use_auth_token=self.hf_token)
64
+ pipe = TuneAVideoPipeline.from_pretrained(base_model_id,
65
+ unet=unet,
66
+ torch_dtype=torch.float16,
67
+ use_auth_token=self.hf_token)
68
  pipe = pipe.to(self.device)
69
  if is_xformers_available():
70
  pipe.unet.enable_xformers_memory_efficient_attention()
 
82
  guidance_scale: float,
83
  ) -> PIL.Image.Image:
84
  if not torch.cuda.is_available():
85
+ raise gr.Error('CUDA is not available.')
86
 
87
  self.load_pipe(model_id)
88
 
 
97
  generator=generator,
98
  ) # type: ignore
99
 
100
+ frames = rearrange(out.videos[0], 'c t h w -> t h w c')
101
  frames = (frames * 255).to(torch.uint8).numpy()
102
 
103
+ out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
104
  writer = imageio.get_writer(out_file.name, fps=fps)
105
  for frame in frames:
106
  writer.append_data(frame)
requirements-monitor.txt DELETED
@@ -1,4 +0,0 @@
1
- nvitop==1.1.1
2
- pandas==2.0.0
3
- plotly==5.14.1
4
- psutil==5.9.4
 
 
 
 
 
requirements.txt CHANGED
@@ -1,19 +1,19 @@
1
- accelerate==0.18.0
2
- bitsandbytes==0.37.2
3
  decord==0.6.0
4
  diffusers[torch]==0.11.1
5
  einops==0.6.0
6
  ftfy==6.1.1
7
- gradio==3.43.2
8
- huggingface-hub==0.17.1
9
- imageio==2.27.0
10
  imageio-ffmpeg==0.4.8
11
  omegaconf==2.3.0
12
- Pillow==9.5.0
13
- python-slugify==8.0.1
14
  tensorboard==2.11.2
15
  torch==1.13.1
16
  torchvision==0.14.1
17
  transformers==4.26.0
18
- triton==2.0.0
19
  xformers==0.0.16
 
1
+ accelerate==0.15.0
2
+ bitsandbytes==0.35.4
3
  decord==0.6.0
4
  diffusers[torch]==0.11.1
5
  einops==0.6.0
6
  ftfy==6.1.1
7
+ gradio==3.16.2
8
+ huggingface-hub==0.12.0
9
+ imageio==2.25.0
10
  imageio-ffmpeg==0.4.8
11
  omegaconf==2.3.0
12
+ Pillow==9.4.0
13
+ python-slugify==7.0.0
14
  tensorboard==2.11.2
15
  torch==1.13.1
16
  torchvision==0.14.1
17
  transformers==4.26.0
18
+ triton==2.0.0.dev20221202
19
  xformers==0.0.16
trainer.py CHANGED
@@ -8,34 +8,46 @@ import shutil
8
  import subprocess
9
  import sys
10
 
 
11
  import slugify
12
  import torch
13
  from huggingface_hub import HfApi
14
  from omegaconf import OmegaConf
15
 
16
- from uploader import upload
17
  from utils import save_model_card
18
 
19
- sys.path.append("Tune-A-Video")
20
 
 
 
 
21
 
22
  class Trainer:
23
- def __init__(self):
24
- self.checkpoint_dir = pathlib.Path("checkpoints")
25
- self.checkpoint_dir.mkdir(exist_ok=True)
26
 
27
- self.log_file = pathlib.Path("log.txt")
28
- self.log_file.touch(exist_ok=True)
29
 
30
  def download_base_model(self, base_model_id: str) -> str:
31
  model_dir = self.checkpoint_dir / base_model_id
32
  if not model_dir.exists():
33
- org_name = base_model_id.split("/")[0]
34
  org_dir = self.checkpoint_dir / org_name
35
  org_dir.mkdir(exist_ok=True)
36
- subprocess.run(shlex.split(f"git clone https://huggingface.co/{base_model_id}"), cwd=org_dir)
 
 
37
  return model_dir.as_posix()
38
 
 
 
 
 
 
 
39
  def run(
40
  self,
41
  training_video: str,
@@ -57,32 +69,37 @@ class Trainer:
57
  use_private_repo: bool,
58
  delete_existing_repo: bool,
59
  upload_to: str,
60
- pause_space_after_training: bool,
61
- hf_token: str,
62
- ) -> None:
 
 
63
  if not torch.cuda.is_available():
64
- raise RuntimeError("CUDA is not available.")
65
  if training_video is None:
66
- raise ValueError("You need to upload a video.")
67
  if not training_prompt:
68
- raise ValueError("The training prompt is missing.")
69
  if not validation_prompt:
70
- raise ValueError("The validation prompt is missing.")
71
 
72
  resolution = int(resolution_s)
73
 
74
  if not output_model_name:
75
- timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
76
- output_model_name = f"tune-a-video-{timestamp}"
77
  output_model_name = slugify.slugify(output_model_name)
78
 
79
  repo_dir = pathlib.Path(__file__).parent
80
- output_dir = repo_dir / "experiments" / output_model_name
81
  if overwrite_existing_model or upload_to_hub:
82
  shutil.rmtree(output_dir, ignore_errors=True)
83
  output_dir.mkdir(parents=True)
84
 
85
- config = OmegaConf.load("Tune-A-Video/configs/man-surfing.yaml")
 
 
 
86
  config.pretrained_model_path = self.download_base_model(base_model)
87
  config.output_dir = output_dir.as_posix()
88
  config.train_data.video_path = training_video.name # type: ignore
@@ -105,40 +122,40 @@ class Trainer:
105
  config.checkpointing_steps = checkpointing_steps
106
  config.validation_steps = validation_epochs
107
  config.seed = seed
108
- config.mixed_precision = "fp16" if fp16 else ""
109
  config.use_8bit_adam = use_8bit_adam
110
 
111
- config_path = output_dir / "config.yaml"
112
- with open(config_path, "w") as f:
113
  OmegaConf.save(config, f)
114
 
115
- command = f"accelerate launch Tune-A-Video/train_tuneavideo.py --config {config_path}"
116
- with open(self.log_file, "w") as f:
117
- subprocess.run(shlex.split(command), stdout=f, stderr=subprocess.STDOUT, text=True)
118
- save_model_card(
119
- save_dir=output_dir,
120
- base_model=base_model,
121
- training_prompt=training_prompt,
122
- test_prompt=validation_prompt,
123
- test_image_dir="samples",
124
- )
125
 
126
- with open(self.log_file, "a") as f:
127
- f.write("Training completed!\n")
128
 
129
  if upload_to_hub:
130
- upload_message = upload(
131
- local_folder_path=output_dir.as_posix(),
132
- target_repo_name=output_model_name,
133
  upload_to=upload_to,
134
  private=use_private_repo,
135
  delete_existing_repo=delete_existing_repo,
136
- hf_token=hf_token,
137
- )
138
- with open(self.log_file, "a") as f:
139
- f.write(upload_message)
140
-
141
- if pause_space_after_training:
142
- if space_id := os.getenv("SPACE_ID"):
143
- api = HfApi(token=os.getenv("HF_TOKEN") or hf_token)
144
- api.pause_space(repo_id=space_id)
 
 
 
 
8
  import subprocess
9
  import sys
10
 
11
+ import gradio as gr
12
  import slugify
13
  import torch
14
  from huggingface_hub import HfApi
15
  from omegaconf import OmegaConf
16
 
17
+ from app_upload import ModelUploader
18
  from utils import save_model_card
19
 
20
+ sys.path.append('Tune-A-Video')
21
 
22
+ URL_TO_JOIN_MODEL_LIBRARY_ORG = 'https://huggingface.co/organizations/Tune-A-Video-library/share/YjTcaNJmKyeHFpMBioHhzBcTzCYddVErEk'
23
+ ORIGINAL_SPACE_ID = 'Tune-A-Video-library/Tune-A-Video-Training-UI'
24
+ SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
25
 
26
  class Trainer:
27
+ def __init__(self, hf_token: str | None = None):
28
+ self.hf_token = hf_token
29
+ self.model_uploader = ModelUploader(hf_token)
30
 
31
+ self.checkpoint_dir = pathlib.Path('checkpoints')
32
+ self.checkpoint_dir.mkdir(exist_ok=True)
33
 
34
  def download_base_model(self, base_model_id: str) -> str:
35
  model_dir = self.checkpoint_dir / base_model_id
36
  if not model_dir.exists():
37
+ org_name = base_model_id.split('/')[0]
38
  org_dir = self.checkpoint_dir / org_name
39
  org_dir.mkdir(exist_ok=True)
40
+ subprocess.run(shlex.split(
41
+ f'git clone https://huggingface.co/{base_model_id}'),
42
+ cwd=org_dir)
43
  return model_dir.as_posix()
44
 
45
+ def join_model_library_org(self, token: str) -> None:
46
+ subprocess.run(
47
+ shlex.split(
48
+ f'curl -X POST -H "Authorization: Bearer {token}" -H "Content-Type: application/json" {URL_TO_JOIN_MODEL_LIBRARY_ORG}'
49
+ ))
50
+
51
  def run(
52
  self,
53
  training_video: str,
 
69
  use_private_repo: bool,
70
  delete_existing_repo: bool,
71
  upload_to: str,
72
+ remove_gpu_after_training: bool,
73
+ input_token: str,
74
+ ) -> str:
75
+ if SPACE_ID == ORIGINAL_SPACE_ID:
76
+ raise gr.Error('This Space does not work on this Shared UI. Duplicate the Space and attribute a GPU')
77
  if not torch.cuda.is_available():
78
+ raise gr.Error('CUDA is not available.')
79
  if training_video is None:
80
+ raise gr.Error('You need to upload a video.')
81
  if not training_prompt:
82
+ raise gr.Error('The training prompt is missing.')
83
  if not validation_prompt:
84
+ raise gr.Error('The validation prompt is missing.')
85
 
86
  resolution = int(resolution_s)
87
 
88
  if not output_model_name:
89
+ timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
90
+ output_model_name = f'tune-a-video-{timestamp}'
91
  output_model_name = slugify.slugify(output_model_name)
92
 
93
  repo_dir = pathlib.Path(__file__).parent
94
+ output_dir = repo_dir / 'experiments' / output_model_name
95
  if overwrite_existing_model or upload_to_hub:
96
  shutil.rmtree(output_dir, ignore_errors=True)
97
  output_dir.mkdir(parents=True)
98
 
99
+ if upload_to_hub:
100
+ self.join_model_library_org(self.hf_token if self.hf_token else input_token)
101
+
102
+ config = OmegaConf.load('Tune-A-Video/configs/man-surfing.yaml')
103
  config.pretrained_model_path = self.download_base_model(base_model)
104
  config.output_dir = output_dir.as_posix()
105
  config.train_data.video_path = training_video.name # type: ignore
 
122
  config.checkpointing_steps = checkpointing_steps
123
  config.validation_steps = validation_epochs
124
  config.seed = seed
125
+ config.mixed_precision = 'fp16' if fp16 else ''
126
  config.use_8bit_adam = use_8bit_adam
127
 
128
+ config_path = output_dir / 'config.yaml'
129
+ with open(config_path, 'w') as f:
130
  OmegaConf.save(config, f)
131
 
132
+ command = f'accelerate launch Tune-A-Video/train_tuneavideo.py --config {config_path}'
133
+ subprocess.run(shlex.split(command))
134
+ save_model_card(save_dir=output_dir,
135
+ base_model=base_model,
136
+ training_prompt=training_prompt,
137
+ test_prompt=validation_prompt,
138
+ test_image_dir='samples')
 
 
 
139
 
140
+ message = 'Training completed!'
141
+ print(message)
142
 
143
  if upload_to_hub:
144
+ upload_message = self.model_uploader.upload_model(
145
+ folder_path=output_dir.as_posix(),
146
+ repo_name=output_model_name,
147
  upload_to=upload_to,
148
  private=use_private_repo,
149
  delete_existing_repo=delete_existing_repo,
150
+ input_token=input_token)
151
+ print(upload_message)
152
+ message = message + '\n' + upload_message
153
+
154
+ if remove_gpu_after_training:
155
+ space_id = os.getenv('SPACE_ID')
156
+ if space_id:
157
+ api = HfApi(token=self.hf_token if self.hf_token else input_token)
158
+ api.request_space_hardware(repo_id=space_id,
159
+ hardware='cpu-basic')
160
+
161
+ return message
uploader.py CHANGED
@@ -1,66 +1,44 @@
1
  from __future__ import annotations
2
 
3
- import os
4
- import pathlib
5
- import shlex
6
- import subprocess
7
-
8
- import slugify
9
  from huggingface_hub import HfApi
10
 
11
- from constants import (
12
- MODEL_LIBRARY_ORG_NAME,
13
- URL_TO_JOIN_MODEL_LIBRARY_ORG,
14
- UploadTarget,
15
- )
16
-
17
-
18
- def join_model_library_org(hf_token: str) -> None:
19
- subprocess.run(
20
- shlex.split(
21
- f'curl -X POST -H "Authorization: Bearer {hf_token}" -H "Content-Type: application/json" {URL_TO_JOIN_MODEL_LIBRARY_ORG}'
22
- )
23
- )
24
-
25
-
26
- def upload(
27
- local_folder_path: str,
28
- target_repo_name: str,
29
- upload_to: str,
30
- private: bool = True,
31
- delete_existing_repo: bool = False,
32
- hf_token: str = "",
33
- ) -> str:
34
- hf_token = os.getenv("HF_TOKEN") or hf_token
35
- if not hf_token:
36
- raise ValueError
37
- api = HfApi(token=hf_token)
38
-
39
- if not local_folder_path:
40
- raise ValueError
41
- if not target_repo_name:
42
- target_repo_name = pathlib.Path(local_folder_path).name
43
- target_repo_name = slugify.slugify(target_repo_name)
44
-
45
- if upload_to == UploadTarget.PERSONAL_PROFILE.value:
46
- organization = api.whoami()["name"]
47
- elif upload_to == UploadTarget.MODEL_LIBRARY.value:
48
- organization = MODEL_LIBRARY_ORG_NAME
49
- join_model_library_org(hf_token)
50
- else:
51
- raise ValueError
52
 
53
- repo_id = f"{organization}/{target_repo_name}"
54
- if delete_existing_repo:
55
- try:
56
- api.delete_repo(repo_id, repo_type="model")
57
- except Exception:
58
- pass
59
- try:
60
- api.create_repo(repo_id, repo_type="model", private=private)
61
- api.upload_folder(repo_id=repo_id, folder_path=local_folder_path, path_in_repo=".", repo_type="model")
62
- url = f"https://huggingface.co/{repo_id}"
63
- message = f"Your model was successfully uploaded to {url}."
64
- except Exception as e:
65
- message = str(e)
66
- return message
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
 
 
 
 
 
 
3
  from huggingface_hub import HfApi
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
+ class Uploader:
7
+ def __init__(self, hf_token: str | None):
8
+ self.hf_token = hf_token
9
+
10
+ def upload(self,
11
+ folder_path: str,
12
+ repo_name: str,
13
+ organization: str = '',
14
+ repo_type: str = 'model',
15
+ private: bool = True,
16
+ delete_existing_repo: bool = False,
17
+ input_token: str | None = None) -> str:
18
+
19
+ api = HfApi(token=self.hf_token if self.hf_token else input_token)
20
+
21
+ if not folder_path:
22
+ raise ValueError
23
+ if not repo_name:
24
+ raise ValueError
25
+ if not organization:
26
+ organization = api.whoami()['name']
27
+
28
+ repo_id = f'{organization}/{repo_name}'
29
+ if delete_existing_repo:
30
+ try:
31
+ self.api.delete_repo(repo_id, repo_type=repo_type)
32
+ except Exception:
33
+ pass
34
+ try:
35
+ api.create_repo(repo_id, repo_type=repo_type, private=private)
36
+ api.upload_folder(repo_id=repo_id,
37
+ folder_path=folder_path,
38
+ path_in_repo='.',
39
+ repo_type=repo_type)
40
+ url = f'https://huggingface.co/{repo_id}'
41
+ message = f'Your model was successfully uploaded to <a href="{url}" target="_blank">{url}</a>.'
42
+ except Exception as e:
43
+ message = str(e)
44
+ return message
utils.py CHANGED
@@ -5,11 +5,14 @@ import pathlib
5
 
6
  def find_exp_dirs() -> list[str]:
7
  repo_dir = pathlib.Path(__file__).parent
8
- exp_root_dir = repo_dir / "experiments"
9
  if not exp_root_dir.exists():
10
  return []
11
- exp_dirs = sorted(exp_root_dir.glob("*"))
12
- exp_dirs = [exp_dir for exp_dir in exp_dirs if (exp_dir / "model_index.json").exists()]
 
 
 
13
  return [path.relative_to(repo_dir).as_posix() for path in exp_dirs]
14
 
15
 
@@ -17,21 +20,21 @@ def save_model_card(
17
  save_dir: pathlib.Path,
18
  base_model: str,
19
  training_prompt: str,
20
- test_prompt: str = "",
21
- test_image_dir: str = "",
22
  ) -> None:
23
- image_str = ""
24
  if test_prompt and test_image_dir:
25
- image_paths = sorted((save_dir / test_image_dir).glob("*.gif"))
26
  if image_paths:
27
  image_path = image_paths[-1]
28
  rel_path = image_path.relative_to(save_dir)
29
- image_str = f"""## Samples
30
  Test prompt: {test_prompt}
31
 
32
- ![{image_path.stem}]({rel_path})"""
33
 
34
- model_card = f"""---
35
  license: creativeml-openrail-m
36
  base_model: {base_model}
37
  training_prompt: {training_prompt}
@@ -56,7 +59,7 @@ inference: false
56
  ## Related papers:
57
  - [Tune-A-Video](https://arxiv.org/abs/2212.11565): One-Shot Tuning of Image Diffusion Models for Text-to-Video Generation
58
  - [Stable-Diffusion](https://arxiv.org/abs/2112.10752): High-Resolution Image Synthesis with Latent Diffusion Models
59
- """
60
 
61
- with open(save_dir / "README.md", "w") as f:
62
  f.write(model_card)
 
5
 
6
  def find_exp_dirs() -> list[str]:
7
  repo_dir = pathlib.Path(__file__).parent
8
+ exp_root_dir = repo_dir / 'experiments'
9
  if not exp_root_dir.exists():
10
  return []
11
+ exp_dirs = sorted(exp_root_dir.glob('*'))
12
+ exp_dirs = [
13
+ exp_dir for exp_dir in exp_dirs
14
+ if (exp_dir / 'model_index.json').exists()
15
+ ]
16
  return [path.relative_to(repo_dir).as_posix() for path in exp_dirs]
17
 
18
 
 
20
  save_dir: pathlib.Path,
21
  base_model: str,
22
  training_prompt: str,
23
+ test_prompt: str = '',
24
+ test_image_dir: str = '',
25
  ) -> None:
26
+ image_str = ''
27
  if test_prompt and test_image_dir:
28
+ image_paths = sorted((save_dir / test_image_dir).glob('*.gif'))
29
  if image_paths:
30
  image_path = image_paths[-1]
31
  rel_path = image_path.relative_to(save_dir)
32
+ image_str = f'''## Samples
33
  Test prompt: {test_prompt}
34
 
35
+ ![{image_path.stem}]({rel_path})'''
36
 
37
+ model_card = f'''---
38
  license: creativeml-openrail-m
39
  base_model: {base_model}
40
  training_prompt: {training_prompt}
 
59
  ## Related papers:
60
  - [Tune-A-Video](https://arxiv.org/abs/2212.11565): One-Shot Tuning of Image Diffusion Models for Text-to-Video Generation
61
  - [Stable-Diffusion](https://arxiv.org/abs/2112.10752): High-Resolution Image Synthesis with Latent Diffusion Models
62
+ '''
63
 
64
+ with open(save_dir / 'README.md', 'w') as f:
65
  f.write(model_card)