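"""FastAPI Space that continuously fine-tunes GPT-2 on dialog and code data.

The script logs in to Hugging Face with HUGGINGFACE_TOKEN, trains GPT-2 on
daily_dialog plus code_search_net in a loop, pushes each updated checkpoint
to the Hub, and serves a small FastAPI health endpoint on port 7860. All
intermediate files are kept in /dev/shm so the workflow stays in RAM.
"""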
import os
from dotenv import load_dotenv
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset, concatenate_datasets
from huggingface_hub import login
import time
import uvicorn
from fastapi import FastAPI
import threading
# Load the environment variables
load_dotenv()
huggingface_token = os.getenv('HUGGINGFACE_TOKEN')
if huggingface_token is None:
    raise ValueError("HUGGINGFACE_TOKEN not found in environment variables.")

# Log in to Hugging Face
login(token=huggingface_token)

# Define the FastAPI application
app = FastAPI()
@app.get("/")
async def root():
return {"message": "Modelo entrenado y en ejecuci贸n."}
def load_and_train():
    model_name = 'gpt2'
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    # GPT-2 ships without a padding token; reuse EOS so padding='max_length' works.
    tokenizer.pad_token = tokenizer.eos_token
    model = GPT2LMHeadModel.from_pretrained(model_name)
    # Try to load the datasets, with error handling
    try:
        dataset_humanizado = load_dataset('daily_dialog', split='train', cache_dir='/dev/shm', trust_remote_code=True)
        dataset_codigo = load_dataset('code_search_net', split='train', cache_dir='/dev/shm', trust_remote_code=True)
    except Exception as e:
        print(f"Error loading the datasets: {e}")
        # On failure, an alternative dataset could be loaded, or the load retried later
        time.sleep(60)  # wait 60 seconds before retrying
        try:
            dataset_humanizado = load_dataset('alternative_dataset', split='train', cache_dir='/dev/shm', trust_remote_code=True)
        except Exception as e:
            print(f"Error loading the alternative dataset: {e}")
            return
print("Daily Dialog columns:", dataset_humanizado.column_names)
print("Code Search Net columns:", dataset_codigo.column_names)
    # Tokenization: each branch handles the schema of one source dataset.
    # Note: code_search_net stores its text in 'func_documentation_string' and
    # 'func_code_string' rather than 'docstring'/'code'.
    def tokenize_function(examples):
        if 'dialog' in examples:
            # daily_dialog holds a list of utterances per example; join them into one string.
            texts = [' '.join(dialog) for dialog in examples['dialog']]
        elif 'func_documentation_string' in examples:
            texts = examples['func_documentation_string']
        elif 'func_code_string' in examples:
            texts = examples['func_code_string']
        else:
            return {}
        return tokenizer(texts, truncation=True, padding='max_length', max_length=512)

    # Tokenize in RAM, dropping the raw columns so both outputs share identical features
    tokenized_humanizado = dataset_humanizado.map(
        tokenize_function, batched=True,
        remove_columns=dataset_humanizado.column_names,
        cache_file_name='/dev/shm/tokenized_dialog.arrow',
    )
    tokenized_codigo = dataset_codigo.map(
        tokenize_function, batched=True,
        remove_columns=dataset_codigo.column_names,
        cache_file_name='/dev/shm/tokenized_code.arrow',
    )
    tokenized_dataset = concatenate_datasets([tokenized_humanizado, tokenized_codigo])
    print("Combined dataset columns:", tokenized_dataset.column_names)
    training_args = TrainingArguments(
        output_dir='/dev/shm/results',  # store results temporarily in RAM
        per_device_train_batch_size=4,
        num_train_epochs=1,
        learning_rate=1e-5,
        logging_steps=100,
        save_total_limit=1,
        seed=42,
        weight_decay=0.01,
        warmup_ratio=0.1,
        evaluation_strategy="no",  # no eval split is built above, so disable evaluation
        lr_scheduler_type="linear",
        save_strategy="epoch",  # save only at the end of each epoch to limit writes
    )
    # A causal-LM collator copies input_ids into labels so the Trainer can compute a loss.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
    )
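
    # Train indefinitely: after each pass, push the new checkpoint to the Hub,
    # pause briefly, and start over.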
    while True:
        try:
            trainer.train()
            # Push the model to Hugging Face straight from RAM
            # (model repos are the Hub default, so no repo_type argument is needed)
            model.push_to_hub('Yhhxhfh/nombre_de_tu_modelo', commit_message="Model update")
            tokenizer.push_to_hub('Yhhxhfh/nombre_de_tu_modelo', commit_message="Tokenizer update")
            time.sleep(300)
        except Exception as e:
            print(f"Error during training: {e}. Restarting the training process...")
            time.sleep(10)
if __name__ == "__main__":
    # Run FastAPI in a separate thread
    threading.Thread(target=lambda: uvicorn.run(app, host="0.0.0.0", port=7860)).start()
    load_and_train()