---
license: apache-2.0
pipeline_tag: image-classification
library_name: transformers
tags:
- image-detection
- ai-image-generation
- anime
- ai-anime
- human-detection
- art
---

# AI Anime Image Detector ViT

This is a proof-of-concept model for detecting AI-generated anime-style images. Built on a Vision Transformer (ViT), it was trained on 1M human-made and 217K AI-generated anime images; during training, both classes were sampled in equal proportion to avoid bias. Training ran on a single RTX 3090 GPU for about 40 hours (~35 epochs).
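The equal-proportion sampling mentioned above can be sketched as follows. This is a minimal illustration, not the actual training code: `balanced_epoch` is a hypothetical helper that undersamples the larger class so each epoch contains the two classes in equal amounts (oversampling the smaller class would be the alternative).

```python
import random

def balanced_epoch(real_paths, ai_paths, seed=0):
    """Build one epoch in which both classes appear in equal amounts.

    The larger class is undersampled to the size of the smaller one.
    Labels: 0 = real (human-made), 1 = AI-generated.
    """
    rng = random.Random(seed)
    n = min(len(real_paths), len(ai_paths))
    epoch = [(path, 0) for path in rng.sample(real_paths, n)] \
          + [(path, 1) for path in rng.sample(ai_paths, n)]
    rng.shuffle(epoch)  # interleave the two classes
    return epoch

# toy example: 10 real images vs. 4 AI images -> 4 + 4 samples per epoch
epoch = balanced_epoch([f"real_{i}.jpg" for i in range(10)],
                       [f"ai_{i}.jpg" for i in range(4)])
```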

The training logs are available on my [wandb](https://wandb.ai/legekka/AI-Image-Detector).

## Evaluation

Each checkpoint was evaluated on 500 real and 500 AI-generated images.

Final result:
- Training Loss: 0.1009
- Eval Loss: 0.1386

Random crops appear to have helped the model generalize better. However, the training dataset contained only 512x512 images, so every crop went through bilinear interpolation; training on 1024x1024 images could probably improve performance further. *(Maybe I'll do it later)*
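To illustrate the random-crop augmentation described above, here is a minimal sketch of picking a crop window; the actual transform pipeline and crop size used in training are not documented here, so `crop=448` is an assumption for illustration only.

```python
import random

def random_crop_box(width, height, crop=448):
    """Pick a random crop window inside a (width, height) image.

    Returns a (left, top, right, bottom) box, e.g. for PIL's Image.crop().
    After cropping, the patch is resized to the model's input resolution,
    which is where the bilinear interpolation mentioned above comes in.
    """
    x = random.randint(0, width - crop)
    y = random.randint(0, height - crop)
    return (x, y, x + crop, y + crop)

# e.g. a random 448x448 window inside a 512x512 training image
box = random_crop_box(512, 512)
```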

## Performance comparison

We ran a small comparison against currently available AI image detectors. Note that these models were not specifically trained on anime images.

| Image        | Nahrawy/AIorNot | umm-maybe/AI-image-detector | Organika/sdxl-detector | mmanikanta/VIT_AI_image_detector  | Ours       |
|--------------|-----------------|-----------------------------|------------------------|-----------------------------------|------------|
| ai_1.jpg     | ai (100%)       | human (86%)                 | artificial (100%)      | FAKE (89%)                        | ai (100%)  |
| ai_2.jpg     | ai (99%)        | human (96%)                 | artificial (100%)      | FAKE (89%)                        | ai (100%)  |
| ai_3.jpg     | ai (77%)        | human (98%)                 | artificial (100%)      | REAL (100%)                       | ai (100%)  |
| ai_4.jpg     | real (66%)      | human (100%)                | human (100%)           | REAL (100%)                       | real (100%)|
| ai_5.jpg     | ai (51%)        | human (99%)                 | artificial (55%)       | REAL (99%)                        | real (65%) |
| ai_6.jpg     | ai (100%)       | human (98%)                 | artificial (100%)      | FAKE (60%)                        | ai (84%)   |
| real_1.jpg   | ai (99%)        | human (99%)                 | artificial (100%)      | REAL (98%)                        | ai (55%)   |
| real_2.jpg   | ai (88%)        | human (100%)                | artificial (100%)      | REAL (100%)                       | real (85%) |
| real_3.jpg   | ai (95%)        | human (96%)                 | artificial (100%)      | REAL (100%)                       | real (97%) |
| real_4.jpg   | real (90%)      | human (100%)                | artificial (97%)       | REAL (100%)                       | real (94%) |
| real_5.jpg   | ai (75%)        | human (100%)                | human (57%)            | REAL (100%)                       | real (100%)|
| real_6.jpg   | ai (89%)        | human (98%)                 | human (100%)           | REAL (100%)                       | real (99%) |
| **Accuracy:**| 50%             | 50%                         | 58%                    | **75%**                           | **75%**    |
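The accuracy row is simply the share of correct predictions over the 12 test images (the `ai_*.jpg` files are AI-generated, the `real_*.jpg` files human-made). For our model's column:

```python
# Ground truth for the 12 test images: ai_1..ai_6 are AI, real_1..real_6 are real
truth = ["ai"] * 6 + ["real"] * 6

# Our model's predicted labels, read off the table above
ours = ["ai", "ai", "ai", "real", "real", "ai",
        "ai", "real", "real", "real", "real", "real"]

accuracy = sum(t == p for t, p in zip(truth, ours)) / len(truth)
print(f"{accuracy:.0%}")  # → 75%
```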



## Usage

Example inference code:

```python
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
import torch
from PIL import Image

model = AutoModelForImageClassification.from_pretrained("legekka/AI-Anime-Image-Detector-ViT")
feature_extractor = AutoFeatureExtractor.from_pretrained("legekka/AI-Anime-Image-Detector-ViT")

model.eval()

# convert to RGB so grayscale/RGBA inputs also work
image = Image.open("example.jpg").convert("RGB")
inputs = feature_extractor(images=image, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

probs = torch.nn.functional.softmax(logits, dim=-1)[0]
prediction = torch.argmax(probs).item()

label = model.config.id2label[prediction]
confidence = probs[prediction].item()

print(f"Prediction: {label} ({round(confidence * 100)}%)")
```