import os
import time
import threading

import torch
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI
from transformers import GPT2LMHeadModel, GPT2Tokenizer, DataCollatorForLanguageModeling, Trainer, TrainingArguments
from datasets import load_dataset, concatenate_datasets
from huggingface_hub import login
import spaces  # provides the @spaces.GPU decorator used for the training loop

# Read HUGGINGFACE_TOKEN from .env and authenticate with the Hugging Face Hub.
load_dotenv()
login(token=os.getenv('HUGGINGFACE_TOKEN'))
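
# Base model: GPT-2 and its tokenizer, loaded from the Hub.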
model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# GPT-2 has no padding token by default; reuse EOS so padding='max_length' works below.
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained(model_name)

# Load datasets and keep everything in RAM.
dataset_humanizado = load_dataset('daily_dialog', split='train', trust_remote_code=True)
dataset_codigo = load_dataset('code_search_net', split='train', trust_remote_code=True)
dataset_prompts = load_dataset('openai_humaneval', split='test', trust_remote_code=True)  # humaneval only provides a 'test' split

# The corpora use different schemas, so reduce each to a single 'text' column before
# concatenating (assumed columns: 'dialog', 'whole_func_string', 'prompt'/'canonical_solution').
dataset_humanizado = dataset_humanizado.map(lambda x: {'text': ' '.join(x['dialog'])}, remove_columns=dataset_humanizado.column_names)
dataset_codigo = dataset_codigo.map(lambda x: {'text': x['whole_func_string']}, remove_columns=dataset_codigo.column_names)
dataset_prompts = dataset_prompts.map(lambda x: {'text': x['prompt'] + x['canonical_solution']}, remove_columns=dataset_prompts.column_names)

combined_dataset = concatenate_datasets([
    dataset_humanizado,
    dataset_codigo,
    dataset_prompts,
])

# Tokenize the combined corpus and keep it in RAM.
def tokenize_function(examples):
    return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=512)

tokenized_dataset = combined_dataset.map(tokenize_function, batched=True)
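
# Training configuration for the Hugging Face Trainer.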
training_args = TrainingArguments(
    output_dir='./results',
    per_device_train_batch_size=100,
    per_device_eval_batch_size=100,
    num_train_epochs=1,       # one full pass over the data per loop iteration
    learning_rate=1e-5,
    logging_steps=100,        # logging_steps must be a positive integer
    max_grad_norm=1.0,
    save_total_limit=1,
    seed=42,
    weight_decay=0.0,
    warmup_ratio=0.0,
    evaluation_strategy="no",
    optim="adamw_torch",
    lr_scheduler_type="constant",
)

# A causal-LM collator (mlm=False) adds 'labels' so the Trainer can compute the loss.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)
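
# Minimal FastAPI app so the Space answers HTTP requests while training runs in the background.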
app = FastAPI()

@app.get("/")
async def root():
    return {"message": "Model trained and running."}
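
# Continuous training loop: train, push the model and tokenizer to the Hub, then repeat;
# any exception restarts the loop after a short pause.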
@spaces.GPU
def run_training():
    while True:
        try:
            trainer.train()
            model.push_to_hub('Yhhxhfh/nombre_de_tu_modelo', use_temp_dir=True, commit_message="Model update")
            tokenizer.push_to_hub('Yhhxhfh/nombre_de_tu_modelo', use_temp_dir=True, commit_message="Tokenizer update")
            time.sleep(5)
        except Exception as e:
            print(f"Error during training: {e}. Restarting the training process...")
            time.sleep(10)

if __name__ == "__main__":
    # Run training in a background thread so the API server stays responsive.
    threading.Thread(target=run_training).start()
    uvicorn.run(app, host="0.0.0.0", port=7860)
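
# To launch: python app.py  (requires a .env file providing HUGGINGFACE_TOKEN);
# the API listens on port 7860 while training runs in the background thread.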