---
license: mit
language:
- en
base_model:
- openai/clip-vit-large-patch14
tags:
- art
- style
- clip
- image
- embedding
- vit
- model_hub_mixin
- pytorch_model_hub_mixin
---
## Measuring Style Similarity in Diffusion Models
Cloned from [learn2phoenix/CSD](https://github.com/learn2phoenix/CSD?tab=readme-ov-file).
Their model weights (`csd-vit-l.pth`) were downloaded from their [Google Drive](https://drive.google.com/file/d/1FX0xs8p-C7Ob-h5Y4cUhTeOepHzXv_46/view?usp=sharing).
The original Git repo is in the `CSD` folder.
## Model architecture
The model CSD ("contrastive style descriptor") is initialized from the image encoder of [openai/clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14). Let $f$ be the function implemented by the image encoder. The encoder consists of a vision Transformer, which converts an image into a $1024$-dimensional real-valued vector, followed by a single $1024 \times 768$ matrix (the "projection matrix"), which maps that vector to a $768$-dimensional CLIP embedding vector.
Now remove the projection matrix. This gives us $g: \text{Image} \to \mathbb{R}^{1024}$, whose output is the `feature vector`. Then add two new projection matrices, each of dimensions $1024 \times 768$. The output of one is the `style vector`, and the output of the other is the `content vector`. All parameters of the resulting model were then fine-tuned using [tadeephuy/GradientReversal](https://github.com/tadeephuy/GradientReversal) for content-style disentanglement, giving the final model.
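In symbols, treating $g(x)$ as a $1 \times 1024$ row vector (matching `feature @ self.last_layer_style` in the loading code below), and writing $W_s$ and $W_c$ for the two new $1024 \times 768$ heads and $s$, $c$ for the style and content maps (notation introduced here for illustration):

$$
f(x) = g(x)\,W_{\text{proj}}, \qquad
s(x) = \frac{g(x)\,W_s}{\lVert g(x)\,W_s \rVert_2}, \qquad
c(x) = \frac{g(x)\,W_c}{\lVert g(x)\,W_c \rVert_2}.
$$

The normalization mirrors the `nn.functional.normalize(..., p=2)` calls in `forward`, and both $W_s$ and $W_c$ start out as copies of $W_{\text{proj}}$, since the code deep-copies `backbone.proj` before fine-tuning.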
The original paper states that the authors trained *two* models, one of them based on ViT-B, but that one was not released.
The model takes as input real-valued tensors. To preprocess images, use the CLIP preprocessor. That is, use `_, preprocess = clip.load("ViT-L/14")`. Explicitly, the preprocessor performs the following operation:
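```python
from torchvision.transforms import CenterCrop, Compose, InterpolationMode, Normalize, Resize, ToTensor

BICUBIC = InterpolationMode.BICUBIC

def _convert_image_to_rgb(image):
    return image.convert("RGB")

# n_px is the input resolution: 224 for ViT-L/14.
def _transform(n_px):
    return Compose([
        Resize(n_px, interpolation=BICUBIC),  # shorter side -> n_px, aspect ratio kept
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),  # HWC uint8 -> CHW float in [0, 1]
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])
```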
See the [`CLIPImageProcessor` documentation](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPImageProcessor) for details.
Also, despite their names, I have noticed by visual inspection that the `style vector` and the `content vector` are about equally good as style embeddings. I don't know why, but I guess that's life?
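To make the steps of `_transform` concrete, here is an illustrative PIL + NumPy reimplementation. The function name `clip_preprocess_sketch` is hypothetical, and the sketch is not guaranteed to be bit-identical to torchvision's output, so use the real CLIP preprocessor when computing actual embeddings. It resizes the shorter side to `n_px` with bicubic interpolation, center-crops, converts to RGB, scales to $[0, 1]$, and normalizes each channel with the same constants.

```python
import numpy as np
from PIL import Image

# CLIP normalization constants, copied from the Normalize(...) call in _transform.
CLIP_MEAN = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
CLIP_STD = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)


def clip_preprocess_sketch(image: Image.Image, n_px: int = 224) -> np.ndarray:
    """Approximate the CLIP transform; returns a (3, n_px, n_px) float32 array."""
    # Resize so the shorter side equals n_px, keeping the aspect ratio.
    w, h = image.size
    if w <= h:
        new_w, new_h = n_px, int(n_px * h / w)
    else:
        new_w, new_h = int(n_px * w / h), n_px
    image = image.resize((new_w, new_h), resample=Image.BICUBIC)

    # Center-crop to n_px x n_px.
    left = int(round((new_w - n_px) / 2.0))
    top = int(round((new_h - n_px) / 2.0))
    image = image.crop((left, top, left + n_px, top + n_px))

    # Convert to RGB and scale to [0, 1], as ToTensor does.
    x = np.asarray(image.convert("RGB"), dtype=np.float32) / 255.0  # (H, W, 3)
    # Per-channel normalization (broadcasts over the last axis).
    x = (x - CLIP_MEAN) / CLIP_STD
    return x.transpose(2, 0, 1)  # HWC -> CHW
```

For a $300 \times 200$ input, the resize step produces $336 \times 224$ and the crop keeps the central $224 \times 224$ region.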
## How to use it
### Quickstart
Go to `examples` and run the `example.ipynb` notebook, then run `tsne_visualization.py`. It will say something like `Running on http://127.0.0.1:49860`. Click that link and enjoy the pretty interactive picture.
![](examples/style_embedding_tsne.png)
### Loading the model
```python
import copy
import torch
import torch.nn as nn
import clip  # OpenAI CLIP: pip install git+https://github.com/openai/CLIP.git
from transformers import CLIPProcessor
from huggingface_hub import PyTorchModelHubMixin
from transformers import PretrainedConfig


class CSDCLIPConfig(PretrainedConfig):
    model_type = "csd_clip"

    def __init__(
        self,
        name="csd_large",
        embedding_dim=1024,
        feature_dim=1024,
        content_dim=768,
        style_dim=768,
        content_proj_head="default",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.name = name
        self.embedding_dim = embedding_dim
        self.feature_dim = feature_dim
        self.content_dim = content_dim
        self.style_dim = style_dim
        self.content_proj_head = content_proj_head
        self.task_specific_params = None  # read by transformers.Pipeline.__init__


class CSD_CLIP(nn.Module, PyTorchModelHubMixin):
    """CLIP ViT backbone + separate style and content projection heads."""

    def __init__(self, name='vit_large', content_proj_head='default'):
        super().__init__()
        self.content_proj_head = content_proj_head
        if name == 'vit_large':
            clipmodel, _ = clip.load("ViT-L/14")
            self.backbone = clipmodel.visual
            self.embedding_dim = 1024
            self.feature_dim = 1024
            self.content_dim = 768
            self.style_dim = 768
            self.name = "csd_large"
        elif name == 'vit_base':
            clipmodel, _ = clip.load("ViT-B/16")
            self.backbone = clipmodel.visual
            self.embedding_dim = 768
            self.feature_dim = 512
            self.content_dim = 512
            self.style_dim = 512
            self.name = "csd_base"
        else:
            raise Exception('This model is not implemented')

        # Both heads start as copies of CLIP's projection matrix.
        self.last_layer_style = copy.deepcopy(self.backbone.proj)
        self.last_layer_content = copy.deepcopy(self.backbone.proj)
        # Drop CLIP's projection, so the backbone returns the feature vector.
        self.backbone.proj = None

        self.config = CSDCLIPConfig(
            name=self.name,
            embedding_dim=self.embedding_dim,
            feature_dim=self.feature_dim,
            content_dim=self.content_dim,
            style_dim=self.style_dim,
            content_proj_head=self.content_proj_head
        )

    def get_config(self):
        return self.config.to_dict()

    @property
    def dtype(self):
        return self.backbone.conv1.weight.dtype

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, input_data):
        # Feature vector from the backbone (no CLIP projection).
        feature = self.backbone(input_data)
        # Style head, L2-normalized.
        style_output = feature @ self.last_layer_style
        style_output = nn.functional.normalize(style_output, dim=1, p=2)
        # Content head, L2-normalized.
        content_output = feature @ self.last_layer_content
        content_output = nn.functional.normalize(content_output, dim=1, p=2)
        return feature, content_output, style_output


device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = CSD_CLIP.from_pretrained("yuxi-liu-wired/CSD")
model.to(device)
model.eval()
```
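Because `forward` L2-normalizes both heads, the cosine similarity between two style vectors is just their dot product. The NumPy sketch below mirrors that computation on random stand-in data; the random `feature` matrix and heads are placeholders rather than the real weights, and `l2_normalize` and `style_similarity` are hypothetical helper names.

```python
import numpy as np

rng = np.random.default_rng(0)


def l2_normalize(x: np.ndarray) -> np.ndarray:
    # Same as nn.functional.normalize(x, dim=1, p=2), apart from its small epsilon.
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def style_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity of already L2-normalized embeddings."""
    return a @ b.T


# Placeholder stand-ins: a batch of 4 feature vectors and two random 1024 x 768 heads.
feature = rng.standard_normal((4, 1024)).astype(np.float32)
last_layer_style = rng.standard_normal((1024, 768)).astype(np.float32)
last_layer_content = rng.standard_normal((1024, 768)).astype(np.float32)

style = l2_normalize(feature @ last_layer_style)        # (4, 768)
content = l2_normalize(feature @ last_layer_content)    # (4, 768)

sim = style_similarity(style, style)  # (4, 4) matrix, ones on the diagonal
```

With the real model, the same product applied to the `style_output` tensors returned by `model(...)` gives pairwise style similarities.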
### Loading the pipeline
```python
import torch
from transformers import CLIPProcessor, Pipeline
from typing import Union, List
from PIL import Image


class CSDCLIPPipeline(Pipeline):
    def __init__(self, model, processor, device=None):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        super().__init__(model=model, tokenizer=None, device=device)
        self.processor = processor

    def _sanitize_parameters(self, **kwargs):
        # No extra parameters for preprocess, _forward, or postprocess.
        return {}, {}, {}

    def preprocess(self, images):
        if isinstance(images, (str, Image.Image)):
            images = [images]
        # Open file paths so the processor always receives PIL images.
        images = [Image.open(im) if isinstance(im, str) else im for im in images]
        processed = self.processor(images=images, return_tensors="pt")
        return {k: v.to(self.device) for k, v in processed.items()}

    def _forward(self, model_inputs):
        # Cast inputs to the backbone's dtype (fp16 or fp32, depending on how CLIP was loaded).
        pixel_values = model_inputs['pixel_values'].to(self.model.dtype)
        with torch.no_grad():
            features, content_output, style_output = self.model(pixel_values)
        return {"features": features, "content_output": content_output, "style_output": style_output}

    def postprocess(self, model_outputs):
        return {
            "features": model_outputs["features"].cpu().numpy(),
            "content_output": model_outputs["content_output"].cpu().numpy(),
            "style_output": model_outputs["style_output"].cpu().numpy()
        }

    def __call__(self, images: Union[str, List[str], Image.Image, List[Image.Image]]):
        return super().__call__(images)


processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
pipeline = CSDCLIPPipeline(model=model, processor=processor, device=device)
```
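Each call to the pipeline on a single image returns a dict of NumPy arrays with a leading batch dimension of 1. A small helper can collect many images into one `(N, D)` matrix. The sketch below is hypothetical (`embed_styles` is not part of this repo); it calls the pipeline one image at a time so that it does not depend on how `transformers.Pipeline` batches list inputs.

```python
import numpy as np


def embed_styles(pipe, images, key="style_output"):
    """Stack one embedding per image into an (N, D) float32 array.

    Assumes pipe(image) returns a dict of NumPy arrays with a leading batch
    dimension of 1, as CSDCLIPPipeline does.
    """
    rows = []
    for image in images:
        output = pipe(image)
        # Drop the batch dimension and use float32 (the model may run in fp16).
        rows.append(np.asarray(output[key], dtype=np.float32).squeeze(0))
    return np.stack(rows)
```

For example, `embed_styles(pipeline, [Image.open(p) for p in paths])` would give one style vector per file, where `paths` is a list of image paths of your choosing.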
### An example application
First, load the model and the pipeline as described above. Then run the following code, which loads the [yuxi-liu-wired/style-content-grid-SDXL](https://huggingface.co/datasets/yuxi-liu-wired/style-content-grid-SDXL) dataset, computes the style vector of each image, and writes the results to a `parquet` file.
```python
import io
from PIL import Image
from datasets import load_dataset
import pandas as pd
from tqdm import tqdm


def to_jpeg(image):
    """Encode a PIL image as JPEG bytes."""
    buffered = io.BytesIO()
    if image.mode != "RGB":  # JPEG cannot store RGBA or palette images
        image = image.convert("RGB")
    image.save(buffered, format='JPEG')
    return buffered.getvalue()


def scale_image(image, max_resolution):
    """Downscale so the longer side is at most max_resolution, keeping the aspect ratio."""
    longest = max(image.width, image.height)
    if longest > max_resolution:
        ratio = max_resolution / longest
        image = image.resize((max(1, round(image.width * ratio)),
                              max(1, round(image.height * ratio))))
    return image


def process_dataset(pipeline, dataset_name, dataset_size=900, max_resolution=192):
    dataset = load_dataset(dataset_name, split='train')
    dataset = dataset.select(range(min(dataset_size, len(dataset))))
    print("Dataset columns:", dataset.column_names)

    embeddings = []   # one style vector per image
    jpeg_images = []  # downscaled JPEG thumbnails for visualization
    for item in tqdm(dataset, desc="Processing images"):
        try:
            img = item['image']
            # If img is a string (file path), load the image
            if isinstance(img, str):
                img = Image.open(img)
            output = pipeline(img)
            # Drop the batch dimension; store as float32 (the model may run in fp16).
            style_output = output["style_output"].squeeze(0).astype("float32")
            jpeg_img = to_jpeg(scale_image(img, max_resolution))
            embeddings.append(style_output)
            jpeg_images.append(jpeg_img)
        except Exception as e:
            print(f"Error processing item: {e}")

    df = pd.DataFrame({
        'embedding': embeddings,
        'image': jpeg_images
    })
    df.to_parquet('processed_dataset.parquet')
    print("Processing complete. Results saved to 'processed_dataset.parquet'")


process_dataset(pipeline, "yuxi-liu-wired/style-content-grid-SDXL",
                dataset_size=900, max_resolution=192)
```
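Once the embeddings are saved, a simple use is style retrieval: find the images whose style vectors are closest to a given one. The sketch below is a hypothetical helper, not part of this repo. It assumes the `embedding` column holds L2-normalized vectors, as produced above, so that cosine similarity reduces to a dot product, and it returns row positions rather than index labels.

```python
import numpy as np
import pandas as pd


def nearest_styles(df: pd.DataFrame, query_index: int, k: int = 5) -> list:
    """Return the positions of the k rows whose style embeddings are closest to row query_index."""
    emb = np.stack(df["embedding"].to_numpy()).astype(np.float32)  # (N, D)
    scores = emb @ emb[query_index]      # cosine similarities to the query
    scores[query_index] = -np.inf        # exclude the query itself
    return np.argsort(-scores)[:k].tolist()
```

Usage: `df = pd.read_parquet("processed_dataset.parquet")`, then `nearest_styles(df, 0)` lists the five images most similar in style to the first one.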
After that, you can go to `examples` and run `tsne_visualization.py` to browse the images in an interactive Dash app.
![](examples/style_embedding_tsne.png)