K00B404 commited on
Commit
f2922b7
1 Parent(s): a375431

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -96
app.py CHANGED
@@ -6,7 +6,7 @@ import random
6
  import os
7
  from PIL import Image
8
  from deep_translator import GoogleTranslator
9
- from gradio_client import Client # Import the gradio client for prompt enhancement
10
 
11
  # os.makedirs('assets', exist_ok=True)
12
  if not os.path.exists('icon.jpg'):
@@ -15,108 +15,85 @@ API_URL_DEV = "https://api-inference.huggingface.co/models/black-forest-labs/FLU
15
  API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
16
  timeout = 100
17
 
18
- # Function to set the system prompt once
19
- def set_system_prompt():
20
- client = Client("Qwen/Qwen2.5-72B-Instruct")
21
  result = client.predict(
22
- system="You are Qwen, an image generation prompt enhancer",
23
- api_name="/modify_system_session"
 
 
 
 
24
  )
25
- print(f"System session modified: {result}")
26
  return result
27
 
28
- # Function to enhance the prompt with Qwen model
29
- def enhance_prompt_with_qwen(prompt):
30
- client = Client("Qwen/Qwen2.5-72B-Instruct")
31
- result = client.predict(
32
- query=prompt,
33
- history=[],
34
- system="You are Qwen, an image generation prompt enhancer",
35
- api_name="/model_chat"
36
- )
37
-
38
- # Extract the relevant part of the tuple, index [0], which contains the enhanced prompt.
39
- enhanced_prompt = result[0] # This is the string we need for the image generation prompt.
40
-
41
- print(f"Enhanced prompt: {enhanced_prompt}")
42
- return enhanced_prompt
43
-
44
- # Image generation query function
45
  def query(prompt, is_negative=False, steps=30, cfg_scale=7, sampler="DPM++ 2M Karras", seed=-1, strength=0.7, huggingface_api_key=None, use_dev=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  try:
47
- # Set system prompt first
48
- set_system_prompt()
49
-
50
- # Enhance the prompt before translation
51
- enhanced_prompt = enhance_prompt_with_qwen(prompt)
52
-
53
- # Determine which API URL to use
54
- api_url = API_URL_DEV if use_dev else API_URL
55
-
56
- # Check if the request is an API call by checking for the presence of the huggingface_api_key
57
- is_api_call = huggingface_api_key is not None
58
-
59
- if is_api_call:
60
- # Use the environment variable for the API key in GUI mode
61
- API_TOKEN = os.getenv("HF_READ_TOKEN")
62
- headers = {"Authorization": f"Bearer {API_TOKEN}"}
63
- else:
64
- # Validate the API key if it's an API call
65
- if huggingface_api_key == "":
66
- raise gr.Error("API key is required for API calls.")
67
- headers = {"Authorization": f"Bearer {huggingface_api_key}"}
68
-
69
- if enhanced_prompt == "" or enhanced_prompt is None:
70
- return None, None
71
-
72
- key = random.randint(0, 999)
73
-
74
- # Translate the enhanced prompt
75
- enhanced_prompt = GoogleTranslator(source='ru', target='en').translate(enhanced_prompt)
76
- print(f'\033[1mGeneration {key} translation:\033[0m {enhanced_prompt}')
77
-
78
- enhanced_prompt = f"{enhanced_prompt} | ultra detail, ultra elaboration, ultra quality, perfect."
79
- print(f'\033[1mGeneration {key}:\033[0m {enhanced_prompt}')
80
-
81
- # If seed is -1, generate a random seed and use it
82
- if seed == -1:
83
- seed = random.randint(1, 1000000000)
84
-
85
- payload = {
86
- "inputs": enhanced_prompt,
87
- "is_negative": is_negative,
88
- "steps": steps,
89
- "cfg_scale": cfg_scale,
90
- "seed": seed,
91
- "strength": strength
92
- }
93
-
94
- response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
95
- if response.status_code != 200:
96
- print(f"Error: Failed to get image. Response status: {response.status_code}")
97
- print(f"Response content: {response.text}")
98
- if response.status_code == 503:
99
- raise gr.Error(f"{response.status_code} : The model is being loaded")
100
- raise gr.Error(f"{response.status_code}")
101
 
102
- try:
103
- # Attempt to open the image
104
- image_bytes = response.content
105
- image = Image.open(io.BytesIO(image_bytes))
106
- print(f'\033[1mGeneration {key} completed!\033[0m ({enhanced_prompt})')
107
-
108
- # Save the image to a file and return the file path and seed
109
- output_path = f"./output_{key}.png"
110
- image.save(output_path)
111
-
112
- return output_path, seed
113
- except Exception as e:
114
- print(f"Error when trying to open the image: {e}")
115
- return None, seed # If the image fails, return None for image, seed is still returned
116
-
117
- except Exception as ex:
118
- print(f"Error in query execution: {ex}")
119
- return None, None # If the entire process fails, return None for both
120
 
121
  css = """
122
  #app-container {
 
6
  import os
7
  from PIL import Image
8
  from deep_translator import GoogleTranslator
9
+ from gradio_client import Client # Import the Gradio Client
10
 
11
  # os.makedirs('assets', exist_ok=True)
12
  if not os.path.exists('icon.jpg'):
 
15
  API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
16
  timeout = 100
17
 
18
+ def enhance_prompt(prompt):
19
+ """Enhance the prompt using the Mistral Nemo prompt enhancer API."""
20
+ client = Client("K00B404/mistral-nemo-prompt-enhancer")
21
  result = client.predict(
22
+ message=prompt,
23
+ system_message="You are an image generation prompt enhancer and should only respond with the enhanced version of the user input image generation prompt.",
24
+ max_tokens=512,
25
+ temperature=0.7,
26
+ top_p=0.95,
27
+ api_name="/chat"
28
  )
 
29
  return result
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  def query(prompt, is_negative=False, steps=30, cfg_scale=7, sampler="DPM++ 2M Karras", seed=-1, strength=0.7, huggingface_api_key=None, use_dev=False):
32
+ # Determine which API URL to use
33
+ api_url = API_URL_DEV if use_dev else API_URL
34
+
35
+ # Check if the request is an API call by checking for the presence of the huggingface_api_key
36
+ is_api_call = huggingface_api_key is not None
37
+
38
+ if is_api_call:
39
+ # Use the environment variable for the API key in GUI mode
40
+ API_TOKEN = os.getenv("HF_READ_TOKEN")
41
+ headers = {"Authorization": f"Bearer {API_TOKEN}"}
42
+ else:
43
+ # Validate the API key if it's an API call
44
+ if huggingface_api_key == "":
45
+ raise gr.Error("API key is required for API calls.")
46
+ headers = {"Authorization": f"Bearer {huggingface_api_key}"}
47
+
48
+ if prompt == "" or prompt is None:
49
+ return None
50
+
51
+ key = random.randint(0, 999)
52
+
53
+ prompt = GoogleTranslator(source='ru', target='en').translate(prompt)
54
+ print(f'\033[1mGeneration {key} translation:\033[0m {prompt}')
55
+
56
+ # Enhance the prompt using the API
57
+ enhanced_prompt = enhance_prompt(prompt)
58
+ print(f'\033[1mEnhanced Prompt:\033[0m {enhanced_prompt}')
59
+
60
+ prompt = f"{enhanced_prompt} | ultra detail, ultra elaboration, ultra quality, perfect."
61
+ print(f'\033[1mGeneration {key}:\033[0m {prompt}')
62
+
63
+ # If seed is -1, generate a random seed and use it
64
+ if seed == -1:
65
+ seed = random.randint(1, 1000000000)
66
+
67
+ payload = {
68
+ "inputs": prompt,
69
+ "is_negative": is_negative,
70
+ "steps": steps,
71
+ "cfg_scale": cfg_scale,
72
+ "seed": seed,
73
+ "strength": strength
74
+ }
75
+
76
+ response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
77
+ if response.status_code != 200:
78
+ print(f"Error: Failed to get image. Response status: {response.status_code}")
79
+ print(f"Response content: {response.text}")
80
+ if response.status_code == 503:
81
+ raise gr.Error(f"{response.status_code} : The model is being loaded")
82
+ raise gr.Error(f"{response.status_code}")
83
+
84
  try:
85
+ image_bytes = response.content
86
+ image = Image.open(io.BytesIO(image_bytes))
87
+ print(f'\033[1mGeneration {key} completed!\033[0m ({prompt})')
88
+
89
+ # Save the image to a file and return the file path and seed
90
+ output_path = f"./output_{key}.png"
91
+ image.save(output_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
+ return output_path, seed
94
+ except Exception as e:
95
+ print(f"Error when trying to open the image: {e}")
96
+ return None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  css = """
99
  #app-container {