# train.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

# Assumed to be a project-local module that provides the model definition.
from super_large_language_model import TransformerModel

# Index 0 is reserved for padding; character indices start at 1.
PAD_IDX = 0


class TextDataset(Dataset):
    """Character-level dataset: maps each text to a 1-D tensor of vocab indices."""

    def __init__(self, texts, vocab):
        self.texts = texts
        self.vocab = vocab

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        text_indices = [self.vocab[char] for char in text]
        return torch.tensor(text_indices)
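

def collate_batch(batch):
    # Helper added here (not in the original script): the example texts have
    # different lengths, so pad each batch to its longest sequence before
    # stacking, otherwise the default DataLoader collation fails on ragged tensors.
    return pad_sequence(batch, batch_first=True, padding_value=PAD_IDX)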


def train_model(model, dataset, num_epochs=10, batch_size=32, learning_rate=0.001):
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                            collate_fn=collate_batch)
    # Padded positions carry no signal, so exclude them from the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):
        model.train()
        for batch in dataloader:
            optimizer.zero_grad()
            # Next-character prediction: slice along the sequence dimension
            # (dim 1), not the batch dimension, to form input/target pairs.
            src, tgt = batch[:, :-1], batch[:, 1:]
            # Assumes the model maps (src, tgt) to logits of shape
            # (batch, tgt_len, vocab_size), as the original loss computation implies.
            output = model(src, tgt)
            # reshape rather than view: the sliced tensors may be non-contiguous.
            loss = criterion(output.reshape(-1, output.size(-1)), tgt.reshape(-1))
            loss.backward()
            optimizer.step()
        print(f'Epoch {epoch+1}, Loss: {loss.item():.4f}')


if __name__ == "__main__":
    # Example texts and vocabulary
    texts = ["hello world", "pytorch is great"]
    # Sort the character set so index assignment is deterministic across runs,
    # and shift every index by 1 so PAD_IDX stays reserved for padding.
    vocab = {char: idx + 1 for idx, char in enumerate(sorted(set("".join(texts))))}
    dataset = TextDataset(texts, vocab)
    # vocab_size accounts for the extra padding index.
    model = TransformerModel(vocab_size=len(vocab) + 1, d_model=512, nhead=8,
                             num_encoder_layers=6, num_decoder_layers=6,
                             dim_feedforward=2048)
    train_model(model, dataset)
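
# Usage sketch: with super_large_language_model on the import path, running
# the script trains for 10 epochs on the two example strings:
#     python train.py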