Update README.md
Browse files
README.md
CHANGED
@@ -248,4 +248,78 @@ The following hyperparameters were used during training:
|
|
- Transformers 4.35.2
- Pytorch 2.1.0+cu118
- Datasets 2.15.0
- Tokenizers 0.15.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
- Transformers 4.35.2
- Pytorch 2.1.0+cu118
- Datasets 2.15.0
- Tokenizers 0.15.0

### Example of usage
|
254 |
+
|
255 |
+
```python
|
256 |
+
import torch  # fix: used by collate_fn below but was never imported in the original example

from datasets import load_dataset, DatasetDict  # fix: DatasetDict was imported mid-script; hoist all imports to the top
from transformers import TrainingArguments
from transformers import CLIPProcessor, AutoModelForImageClassification

# Load the fine-tuned CLIP classifier and its processor from the Hub.
processor = CLIPProcessor.from_pretrained("Andron00e/CLIPForImageClassification-v1")
model = AutoModelForImageClassification.from_pretrained("Andron00e/CLIPForImageClassification-v1")

# 80/10/10 train/validation/test split: carve off 20%, then halve that
# held-out part into validation and test.
dataset = load_dataset("Andron00e/CIFAR100-custom")
dataset = dataset["train"].train_test_split(test_size=0.2)
val_test = dataset["test"].train_test_split(test_size=0.5)
dataset = DatasetDict({
    "train": dataset["train"],
    "validation": val_test["train"],
    "test": val_test["test"],
})
|
273 |
+
|
274 |
+
def transform(example_batch):
    """Encode a batch of examples with the CLIP processor.

    Expects `example_batch` to carry 'image' (images the processor accepts)
    and 'labels' (integer class ids). Uses the module-level `processor` and
    a `classes` mapping from label id to class name.

    NOTE(review): `classes` is never defined in this example — it must be
    provided, e.g. classes = dataset["train"].features["labels"].names.
    """
    inputs = processor(
        text=[classes[label] for label in example_batch['labels']],
        images=list(example_batch['image']),  # plain copy; list(...) instead of a copy comprehension
        padding=True,
        return_tensors='pt',
    )
    # Carry the labels through so the collator can batch them.
    inputs['labels'] = example_batch['labels']
    return inputs
|
278 |
+
|
279 |
+
def collate_fn(batch):
    """Collate a list of per-example encodings into batched tensors.

    Stacks 'input_ids', 'attention_mask' and 'pixel_values' along a new
    batch dimension and turns the integer labels into a single tensor.
    """
    batched = {
        key: torch.stack([example[key] for example in batch])
        for key in ('input_ids', 'attention_mask', 'pixel_values')
    }
    batched['labels'] = torch.tensor([example['labels'] for example in batch])
    return batched
|
286 |
+
|
287 |
+
# Training configuration: evaluate and checkpoint every 100 steps, keep only
# the two most recent checkpoints, and reload the best one at the end.
training_args = TrainingArguments(
    output_dir="./outputs",
    num_train_epochs=4,
    per_device_train_batch_size=16,
    learning_rate=2e-4,
    fp16=False,
    # evaluation / checkpoint cadence
    evaluation_strategy="steps",
    eval_steps=100,
    save_steps=100,
    save_total_limit=2,
    load_best_model_at_end=True,
    # logging
    logging_steps=10,
    report_to='tensorboard',
    # keep the raw columns so the transform/collator can see them
    remove_unused_columns=False,
    push_to_hub=False,
)
|
303 |
+
|
304 |
+
from transformers import Trainer

# Apply the on-the-fly encoding once so all three splits share it.
encoded = dataset.with_transform(transform)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    # NOTE(review): compute_metrics is never defined in this example —
    # supply one (e.g. accuracy over logits.argmax) or drop this argument.
    compute_metrics=compute_metrics,
    train_dataset=encoded["train"],
    eval_dataset=encoded["validation"],
    # fix: the original passed `model.processor`, but the model loaded via
    # AutoModelForImageClassification has no `.processor` attribute; pass
    # the CLIPProcessor loaded above so it is saved with checkpoints.
    tokenizer=processor,
)

train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

# fix: the original evaluated `processed_dataset['test']`, which was never
# defined; evaluate the held-out test split of the transformed dataset.
metrics = trainer.evaluate(encoded["test"])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
|
325 |
+
```
|