gokaygokay committed on
Commit 9c8e948
1 Parent(s): f93e9ec
Files changed (2)
  1. app.py +135 -158
  2. requirements.txt +11 -4
app.py CHANGED
@@ -1,181 +1,158 @@
 import spaces
-import json
-import subprocess
-from llama_cpp import Llama
-from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
-from llama_cpp_agent.providers import LlamaCppPythonProvider
-from llama_cpp_agent.chat_history import BasicChatHistory
-from llama_cpp_agent.chat_history.messages import Roles
 import gradio as gr
-from huggingface_hub import hf_hub_download
+import torch
+from PIL import Image
+from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
+from diffusers import DiffusionPipeline
+import random
+import numpy as np
+import os
+import subprocess
 
-llm = None
-llm_model = None
-
-hf_hub_download(
-    repo_id="unsloth/Reflection-Llama-3.1-70B-GGUF",
-    filename="Reflection-Llama-3.1-70B.Q3_K_L.gguf",
-    local_dir="./models"
-)
-
-hf_hub_download(
-    repo_id="jhofseth/Reflection-Llama-3.1-70B-GGUF",
-    filename="Reflection-Llama-3.1-70B-IQ3_XXS.gguf",
-    local_dir="./models"
-)
-
-hf_hub_download(
-    repo_id="bartowski/Reflection-Llama-3.1-70B-GGUF",
-    filename="Reflection-Llama-3.1-70B.imatrix",
-    local_dir="./random"
-)
-
-def get_messages_formatter_type(model_name):
-    if "Llama" in model_name:
-        return MessagesFormatterType.LLAMA_3
-    else:
-        raise ValueError(f"Unsupported model: {model_name}")
-
-
+# Install flash-attn
+subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+
+# Initialize models
+device = "cuda" if torch.cuda.is_available() else "cpu"
+dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+
+huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
+
+# SD3.5 model
+pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-3.5-large", torch_dtype=dtype, use_safetensors=True, variant="fp16", token=huggingface_token).to(device)
+
+# Initialize Florence model
+florence_model = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to(device).eval()
+florence_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)
+
+# Prompt Enhancer
+enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
+
+MAX_SEED = np.iinfo(np.int32).max
+MAX_IMAGE_SIZE = 1024
+
+# Florence caption function
 @spaces.GPU
-def respond(
-    message,
-    history: list[tuple[str, str]],
-    model,
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-    top_k,
-    repeat_penalty,
-):
-    global llm
-    global llm_model
-
-    chat_template = get_messages_formatter_type(model)
-
-    if llm is None or llm_model != model:
-        llm = Llama(
-            model_path=f"models/{model}",
-            flash_attn=True,
-            n_gpu_layers=81,
-            n_batch=1024,
-            n_ctx=8192,
-        )
-        llm_model = model
-
-    provider = LlamaCppPythonProvider(llm)
-
-    agent = LlamaCppAgent(
-        provider,
-        system_prompt=f"{system_message}",
-        predefined_messages_formatter_type=chat_template,
-        debug_output=True
+def florence_caption(image):
+    # Convert image to PIL if it's not already
+    if not isinstance(image, Image.Image):
+        image = Image.fromarray(image)
+
+    inputs = florence_processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to(device)
+    generated_ids = florence_model.generate(
+        input_ids=inputs["input_ids"],
+        pixel_values=inputs["pixel_values"],
+        max_new_tokens=1024,
+        early_stopping=False,
+        do_sample=False,
+        num_beams=3,
     )
-
-    settings = provider.get_provider_default_settings()
-    settings.temperature = temperature
-    settings.top_k = top_k
-    settings.top_p = top_p
-    settings.max_tokens = max_tokens
-    settings.repeat_penalty = repeat_penalty
-    settings.stream = True
-
-    messages = BasicChatHistory()
-
-    for msn in history:
-        user = {
-            'role': Roles.user,
-            'content': msn[0]
-        }
-        assistant = {
-            'role': Roles.assistant,
-            'content': msn[1]
-        }
-        messages.add_message(user)
-        messages.add_message(assistant)
-
-    stream = agent.get_chat_response(
-        message,
-        llm_sampling_settings=settings,
-        chat_history=messages,
-        returns_streaming_generator=True,
-        print_output=False
-    )
-
-    outputs = ""
-    for output in stream:
-        outputs += output
-        yield outputs
-
-description = """<p><center>
-<a href="https://huggingface.co/mattshumer/ref_70_e3" target="_blank">[Reflection Llama 3.1 70B Correct Weights]</a>
-<a href="https://huggingface.co/mattshumer/Reflection-Llama-3.1-70B" target="_blank">[Old Repo]</a>
-<a href="https://huggingface.co/unsloth/Reflection-Llama-3.1-70B-GGUF" target="_blank">[Reflection-Llama-3.1-70B-GGUF]</a>
-
+    generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+    parsed_answer = florence_processor.post_process_generation(
+        generated_text,
+        task="<MORE_DETAILED_CAPTION>",
+        image_size=(image.width, image.height)
+    )
+    return parsed_answer["<MORE_DETAILED_CAPTION>"]
+
+# Prompt Enhancer function
+def enhance_prompt(input_prompt):
+    result = enhancer_long("Enhance the description: " + input_prompt)
+    enhanced_text = result[0]['summary_text']
+    return enhanced_text
+
+@spaces.GPU(duration=190)
+def process_workflow(image, text_prompt, use_enhancer, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, negative_prompt="", progress=gr.Progress(track_tqdm=True)):
+    if image is not None:
+        # Convert image to PIL if it's not already
+        if not isinstance(image, Image.Image):
+            image = Image.fromarray(image)
+
+        prompt = florence_caption(image)
+        print(prompt)
+    else:
+        prompt = text_prompt
+
+    if use_enhancer:
+        prompt = enhance_prompt(prompt)
+
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+
+    generator = torch.Generator(device=device).manual_seed(seed)
+
+    image = pipe(
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        generator=generator,
+        num_inference_steps=num_inference_steps,
+        width=width,
+        height=height,
+        guidance_scale=guidance_scale
+    ).images[0]
+
+    return image, prompt, seed
+
+custom_css = """
+.input-group, .output-group {
+    border: 1px solid #e0e0e0;
+    border-radius: 10px;
+    padding: 20px;
+    margin-bottom: 20px;
+    background-color: #f9f9f9;
+}
+.submit-btn {
+    background-color: #2980b9 !important;
+    color: white !important;
+}
+.submit-btn:hover {
+    background-color: #3498db !important;
+}
+"""
+
+title = """<h1 align="center">Stable Diffusion 3.5 with Florence-2 Captioner and Prompt Enhancer</h1>
+<p><center>
+<a href="https://huggingface.co/stabilityai/stable-diffusion-3.5-large" target="_blank">[Stable Diffusion 3.5 Model]</a>
+<a href="https://huggingface.co/microsoft/Florence-2-base" target="_blank">[Florence-2 Model]</a>
+<a href="https://huggingface.co/gokaygokay/Lamini-Prompt-Enchance-Long" target="_blank">[Prompt Enhancer Long]</a>
+<p align="center">Create long prompts from images or enhance your short prompts with prompt enhancer</p>
 </center></p>
 """
 
-demo = gr.ChatInterface(
-    respond,
-    additional_inputs=[
-        gr.Dropdown([
-                "Reflection-Llama-3.1-70B.Q3_K_L.gguf",
-                "Reflection-Llama-3.1-70B-IQ3_XXS.gguf"
-            ],
-            value="Reflection-Llama-3.1-70B.Q3_K_L.gguf",
-            label="Model"
-        ),
-        gr.Textbox(value="You are a world-class AI system, capable of complex reasoning and reflection. Reason through the query inside <thinking> tags, and then provide your final response inside <output> tags. If you detect that you made a mistake in your reasoning at any point, correct yourself inside <reflection> tags.", label="System message"),
-        gr.Slider(minimum=1, maximum=8192, value=2048, step=1, label="Max tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p",
-        ),
-        gr.Slider(
-            minimum=0,
-            maximum=100,
-            value=40,
-            step=1,
-            label="Top-k",
-        ),
-        gr.Slider(
-            minimum=0.0,
-            maximum=2.0,
-            value=1.1,
-            step=0.1,
-            label="Repetition penalty",
-        ),
-    ],
-    theme=gr.themes.Soft(primary_hue="violet", secondary_hue="violet", neutral_hue="gray", font=[gr.themes.GoogleFont("Exo"), "ui-sans-serif", "system-ui", "sans-serif"]).set(
-        body_background_fill_dark="#16141c",
-        block_background_fill_dark="#16141c",
-        block_border_width="1px",
-        block_title_background_fill_dark="#1e1c26",
-        input_background_fill_dark="#292733",
-        button_secondary_background_fill_dark="#24212b",
-        border_color_accent_dark="#343140",
-        border_color_primary_dark="#343140",
-        background_fill_secondary_dark="#16141c",
-        color_accent_soft_dark="transparent",
-        code_background_fill_dark="#292733",
-    ),
-    retry_btn="Retry",
-    undo_btn="Undo",
-    clear_btn="Clear",
-    submit_btn="Send",
-    title="Reflection Llama-3.1 70B",
-    description=description,
-    chatbot=gr.Chatbot(
-        scale=1,
-        likeable=False,
-        show_copy_button=True
+with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray")) as demo:
+    gr.HTML(title)
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            with gr.Group(elem_classes="input-group"):
+                input_image = gr.Image(label="Input Image (Florence-2 Captioner)")
+
+            with gr.Accordion("Advanced Settings", open=False):
+                text_prompt = gr.Textbox(label="Text Prompt (optional, used if no image is uploaded)")
+                negative_prompt = gr.Textbox(label="Negative Prompt")
+                use_enhancer = gr.Checkbox(label="Use Prompt Enhancer", value=False)
+                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
+                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+                width = gr.Slider(label="Width", minimum=512, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
+                height = gr.Slider(label="Height", minimum=512, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
+                guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=7.5, step=0.1, value=4.5)
+                num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=40)
+
+            generate_btn = gr.Button("Generate Image", elem_classes="submit-btn")
+
+        with gr.Column(scale=1):
+            with gr.Group(elem_classes="output-group"):
+                output_image = gr.Image(label="Result", elem_id="gallery", show_label=False)
+                final_prompt = gr.Textbox(label="Final Prompt Used")
+                used_seed = gr.Number(label="Seed Used")
+
+    generate_btn.click(
+        fn=process_workflow,
+        inputs=[
+            input_image, text_prompt, use_enhancer, seed, randomize_seed,
+            width, height, guidance_scale, num_inference_steps, negative_prompt
+        ],
+        outputs=[output_image, final_prompt, used_seed]
     )
-)
 
-if __name__ == "__main__":
-    demo.launch()
+demo.launch(debug=True)
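
For reference, the new enhance-and-generate flow can also be exercised without the Gradio UI. Below is a minimal headless sketch grounded in the code added above; it skips the Florence-2 captioning step, and the prompt, seed, and output filename are illustrative:

import os
import torch
from diffusers import DiffusionPipeline
from transformers import pipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Same checkpoints as app.py; the gated SD3.5 weights need a valid HF token.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",
    torch_dtype=dtype,
    token=os.getenv("HUGGINGFACE_TOKEN"),
).to(device)
enhancer = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)

short_prompt = "a lighthouse on a rocky coast at dusk"  # illustrative input prompt
prompt = enhancer("Enhance the description: " + short_prompt)[0]["summary_text"]

image = pipe(
    prompt=prompt,
    num_inference_steps=40,
    guidance_scale=4.5,
    width=1024,
    height=1024,
    generator=torch.Generator(device=device).manual_seed(42),  # fixed seed for reproducibility
).images[0]
image.save("output.png")  # illustrative output path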
 
requirements.txt CHANGED
@@ -1,4 +1,11 @@
-huggingface_hub==0.22.2
-scikit-build-core
-https://github.com/abetlen/llama-cpp-python/releases/download/v0.2.90-cu124/llama_cpp_python-0.2.90-cp310-cp310-linux_x86_64.whl
-llama-cpp-agent>=0.2.25
+accelerate
+diffusers
+torch
+transformers
+git+https://github.com/huggingface/diffusers.git
+sentencepiece
+spaces
+xformers
+sentencepiece
+timm
+einops
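
Note that flash-attn is absent from the new requirements.txt: app.py installs it at startup with the CUDA kernel build skipped (see the subprocess.run call in the diff above). A minimal sketch of that pattern; merging os.environ is an addition here, so the child process keeps the parent's PATH:

import os
import subprocess

# Runtime install of flash-attn, as the new app.py does before loading models.
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE skips compiling the CUDA kernels from source.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': 'TRUE'},
    shell=True,
)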