Nanobit committed on
Commit
b432889
1 Parent(s): 54fe07a

feat: enable trl's autounwrap (#1060)


* feat: test trl's autounwrap

* fix: add check for adapter

* feat: add config to disable autounwrap

* chore: fix lint

.vscode/launch.json CHANGED
@@ -11,7 +11,7 @@
       "request": "launch",
       "args": [
         "-m", "axolotl.cli.train", "dev_sharegpt.yml",
-        // The flags below simplify debugging by overriding the axolotl config
+        // The flags below simplify debugging by overriding the axolotl config
         // with the debugging tips above. Modify as needed.
         "--dataset_processes=1", // limits data preprocessing to one process
         "--max_steps=1", // limits training to just one step
devtools/README.md CHANGED
@@ -1 +1 @@
- This directory contains example config files that might be useful for debugging. Please see [docs/debugging.md](../docs/debugging.md) for more information.
+ This directory contains example config files that might be useful for debugging. Please see [docs/debugging.md](../docs/debugging.md) for more information.
docs/debugging.md CHANGED
@@ -30,13 +30,13 @@ While debugging it's helpful to simplify your test scenario as much as possible.
 3. **Use a small model**: A good example of a small model is [TinyLlama/TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0).
 4. **Minimize iteration time**: Make sure the training loop finishes as fast as possible, with these settings.
     - `micro_batch_size: 1`
-    - `max_steps: 1`
+    - `max_steps: 1`
     - `val_set_size: 0`
 5. **Clear Caches:** Axolotl caches certain steps and so does the underlying HuggingFace trainer. You may want to clear some of these caches when debugging.
     - Data preprocessing: When debugging data preprocessing, which includes prompt template formation, you may want to delete the directory set in `dataset_prepared_path:` in your axolotl config. If you didn't set this value, the default is `last_run_prepared`.
     - HF Hub: If you are debugging data preprocessing, you should clear the relevant HF cache [HuggingFace cache](https://huggingface.co/docs/datasets/cache), by deleting the appropriate `~/.cache/huggingface/datasets/...` folder(s).
     - **The recommended approach is to redirect all outputs and caches to a temporary folder and delete selected subfolders before each run. This is demonstrated in the example configuration below.**
-
+

 ## Debugging with VSCode

@@ -74,7 +74,7 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
       "request": "launch",
       "args": [
         "-m", "axolotl.cli.train", "dev_sharegpt.yml",
-        // The flags below simplify debugging by overriding the axolotl config
+        // The flags below simplify debugging by overriding the axolotl config
         // with the debugging tips above. Modify as needed.
         "--dataset_processes=1", // limits data preprocessing to one process
         "--max_steps=1", // limits training to just one step
@@ -101,7 +101,7 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler

 - The argument `justMyCode` is set to `true` such that you step through only the axolotl code. If you want to step into dependencies, set this to `false`.
 - The `preLaunchTask`: `cleanup-for-dataprep` is defined in [.vscode/tasks.json](../.vscode/tasks.json) and is used to delete the following folders before debugging, which is essential to ensure that the data pre-processing code is run from scratch:
-    - `./devtools/temp_debug/axolotl_outputs`
+    - `./devtools/temp_debug/axolotl_outputs`
     - `./devtools/temp_debug/.hf-cache/datasets`

 >[!Tip]
docs/rlhf.md CHANGED
@@ -33,3 +33,12 @@ datasets:
 ```yaml
 rl: ipo
 ```
+
+#### TRL autounwrap for PEFT
+
+TRL supports auto-unwrapping PEFT models, so a reference model does not need to be loaded separately, reducing VRAM usage. This is enabled by default. To turn it off, pass the following config:
+
+```yaml
+# load ref model when adapter training.
+rl_adapter_ref_model: true
+```
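For context on what "autounwrap" means here: when TRL's DPO trainer receives `ref_model=None` and the policy is a PEFT model, it computes reference log-probabilities by temporarily disabling the adapter instead of keeping a second full copy of the weights in memory. Below is a minimal, hedged sketch of that usage; constructor argument names (`beta`, `tokenizer`, etc.) vary between trl releases, so treat them as assumptions rather than a fixed API.

```python
# Sketch only: TRL's built-in "autounwrap" when no ref model is passed.
# Assumes trl, peft, and datasets are installed; argument names differ across trl versions.
from datasets import Dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # the small model suggested in docs/debugging.md
tokenizer = AutoTokenizer.from_pretrained(model_name)
policy = get_peft_model(
    AutoModelForCausalLM.from_pretrained(model_name),
    LoraConfig(task_type="CAUSAL_LM"),
)

pairs = Dataset.from_dict(
    {"prompt": ["Hi"], "chosen": ["Hello!"], "rejected": ["Go away."]}
)

trainer = DPOTrainer(
    model=policy,    # PEFT-wrapped policy model
    ref_model=None,  # None + PEFT model -> TRL disables the adapter for the reference pass
    beta=0.1,
    args=TrainingArguments(output_dir="tmp_dpo", per_device_train_batch_size=1, max_steps=1),
    train_dataset=pairs,
    tokenizer=tokenizer,
)
trainer.train()
```

With `rl_adapter_ref_model: true`, axolotl instead loads the base model a second time and passes it as the reference model, trading extra VRAM for an independently held reference.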
src/axolotl/train.py CHANGED
@@ -63,10 +63,15 @@ def train(
     model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
     model_ref = None
     if cfg.rl:
-        # load the model again for model_ref/baseline
-        model_ref, _ = load_model(
-            cfg, tokenizer, inference=cli_args.inference, reference_model=True
-        )
+        if cfg.adapter and not cfg.rl_adapter_ref_model:
+            # use built-in trl autounwrap
+            LOG.debug("Passing model_ref: None to RL trainer")
+            model_ref = None  # explicit setting to None
+        else:
+            # load the model again for model_ref/baseline
+            model_ref, _ = load_model(
+                cfg, tokenizer, inference=cli_args.inference, reference_model=True
+            )

     safe_serialization = cfg.save_safetensors is True
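To make the double negation in the new guard easier to read, here is an illustrative-only reduction of the branch; the `needs_ref_model` helper is hypothetical and not part of axolotl, it just enumerates which config combinations cause a second model to be loaded.

```python
# Hypothetical helper mirroring the guard added in train() above:
# a second reference model is loaded unless we are adapter training
# and rl_adapter_ref_model is left unset/false.
from typing import Optional


def needs_ref_model(adapter: Optional[str], rl_adapter_ref_model: bool) -> bool:
    return not (adapter and not rl_adapter_ref_model)


assert needs_ref_model("lora", False) is False  # autounwrap: model_ref stays None
assert needs_ref_model("lora", True) is True    # opt-out: load an explicit ref model
assert needs_ref_model(None, False) is True     # full finetune: always load a ref model
```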