import os
import random
import hashlib
import pickle

import gradio as gr
import spaces
import torch
import yaml
from diffusers import DiffusionPipeline
from huggingface_hub import login

# Load config file
with open('config.yaml', 'r') as file:
    config = yaml.safe_load(file)

# Authenticate using the token stored in Hugging Face Spaces secrets
if 'HF_TOKEN' in os.environ:
    login(token=os.environ['HF_TOKEN'])
else:
    raise ValueError("HF_TOKEN not found in environment variables. Please add it to your Space's secrets.")

# Access the config values
process_config = config['config']['process'][0]  # Assuming the first process is the one we want
base_model = process_config['model']['name_or_path']
lora_model = "sagar007/sagar_flux"  # This isn't in the config, so we're keeping it as is
trigger_word = process_config['trigger_word']

# Global variables
pipe = None
cache = {}
CACHE_FILE = "image_cache.pkl"

# Example prompts
example_prompts = [
    "Photos of sagar as superman flying in the sky, cape billowing in the wind, sagar",
    "Professional photo of sagar for LinkedIn headshot, DSLR quality, neutral background, sagar",
    "Sagar as an astronaut exploring a distant alien planet, vibrant colors, sagar",
    "Sagar hiking in a lush green forest, sunlight filtering through the trees, sagar",
    "Sagar as a wizard casting a spell, magical energy swirling around, sagar",
    "Sagar scoring a goal in a dramatic soccer match, stadium lights shining, sagar",
    "Sagar as a Roman emperor addressing a crowd, wearing a toga and laurel wreath, sagar"
]


def initialize_model():
    """Load the base pipeline and LoRA weights once, then keep them on the GPU."""
    global pipe
    if pipe is None:
        try:
            print("Loading base model...")
            pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float16)
            print("Loading LoRA weights...")
            # Load the fine-tuned LoRA so that lora_scale actually takes effect
            pipe.load_lora_weights(lora_model)
            print("Moving model to CUDA...")
            pipe = pipe.to("cuda")
            print(f"Successfully loaded base model: {base_model}")
        except Exception as e:
            print(f"Error initializing model: {str(e)}")
            import traceback
            print(traceback.format_exc())
            raise


def load_cache():
    """Restore previously generated images from disk, if a cache file exists."""
    global cache
    if os.path.exists(CACHE_FILE):
        with open(CACHE_FILE, 'rb') as f:
            cache = pickle.load(f)
        print(f"Loaded {len(cache)} cached images")


def save_cache():
    with open(CACHE_FILE, 'wb') as f:
        pickle.dump(cache, f)
    print(f"Saved {len(cache)} cached images")


def get_cache_key(prompt, cfg_scale, steps, seed, width, height, lora_scale):
    # Hash all generation parameters so any change produces a fresh image
    return hashlib.md5(f"{prompt}{cfg_scale}{steps}{seed}{width}{height}{lora_scale}".encode()).hexdigest()


@spaces.GPU(duration=80)
def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale):
    global pipe, cache

    if randomize_seed:
        seed = random.randint(0, 2**32 - 1)

    # Return a cached image if this exact parameter combination was generated before
    cache_key = get_cache_key(prompt, cfg_scale, steps, seed, width, height, lora_scale)
    if cache_key in cache:
        print("Using cached image")
        return cache[cache_key], seed

    try:
        print(f"Starting run_lora with prompt: {prompt}")
        if pipe is None:
            print("Initializing model...")
            initialize_model()

        print(f"Using seed: {seed}")
        generator = torch.Generator(device="cuda").manual_seed(seed)

        full_prompt = f"{prompt} {trigger_word}"
        print(f"Full prompt: {full_prompt}")

        print("Starting image generation...")
        image = pipe(
            prompt=full_prompt,
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            width=width,
            height=height,
            generator=generator,
            # FLUX pipelines read the LoRA strength from joint_attention_kwargs
            joint_attention_kwargs={"scale": lora_scale},
        ).images[0]
        print("Image generation completed successfully")

        # Cache the generated image
        cache[cache_key] = image
        save_cache()

        return image, seed
    except Exception as e:
        print(f"Error during generation: {str(e)}")
        import traceback
        print(traceback.format_exc())
        return None, seed


def update_prompt(example):
    return example


# Load cache at startup
load_cache()


# Pre-generate and cache example images
def cache_example_images():
    # Note: if 'walk_seed' is true, each call picks a random seed, so these
    # pre-generated images will not be hit by later UI requests.
    for prompt in example_prompts:
        run_lora(prompt,
                 process_config['sample']['guidance_scale'],
                 process_config['sample']['sample_steps'],
                 process_config['sample']['walk_seed'],
                 process_config['sample']['seed'],
                 process_config['sample']['width'],
                 process_config['sample']['height'],
                 0.75)


# Gradio interface setup
with gr.Blocks() as app:
    gr.Markdown("# Text-to-Image Generation with FLUX (ZeroGPU)")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt")
            example_dropdown = gr.Dropdown(choices=example_prompts, label="Example Prompts")
            run_button = gr.Button("Generate")
        with gr.Column():
            result = gr.Image(label="Result")
    with gr.Row():
        cfg_scale = gr.Slider(minimum=1, maximum=20, value=process_config['sample']['guidance_scale'], step=0.1, label="CFG Scale")
        steps = gr.Slider(minimum=1, maximum=100, value=process_config['sample']['sample_steps'], step=1, label="Steps")
    with gr.Row():
        width = gr.Slider(minimum=128, maximum=1024, value=process_config['sample']['width'], step=64, label="Width")
        height = gr.Slider(minimum=128, maximum=1024, value=process_config['sample']['height'], step=64, label="Height")
    with gr.Row():
        seed = gr.Number(label="Seed", value=process_config['sample']['seed'], precision=0)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=process_config['sample']['walk_seed'])
        lora_scale = gr.Slider(minimum=0, maximum=1, value=0.75, step=0.01, label="LoRA Scale")

    example_dropdown.change(update_prompt, inputs=[example_dropdown], outputs=[prompt])
    run_button.click(
        run_lora,
        inputs=[prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale],
        outputs=[result, seed]
    )

# Launch the app
if __name__ == "__main__":
    print("Starting the Gradio app...")
    print("Pre-generating example images...")
    cache_example_images()
    print("Launching Gradio app...")
    app.launch(share=True)  # blocks until the server is stopped
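
# For reference, a minimal config.yaml matching the keys this script reads
# might look like the sketch below. This is an illustrative assumption, not
# the actual training config; the model path and sample values are placeholders.
#
# config:
#   process:
#     - model:
#         name_or_path: "black-forest-labs/FLUX.1-dev"  # assumed FLUX base model
#       trigger_word: "sagar"
#       sample:
#         guidance_scale: 3.5
#         sample_steps: 28
#         walk_seed: true
#         seed: 42
#         width: 1024
#         height: 1024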