vumichien commited on
Commit
bcff9ec
1 Parent(s): 0aa9213
app.py CHANGED
@@ -31,7 +31,7 @@ st.set_page_config(
31
 
32
  save_memory = False
33
 
34
- @st.experimental_singleton
35
  def load_model():
36
  model_path = hf_hub_download('lllyasviel/ControlNet', 'models/control_sd15_scribble.pth')
37
  model = create_model('./models/cldm_v15.yaml').cpu()
@@ -39,7 +39,6 @@ def load_model():
39
  model = model.cuda()
40
  return model
41
 
42
- @st.experimental_singleton
43
  def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
44
  with torch.no_grad():
45
 
@@ -60,7 +59,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
60
  control = einops.rearrange(control, 'b h w c -> b c h w').clone()
61
 
62
  if seed == -1:
63
- seed = random.randint(0, 65535)
64
  seed_everything(seed)
65
 
66
  if save_memory:
@@ -105,8 +104,8 @@ def main():
105
  st.header("Generate image with ControllNet")
106
  with st.sidebar:
107
  st_lottie(lottie_penguin, height=200)
108
- choose = option_menu("Generate image", ["Upload", "Canvas"],
109
- icons=['collection', 'file-plus'],
110
  menu_icon="infinity", default_index=0,
111
  styles={
112
  "container": {"padding": ".0rem", "font-size": "14px"},
@@ -158,7 +157,7 @@ def main():
158
  # file_bytes = np.asarray(bytearray(upload_file.read()), dtype=np.uint8)
159
  # imageBGR = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
160
  # input_image = cv2.cvtColor(imageBGR , cv2.COLOR_BGR2RGB)
161
- input_image = np.asarray(Image.open(upload_file))
162
  print("input_image", input_image.shape)
163
 
164
  if generate_button:
@@ -167,20 +166,17 @@ def main():
167
  print("input_image", input_image.shape)
168
  print("results", results[0].shape)
169
  H, W, C = input_image.shape
170
- # output_image = cv2.resize(results[0], (W, H), interpolation=cv2.INTER_AREA)
171
  col11.image(input_image, channels='RGB', width=None, clamp=False, caption='Input image')
172
- col12.image(results[0], channels='RGB', width=None, clamp=False, caption='Generated image')
173
 
174
  elif choose == 'Canvas':
175
- with st.form(key='canvas_form'):
176
  # Specify canvas parameters in application
177
  stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 3)
178
  stroke_color = st.sidebar.color_picker("Stroke color hex: ")
179
  bg_color = st.sidebar.color_picker("Background color hex: ", "#eee")
180
- bg_height = st.sidebar.slider("Canvas height", min_value=256, max_value=512, value=512, step=64)
181
- bg_width = st.sidebar.slider("Canvas width", min_value=256, max_value=512, value=512, step=64)
182
  realtime_update = st.sidebar.checkbox("Update in realtime", True)
183
-
184
  # Create a canvas component
185
  col31, col32 = st.columns(2)
186
  with col31:
@@ -191,22 +187,25 @@ def main():
191
  background_color=bg_color,
192
  background_image=None,
193
  update_streamlit=realtime_update,
194
- height=bg_height,
195
- width=bg_width,
196
  drawing_mode="freedraw",
197
  point_display_radius=0,
198
  key="canvas",
199
  )
200
- prompt = st.text_input(label="Prompt", placeholder='Type your instruction')
201
-
 
202
  with st.expander('Advanced option', expanded=False):
203
  col41, col42 = st.columns(2)
 
204
  with col41:
205
  image_resolution = st.slider(label="Image Resolution", min_value=256, max_value=512, value=512, step=256)
206
  strength = st.slider(label="Control Strength", min_value=0.0, max_value=2.0, value=1.0, step=0.01)
207
  guess_mode = st.checkbox(label='Guess Mode', value=False)
208
  detect_resolution = st.slider(label="HED Resolution", min_value=128, max_value=1024, value=512, step=1)
209
  ddim_steps = st.slider(label="Steps", min_value=1, max_value=100, value=20, step=1)
 
210
  with col42:
211
  scale = st.slider(label="Guidance Scale", min_value=0.1, max_value=30.0, value=9.0, step=0.1)
212
  seed = st.number_input(label="Seed", min_value=-1, value=-1)
@@ -217,13 +216,27 @@ def main():
217
 
218
  # Do something interesting with the image data and paths
219
  generate_button = st.form_submit_button(label='Generate Image')
220
- if canvas_result.image_data is not None:
221
- input_image = canvas_result.image_data
222
- with st.spinner(text=f"It may take up to 1 minute under high load. Generating images..."):
223
- results = process(input_image, prompt, a_prompt, n_prompt, 1, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta)
224
- H, W, C = input_image.shape
225
- output_image = cv2.resize(results[0], (W, H), interpolation=cv2.INTER_AREA)
226
- col32.image(output_image, channels='RGB', width=384, clamp=True, caption='Generated image')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
  if __name__ == '__main__':
229
  main()
 
31
 
32
  save_memory = False
33
 
34
+ @st.experimental_memo
35
  def load_model():
36
  model_path = hf_hub_download('lllyasviel/ControlNet', 'models/control_sd15_scribble.pth')
37
  model = create_model('./models/cldm_v15.yaml').cpu()
 
39
  model = model.cuda()
40
  return model
41
 
 
42
  def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
43
  with torch.no_grad():
44
 
 
59
  control = einops.rearrange(control, 'b h w c -> b c h w').clone()
60
 
61
  if seed == -1:
62
+ seed = random.randint(0, 2147483647)
63
  seed_everything(seed)
64
 
65
  if save_memory:
 
104
  st.header("Generate image with ControllNet")
105
  with st.sidebar:
106
  st_lottie(lottie_penguin, height=200)
107
+ choose = option_menu("Generate image", ["Upload", "Canvas", "Image Gallery"],
108
+ icons=['cloud-upload', 'file-plus', 'collection'],
109
  menu_icon="infinity", default_index=0,
110
  styles={
111
  "container": {"padding": ".0rem", "font-size": "14px"},
 
157
  # file_bytes = np.asarray(bytearray(upload_file.read()), dtype=np.uint8)
158
  # imageBGR = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
159
  # input_image = cv2.cvtColor(imageBGR , cv2.COLOR_BGR2RGB)
160
+ input_image = np.asarray(Image.open(upload_file).convert("RGB"))
161
  print("input_image", input_image.shape)
162
 
163
  if generate_button:
 
166
  print("input_image", input_image.shape)
167
  print("results", results[0].shape)
168
  H, W, C = input_image.shape
169
+ output_image = cv2.resize(results[0], (W, H), interpolation=cv2.INTER_AREA)
170
  col11.image(input_image, channels='RGB', width=None, clamp=False, caption='Input image')
171
+ col12.image(output_image, channels='RGB', width=None, clamp=False, caption='Generated image')
172
 
173
  elif choose == 'Canvas':
174
+ with st.form(key='canvas_generate_form'):
175
  # Specify canvas parameters in application
176
  stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 3)
177
  stroke_color = st.sidebar.color_picker("Stroke color hex: ")
178
  bg_color = st.sidebar.color_picker("Background color hex: ", "#eee")
 
 
179
  realtime_update = st.sidebar.checkbox("Update in realtime", True)
 
180
  # Create a canvas component
181
  col31, col32 = st.columns(2)
182
  with col31:
 
187
  background_color=bg_color,
188
  background_image=None,
189
  update_streamlit=realtime_update,
190
+ height=512,
191
+ width=512,
192
  drawing_mode="freedraw",
193
  point_display_radius=0,
194
  key="canvas",
195
  )
196
+
197
+ prompt = st.text_input(label="Prompt", placeholder='Type your instruction')
198
+
199
  with st.expander('Advanced option', expanded=False):
200
  col41, col42 = st.columns(2)
201
+
202
  with col41:
203
  image_resolution = st.slider(label="Image Resolution", min_value=256, max_value=512, value=512, step=256)
204
  strength = st.slider(label="Control Strength", min_value=0.0, max_value=2.0, value=1.0, step=0.01)
205
  guess_mode = st.checkbox(label='Guess Mode', value=False)
206
  detect_resolution = st.slider(label="HED Resolution", min_value=128, max_value=1024, value=512, step=1)
207
  ddim_steps = st.slider(label="Steps", min_value=1, max_value=100, value=20, step=1)
208
+
209
  with col42:
210
  scale = st.slider(label="Guidance Scale", min_value=0.1, max_value=30.0, value=9.0, step=0.1)
211
  seed = st.number_input(label="Seed", min_value=-1, value=-1)
 
216
 
217
  # Do something interesting with the image data and paths
218
  generate_button = st.form_submit_button(label='Generate Image')
219
+ if generate_button:
220
+ if canvas_result.image_data is not None:
221
+ input_image = canvas_result.image_data
222
+ with st.spinner(text=f"It may take up to 1 minute under high load. Generating images..."):
223
+ results = process(input_image, prompt, a_prompt, n_prompt, 1, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta)
224
+ H, W, C = input_image.shape
225
+ output_image = cv2.resize(results[0], (W, H), interpolation=cv2.INTER_AREA)
226
+ col32.image(output_image, channels='RGB', width=None, clamp=True, caption='Generated image')
227
+
228
+ elif choose == "Image Gallery":
229
+ with st.expander('Image gallery', expanded=True):
230
+ col01, col02, = st.columns(2)
231
+ with col01:
232
+ st.image('demo/example_1.jpg', caption="Sport car")
233
+ st.image('demo/example_2.jpg', caption="Dog house")
234
+ st.image('demo/example_3.jpg', caption="Guitar")
235
+ with col02:
236
+ st.image('demo/example_4.jpg', caption="Sport car")
237
+ st.image('demo/example_5.jpg', caption="Dog house")
238
+ st.image('demo/example_6.jpg', caption="Guitar")
239
+
240
 
241
  if __name__ == '__main__':
242
  main()
demo/example_1.jpg ADDED
demo/example_2.jpg ADDED
demo/example_3.jpg ADDED
demo/example_4.jpg ADDED
demo/example_5.jpg ADDED
demo/example_6.jpg ADDED