from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from langchain_community.llms import LlamaCpp
from huggingface_hub.file_download import http_get
# from llama_cpp import Llama
from langchain_core.callbacks import CallbackManager, StreamingStdOutCallbackHandler
from langchain_core.prompts import ChatPromptTemplate

import os
import fal_client

# FastAPI app
app = FastAPI()

# fal_client authenticates using the FAL_KEY environment variable
os.environ['FAL_KEY'] = 'bb79b746-999d-4bec-af22-04fddb05d087:49350e8b76fd8dda0fb9dd8442a9ccf5'

# Request body model
class StoryRequest(BaseModel):
    mood: str
    story_type: str
    theme: str
    num_scenes: int
    txt: str
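
# Example request body for the model above (values are illustrative only):
# {
#     "mood": "adventurous",
#     "story_type": "fantasy",
#     "theme": "friendship",
#     "num_scenes": 3,
#     "txt": "A young fox who dreams of flying"
# }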

# Callback manager that streams generated tokens to stdout
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

def load_model(
    directory: str = ".",
    model_name: str = "natsumura-storytelling-rp-1.0-llama-3.1-8B.Q3_K_M.gguf",
    model_url: str = "https://huggingface.co/tohur/natsumura-storytelling-rp-1.0-llama-3.1-8b-GGUF/resolve/main/natsumura-storytelling-rp-1.0-llama-3.1-8B.Q3_K_M.gguf"
):
    final_model_path = os.path.join(directory, model_name)
    
    print("Downloading all files...")
    if not os.path.exists(final_model_path):
        with open(final_model_path, "wb") as f:
            http_get(model_url, f)
    os.chmod(final_model_path, 0o777)
    print("Files downloaded!")
    
    # model = Llama(
    #     model_path=final_model_path,
    #     n_ctx=1024
    # )

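    # The LlamaCpp settings below keep the original values: n_ctx bounds the
    # total prompt + completion length at 1024 tokens, so the max_tokens=2000
    # ceiling may not be fully reachable in practice, and the streaming
    # callback echoes tokens to stdout as they are generated.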
    model = LlamaCpp(
        model_path=final_model_path,
        temperature=0.3,
        max_tokens=2000,
        top_p=1,
        n_ctx=1024,
        callback_manager=callback_manager,
        verbose=True,
    )
    
    print("Model loaded!")
    return model


llm = load_model()


# Create a prompt template
# system = """You are a helpful and creative assistant that specializes in generating engaging and imaginative stories for kids.
# Based on the user's provided mood, preferred story type, theme, age, and desired story length of 500-600 words, create a unique and captivating story.
# Always start with Story Title then generate a single story and dont ask for any feedback at the end just sign off with a cute closing inviting the reader 
# to create another adventure soon! 
# """

system = """You are a helpful and creative assistant that specializes in generating engaging and imaginative short storie for kids.
Based on the user's provided mood, preferred story type, theme, age, and desired story length of 500-600 words, create a unique and captivating story.
Always start with Story Title then generate a single story.Storie begin on Page 1(also mention the all pages headings in bold) and end on Page 7.
Total pages in storie are seven each page have one short paragraph and dont ask for any feedback at the end just sign off with a cute closing inviting the reader 
to create another adventure soon! 
"""

prompt_template = ChatPromptTemplate.from_messages([("system", system), ("human", "{text}")])

# FastAPI endpoint to generate the story
@app.post("/generate_story/")
async def generate_story(story_request: StoryRequest):
    story = f"""here are the inputs from user:
    - **Mood:** {story_request.mood}
    - **Story Type:** {story_request.story_type}
    - **Theme:** {story_request.theme}
    - **Details Provided:** {story_request.txt}
    """
    
    # Build the chain with LCEL: the prompt template feeds its output into the LLM
    chain = prompt_template | llm

    try:
        response = chain.invoke({"text": story})
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    if not response:
        raise HTTPException(status_code=500, detail="Failed to generate the story")
    
    images = []
    for i in range(story_request.num_scenes):
        # image_prompt = f"Generate an image for Scene {i+1} based on this story: Mood: {story_request.mood}, Story Type: {story_request.story_type}, Theme: {story_request.theme}. Story: {response}"
        image_prompt = (
        f"Generate an image for Scene {i+1}. "
        f"This image should represent the details described in paragraph {i+1} of the story. "
        f"Mood: {story_request.mood}, Story Type: {', '.join(story_request.story_type)}, Theme: {story_request.theme}. "
        f"Story: {response} "
        f"Focus on the key elements in paragraph {i+1}."
        )
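        # Submit the prompt to the fal.ai queue and block until the image result is ready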
        handler = fal_client.submit(
            "fal-ai/flux/schnell",
            arguments={
                "prompt": image_prompt,
                "num_images": 1,
                "enable_safety_checker": True
            },
        )
        result = handler.get()
        image_url = result['images'][0]['url']
        images.append(image_url)
    
    return {
        "story": response,
        "images": images
    }
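
# A minimal way to run and exercise this service locally, assuming this file is
# saved as main.py and uvicorn is installed (payload values are illustrative):
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
#   curl -X POST http://localhost:8000/generate_story/ \
#        -H "Content-Type: application/json" \
#        -d '{"mood": "happy", "story_type": "adventure", "theme": "space", "num_scenes": 3, "txt": "a curious robot"}'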