
AGE-ViT

Age-classifying Generative Entity Vision Transformer

A Vision Transformer finetuned to classify images of human faces into 'minor' or 'adult'.

This model is a finetuned version of https://huggingface.co/nateraw/vit-age-classifier, which was itself finetuned on the FairFace dataset. We then finetuned it further on a dataset of generated humans so that the model recognizes the compositions and styles that diffusion models produce across anime, cartoons, digital art, and similar genres.

Note that FairFace consists of images of the human face, sometimes with a small portion of the body, similar to a headshot, whereas images in the generated dataset may be headshot-style or show more of the body. To improve recognition on such images, we did not crop faces out of the generated dataset during training; the model was trained on the full images instead.

Datasets

The following datasets were used for finetuning. FairFace was used to finetune the base classifier (nateraw/vit-age-classifier) that this model builds on.

FairFace dataset

https://github.com/dchen236/FairFace

This dataset is balanced for race, gender, and age and was initially intended for bias mitigation. Most of its images are direct, front-facing shots.

Synthetic Dataset

https://civitai.com/models/668458/synthetic-human-dataset

This dataset was generated entirely with Flux and contains 15k images of men, women, boys, and girls, viewed from the front, the side, and slightly above. It will be expanded with Stable Diffusion 1.5 (sd15) images, after which the model will be retrained.
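The card does not state how the four subject categories map onto the model's two labels. The sketch below assumes the natural mapping (men and women to 'adult', boys and girls to 'minor') and assumes each image is tagged with one of those category names. `CATEGORY_TO_LABEL` and `to_binary_labels` are hypothetical names for illustration, not the authors' training code.

```python
from typing import Dict, List, Tuple

# Hypothetical mapping from the synthetic dataset's subject categories
# to the model's two output labels (an assumption, not the authors' code).
CATEGORY_TO_LABEL: Dict[str, str] = {
    "men": "adult",
    "women": "adult",
    "boys": "minor",
    "girls": "minor",
}


def to_binary_labels(samples: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Map (image_path, category) pairs to (image_path, label) pairs.

    Category names are matched case-insensitively. An unknown category
    raises KeyError, so unexpected data is not silently mislabeled.
    """
    return [(path, CATEGORY_TO_LABEL[category.lower()]) for path, category in samples]
```

Failing loudly on unknown categories is a deliberate choice here: for a minor/adult classifier, a silently mislabeled training image is worse than a crash.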

To use the model

import requests
import torch
from io import BytesIO
from PIL import Image

from transformers import ViTImageProcessor, ViTForImageClassification

# Download an example image (hosted on Civitai) and open it as an RGB PIL image
url = 'https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/9488af10-7f1f-4361-877b-d9cfafeab131/original=true,quality=90/24599129.jpeg'
r = requests.get(url, timeout=30)
r.raise_for_status()
im = Image.open(BytesIO(r.content)).convert('RGB')

model_dir = 'civitai/age-vit'

# Load the model and its matching image processor
model = ViTForImageClassification.from_pretrained(model_dir)
processor = ViTImageProcessor.from_pretrained(model_dir)

# Preprocess the image and run it through the model (no gradients needed for inference)
inputs = processor(images=im, return_tensors='pt')
with torch.no_grad():
    output = model(**inputs)

# Predicted class probabilities, shape (batch_size, num_labels)
proba = output.logits.softmax(dim=1)

# Index of the most likely class for each image
preds = proba.argmax(dim=1)

# Map the index to its string label
prediction = model.config.id2label[preds.item()]
print(prediction)
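
To get a probability per label instead of only the argmax, or to apply a stricter cutoff for flagging minors, a plain-Python sketch like the following could be used. `label_probabilities`, `flag_minor`, the `threshold` parameter, and the label string 'minor' are assumptions for illustration; check `model.config.id2label` for the model's actual label names.

```python
import math
from typing import Dict, Mapping, Sequence


def label_probabilities(logits: Sequence[float], id2label: Mapping[int, str]) -> Dict[str, float]:
    """Softmax over one image's logits, keyed by label name."""
    # Subtract the largest logit first so math.exp cannot overflow
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return {id2label[i]: e / total for i, e in enumerate(exps)}


def flag_minor(probs: Mapping[str, float], threshold: float = 0.5, minor_label: str = "minor") -> bool:
    """Return True when the (assumed) 'minor' probability reaches the threshold.

    Raises KeyError if minor_label is not one of the model's labels, so a
    label-name mismatch fails loudly instead of silently returning False.
    """
    return probs[minor_label] >= threshold
```

With the variables from the usage example, `probs = label_probabilities(output.logits[0].tolist(), model.config.id2label)` gives one probability per label for the first image. With two labels, a threshold of 0.5 is essentially equivalent to the argmax; a lower threshold flags more images as 'minor' at the cost of more false positives.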

Model size: 85.8M parameters (Safetensors, F32)