# train.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from super_large_language_model import TransformerModel

PAD_IDX = 0  # reserved padding index; real characters start at 1


class TextDataset(Dataset):
    """Character-level dataset: maps each text to a tensor of vocab indices."""

    def __init__(self, texts, vocab):
        self.texts = texts
        self.vocab = vocab

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        text_indices = [self.vocab[char] for char in text]
        return torch.tensor(text_indices, dtype=torch.long)


def collate_batch(batch):
    # Texts have different lengths, so pad them to the longest sequence in
    # the batch before stacking into a single (batch, seq_len) tensor.
    return nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=PAD_IDX)


def train_model(model, dataset, num_epochs=10, batch_size=32, learning_rate=0.001):
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                            collate_fn=collate_batch)
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)  # skip padded positions
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        model.train()
        for batch in dataloader:
            optimizer.zero_grad()
            # Shift along the sequence dimension (dim 1), not the batch
            # dimension: the model sees tokens [0..n-1], predicts [1..n].
            src, tgt = batch[:, :-1], batch[:, 1:]
            output = model(src, tgt)
            # reshape (not view): the sliced tensors are non-contiguous.
            loss = criterion(output.reshape(-1, output.size(-1)), tgt.reshape(-1))
            loss.backward()
            optimizer.step()
        print(f'Epoch {epoch+1}, Loss: {loss.item()}')


if __name__ == "__main__":
    # Example texts and vocabulary
    texts = ["hello world", "pytorch is great"]
    # Sort the character set so the vocabulary is deterministic across runs,
    # and start indices at 1 so that 0 stays free for padding.
    vocab = {char: idx + 1 for idx, char in enumerate(sorted(set("".join(texts))))}

    dataset = TextDataset(texts, vocab)
    model = TransformerModel(vocab_size=len(vocab) + 1,  # +1 for the padding index
                             d_model=512, nhead=8,
                             num_encoder_layers=6, num_decoder_layers=6,
                             dim_feedforward=2048)
    train_model(model, dataset)
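
# ---------------------------------------------------------------------------
# The super_large_language_model module imported above is not shown here. If
# it is unavailable, a minimal stand-in saved as super_large_language_model.py
# could look like the sketch below: a thin wrapper around torch.nn.Transformer
# matching the constructor arguments and the model(src, tgt) call used above.
# This is an assumed sketch, not the actual module; padding masks are omitted
# for brevity.
#
# import math
# import torch
# import torch.nn as nn
#
#
# class TransformerModel(nn.Module):
#     def __init__(self, vocab_size, d_model=512, nhead=8,
#                  num_encoder_layers=6, num_decoder_layers=6,
#                  dim_feedforward=2048):
#         super().__init__()
#         self.d_model = d_model
#         self.embedding = nn.Embedding(vocab_size, d_model)
#         self.transformer = nn.Transformer(
#             d_model=d_model, nhead=nhead,
#             num_encoder_layers=num_encoder_layers,
#             num_decoder_layers=num_decoder_layers,
#             dim_feedforward=dim_feedforward,
#             batch_first=True,  # inputs are (batch, seq_len, d_model)
#         )
#         self.fc_out = nn.Linear(d_model, vocab_size)
#
#     def forward(self, src, tgt):
#         # Scale embeddings by sqrt(d_model), as in "Attention Is All You Need".
#         src_emb = self.embedding(src) * math.sqrt(self.d_model)
#         tgt_emb = self.embedding(tgt) * math.sqrt(self.d_model)
#         # Causal mask: each target position attends only to earlier positions.
#         tgt_mask = nn.Transformer.generate_square_subsequent_mask(
#             tgt.size(1)).to(tgt.device)
#         out = self.transformer(src_emb, tgt_emb, tgt_mask=tgt_mask)
#         return self.fc_out(out)  # (batch, seq_len, vocab_size) logits
# ---------------------------------------------------------------------------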