import os
import re
import gc
import random
import atexit

import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from diffusers import StableDiffusionPipeline
import gradio as gr

# Initialize the text generation pipeline (full-precision weights, run on CPU)
model_name = 'HuggingFaceTB/SmolLM-1.7B-Instruct'
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
text_generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device=-1)  # device=-1 -> CPU

# Load the Stable Diffusion model (the 2.1 "base" checkpoint is the smaller variant)
model_id = "stabilityai/stable-diffusion-2-1-base"
pipe = StableDiffusionPipeline.from_pretrained(model_id)
pipe = pipe.to("cpu")  # Use CPU

# Create a directory to save the generated images
output_dir = 'generated_images'
os.makedirs(output_dir, exist_ok=True)
os.chmod(output_dir, 0o777)

# Ask the language model for a detailed visual description, returned in double quotes
def generate_description_prompt(user_prompt, user_examples):
    prompt = (f'generate enclosed in quotes in the format "" description '
              f'according to guidelines of {user_prompt} different from {user_examples}')
    try:
        generated_text = text_generator(prompt, max_length=150, num_return_sequences=1,
                                        truncation=True)[0]['generated_text']
        match = re.search(r'"(.*?)"', generated_text)
        if match:
            # Capture the description between the first pair of quotes
            generated_description = match.group(1).strip()
            return f'"{generated_description}"'
        return None
    except Exception as e:
        print(f"Error generating description for prompt '{user_prompt}': {e}")
        return None

# Seed-word pool: starts from the user's examples and grows with each generated description
seed_words = []
used_words = set()

def generate_description(user_prompt, user_examples_list):
    seed_words.extend(user_examples_list)
    # Select a subject that has not been used yet
    available_subjects = [word for word in seed_words if word not in used_words]
    if not available_subjects:
        print("No more available subjects to use.")
        return None, None
    subject = random.choice(available_subjects)
    generated_description = generate_description_prompt(user_prompt, subject)
    if generated_description:
        # Strip any non-ASCII symbols from the model output
        clean_description = generated_description.encode('ascii', 'ignore').decode('ascii')
        print(f"Generated description for subject '{subject}': {clean_description}")
        # Mark the subject as used and feed the new description (without quotes) back into the pool
        used_words.add(subject)
        seed_words.append(clean_description.strip('"'))
        return clean_description, subject
    return None, None

# Generate an image from the description, with a fixed seed for reproducibility
def generate_image(description, seed=42):
    prompt = f'detailed photorealistic full shot of {description}'
    generator = torch.Generator().manual_seed(seed)
    image = pipe(
        prompt=prompt,
        width=512,
        height=512,
        num_inference_steps=10,  # few steps to keep CPU inference tolerable
        generator=generator,
        guidance_scale=7.5,
    ).images[0]
    return image

# Gradio callback: generate a description, render it, and save the image to disk
def gradio_interface(user_prompt, user_examples):
    user_examples_list = [example.strip().strip('"') for example in user_examples.split(',')]
    generated_description, subject = generate_description(user_prompt, user_examples_list)
    if generated_description:
        image = generate_image(generated_description)
        # Name by current file count (simple scheme; assumes files are never deleted)
        image_path = os.path.join(output_dir, f"image_{len(os.listdir(output_dir))}.png")
        image.save(image_path)
        os.chmod(image_path, 0o777)
        return image, generated_description
    return None, "Failed to generate description."
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter the generation task or general thing you are looking for"),
        gr.Textbox(lines=2, placeholder='Provide a few examples (enclosed in quotes and separated by commas)'),
    ],
    outputs=[
        gr.Image(label="Generated Image"),
        gr.Textbox(label="Generated Description"),
    ],
    title="Description and Image Generator",
    description="Generate detailed descriptions and images based on your input.",
)

# Release memory when the process exits. Registered via atexit so it actually runs:
# launch() blocks, so a definition placed after it would never be reached or called.
# empty_cache() is guarded because this script runs the models on CPU.
def clear_gpu_memory():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    print("Memory cleared.")

atexit.register(clear_gpu_memory)

iface.launch(server_name="0.0.0.0", server_port=7860)
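# Optional smoke test (a minimal sketch, not part of the app): comment out the
# iface.launch() call above to run it, since launch() blocks. The prompt,
# example subject, and seed below are illustrative values only.
# description, subject = generate_description("vintage vehicles", ["a red 1960s scooter"])
# if description:
#     generate_image(description, seed=0).save("smoke_test.png")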