lucidmorto committed on
Commit
085809d
1 Parent(s): fea89cd

Initial commit with training script

Browse files
Files changed (3) hide show
  1. humanizer.py +94 -0
  2. requirements.txt +3 -0
  3. run_trainer.py +5 -0
humanizer.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset, DatasetDict
2
+ from transformers import AutoTokenizer, T5ForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer
3
+ from transformers import EarlyStoppingCallback
4
+ from transformers.integrations import TensorBoardCallback
5
+ import logging
6
+
7
# Configure root logging once at import time so that progress messages
# emitted at INFO level by this module are visible.
logging.basicConfig(level=logging.INFO)

# Module-scoped logger used by every function in this file.
logger = logging.getLogger(__name__)
10
+
11
def generate_formal_text(text):
    """Return the formal counterpart of ``text``.

    This is a placeholder: no formalisation is implemented yet, so the
    input is handed back unchanged.

    Args:
        text: The source text to formalise.

    Returns:
        The same object that was passed in.
    """
    # Implement formal text generation here
    formal_text = text  # Placeholder
    return formal_text
14
+
15
def prepare_data(example):
    """Add a ``formal_text`` field to a single dataset record.

    Args:
        example: A mapping holding the original comment under ``"text"``.
            It is modified in place.

    Returns:
        The same ``example`` mapping, now also carrying ``"formal_text"``
        as produced by ``generate_formal_text``.
    """
    source_text = example["text"]
    example["formal_text"] = generate_formal_text(source_text)
    return example
18
+
19
def tokenize_function(examples, tokenizer):
    """Tokenize formal/original text pairs for seq2seq training.

    The formal text is the model input and the original text is the target.
    Both are truncated and padded to 128 tokens.

    Args:
        examples: Mapping with ``"formal_text"`` and ``"text"`` entries,
            each either a single string or a list of strings (batched).
        tokenizer: Callable tokenizer accepting ``max_length``,
            ``truncation`` and ``padding`` and returning a mapping with an
            ``"input_ids"`` entry. Its ``pad_token_id`` attribute, when
            present, is used to mask padding in the labels.

    Returns:
        The encoded inputs with an added ``"labels"`` entry in which every
        padding position is replaced by ``-100``.
    """
    model_inputs = tokenizer(
        examples["formal_text"], max_length=128, truncation=True, padding="max_length"
    )
    labels = tokenizer(
        examples["text"], max_length=128, truncation=True, padding="max_length"
    )

    label_ids = labels["input_ids"]
    pad_id = getattr(tokenizer, "pad_token_id", None)
    if pad_id is not None:
        # Targets are padded to a fixed length; -100 marks positions the
        # loss must ignore, so the model is not trained to emit padding.
        if label_ids and isinstance(label_ids[0], int):
            # Un-batched call: a single flat sequence of token ids.
            label_ids = [-100 if tok == pad_id else tok for tok in label_ids]
        else:
            label_ids = [
                [-100 if tok == pad_id else tok for tok in seq]
                for seq in label_ids
            ]

    model_inputs["labels"] = label_ids
    return model_inputs
24
+
25
def main():
    """Build the seq2seq trainer used to fine-tune ``t5-small``.

    Loads the ``LucasChu/reddit_comments`` dataset, keeps at most 10,000
    shuffled rows of its ``train`` split, and divides them 80/10/10 into
    train/validation/test. Each row gets a ``formal_text`` field (the model
    input) and is tokenized against its original ``text`` (the target).

    The validation split drives periodic evaluation, early stopping and
    best-checkpoint selection. The test split is not passed to the trainer,
    so it stays available as a held-out final check.

    Returns:
        Seq2SeqTrainer: A configured trainer. Training is not started here;
        the caller is expected to invoke ``train()`` on the result.
    """
    # Load the dataset and keep at most 10,000 shuffled samples
    logger.info("Loading dataset...")
    dataset = load_dataset("LucasChu/reddit_comments")
    dataset = dataset.shuffle(seed=42)
    train_data = dataset["train"]
    # Guard against a train split smaller than the requested sample size;
    # select() raises on out-of-range indices.
    sample_count = min(10000, len(train_data))
    train_data = train_data.select(range(sample_count))
    logger.info(f"Dataset loaded, shuffled, and truncated to {sample_count:,} samples.")

    # Split 80% train / 20% held out, then halve the held-out part into
    # validation and test.
    train_testvalid = train_data.train_test_split(test_size=0.2, seed=42)
    test_valid = train_testvalid["test"].train_test_split(test_size=0.5, seed=42)

    dataset = DatasetDict({
        "train": train_testvalid["train"],
        "test": test_valid["test"],
        "validation": test_valid["train"],
    })

    # Prepare the dataset (DatasetDict.map applies the function to every split)
    logger.info("Preparing dataset...")
    processed_dataset = dataset.map(prepare_data)
    logger.info("Dataset prepared.")

    # Tokenize the dataset
    model_name = "t5-small"
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    logger.info("Tokenizing dataset...")
    tokenized_dataset = processed_dataset.map(
        tokenize_function,
        batched=True,
        fn_kwargs={"tokenizer": tokenizer},
    )
    logger.info("Dataset tokenized.")

    # Report the splits that are available after preprocessing
    available_splits = list(tokenized_dataset.keys())
    logger.info(f"Available splits in the dataset: {available_splits}")

    # Set up the model and trainer
    logger.info("Setting up model and trainer...")
    model = T5ForConditionalGeneration.from_pretrained(model_name)

    # NOTE(review): with 8,000 training rows and a per-device batch of 16,
    # one epoch is about 500 steps on a single device, so eval_steps=500
    # and save_steps=1000 leave little room for early stopping (patience 3)
    # to act - confirm these intervals are intended.
    training_args = Seq2SeqTrainingArguments(
        output_dir="./results",
        num_train_epochs=1,
        per_device_train_batch_size=16,
        warmup_steps=100,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=100,
        # Evaluation must be enabled: load_best_model_at_end and the early
        # stopping callback both rely on the eval_loss metric.
        evaluation_strategy="steps",
        eval_steps=500,
        save_steps=1000,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        # Model selection uses the validation split; evaluating on the test
        # split here would leak it into early stopping and checkpoint choice.
        eval_dataset=tokenized_dataset["validation"],
        tokenizer=tokenizer,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3), TensorBoardCallback()],
    )
    logger.info("Model and trainer set up.")

    # Return the trainer object
    return trainer
92
+
93
# Running this module directly only builds the trainer; the returned object
# is discarded here. Training itself is started from run_trainer.py.
if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ datasets
2
+ transformers
3
+ torch
run_trainer.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
"""Entry point: build the trainer defined in humanizer.py and run training."""
from humanizer import main


if __name__ == "__main__":
    # main() only constructs the trainer; training is kicked off here.
    configured_trainer = main()
    configured_trainer.train()