File size: 6,098 Bytes
aa8012e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import io
import os
import time
from pathlib import Path
from typing import Optional, Union
from urllib.parse import urlparse

import requests
from PIL import Image

# Base URL of the API; request paths such as "/v1/image" are appended to it.
API_ENDPOINT = "https://api.bfl.ml"


class ApiException(Exception):
    """
    Error raised for a failed API interaction.

    Attributes:
        status_code: Status code associated with the failure.
        detail: Error payload. Either a plain string, a list of entries
            (dict entries carrying a ``"msg"`` key are summarised by that
            key), any other object, or ``None`` when nothing was provided.
    """

    def __init__(self, status_code: int, detail: Optional[Union[str, list]] = None):
        # Forward the values to Exception so that ``args`` is populated;
        # ``copy`` and ``pickle`` rebuild exceptions from ``args`` and would
        # otherwise fail on the required ``status_code`` parameter.
        super().__init__(status_code, detail)
        self.detail = detail
        self.status_code = status_code

    def __str__(self) -> str:
        return self.__repr__()

    def _message(self) -> Optional[str]:
        """
        Condense ``detail`` into a short message without ever raising.

        Returns:
            ``None`` or the string itself for ``None``/``str`` details, a
            bracketed comma-separated summary for lists/tuples, and
            ``str(detail)`` for anything else.
        """
        detail = self.detail
        if detail is None or isinstance(detail, str):
            return detail
        if isinstance(detail, (list, tuple)):
            # Entries without a "msg" key (or that are not dicts at all) fall
            # back to their string form instead of breaking the formatting.
            parts = [
                str(entry["msg"]) if isinstance(entry, dict) and "msg" in entry else str(entry)
                for entry in detail
            ]
            return "[" + ",".join(parts) + "]"
        return str(detail)

    def __repr__(self) -> str:
        message = self._message()
        return f"ApiException({self.status_code=}, {message=}, detail={self.detail})"


class ImageRequest:
    """
    Handle for a single image generation request.

    The request is submitted with ``request()``, polled with ``retrieve()``
    and its outcome is exposed lazily through the ``url``, ``bytes`` and
    ``image`` properties; ``save()`` writes the image to disk. Results are
    cached on the instance so each network step is performed at most once.
    """

    # Seconds to wait on each individual HTTP call; without a timeout a
    # stalled connection would block the caller indefinitely.
    _REQUEST_TIMEOUT = 60
    # Seconds to sleep between two polls while the generation is pending.
    _POLL_INTERVAL = 0.5

    def __init__(
        self,
        prompt: str,
        width: int = 1024,
        height: int = 1024,
        name: str = "flux.1-pro",
        num_steps: int = 50,
        prompt_upsampling: bool = False,
        seed: Optional[int] = None,
        validate: bool = True,
        launch: bool = True,
        api_key: Optional[str] = None,
    ):
        """
        Manages an image generation request to the API.

        Args:
            prompt: Prompt to sample
            width: Width of the image in pixel
            height: Height of the image in pixel
            name: Name of the model
            num_steps: Number of network evaluations
            prompt_upsampling: Use prompt upsampling
            seed: Fix the generation seed
            validate: Run input validation
            launch: Directly launches request
            api_key: Your API key if not provided by the environment
                (read from the ``BFL_API_KEY`` environment variable)

        Raises:
            ValueError: For invalid input
            ApiException: For errors raised from the API
        """
        if validate:
            # Each check raises, so plain guard clauses are sufficient; the
            # order determines which problem is reported first.
            if name not in ["flux.1-pro"]:
                raise ValueError(f"Invalid model {name}")
            if width % 32 != 0:
                raise ValueError(f"width must be divisible by 32, got {width}")
            if not (256 <= width <= 1440):
                raise ValueError(f"width must be between 256 and 1440, got {width}")
            if height % 32 != 0:
                raise ValueError(f"height must be divisible by 32, got {height}")
            if not (256 <= height <= 1440):
                raise ValueError(f"height must be between 256 and 1440, got {height}")
            if not (1 <= num_steps <= 50):
                raise ValueError(f"steps must be between 1 and 50, got {num_steps}")

        self.request_json = {
            "prompt": prompt,
            "width": width,
            "height": height,
            "variant": name,
            "steps": num_steps,
            "prompt_upsampling": prompt_upsampling,
        }
        # The seed is only sent when explicitly requested.
        if seed is not None:
            self.request_json["seed"] = seed

        self.request_id: Optional[str] = None
        self.result: Optional[dict] = None
        self._image_bytes: Optional[bytes] = None
        self._url: Optional[str] = None
        if api_key is None:
            self.api_key = os.environ.get("BFL_API_KEY")
        else:
            self.api_key = api_key

        if launch:
            self.request()

    @staticmethod
    def _parse_json(response: "requests.Response") -> dict:
        """
        Decode a response body as a JSON object.

        Raises:
            ApiException: If the body is not valid JSON or is not a JSON
                object, carrying the HTTP status code and the raw payload.
        """
        try:
            payload = response.json()
        except ValueError as err:
            # JSON decode errors derive from ValueError; surface them as an
            # ApiException so callers only have one error type to handle.
            raise ApiException(status_code=response.status_code, detail=response.text) from err
        if not isinstance(payload, dict):
            raise ApiException(status_code=response.status_code, detail=str(payload))
        return payload

    def request(self):
        """
        Request to generate the image.

        Does nothing if a request id has already been obtained.

        Raises:
            ApiException: If the API answers with a non-200 status, a body
                that is not a JSON object, or a body without an ``"id"``.
        """
        if self.request_id is not None:
            return
        response = requests.post(
            f"{API_ENDPOINT}/v1/image",
            headers={
                "accept": "application/json",
                "x-key": self.api_key,
                "Content-Type": "application/json",
            },
            json=self.request_json,
            timeout=self._REQUEST_TIMEOUT,
        )
        result = self._parse_json(response)
        if response.status_code != 200:
            raise ApiException(status_code=response.status_code, detail=result.get("detail"))
        if "id" not in result:
            raise ApiException(
                status_code=response.status_code,
                detail="API response did not contain a request id",
            )
        self.request_id = result["id"]

    def retrieve(self) -> dict:
        """
        Wait for the generation to finish and retrieve response.

        Submits the request first if that has not happened yet, then polls
        the result endpoint until the reported status is ``"Ready"``.

        Returns:
            The ``"result"`` entry of the API response (cached afterwards).

        Raises:
            ApiException: If a poll response is not a JSON object, carries
                no ``"status"``, reports a status other than ``"Ready"`` or
                ``"Pending"``, or reports ``"Ready"`` without a result.
        """
        if self.result is not None:
            return self.result
        if self.request_id is None:
            self.request()
        while True:
            response = requests.get(
                f"{API_ENDPOINT}/v1/get_result",
                headers={
                    "accept": "application/json",
                    "x-key": self.api_key,
                },
                params={
                    "id": self.request_id,
                },
                timeout=self._REQUEST_TIMEOUT,
            )
            result = self._parse_json(response)
            if "status" not in result:
                raise ApiException(status_code=response.status_code, detail=result.get("detail"))
            status = result["status"]
            if status == "Ready":
                outcome = result.get("result")
                if outcome is None:
                    # Guard against polling forever (without any sleep) on a
                    # "Ready" response that carries no result.
                    raise ApiException(
                        status_code=response.status_code,
                        detail="API returned status 'Ready' without a result",
                    )
                self.result = outcome
                return self.result
            if status == "Pending":
                time.sleep(self._POLL_INTERVAL)
                continue
            # Any other status is a job-level failure; it is reported with
            # status_code 200, the value callers may already rely on.
            raise ApiException(status_code=200, detail=f"API returned status '{status}'")

    @property
    def bytes(self) -> bytes:
        """
        Generated image as bytes.

        Downloaded from ``url`` on first access and cached afterwards.

        Raises:
            ApiException: If the download does not return status 200.
        """
        if self._image_bytes is None:
            response = requests.get(self.url, timeout=self._REQUEST_TIMEOUT)
            if response.status_code != 200:
                raise ApiException(status_code=response.status_code)
            self._image_bytes = response.content
        return self._image_bytes

    @property
    def url(self) -> str:
        """
        Public url to retrieve the image from

        Taken from the ``"sample"`` entry of the retrieved result and cached.
        """
        if self._url is None:
            result = self.retrieve()
            self._url = result["sample"]
        return self._url

    @property
    def image(self) -> Image.Image:
        """
        Load the image as a PIL Image
        """
        return Image.open(io.BytesIO(self.bytes))

    def save(self, path: Union[str, os.PathLike]):
        """
        Save the generated image to a local path

        The file extension found in the image URL is appended to ``path``
        unless ``path`` already ends with it. Missing parent directories
        are created.

        Args:
            path: Destination file path (string or path-like object).
        """
        path = os.fspath(path)
        # Take the suffix from the URL's path component only, so a query
        # string or fragment never ends up in the file name.
        suffix = Path(urlparse(self.url).path).suffix
        if not path.endswith(suffix):
            path = path + suffix
        # Download before touching the file system so a failed download does
        # not leave an empty file behind.
        data = self.bytes
        Path(path).resolve().parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as file:
            file.write(data)


if __name__ == "__main__":
    # When executed as a script, hand ImageRequest to python-fire, which is
    # imported lazily so that importing this module does not require it.
    import fire

    fire.Fire(ImageRequest)