|
from datasets import load_dataset, DatasetDict |
|
from transformers import AutoTokenizer, T5ForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer |
|
from transformers import EarlyStoppingCallback |
|
from transformers.integrations import TensorBoardCallback |
|
import logging |
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
logger = logging.getLogger(__name__) |
|
|
|
def generate_formal_text(text):
    """Produce the "formal" counterpart of *text*.

    Placeholder implementation: returns the input unchanged. Swap in a real
    formalization model or rule set here to create (formal, informal)
    training pairs for the seq2seq task.
    """
    return text
|
|
|
def prepare_data(example):
    """Augment one dataset row with a ``formal_text`` field derived from ``text``.

    Mutates and returns *example*, matching the contract expected by
    ``datasets.Dataset.map``.
    """
    formal = generate_formal_text(example["text"])
    example["formal_text"] = formal
    return example
|
|
|
def tokenize_function(examples, tokenizer):
    """Tokenize a batch for seq2seq training (formal text -> original text).

    Args:
        examples: Batch dict with "formal_text" (encoder input) and "text"
            (decoder target) lists.
        tokenizer: HF tokenizer; must expose ``pad_token_id``.

    Returns:
        The tokenized model inputs with a "labels" key. Padding positions in
        the labels are replaced by -100 so the cross-entropy loss ignores
        them — leaving pad ids in the labels would train the model to emit
        padding.
    """
    model_inputs = tokenizer(examples["formal_text"], max_length=128, truncation=True, padding="max_length")
    labels = tokenizer(examples["text"], max_length=128, truncation=True, padding="max_length")
    pad_id = tokenizer.pad_token_id
    # -100 is the ignore_index used by HF models' loss computation.
    model_inputs["labels"] = [
        [tok if tok != pad_id else -100 for tok in seq]
        for seq in labels["input_ids"]
    ]
    return model_inputs
|
|
|
def main():
    """Build a Seq2SeqTrainer that fine-tunes t5-small on reddit comments.

    Loads and subsamples the dataset, splits it 80/10/10 into
    train/validation/test, tokenizes it, and configures model + trainer.
    Training itself is left to the caller via ``trainer.train()``.

    Returns:
        The configured ``Seq2SeqTrainer``.
    """
    logger.info("Loading dataset...")
    dataset = load_dataset("LucasChu/reddit_comments")
    dataset = dataset.shuffle(seed=42)
    dataset["train"] = dataset["train"].select(range(10000))
    logger.info("Dataset loaded, shuffled, and truncated to 10,000 samples.")

    # 80/10/10 split: the 20% held out is halved into test and validation.
    train_testvalid = dataset["train"].train_test_split(test_size=0.2, seed=42)
    test_valid = train_testvalid["test"].train_test_split(test_size=0.5, seed=42)

    dataset = DatasetDict({
        "train": train_testvalid["train"],
        "test": test_valid["test"],
        "validation": test_valid["train"]
    })

    logger.info("Preparing dataset...")
    processed_dataset = {split: data.map(prepare_data) for split, data in dataset.items()}
    logger.info("Dataset prepared.")

    model_name = "t5-small"
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    logger.info("Tokenizing dataset...")
    tokenized_dataset = {split: data.map(lambda examples: tokenize_function(examples, tokenizer), batched=True) for split, data in processed_dataset.items()}
    logger.info("Dataset tokenized.")

    available_splits = list(tokenized_dataset.keys())
    logger.info(f"Available splits in the dataset: {available_splits}")

    logger.info("Setting up model and trainer...")
    model = T5ForConditionalGeneration.from_pretrained(model_name)

    # Evaluate on the *validation* split during training; the test split is
    # reserved for a final, unbiased evaluation. (Previously the test split
    # was used for in-training eval, leaking the held-out set.)
    has_eval = "validation" in available_splits

    training_args = Seq2SeqTrainingArguments(
        output_dir="./results",
        num_train_epochs=1,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        warmup_steps=100,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=100,
        evaluation_strategy="steps" if has_eval else "no",
        eval_steps=500,
        # save_steps must be a round multiple of eval_steps when
        # load_best_model_at_end is enabled.
        save_steps=1000,
        # load_best_model_at_end requires periodic evaluation; disable it
        # together with evaluation if no validation split exists.
        load_best_model_at_end=has_eval,
        metric_for_best_model="eval_loss",
        greater_is_better=False
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset.get("validation"),
        tokenizer=tokenizer,
        # EarlyStoppingCallback relies on eval metrics + load_best_model_at_end.
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3), TensorBoardCallback()]
    )
    logger.info("Model and trainer set up.")

    return trainer
|
|
|
# Script entry point: build the trainer. Note main() only configures the
# trainer and returns it; call trainer.train() to actually fine-tune.
if __name__ == "__main__":

    main()