sagar007 committed
Commit 6d39491
Parent: 1bbe2e3

Update app.py

Files changed (1)
  1. app.py +29 -22
app.py CHANGED
@@ -7,6 +7,11 @@ import torch
 from diffusers import DiffusionPipeline
 import hashlib
 import pickle
+import yaml
+
+# Load config file
+with open('config.yaml', 'r') as file:
+    config = yaml.safe_load(file)
 
 # Authenticate using the token stored in Hugging Face Spaces secrets
 if 'HF_TOKEN' in os.environ:
@@ -14,9 +19,9 @@ if 'HF_TOKEN' in os.environ:
 else:
     raise ValueError("HF_TOKEN not found in environment variables. Please add it to your Space's secrets.")
 
-base_model = "black-forest-labs/FLUX.1-dev"
-lora_model = "sagar007/sagar_flux"
-trigger_word = "sagar"
+base_model = config['config']['model']['name_or_path']
+lora_model = "sagar007/sagar_flux"  # This isn't in the config, so we're keeping it as is
+trigger_word = config['config']['trigger_word']
 
 # Global variables
 pipe = None
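
Note: the nested lookups above (and in the hunks below) imply a config.yaml of roughly the following shape. The structure is inferred from the keys the code reads; the values are placeholders based on the literals this commit removes, not the actual config:

# Hypothetical config.yaml -- structure inferred from the keys read in app.py;
# values are illustrative placeholders.
config:
  model:
    name_or_path: black-forest-labs/FLUX.1-dev
  trigger_word: sagar
  sample:
    guidance_scale: 4
    sample_steps: 20
    walk_seed: false
    seed: 42
    width: 1024
    height: 1024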
@@ -67,29 +72,29 @@ def get_cache_key(prompt, cfg_scale, steps, seed, width, height, lora_scale):
 @spaces.GPU(duration=80)
 def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale):
     global pipe, cache
-
+
     if randomize_seed:
         seed = random.randint(0, 2**32-1)
-
+
     cache_key = get_cache_key(prompt, cfg_scale, steps, seed, width, height, lora_scale)
-
+
     if cache_key in cache:
         print("Using cached image")
         return cache[cache_key], seed
-
+
     try:
         print(f"Starting run_lora with prompt: {prompt}")
         if pipe is None:
             print("Initializing model...")
             initialize_model()
-
+
         print(f"Using seed: {seed}")
-
+
         generator = torch.Generator(device="cuda").manual_seed(seed)
-
+
         full_prompt = f"{prompt} {trigger_word}"
         print(f"Full prompt: {full_prompt}")
-
+
         print("Starting image generation...")
         image = pipe(
             prompt=full_prompt,
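
Note: initialize_model() is called above but is not touched by this diff. A minimal sketch of what it plausibly does, given the DiffusionPipeline import and the base_model / lora_model globals (the dtype and device handling here are assumptions):

# Hypothetical sketch -- initialize_model() is defined elsewhere in app.py.
def initialize_model():
    global pipe
    # Load the base pipeline, then attach the LoRA weights on top of it.
    pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
    pipe.load_lora_weights(lora_model)
    pipe.to("cuda")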
@@ -100,11 +105,11 @@ def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora
             generator=generator,
         ).images[0]
         print("Image generation completed successfully")
-
+
         # Cache the generated image
         cache[cache_key] = image
         save_cache()
-
+
         return image, seed
     except Exception as e:
         print(f"Error during generation: {str(e)}")
@@ -121,7 +126,9 @@ load_cache()
 # Pre-generate and cache example images
 def cache_example_images():
     for prompt in example_prompts:
-        run_lora(prompt, 4, 20, False, 42, 1024, 1024, 0.75)
+        run_lora(prompt, config['config']['sample']['guidance_scale'], config['config']['sample']['sample_steps'],
+                 config['config']['sample']['walk_seed'], config['config']['sample']['seed'],
+                 config['config']['sample']['width'], config['config']['sample']['height'], 0.75)
 
 # Gradio interface setup
 with gr.Blocks() as app:
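
Note: the repeated config['config']['sample'][...] chains are verbose and raise a bare KeyError when a key is missing. An optional tidy-up, sketched here rather than taken from this commit, is to hoist each nested section once near the top of the file:

# Optional refactor (not part of this commit): hoist the nested config sections.
model_cfg = config['config']['model']
sample_cfg = config['config']['sample']

base_model = model_cfg['name_or_path']
trigger_word = config['config']['trigger_word']

# Call sites then shorten to, e.g.:
# run_lora(prompt, sample_cfg['guidance_scale'], sample_cfg['sample_steps'],
#          sample_cfg['walk_seed'], sample_cfg['seed'],
#          sample_cfg['width'], sample_cfg['height'], 0.75)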
@@ -134,18 +141,18 @@ with gr.Blocks() as app:
         with gr.Column():
             result = gr.Image(label="Result")
             with gr.Row():
-                cfg_scale = gr.Slider(minimum=1, maximum=20, value=4, step=0.1, label="CFG Scale")
-                steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Steps")
+                cfg_scale = gr.Slider(minimum=1, maximum=20, value=config['config']['sample']['guidance_scale'], step=0.1, label="CFG Scale")
+                steps = gr.Slider(minimum=1, maximum=100, value=config['config']['sample']['sample_steps'], step=1, label="Steps")
             with gr.Row():
-                width = gr.Slider(minimum=128, maximum=1024, value=1024, step=64, label="Width")
-                height = gr.Slider(minimum=128, maximum=1024, value=1024, step=64, label="Height")
+                width = gr.Slider(minimum=128, maximum=1024, value=config['config']['sample']['width'], step=64, label="Width")
+                height = gr.Slider(minimum=128, maximum=1024, value=config['config']['sample']['height'], step=64, label="Height")
             with gr.Row():
-                seed = gr.Number(label="Seed", precision=0)
-                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+                seed = gr.Number(label="Seed", value=config['config']['sample']['seed'], precision=0)
+                randomize_seed = gr.Checkbox(label="Randomize seed", value=config['config']['sample']['walk_seed'])
             lora_scale = gr.Slider(minimum=0, maximum=1, value=0.75, step=0.01, label="LoRA Scale")
-
+
     example_dropdown.change(update_prompt, inputs=[example_dropdown], outputs=[prompt])
-
+
     run_button.click(
         run_lora,
         inputs=[prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale],
 