|
import os |
|
from dotenv import load_dotenv |
|
import torch |
|
from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments |
|
from datasets import load_dataset, concatenate_datasets |
|
from huggingface_hub import login |
|
import time |
|
import uvicorn |
|
from fastapi import FastAPI |
|
import threading |
|
|
|
|
|
# Read a local .env file (if any) into the process environment so the
# Hugging Face token can be supplied without exporting it in the shell.
load_dotenv()

# The token is mandatory: fail at import time with a clear message rather
# than later, in the middle of a Hub call.
huggingface_token = os.environ.get('HUGGINGFACE_TOKEN')
if huggingface_token is None:
    raise ValueError("HUGGINGFACE_TOKEN not found in environment variables.")

# Authenticate this process against the Hugging Face Hub.
login(token=huggingface_token)

# ASGI application object; route handlers below register themselves on it.
app = FastAPI()
|
|
|
@app.get("/")
async def root():
    """Return a static status message for ``GET /``.

    The payload is a JSON object with a single ``message`` key. The text is
    Spanish for "Model trained and running."; it is returned unconditionally
    and does not reflect the actual training state.
    """
    # The literal previously contained a mis-decoded "ó" (mojibake); it is
    # restored here so clients receive readable text.
    return {"message": "Modelo entrenado y en ejecución."}
|
|
|
# Columns probed, in priority order, for the raw text of a dataset row.
# The first three are the keys the tokenization step has always looked for.
# NOTE(review): the remaining names are assumed from the public
# code_search_net schema and from generic text datasets -- confirm.
_TEXT_COLUMN_CANDIDATES = (
    'dialog',
    'docstring',
    'code',
    'func_documentation_string',
    'func_code_string',
    'text',
)

# Label value used to mask padding positions so they do not contribute to
# the language-modelling loss.
_IGNORE_INDEX = -100


def _load_train_split(dataset_name):
    """Load the ``train`` split of *dataset_name* using the shared cache settings."""
    return load_dataset(
        dataset_name,
        split='train',
        cache_dir='/dev/shm',
        trust_remote_code=True,
    )


def _flatten_text(value):
    """Collapse one raw column value into a single string.

    Lists or tuples (for example a dialog stored as a list of utterances)
    are joined with newlines; ``None`` becomes an empty string.
    """
    if value is None:
        return ""
    if isinstance(value, (list, tuple)):
        return "\n".join(_flatten_text(item).strip() for item in value)
    return str(value)


def _to_text_dataset(dataset):
    """Reduce *dataset* to a single ``text`` column with no empty rows.

    Raises:
        ValueError: if none of ``_TEXT_COLUMN_CANDIDATES`` is a column of
            *dataset*.
    """
    source_column = next(
        (name for name in _TEXT_COLUMN_CANDIDATES if name in dataset.column_names),
        None,
    )
    if source_column is None:
        raise ValueError(
            f"No text column found among {dataset.column_names}"
        )

    def extract_text(batch):
        return {'text': [_flatten_text(value) for value in batch[source_column]]}

    text_only = dataset.map(
        extract_text,
        batched=True,
        remove_columns=dataset.column_names,
    )
    # Rows without any text would tokenize to pure padding and carry no
    # training signal, so they are dropped.
    return text_only.filter(lambda row: bool(row['text'].strip()))


def load_and_train():
    """Fine-tune GPT-2 on dialog and code text and publish it to the Hub.

    Steps:
      1. Load the ``gpt2`` tokenizer and model.
      2. Load the ``daily_dialog`` and ``code_search_net`` train splits. On
         failure, wait 60 seconds and make one more attempt for whatever is
         still missing.
      3. Normalise both datasets to a single ``text`` column so they can be
         concatenated, then tokenize to fixed-length sequences with labels.
      4. Train, push model and tokenizer to the Hub, sleep, and repeat
         forever. Errors inside this loop are printed and retried.

    Returns:
        None. The function returns early only if the datasets cannot be
        loaded or prepared; otherwise it never returns.
    """
    model_name = 'gpt2'
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    model = GPT2LMHeadModel.from_pretrained(model_name)

    # GPT-2 ships without a padding token, but padding='max_length' below
    # needs one; reusing EOS is the usual choice.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.pad_token_id

    # Both names are bound up front so the fallback path can tell which
    # dataset is still missing (previously `dataset_codigo` could be left
    # undefined and crash with NameError further down).
    dataset_humanizado = None
    dataset_codigo = None
    try:
        dataset_humanizado = _load_train_split('daily_dialog')
        dataset_codigo = _load_train_split('code_search_net')
    except Exception as e:
        print(f"Error al cargar los datasets: {e}")

        time.sleep(60)
        try:
            if dataset_humanizado is None:
                # NOTE(review): 'alternative_dataset' looks like a
                # placeholder name -- confirm it resolves to a real dataset.
                dataset_humanizado = _load_train_split('alternative_dataset')
            if dataset_codigo is None:
                dataset_codigo = _load_train_split('code_search_net')
        except Exception as retry_error:
            print(f"Error al cargar el dataset alternativo: {retry_error}")
            return

    print("Daily Dialog columns:", dataset_humanizado.column_names)
    print("Code Search Net columns:", dataset_codigo.column_names)

    # The two datasets have different columns, so they are first reduced to
    # one shared `text` column; concatenation requires matching features.
    try:
        text_datasets = [
            _to_text_dataset(dataset)
            for dataset in (dataset_humanizado, dataset_codigo)
        ]
    except ValueError as e:
        print(f"Error al preparar los datasets: {e}")
        return

    combined_dataset = concatenate_datasets(text_datasets)

    print("Combined dataset columns:", combined_dataset.column_names)

    def tokenize_function(examples):
        tokens = tokenizer(
            examples['text'],
            truncation=True,
            padding='max_length',
            max_length=512,
        )
        # Causal-LM training needs labels; they mirror input_ids, with the
        # padded positions masked out of the loss.
        tokens['labels'] = [
            [
                token_id if keep else _IGNORE_INDEX
                for token_id, keep in zip(ids, mask)
            ]
            for ids, mask in zip(tokens['input_ids'], tokens['attention_mask'])
        ]
        return tokens

    tokenized_dataset = combined_dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=combined_dataset.column_names,
        cache_file_name='/dev/shm/tokenized_dataset.arrow',
    )

    # No evaluation strategy is configured: there is no eval dataset, and
    # requesting per-epoch evaluation without one makes the Trainer fail.
    training_args = TrainingArguments(
        output_dir='/dev/shm/results',
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        num_train_epochs=1,
        learning_rate=1e-5,
        logging_steps=100,
        save_total_limit=1,
        seed=42,
        weight_decay=0.01,
        warmup_ratio=0.1,
        lr_scheduler_type="linear",
        save_steps=500,
        save_strategy="epoch",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
    )

    # Train/publish cycle: runs until the process is stopped.
    while True:
        try:
            trainer.train()

            model.push_to_hub('Yhhxhfh/nombre_de_tu_modelo', repo_type='model', commit_message="Actualización del modelo")
            tokenizer.push_to_hub('Yhhxhfh/nombre_de_tu_modelo', repo_type='model', commit_message="Actualización del tokenizador")

            # Pause between successful cycles.
            time.sleep(300)
        except Exception as e:
            print(f"Error durante el entrenamiento: {e}. Reiniciando el proceso de entrenamiento...")
            # Short back-off before retrying after a failure.
            time.sleep(10)
|
|
|
if __name__ == "__main__":
    # Serve the HTTP API from a background thread so the main thread is free
    # to run the training loop.
    server_thread = threading.Thread(
        target=uvicorn.run,
        args=(app,),
        kwargs={"host": "0.0.0.0", "port": 7860},
    )
    server_thread.start()

    load_and_train()
|
|