sagar007's picture
Update app.py
3cabadc verified
raw
history blame
No virus
6.69 kB
import os
import random
import gradio as gr
from huggingface_hub import login, hf_hub_download
import spaces
import torch
from diffusers import DiffusionPipeline
import hashlib
import pickle
import yaml
import logging
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Load config file. The encoding is given explicitly so that parsing does not
# depend on the host's locale default.
with open('config.yaml', 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)

# Authenticate using the token stored in Hugging Face Spaces secrets.
# An empty value is treated the same as a missing one instead of being
# handed to login().
hf_token = os.environ.get('HF_TOKEN')
if hf_token:
    login(token=hf_token)
    logging.info("Successfully logged in with HF_TOKEN")
else:
    logging.warning("HF_TOKEN not found in environment variables. Some functionality may be limited.")

# Correctly access the config values
process_config = config['config']['process'][0]  # Assuming the first process is the one we want
base_model = "black-forest-labs/FLUX.1-dev"
lora_model = "sagar007/sagar_flux"  # This isn't in the config, so we're keeping it as is
trigger_word = process_config['trigger_word']  # appended to every prompt by run_lora()
logging.info(f"Base model: {base_model}")
logging.info(f"LoRA model: {lora_model}")
logging.info(f"Trigger word: {trigger_word}")

# Global variables
pipe = None  # diffusion pipeline, created lazily by initialize_model()
cache = {}  # maps get_cache_key() digests to generated images
CACHE_FILE = "image_cache.pkl"  # file the cache dict is pickled to / loaded from
# Example prompts: these populate the "Example Prompts" dropdown and are the
# prompts that cache_example_images() pre-generates at startup.
example_prompts = [
    "Photos of sagar as superman flying in the sky, cape billowing in the wind, sagar",
    "Professional photo of sagar for LinkedIn headshot, DSLR quality, neutral background, sagar",
    "Sagar as an astronaut exploring a distant alien planet, vibrant colors, sagar",
    "Sagar hiking in a lush green forest, sunlight filtering through the trees, sagar",
    "Sagar as a wizard casting a spell, magical energy swirling around, sagar",
    "Sagar scoring a goal in a dramatic soccer match, stadium lights shining, sagar",
    "Sagar as a Roman emperor addressing a crowd, wearing a toga and laurel wreath, sagar"
]
def initialize_model():
    """Create the global diffusion pipeline once, with the LoRA weights applied.

    Does nothing when ``pipe`` is already set. Otherwise loads ``base_model``,
    loads the ``lora_model`` weights into it, and moves the result to CUDA.
    The global ``pipe`` is only assigned after every step has succeeded, so a
    failure part-way through cannot leave a half-initialised pipeline behind
    that later calls would mistake for a ready one.

    Raises:
        Exception: whatever the loading steps raised, re-raised after logging.
    """
    global pipe
    if pipe is not None:
        return

    try:
        logging.info(f"Attempting to load model: {base_model}")
        # NOTE(review): float16 is kept from the original; FLUX examples
        # commonly use bfloat16 -- confirm which dtype is intended.
        loaded = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float16, use_safetensors=True)
        # Without this step the app would run the plain base model and the
        # trigger word / LoRA scale would have nothing to act on.
        logging.info(f"Loading LoRA weights: {lora_model}")
        loaded.load_lora_weights(lora_model)
        logging.info("Moving model to CUDA...")
        loaded = loaded.to("cuda")
    except Exception as e:
        logging.error(f"Error loading model {base_model}: {str(e)}")
        raise

    pipe = loaded
    logging.info(f"Successfully loaded model: {base_model}")
def load_cache():
    """Replace the global ``cache`` with the dict pickled in ``CACHE_FILE``.

    Loading is best-effort: when the file is missing, unreadable, truncated,
    or does not contain a dict, a warning is logged and the current ``cache``
    is left untouched so the app can still start.
    """
    global cache
    if not os.path.exists(CACHE_FILE):
        return

    try:
        with open(CACHE_FILE, 'rb') as f:
            # SECURITY: unpickling can execute arbitrary code. CACHE_FILE must
            # only ever contain data written by save_cache() on this host.
            loaded = pickle.load(f)
    except (OSError, EOFError, pickle.UnpicklingError, AttributeError,
            ImportError, IndexError, ValueError) as e:
        logging.warning(f"Could not load cache from {CACHE_FILE}: {e}")
        return

    if not isinstance(loaded, dict):
        logging.warning(f"Ignoring cache file {CACHE_FILE}: expected a dict, got {type(loaded).__name__}")
        return

    cache = loaded
    logging.info(f"Loaded {len(cache)} cached images")
def save_cache():
    """Pickle the global ``cache`` dict to ``CACHE_FILE``.

    The data is first written to a sibling temporary file and then moved into
    place with ``os.replace``, so a failure while pickling or writing leaves
    the previously saved cache file intact instead of truncating it.

    Raises:
        Exception: any error raised while pickling or writing the file.
    """
    tmp_path = f"{CACHE_FILE}.tmp"
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(cache, f)
        os.replace(tmp_path, CACHE_FILE)
    finally:
        # Only still present when something above failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    logging.info(f"Saved {len(cache)} cached images")
def get_cache_key(prompt, cfg_scale, steps, seed, width, height, lora_scale):
    """Return a hex MD5 digest identifying one set of generation parameters.

    The parameters are hashed as the ``repr`` of a tuple, so field boundaries
    are unambiguous. Plain concatenation would give the same key to different
    settings (for example steps=2, seed=34 and steps=23, seed=4).

    MD5 serves only as a compact dictionary key here, not as a security
    measure.

    Returns:
        str: 32-character hexadecimal digest.
    """
    fields = (prompt, cfg_scale, steps, seed, width, height, lora_scale)
    return hashlib.md5(repr(fields).encode()).hexdigest()
@spaces.GPU(duration=80)
def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale):
    """Generate (or fetch from cache) an image for the given settings.

    Args:
        prompt: User prompt; ``trigger_word`` is appended before generation.
        cfg_scale: Passed to the pipeline as ``guidance_scale``.
        steps: Passed to the pipeline as ``num_inference_steps``.
        randomize_seed: When truthy, ``seed`` is replaced by a random value
            in ``[0, 2**32 - 1]``.
        seed: Seed for the CUDA random generator.
        width: Output width passed to the pipeline.
        height: Output height passed to the pipeline.
        lora_scale: LoRA strength, passed to the pipeline through
            ``joint_attention_kwargs={"scale": ...}``.

    Returns:
        tuple: ``(image, seed)`` with the seed actually used. ``image`` is
        ``None`` when model loading or generation failed; the error is logged
        with its traceback.
    """
    if randomize_seed:
        seed = random.randint(0, 2**32-1)

    cache_key = get_cache_key(prompt, cfg_scale, steps, seed, width, height, lora_scale)
    if cache_key in cache:
        logging.info("Using cached image")
        return cache[cache_key], seed

    try:
        logging.info(f"Starting run_lora with prompt: {prompt}")
        if pipe is None:
            logging.info("Initializing model...")
            initialize_model()
        logging.info(f"Using seed: {seed}")
        generator = torch.Generator(device="cuda").manual_seed(seed)
        full_prompt = f"{prompt} {trigger_word}"
        logging.info(f"Full prompt: {full_prompt}")
        logging.info("Starting image generation...")
        image = pipe(
            prompt=full_prompt,
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            width=width,
            height=height,
            generator=generator,
            # Forward the UI's LoRA strength; previously it was accepted and
            # hashed into the cache key but never reached the pipeline.
            joint_attention_kwargs={"scale": lora_scale},
        ).images[0]
        logging.info("Image generation completed successfully")
    except Exception as e:
        # logging.exception records the traceback along with the message.
        logging.exception(f"Error during generation: {str(e)}")
        return None, seed

    # Cache the generated image. Persisting is best-effort: a disk or pickling
    # problem must not throw away an image that was generated successfully.
    cache[cache_key] = image
    try:
        save_cache()
    except Exception as e:
        logging.exception(f"Could not save image cache: {str(e)}")
    return image, seed
def update_prompt(example):
    """Return the chosen example unchanged.

    Used as the dropdown's change handler so the selected example text is
    copied into the prompt textbox.
    """
    selected_text = example
    return selected_text
# Load cache at startup so that images pickled to CACHE_FILE by an earlier
# run are available in `cache` before any generation is requested.
load_cache()
# Pre-generate and cache example images
def cache_example_images():
    """Generate every example prompt once with the configured sample settings.

    The seed is always the fixed ``seed`` from the sample config and is never
    randomised. That keeps the resulting cache keys identical from one start
    to the next, so later starts (and UI requests with the same settings) can
    reuse the stored images; a random seed would store every image under a
    key that is never looked up again.

    Prompts whose generation fails are reported with a warning and skipped.
    """
    sample = process_config['sample']
    for prompt in example_prompts:
        image, _ = run_lora(
            prompt,
            sample['guidance_scale'],
            sample['sample_steps'],
            False,  # randomize_seed: keep the cache key reproducible
            sample['seed'],
            sample['width'],
            sample['height'],
            0.75,  # same value as the LoRA Scale slider default
        )
        if image is None:
            logging.warning(f"Could not pre-generate example image for prompt: {prompt}")
# Gradio interface setup
# Builds the Blocks UI: prompt controls and the result image on top, then the
# generation settings, whose defaults come from the sample section of the
# config.
with gr.Blocks() as app:
    gr.Markdown("# Text-to-Image Generation with FLUX (ZeroGPU)")
    with gr.Row():
        # Prompt entry, example picker and the trigger button.
        with gr.Column():
            prompt = gr.Textbox(label="Prompt")
            example_dropdown = gr.Dropdown(choices=example_prompts, label="Example Prompts")
            run_button = gr.Button("Generate")
        # Generated image.
        with gr.Column():
            result = gr.Image(label="Result")
    with gr.Row():
        cfg_scale = gr.Slider(minimum=1, maximum=20, value=process_config['sample']['guidance_scale'], step=0.1, label="CFG Scale")
        steps = gr.Slider(minimum=1, maximum=100, value=process_config['sample']['sample_steps'], step=1, label="Steps")
    with gr.Row():
        width = gr.Slider(minimum=128, maximum=1024, value=process_config['sample']['width'], step=64, label="Width")
        height = gr.Slider(minimum=128, maximum=1024, value=process_config['sample']['height'], step=64, label="Height")
    with gr.Row():
        seed = gr.Number(label="Seed", value=process_config['sample']['seed'], precision=0)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=process_config['sample']['walk_seed'])
    lora_scale = gr.Slider(minimum=0, maximum=1, value=0.75, step=0.01, label="LoRA Scale")
    # Selecting an example copies its text into the prompt box.
    example_dropdown.change(update_prompt, inputs=[example_dropdown], outputs=[prompt])
    # `seed` is both an input and an output: run_lora returns the seed it
    # actually used, which is written back into the Seed field.
    run_button.click(
        run_lora,
        inputs=[prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale],
        outputs=[result, seed]
    )
# Launch the app
def _main():
    """Pre-generate the example images, then launch the Gradio app."""
    for startup_message in ("Starting the Gradio app...", "Pre-generating example images..."):
        logging.info(startup_message)
    cache_example_images()
    app.launch(share=True)
    logging.info("Gradio app launched successfully")


if __name__ == "__main__":
    _main()