yanze commited on
Commit
9ebe534
1 Parent(s): 240a720

Update flux/util.py

Browse files
Files changed (1) hide show
  1. flux/util.py +0 -45
flux/util.py CHANGED
@@ -4,7 +4,6 @@ from dataclasses import dataclass
4
  import torch
5
  from einops import rearrange
6
  from huggingface_hub import hf_hub_download
7
- from imwatermark import WatermarkEncoder
8
  from safetensors.torch import load_file as load_sft
9
 
10
  from flux.model import Flux, FluxParams
@@ -155,47 +154,3 @@ def load_ae(name: str, device: str = "cuda", hf_download: bool = True) -> AutoEn
155
  missing, unexpected = ae.load_state_dict(sd, strict=False)
156
  print_load_warning(missing, unexpected)
157
  return ae
158
-
159
-
160
- class WatermarkEmbedder:
161
- def __init__(self, watermark):
162
- self.watermark = watermark
163
- self.num_bits = len(WATERMARK_BITS)
164
- self.encoder = WatermarkEncoder()
165
- self.encoder.set_watermark("bits", self.watermark)
166
-
167
- def __call__(self, image: torch.Tensor) -> torch.Tensor:
168
- """
169
- Adds a predefined watermark to the input image
170
-
171
- Args:
172
- image: ([N,] B, RGB, H, W) in range [-1, 1]
173
-
174
- Returns:
175
- same as input but watermarked
176
- """
177
- image = 0.5 * image + 0.5
178
- squeeze = len(image.shape) == 4
179
- if squeeze:
180
- image = image[None, ...]
181
- n = image.shape[0]
182
- image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
183
- # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
184
- # watermarking libary expects input as cv2 BGR format
185
- for k in range(image_np.shape[0]):
186
- image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
187
- image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
188
- image.device
189
- )
190
- image = torch.clamp(image / 255, min=0.0, max=1.0)
191
- if squeeze:
192
- image = image[0]
193
- image = 2 * image - 1
194
- return image
195
-
196
-
197
- # A fixed 48-bit message that was choosen at random
198
- WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
199
- # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
200
- WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
201
- embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
 
4
  import torch
5
  from einops import rearrange
6
  from huggingface_hub import hf_hub_download
 
7
  from safetensors.torch import load_file as load_sft
8
 
9
  from flux.model import Flux, FluxParams
 
154
  missing, unexpected = ae.load_state_dict(sd, strict=False)
155
  print_load_warning(missing, unexpected)
156
  return ae