import logging

from datasets import load_dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    T5ForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
from transformers import EarlyStoppingCallback
from transformers.integrations import TensorBoardCallback

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Upper bound on the number of rows kept from the raw "train" split.
MAX_SAMPLES = 10_000
# Split used for in-training evaluation, early stopping and best-model
# selection; "test" is left untouched for a final held-out evaluation.
EVAL_SPLIT = "validation"
# Label value skipped by the loss in transformers seq2seq models
# (see the T5ForConditionalGeneration `labels` documentation).
LABEL_IGNORE_INDEX = -100


def generate_formal_text(text):
    """Return the formal rewrite of *text*.

    Placeholder: formal text generation is not implemented yet, so the
    input is returned unchanged.
    """
    # Implement formal text generation here
    return text


def prepare_data(example):
    """Add a ``"formal_text"`` field derived from ``example["text"]``.

    The example is modified in place and also returned, which is the shape
    ``Dataset.map`` expects from a per-row function.
    """
    example["formal_text"] = generate_formal_text(example["text"])
    return example


def _mask_padding(label_ids, pad_token_id):
    """Replace pad-token ids with ``LABEL_IGNORE_INDEX``.

    Accepts either one sequence of ids or a batch (list of sequences).
    """
    if label_ids and isinstance(label_ids[0], (list, tuple)):
        return [_mask_padding(row, pad_token_id) for row in label_ids]
    return [
        LABEL_IGNORE_INDEX if token_id == pad_token_id else token_id
        for token_id in label_ids
    ]


def tokenize_function(examples, tokenizer, max_length=128):
    """Tokenize ``"formal_text"`` as model input and ``"text"`` as labels.

    Args:
        examples: A row or batch with ``"formal_text"`` and ``"text"`` fields.
        tokenizer: Tokenizer callable that returns ``"input_ids"`` and
            exposes ``pad_token_id``.
        max_length: Length both inputs and labels are truncated/padded to.

    Returns:
        The tokenizer output for the inputs, with a ``"labels"`` entry added.
    """
    model_inputs = tokenizer(
        examples["formal_text"],
        max_length=max_length,
        truncation=True,
        padding="max_length",
    )
    labels = tokenizer(
        examples["text"],
        max_length=max_length,
        truncation=True,
        padding="max_length",
    )
    # Labels are padded to max_length; without masking, the loss would also
    # be computed on the padding positions.
    model_inputs["labels"] = _mask_padding(
        labels["input_ids"], tokenizer.pad_token_id
    )
    return model_inputs


def main():
    """Build and return a ``Seq2SeqTrainer`` for t5-small on Reddit comments.

    Loads and shuffles the dataset, keeps at most ``MAX_SAMPLES`` training
    rows, splits them 80/10/10 into train/validation/test, prepares and
    tokenizes every split, and configures the trainer. Training itself is
    not started here; the caller receives the trainer object.
    """
    # Load the dataset and take only MAX_SAMPLES samples
    logger.info("Loading dataset...")
    dataset = load_dataset("LucasChu/reddit_comments")
    dataset = dataset.shuffle(seed=42)
    # Clamp so that a train split smaller than MAX_SAMPLES does not raise.
    sample_count = min(MAX_SAMPLES, dataset["train"].num_rows)
    dataset["train"] = dataset["train"].select(range(sample_count))
    logger.info(
        "Dataset loaded, shuffled, and truncated to %s samples.",
        f"{sample_count:,}",
    )

    # Split the train dataset into train (80%), validation (10%), test (10%)
    train_testvalid = dataset["train"].train_test_split(test_size=0.2, seed=42)
    test_valid = train_testvalid["test"].train_test_split(test_size=0.5, seed=42)
    dataset = DatasetDict({
        "train": train_testvalid["train"],
        "test": test_valid["test"],
        "validation": test_valid["train"],
    })

    # Prepare the dataset
    logger.info("Preparing dataset...")
    processed_dataset = {
        split: data.map(prepare_data) for split, data in dataset.items()
    }
    logger.info("Dataset prepared.")

    # Tokenize the dataset
    model_name = "t5-small"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    logger.info("Tokenizing dataset...")
    tokenized_dataset = {
        split: data.map(
            tokenize_function,
            batched=True,
            fn_kwargs={"tokenizer": tokenizer},
        )
        for split, data in processed_dataset.items()
    }
    logger.info("Dataset tokenized.")

    # Check available splits in the dataset
    available_splits = list(tokenized_dataset.keys())
    logger.info("Available splits in the dataset: %s", available_splits)

    # Evaluation, best-model loading and early stopping all need an
    # evaluation split, so they are switched on or off together.
    has_eval_split = EVAL_SPLIT in tokenized_dataset

    # Set up the model and trainer
    logger.info("Setting up model and trainer...")
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    # NOTE(review): with the default 10,000-row cap the train split has 8,000
    # rows, i.e. 500 optimizer steps per epoch at batch size 16 on a single
    # device, so save_steps=1000 would never be reached in one epoch and no
    # "best" checkpoint would exist to reload -- confirm the intended values.
    training_args = Seq2SeqTrainingArguments(
        output_dir="./results",
        num_train_epochs=1,
        per_device_train_batch_size=16,
        warmup_steps=100,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=100,
        evaluation_strategy="steps" if has_eval_split else "no",
        eval_steps=500,
        save_steps=1000,
        load_best_model_at_end=has_eval_split,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
    )

    callbacks = []
    if has_eval_split:
        callbacks.append(EarlyStoppingCallback(early_stopping_patience=3))
    callbacks.append(TensorBoardCallback())

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset.get(EVAL_SPLIT),
        tokenizer=tokenizer,
        callbacks=callbacks,
    )
    logger.info("Model and trainer set up.")

    # Return the trainer object
    return trainer


if __name__ == "__main__":
    main()