# RWKV v5 multi-size training experiment

**Note:** This project assumes you have the `rwkv-infctx` conda environment set up.

# Basic Setup

In [1]:
# First lets setup the various directories, and init the model
!mkdir -p ../../../../model/
!mkdir -p ../../../../datapath/
!mkdir -p ../../../../checkpoint/

In [2]:
# Experiment configuration. Every later cell reads these globals (through
# `{VAR}` expansion inside `!` shell commands), so this cell must run first.
import os

# Training runtime settings
DEEPSPEED_STRAT="deepspeed_stage_1"
GPU_DEVICES="auto"
ENABLE_WANDB=True

# Embedding init scale; the label variant swaps "." for "_" so it is safe to
# embed in file / directory names (0.01 -> "0_01").
EMBED_SCALE=0.01
EMBED_SCALE_LABEL=str(EMBED_SCALE).replace(".", "_")

# Model size
LAYER_COUNT=12
EMBED_SIZE=2048

# WANDB_PREFIX is the human-readable run-name prefix (keeps "0.01");
# FILENAME_PREFIX names the model / checkpoint paths (e.g. "v5-L12-D2048-E0_01").
WANDB_PREFIX=f"[Multi-size] v5-L{LAYER_COUNT}-D{EMBED_SIZE}-E{EMBED_SCALE}"
FILENAME_PREFIX=f"v5-L{LAYER_COUNT}-D{EMBED_SIZE}-E{EMBED_SCALE_LABEL}"

print("DEEPSPEED_STRAT:", DEEPSPEED_STRAT)
print("ENABLE_WANDB:", ENABLE_WANDB)
print("GPU_DEVICES:", GPU_DEVICES)

# Exported as the WANDB_MODE environment variable for the training process
WANDB_MODE = "online" if ENABLE_WANDB else "disabled"

# Computing the notebook, and various paths.
# `__file__` is not defined inside a notebook kernel. The kernel's working
# directory is the notebook's directory, so read it explicitly rather than
# relying on abspath("__file__") (a string literal) being stripped back to cwd.
NOTEBOOK_DIR=os.getcwd()
PROJECT_DIR=os.path.abspath(os.path.join(NOTEBOOK_DIR, "../../../../"))
# Training and inference scripts both live in the RWKV-v5 folder
TRAINER_DIR=os.path.abspath(os.path.join(PROJECT_DIR, "./RWKV-v5/"))
INFERENCE_DIR=os.path.abspath(os.path.join(PROJECT_DIR, "./RWKV-v5/"))

print("NOTEBOOK_DIR:", NOTEBOOK_DIR)
print("INFERENCE_DIR:", INFERENCE_DIR)
print("TRAINER_DIR:", TRAINER_DIR)
print("PROJECT_DIR:", PROJECT_DIR)

DEEPSPEED_STRAT: deepspeed_stage_1
ENABLE_WANDB: True
GPU_DEVICES: auto
NOTEBOOK_DIR: /actions-runner/_work/RWKV-infctx-trainer/RWKV-infctx-trainer/notebook/experiment/rwkv-x-exp/multi-size-train
INFERENCE_DIR: /actions-runner/_work/RWKV-infctx-trainer/RWKV-infctx-trainer/RWKV-v5
TRAINER_DIR: /actions-runner/_work/RWKV-infctx-trainer/RWKV-infctx-trainer/RWKV-v5
PROJECT_DIR: /actions-runner/_work/RWKV-infctx-trainer/RWKV-infctx-trainer


In [3]:
# Init the model
!cd "{TRAINER_DIR}" && \
    python3 ./init_model.py \
        --n_layer {LAYER_COUNT} --n_embd {EMBED_SIZE} \
        --emb-scale "{EMBED_SCALE}" \
        --vocab_size neox --skip-if-exists \
        "../model/{FILENAME_PREFIX}-neox-v5base-init.pth"

[2023-09-29 09:57:16,435] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)


[RWKV.model] Running RWKV model using 'torch-jit' with torch '2.0.1+cu118'
---- Initializing model ----
No of layers: 12
Embedding size: 2048
Output model path: ../model/v5-L12-D2048-E0_01-neox-v5base-init.pth
Vocab size: 50277
Emb scale: 0.01
Note: this process takes a significant time (and ram) for large models
---- ----- ----


50277 2048  -0.01 emb.weight


2048  2048  1.0  blocks.0.att.gate.weight


2048  2048  1.0  blocks.0.att.receptance.weight


2048  2048  1.0  blocks.0.att.key.weight


2048  2048  1.0  blocks.0.att.value.weight


2048  2048  0    blocks.0.att.output.weight
7168  2048  1.0  blocks.0.ffn.key.weight


2048  2048  0    blocks.0.ffn.receptance.weight
2048  7168  0    blocks.0.ffn.value.weight


2048  2048  1.0  blocks.1.att.gate.weight


2048  2048  1.0  blocks.1.att.receptance.weight


2048  2048  1.0  blocks.1.att.key.weight


2048  2048  1.0  blocks.1.att.value.weight


2048  2048  0    blocks.1.att.output.weight


7168  2048  1.0  blocks.1.ffn.key.weight


2048  2048  0    blocks.1.ffn.receptance.weight
2048  7168  0    blocks.1.ffn.value.weight
2048  2048  1.0  blocks.2.att.gate.weight


2048  2048  1.0  blocks.2.att.receptance.weight


2048  2048  1.0  blocks.2.att.key.weight


2048  2048  1.0  blocks.2.att.value.weight


2048  2048  0    blocks.2.att.output.weight
7168  2048  1.0  blocks.2.ffn.key.weight


2048  2048  0    blocks.2.ffn.receptance.weight
2048  7168  0    blocks.2.ffn.value.weight


2048  2048  1.0  blocks.3.att.gate.weight


2048  2048  1.0  blocks.3.att.receptance.weight


2048  2048  1.0  blocks.3.att.key.weight


2048  2048  1.0  blocks.3.att.value.weight


2048  2048  0    blocks.3.att.output.weight


7168  2048  1.0  blocks.3.ffn.key.weight


2048  2048  0    blocks.3.ffn.receptance.weight
2048  7168  0    blocks.3.ffn.value.weight


2048  2048  1.0  blocks.4.att.gate.weight


2048  2048  1.0  blocks.4.att.receptance.weight


2048  2048  1.0  blocks.4.att.key.weight


2048  2048  1.0  blocks.4.att.value.weight


2048  2048  0    blocks.4.att.output.weight


7168  2048  1.0  blocks.4.ffn.key.weight


2048  2048  0    blocks.4.ffn.receptance.weight
2048  7168  0    blocks.4.ffn.value.weight


2048  2048  1.0  blocks.5.att.gate.weight


2048  2048  1.0  blocks.5.att.receptance.weight


2048  2048  1.0  blocks.5.att.key.weight


2048  2048  1.0  blocks.5.att.value.weight


2048  2048  0    blocks.5.att.output.weight
7168  2048  1.0  blocks.5.ffn.key.weight


2048  2048  0    blocks.5.ffn.receptance.weight
2048  7168  0    blocks.5.ffn.value.weight


2048  2048  1.0  blocks.6.att.gate.weight


2048  2048  1.0  blocks.6.att.receptance.weight


2048  2048  1.0  blocks.6.att.key.weight


2048  2048  1.0  blocks.6.att.value.weight


2048  2048  0    blocks.6.att.output.weight
7168  2048  1.0  blocks.6.ffn.key.weight


2048  2048  0    blocks.6.ffn.receptance.weight
2048  7168  0    blocks.6.ffn.value.weight


2048  2048  1.0  blocks.7.att.gate.weight


2048  2048  1.0  blocks.7.att.receptance.weight


2048  2048  1.0  blocks.7.att.key.weight


2048  2048  1.0  blocks.7.att.value.weight


2048  2048  0    blocks.7.att.output.weight
7168  2048  1.0  blocks.7.ffn.key.weight


2048  2048  0    blocks.7.ffn.receptance.weight
2048  7168  0    blocks.7.ffn.value.weight
2048  2048  1.0  blocks.8.att.gate.weight


2048  2048  1.0  blocks.8.att.receptance.weight


2048  2048  1.0  blocks.8.att.key.weight


2048  2048  1.0  blocks.8.att.value.weight


2048  2048  0    blocks.8.att.output.weight
7168  2048  1.0  blocks.8.ffn.key.weight


2048  2048  0    blocks.8.ffn.receptance.weight
2048  7168  0    blocks.8.ffn.value.weight


2048  2048  1.0  blocks.9.att.gate.weight


2048  2048  1.0  blocks.9.att.receptance.weight


2048  2048  1.0  blocks.9.att.key.weight


2048  2048  1.0  blocks.9.att.value.weight


2048  2048  0    blocks.9.att.output.weight
7168  2048  1.0  blocks.9.ffn.key.weight


2048  2048  0    blocks.9.ffn.receptance.weight
2048  7168  0    blocks.9.ffn.value.weight


2048  2048  1.0  blocks.10.att.gate.weight


2048  2048  1.0  blocks.10.att.receptance.weight


2048  2048  1.0  blocks.10.att.key.weight


2048  2048  1.0  blocks.10.att.value.weight


2048  2048  0    blocks.10.att.output.weight
7168  2048  1.0  blocks.10.ffn.key.weight


2048  2048  0    blocks.10.ffn.receptance.weight
2048  7168  0    blocks.10.ffn.value.weight


2048  2048  1.0  blocks.11.att.gate.weight


2048  2048  1.0  blocks.11.att.receptance.weight


2048  2048  1.0  blocks.11.att.key.weight


2048  2048  1.0  blocks.11.att.value.weight


2048  2048  0    blocks.11.att.output.weight
7168  2048  1.0  blocks.11.ffn.key.weight


2048  2048  0    blocks.11.ffn.receptance.weight
2048  7168  0    blocks.11.ffn.value.weight


50277 2048  0.5  head.weight


## Enwiki Stage 1: Foundation 4k model training

In [4]:
# Lets preload the requried dataset 
!cd "{TRAINER_DIR}" && \
    python3 preload_datapath.py "{NOTEBOOK_DIR}/enwiki-4k-part1.yaml"

Saving the dataset (0/3 shards):   0%|         | 0/54401 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   4%| | 2000/54401 [00:00<00:03, 15197.33 examp

Saving the dataset (0/3 shards):   7%| | 4000/54401 [00:00<00:03, 15929.46 examp

Saving the dataset (0/3 shards):  11%| | 6000/54401 [00:00<00:02, 16418.37 examp

Saving the dataset (0/3 shards):  15%|‚ñè| 8000/54401 [00:00<00:02, 16923.89 examp

Saving the dataset (0/3 shards):  18%|‚ñè| 10000/54401 [00:00<00:02, 17273.31 exam

Saving the dataset (0/3 shards):  22%|‚ñè| 12000/54401 [00:00<00:02, 17662.61 exam

Saving the dataset (0/3 shards):  26%|‚ñé| 14000/54401 [00:00<00:02, 17923.49 exam

Saving the dataset (0/3 shards):  29%|‚ñé| 16000/54401 [00:00<00:02, 18184.27 exam

Saving the dataset (0/3 shards):  33%|‚ñé| 18000/54401 [00:01<00:01, 18438.75 examSaving the dataset (1/3 shards):  33%|‚ñé| 18134/54401 [00:01<00:01, 18438.75 exam

Saving the dataset (1/3 shards):  37%|‚ñé| 20134/54401 [00:01<00:01, 17356.03 exam

Saving the dataset (1/3 shards):  41%|‚ñç| 22134/54401 [00:01<00:01, 17970.31 exam

Saving the dataset (1/3 shards):  44%|‚ñç| 24134/54401 [00:01<00:01, 18401.36 exam

Saving the dataset (1/3 shards):  48%|‚ñç| 26134/54401 [00:01<00:01, 18772.52 exam

Saving the dataset (1/3 shards):  52%|‚ñå| 28134/54401 [00:01<00:01, 19015.25 exam

Saving the dataset (1/3 shards):  55%|‚ñå| 30134/54401 [00:01<00:01, 19175.86 exam

Saving the dataset (1/3 shards):  59%|‚ñå| 32134/54401 [00:01<00:01, 19340.44 exam

Saving the dataset (1/3 shards):  63%|‚ñã| 34134/54401 [00:01<00:01, 19458.62 exam

Saving the dataset (1/3 shards):  67%|‚ñã| 36268/54401 [00:01<00:00, 19480.21 examSaving the dataset (2/3 shards):  67%|‚ñã| 36268/54401 [00:01<00:00, 19480.21 exam

Saving the dataset (2/3 shards):  72%|‚ñã| 39268/54401 [00:02<00:00, 19488.59 exam

Saving the dataset (2/3 shards):  80%|‚ñä| 43268/54401 [00:02<00:00, 19830.22 exam

Saving the dataset (2/3 shards):  87%|‚ñä| 47268/54401 [00:02<00:00, 20058.57 exam

Saving the dataset (2/3 shards):  94%|‚ñâ| 51268/54401 [00:02<00:00, 20178.13 exam

Saving the dataset (2/3 shards): 100%|‚ñà| 54401/54401 [00:02<00:00, 20197.74 examSaving the dataset (3/3 shards): 100%|‚ñà| 54401/54401 [00:02<00:00, 20197.74 examSaving the dataset (3/3 shards): 100%|‚ñà| 54401/54401 [00:02<00:00, 18877.90 exam
Saving the dataset (0/1 shards):   0%|           | 0/109 [00:00<?, ? examples/s]

Saving the dataset (1/1 shards): 100%|‚ñà| 109/109 [00:00<00:00, 7330.11 examples/Saving the dataset (1/1 shards): 100%|‚ñà| 109/109 [00:00<00:00, 7058.50 examples/


In [5]:
# Start the foundation model training (Enwiki stage 1, 4k training context).
# - Base settings come from enwiki-4k-part1.yaml (`-c`). The flags that follow
#   override the matching keys of that config, so keep `-c` first.
# - WANDB_MODE (online/disabled, from ENABLE_WANDB) is exported for this run only.
# - Starts from the init weights written by the "Init the model" cell and saves
#   checkpoints under ../checkpoint/{FILENAME_PREFIX}-enwiki-4k-p1/, which
#   the export cell below reads (last.ckpt).
# NOTE(review): a failing `!` command does not stop "Run All". In the recorded
#   run, training crashed (truncated traceback below), and the export and dragon
#   cells then failed, presumably because last.ckpt was never written.
!cd "{TRAINER_DIR}" && \
    export WANDB_MODE="{WANDB_MODE}" && \
    python3 lightning_trainer.py fit \
        -c "{NOTEBOOK_DIR}/enwiki-4k-part1.yaml" \
        --trainer.logger.init_args.name="{WANDB_PREFIX} - Enwiki-4k Part 1 (train-ctx=4k, {DEEPSPEED_STRAT})" \
        --trainer.strategy="{DEEPSPEED_STRAT}" \
        --trainer.devices="{GPU_DEVICES}" \
        --trainer.callbacks.init_args.dirpath="../checkpoint/{FILENAME_PREFIX}-enwiki-4k-p1/" \
        --model.load_model="../model/{FILENAME_PREFIX}-neox-v5base-init.pth" \
        --model.ctx_len=4096 \
        --model.bptt_learning_range=1

[2023-09-29 09:58:12,868] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)


[RWKV.model] Running RWKV model using 'torch-jit' with torch '2.0.1+cu118'


  rank_zero_warn(


  rank_zero_warn(f"No seed found, seed set to {seed}")
Global seed set to 207026176


[34m[1mwandb[0m: Currently logged in as: [33mpicocreator[0m ([33mrwkv-x-dev[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Tracking run with wandb version 0.15.11
[34m[1mwandb[0m: Run data is saved locally in [35m[1m./wandb/run-20230929_095815-3rwyj6ei[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33m[Multi-size] v5-L12-D2048-E0.01 - Enwiki-4k Part 1 (train-ctx=4k, deepspeed_stage_1)[0m
[34m[1mwandb[0m: ‚≠êÔ∏è View project at [34m[4mhttps://wandb.ai/rwkv-x-dev/RWKV-5X-Experiments[0m
[34m[1mwandb[0m: üöÄ View run at [34m[4mhttps://wandb.ai/rwkv-x-dev/RWKV-5X-Experiments/runs/3rwyj6ei[0m


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


[RWKV.Trainer] Applying 'target_batch_size' with the following:
   - target_batch_size:       32
   - num_nodes:               1
   - num_devices:             1
   - accumulate_grad_batches: 32
   - effective_batch_size:    32



Saving the dataset (0/3 shards):   0%|         | 0/54401 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   2%| | 1000/54401 [00:00<00:07, 6759.93 exampl

Saving the dataset (0/3 shards):   6%| | 3000/54401 [00:00<00:04, 12309.88 examp

Saving the dataset (0/3 shards):   9%| | 5000/54401 [00:00<00:03, 14423.41 examp

Saving the dataset (0/3 shards):  13%|‚ñè| 7000/54401 [00:00<00:03, 15440.72 examp

Saving the dataset (0/3 shards):  17%|‚ñè| 9000/54401 [00:00<00:02, 16338.39 examp

Saving the dataset (0/3 shards):  22%|‚ñè| 12000/54401 [00:00<00:02, 17519.79 exam

Saving the dataset (0/3 shards):  28%|‚ñé| 15000/54401 [00:00<00:02, 18285.95 exam

Saving the dataset (0/3 shards):  33%|‚ñé| 18134/54401 [00:01<00:01, 19055.07 exam

Saving the dataset (1/3 shards):  33%|‚ñé| 18134/54401 [00:01<00:01, 19055.07 exam

Saving the dataset (1/3 shards):  37%|‚ñé| 20134/54401 [00:01<00:05, 6782.90 examp

Saving the dataset (1/3 shards):  44%|‚ñç| 24134/54401 [00:02<00:03, 9522.94 examp

Saving the dataset (1/3 shards):  52%|‚ñå| 28134/54401 [00:02<00:02, 12076.84 exam

Saving the dataset (1/3 shards):  59%|‚ñå| 32134/54401 [00:02<00:01, 14336.61 exam

Saving the dataset (1/3 shards):  66%|‚ñã| 36134/54401 [00:02<00:01, 16234.67 exam

Saving the dataset (2/3 shards):  67%|‚ñã| 36268/54401 [00:03<00:01, 16234.67 exam

Saving the dataset (2/3 shards):  70%|‚ñã| 38268/54401 [00:03<00:02, 7102.74 examp

Saving the dataset (2/3 shards):  78%|‚ñä| 42268/54401 [00:03<00:01, 9363.53 examp

Saving the dataset (2/3 shards):  85%|‚ñä| 46268/54401 [00:03<00:00, 11635.76 exam

Saving the dataset (2/3 shards):  92%|‚ñâ| 50268/54401 [00:04<00:00, 13831.20 exam

Saving the dataset (2/3 shards): 100%|‚ñâ| 54268/54401 [00:04<00:00, 15794.33 exam

Saving the dataset (3/3 shards): 100%|‚ñà| 54401/54401 [00:05<00:00, 15794.33 examSaving the dataset (3/3 shards): 100%|‚ñà| 54401/54401 [00:05<00:00, 10560.28 exam
Saving the dataset (0/1 shards):   0%|           | 0/109 [00:00<?, ? examples/s]Saving the dataset (1/1 shards): 100%|‚ñà| 109/109 [00:00<00:00, 7298.75 examples/

Saving the dataset (1/1 shards): 100%|‚ñà| 109/109 [00:00<00:00, 6907.18 examples/


[rank: 0] Global seed set to 207026176
initializing deepspeed distributed: GLOBAL_RANK: 0, MEMBER: 1/1


Enabling DeepSpeed BF16.


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
#
# RWKV lighting_trainer.py important notes 
# https://github.com/RWKV/RWKV-infctx-trainer 
#
# - Ensure your host is not running cuda 12.0 (use either 11.8, or >=12.1), as this is known to have freeze issues
# - The terms used in wandb / the progress bar can be confusing, see the github README.md for beter clarifications
# - When resuming from checkpoint, the estimated time is inaccurate
#

[RWKV.model] Configuring optimizer with
    - lr_init:  6.000e-04 (0.0006)
    - lr_final: 5.000e-04 (0.0005)

Using /root/.cache/torch_extensions/py310_cu118 as PyTorch extensions root...


Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/torch_extensions/py310_cu118/fused_adam/build.ninja...
Building extension module fused_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


ninja: no work to do.
Loading extension module fused_adam...
Time to load fused_adam op: 0.07915163040161133 seconds
Loading `train_dataloader` to estimate number of stepping batches.


Rank: 0 partition count [1, 1] and sizes[(860549120, False), (768, False)] 



  | Name   | Type       | Params
--------------------------------------
0 | emb    | Embedding  | 102 M 
1 | blocks | ModuleList | 654 M 
2 | ln_out | LayerNorm  | 4.1 K 
3 | head   | Linear     | 102 M 
--------------------------------------
860 M     Trainable params
0         Non-trainable params
860 M     Total params
3,442.200 Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]Training:   0%|                                       | 0/54401 [00:00<?, ?it/s]Epoch 0:   0%|                                        | 0/54401 [00:00<?, ?it/s]

Epoch 0:   0%|                            | 1/54401 [00:10<154:16:48, 10.21s/it]Epoch 0:   0%| | 1/54401 [00:10<154:17:58, 10.21s/it, v_num=j6ei, train/loss=10.

Epoch 0:   0%| | 2/54401 [00:11<86:21:01,  5.71s/it, v_num=j6ei, train/loss=10.9Epoch 0:   0%| | 2/54401 [00:11<86:21:23,  5.71s/it, v_num=j6ei, train/loss=11.0

Epoch 0:   0%| | 3/54401 [00:12<63:41:34,  4.22s/it, v_num=j6ei, train/loss=11.0Epoch 0:   0%| | 3/54401 [00:12<63:41:49,  4.22s/it, v_num=j6ei, train/loss=10.9

Epoch 0:   0%| | 4/54401 [00:13<52:21:56,  3.47s/it, v_num=j6ei, train/loss=10.9Epoch 0:   0%| | 4/54401 [00:13<52:22:08,  3.47s/it, v_num=j6ei, train/loss=10.9

Epoch 0:   0%| | 5/54401 [00:15<45:34:11,  3.02s/it, v_num=j6ei, train/loss=10.9Epoch 0:   0%| | 5/54401 [00:15<45:34:21,  3.02s/it, v_num=j6ei, train/loss=10.9

Epoch 0:   0%| | 6/54401 [00:16<41:01:56,  2.72s/it, v_num=j6ei, train/loss=10.9Epoch 0:   0%| | 6/54401 [00:16<41:02:04,  2.72s/it, v_num=j6ei, train/loss=10.9

Epoch 0:   0%| | 7/54401 [00:17<37:47:47,  2.50s/it, v_num=j6ei, train/loss=10.9Epoch 0:   0%| | 7/54401 [00:17<37:47:53,  2.50s/it, v_num=j6ei, train/loss=10.9

Epoch 0:   0%| | 8/54401 [00:18<35:22:29,  2.34s/it, v_num=j6ei, train/loss=10.9Epoch 0:   0%| | 8/54401 [00:18<35:22:35,  2.34s/it, v_num=j6ei, train/loss=10.9

Epoch 0:   0%| | 9/54401 [00:19<33:29:09,  2.22s/it, v_num=j6ei, train/loss=10.9Epoch 0:   0%| | 9/54401 [00:19<33:29:15,  2.22s/it, v_num=j6ei, train/loss=10.9

Epoch 0:   0%| | 10/54401 [00:21<31:58:25,  2.12s/it, v_num=j6ei, train/loss=10.Epoch 0:   0%| | 10/54401 [00:21<31:58:30,  2.12s/it, v_num=j6ei, train/loss=10.

Epoch 0:   0%| | 11/54401 [00:22<30:44:06,  2.03s/it, v_num=j6ei, train/loss=10.Epoch 0:   0%| | 11/54401 [00:22<30:44:10,  2.03s/it, v_num=j6ei, train/loss=11.

Epoch 0:   0%| | 12/54401 [00:23<29:42:14,  1.97s/it, v_num=j6ei, train/loss=11.Epoch 0:   0%| | 12/54401 [00:23<29:42:17,  1.97s/it, v_num=j6ei, train/loss=10.

Epoch 0:   0%| | 13/54401 [00:24<28:49:59,  1.91s/it, v_num=j6ei, train/loss=10.Epoch 0:   0%| | 13/54401 [00:24<28:50:03,  1.91s/it, v_num=j6ei, train/loss=10.

Epoch 0:   0%| | 14/54401 [00:26<28:05:17,  1.86s/it, v_num=j6ei, train/loss=10.Epoch 0:   0%| | 14/54401 [00:26<28:05:20,  1.86s/it, v_num=j6ei, train/loss=10.

Epoch 0:   0%| | 15/54401 [00:27<27:26:24,  1.82s/it, v_num=j6ei, train/loss=10.Epoch 0:   0%| | 15/54401 [00:27<27:26:27,  1.82s/it, v_num=j6ei, train/loss=11.

Epoch 0:   0%| | 16/54401 [00:28<26:52:37,  1.78s/it, v_num=j6ei, train/loss=11.

Epoch 0:   0%| | 16/54401 [00:28<26:52:39,  1.78s/it, v_num=j6ei, train/loss=10.

Epoch 0:   0%| | 17/54401 [00:29<26:22:57,  1.75s/it, v_num=j6ei, train/loss=10.Epoch 0:   0%| | 17/54401 [00:29<26:23:00,  1.75s/it, v_num=j6ei, train/loss=10.

Epoch 0:   0%| | 18/54401 [00:30<25:57:05,  1.72s/it, v_num=j6ei, train/loss=10.Epoch 0:   0%| | 18/54401 [00:30<25:57:08,  1.72s/it, v_num=j6ei, train/loss=10.

Epoch 0:   0%| | 19/54401 [00:32<25:33:20,  1.69s/it, v_num=j6ei, train/loss=10.Epoch 0:   0%| | 19/54401 [00:32<25:33:22,  1.69s/it, v_num=j6ei, train/loss=10.

Epoch 0:   0%| | 20/54401 [00:33<25:11:45,  1.67s/it, v_num=j6ei, train/loss=10.Epoch 0:   0%| | 20/54401 [00:33<25:11:48,  1.67s/it, v_num=j6ei, train/loss=10.

Epoch 0:   0%| | 21/54401 [00:34<24:52:15,  1.65s/it, v_num=j6ei, train/loss=10.Epoch 0:   0%| | 21/54401 [00:34<24:52:17,  1.65s/it, v_num=j6ei, train/loss=10.

Epoch 0:   0%| | 22/54401 [00:35<24:34:31,  1.63s/it, v_num=j6ei, train/loss=10.Epoch 0:   0%| | 22/54401 [00:35<24:34:33,  1.63s/it, v_num=j6ei, train/loss=10.

Epoch 0:   0%| | 23/54401 [00:37<24:18:22,  1.61s/it, v_num=j6ei, train/loss=10.Epoch 0:   0%| | 23/54401 [00:37<24:18:24,  1.61s/it, v_num=j6ei, train/loss=10.

Epoch 0:   0%| | 24/54401 [00:38<24:03:23,  1.59s/it, v_num=j6ei, train/loss=10.Epoch 0:   0%| | 24/54401 [00:38<24:03:25,  1.59s/it, v_num=j6ei, train/loss=10.

Epoch 0:   0%| | 25/54401 [00:39<23:49:44,  1.58s/it, v_num=j6ei, train/loss=10.Epoch 0:   0%| | 25/54401 [00:39<23:49:46,  1.58s/it, v_num=j6ei, train/loss=10.

Epoch 0:   0%| | 26/54401 [00:40<23:37:08,  1.56s/it, v_num=j6ei, train/loss=10.Epoch 0:   0%| | 26/54401 [00:40<23:37:10,  1.56s/it, v_num=j6ei, train/loss=11.

Epoch 0:   0%| | 27/54401 [00:41<23:25:25,  1.55s/it, v_num=j6ei, train/loss=11.Epoch 0:   0%| | 27/54401 [00:41<23:25:27,  1.55s/it, v_num=j6ei, train/loss=10.

Epoch 0:   0%| | 28/54401 [00:43<23:14:36,  1.54s/it, v_num=j6ei, train/loss=10.Epoch 0:   0%| | 28/54401 [00:43<23:14:38,  1.54s/it, v_num=j6ei, train/loss=10.

Epoch 0:   0%| | 29/54401 [00:44<23:04:27,  1.53s/it, v_num=j6ei, train/loss=10.Epoch 0:   0%| | 29/54401 [00:44<23:04:28,  1.53s/it, v_num=j6ei, train/loss=10.

Epoch 0:   0%| | 30/54401 [00:45<22:55:04,  1.52s/it, v_num=j6ei, train/loss=10.Epoch 0:   0%| | 30/54401 [00:45<22:55:05,  1.52s/it, v_num=j6ei, train/loss=11.

Epoch 0:   0%| | 31/54401 [00:46<22:46:11,  1.51s/it, v_num=j6ei, train/loss=11.Epoch 0:   0%| | 31/54401 [00:46<22:46:12,  1.51s/it, v_num=j6ei, train/loss=10.

Epoch 0:   0%| | 32/54401 [00:48<22:43:11,  1.50s/it, v_num=j6ei, train/loss=10.

Epoch 0:   0%| | 32/54401 [00:48<22:45:59,  1.51s/it, v_num=j6ei, train/loss=10.

Epoch 0:   0%| | 33/54401 [00:49<22:37:44,  1.50s/it, v_num=j6ei, train/loss=10.Epoch 0:   0%| | 33/54401 [00:49<22:37:46,  1.50s/it, v_num=j6ei, train/loss=9.4

Epoch 0:   0%| | 34/54401 [00:50<22:30:18,  1.49s/it, v_num=j6ei, train/loss=9.4Epoch 0:   0%| | 34/54401 [00:50<22:30:19,  1.49s/it, v_num=j6ei, train/loss=9.2

Epoch 0:   0%| | 35/54401 [00:51<22:23:17,  1.48s/it, v_num=j6ei, train/loss=9.2Epoch 0:   0%| | 35/54401 [00:51<22:23:18,  1.48s/it, v_num=j6ei, train/loss=9.3

Epoch 0:   0%| | 36/54401 [00:53<22:16:41,  1.48s/it, v_num=j6ei, train/loss=9.3Epoch 0:   0%| | 36/54401 [00:53<22:16:42,  1.48s/it, v_num=j6ei, train/loss=9.6

Epoch 0:   0%| | 37/54401 [00:54<22:10:26,  1.47s/it, v_num=j6ei, train/loss=9.6Epoch 0:   0%| | 37/54401 [00:54<22:10:27,  1.47s/it, v_num=j6ei, train/loss=9.6

Epoch 0:   0%| | 38/54401 [00:55<22:04:33,  1.46s/it, v_num=j6ei, train/loss=9.6Epoch 0:   0%| | 38/54401 [00:55<22:04:35,  1.46s/it, v_num=j6ei, train/loss=9.4

Epoch 0:   0%| | 39/54401 [00:56<21:58:58,  1.46s/it, v_num=j6ei, train/loss=9.4Epoch 0:   0%| | 39/54401 [00:56<21:58:59,  1.46s/it, v_num=j6ei, train/loss=9.5

Epoch 0:   0%| | 40/54401 [00:57<21:53:35,  1.45s/it, v_num=j6ei, train/loss=9.5Epoch 0:   0%| | 40/54401 [00:57<21:53:36,  1.45s/it, v_num=j6ei, train/loss=9.6

Epoch 0:   0%| | 41/54401 [00:59<21:48:31,  1.44s/it, v_num=j6ei, train/loss=9.6Epoch 0:   0%| | 41/54401 [00:59<21:48:32,  1.44s/it, v_num=j6ei, train/loss=9.4

Epoch 0:   0%| | 42/54401 [01:00<21:43:42,  1.44s/it, v_num=j6ei, train/loss=9.4Epoch 0:   0%| | 42/54401 [01:00<21:43:43,  1.44s/it, v_num=j6ei, train/loss=9.4

Epoch 0:   0%| | 43/54401 [01:01<21:39:04,  1.43s/it, v_num=j6ei, train/loss=9.4Epoch 0:   0%| | 43/54401 [01:01<21:39:06,  1.43s/it, v_num=j6ei, train/loss=9.6

Epoch 0:   0%| | 44/54401 [01:02<21:34:37,  1.43s/it, v_num=j6ei, train/loss=9.6Epoch 0:   0%| | 44/54401 [01:02<21:34:38,  1.43s/it, v_num=j6ei, train/loss=9.5

Epoch 0:   0%| | 45/54401 [01:04<21:30:26,  1.42s/it, v_num=j6ei, train/loss=9.5Epoch 0:   0%| | 45/54401 [01:04<21:30:27,  1.42s/it, v_num=j6ei, train/loss=9.5

Epoch 0:   0%| | 46/54401 [01:05<21:26:24,  1.42s/it, v_num=j6ei, train/loss=9.5Epoch 0:   0%| | 46/54401 [01:05<21:26:25,  1.42s/it, v_num=j6ei, train/loss=9.5

Epoch 0:   0%| | 47/54401 [01:06<21:22:30,  1.42s/it, v_num=j6ei, train/loss=9.5Epoch 0:   0%| | 47/54401 [01:06<21:22:31,  1.42s/it, v_num=j6ei, train/loss=9.5

Epoch 0:   0%| | 48/54401 [01:07<21:18:53,  1.41s/it, v_num=j6ei, train/loss=9.5Epoch 0:   0%| | 48/54401 [01:07<21:18:54,  1.41s/it, v_num=j6ei, train/loss=9.5

Epoch 0:   0%| | 49/54401 [01:08<21:15:23,  1.41s/it, v_num=j6ei, train/loss=9.5Epoch 0:   0%| | 49/54401 [01:08<21:15:24,  1.41s/it, v_num=j6ei, train/loss=9.5

Epoch 0:   0%| | 50/54401 [01:10<21:11:54,  1.40s/it, v_num=j6ei, train/loss=9.5Epoch 0:   0%| | 50/54401 [01:10<21:11:55,  1.40s/it, v_num=j6ei, train/loss=9.4

Epoch 0:   0%| | 51/54401 [01:11<21:08:37,  1.40s/it, v_num=j6ei, train/loss=9.4Epoch 0:   0%| | 51/54401 [01:11<21:08:38,  1.40s/it, v_num=j6ei, train/loss=9.5

Epoch 0:   0%| | 52/54401 [01:12<21:05:29,  1.40s/it, v_num=j6ei, train/loss=9.5Epoch 0:   0%| | 52/54401 [01:12<21:05:30,  1.40s/it, v_num=j6ei, train/loss=9.6

Epoch 0:   0%| | 53/54401 [01:13<21:02:31,  1.39s/it, v_num=j6ei, train/loss=9.6Epoch 0:   0%| | 53/54401 [01:13<21:02:32,  1.39s/it, v_num=j6ei, train/loss=9.6

Epoch 0:   0%| | 54/54401 [01:15<20:59:37,  1.39s/it, v_num=j6ei, train/loss=9.6Epoch 0:   0%| | 54/54401 [01:15<20:59:38,  1.39s/it, v_num=j6ei, train/loss=9.5

Epoch 0:   0%| | 55/54401 [01:16<20:56:54,  1.39s/it, v_num=j6ei, train/loss=9.5Epoch 0:   0%| | 55/54401 [01:16<20:56:55,  1.39s/it, v_num=j6ei, train/loss=9.5

Epoch 0:   0%| | 56/54401 [01:17<20:54:17,  1.38s/it, v_num=j6ei, train/loss=9.5Epoch 0:   0%| | 56/54401 [01:17<20:54:18,  1.38s/it, v_num=j6ei, train/loss=9.3

Epoch 0:   0%| | 57/54401 [01:18<20:51:46,  1.38s/it, v_num=j6ei, train/loss=9.3Epoch 0:   0%| | 57/54401 [01:18<20:51:47,  1.38s/it, v_num=j6ei, train/loss=9.4

Epoch 0:   0%| | 58/54401 [01:20<20:49:18,  1.38s/it, v_num=j6ei, train/loss=9.4Epoch 0:   0%| | 58/54401 [01:20<20:49:19,  1.38s/it, v_num=j6ei, train/loss=9.3

Epoch 0:   0%| | 59/54401 [01:21<20:46:56,  1.38s/it, v_num=j6ei, train/loss=9.3Epoch 0:   0%| | 59/54401 [01:21<20:46:57,  1.38s/it, v_num=j6ei, train/loss=9.6

Epoch 0:   0%| | 60/54401 [01:22<20:44:35,  1.37s/it, v_num=j6ei, train/loss=9.6Epoch 0:   0%| | 60/54401 [01:22<20:44:35,  1.37s/it, v_num=j6ei, train/loss=9.5

Epoch 0:   0%| | 61/54401 [01:23<20:42:18,  1.37s/it, v_num=j6ei, train/loss=9.5Epoch 0:   0%| | 61/54401 [01:23<20:42:19,  1.37s/it, v_num=j6ei, train/loss=9.3

Epoch 0:   0%| | 62/54401 [01:24<20:40:02,  1.37s/it, v_num=j6ei, train/loss=9.3Epoch 0:   0%| | 62/54401 [01:24<20:40:03,  1.37s/it, v_num=j6ei, train/loss=9.6

Epoch 0:   0%| | 63/54401 [01:26<20:37:53,  1.37s/it, v_num=j6ei, train/loss=9.6Epoch 0:   0%| | 63/54401 [01:26<20:37:54,  1.37s/it, v_num=j6ei, train/loss=9.4

Traceback (most recent call last):
  File "/actions-runner/_work/RWKV-infctx-trainer/RWKV-infctx-trainer/RWKV-v5/lightning_trainer.py", line 278, in <module>
    cli_main()
  File "/actions-runner/_work/RWKV-infctx-trainer/RWKV-infctx-trainer/RWKV-v5/lightning_trainer.py", line 253, in cli_main
    LightningCLI(
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/cli.py", line 353, in __init__
    self._run_subcommand(self.subcommand)
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/cli.py", line 642, in _run_subcommand
    fn(**fn_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 529, in fit
    call._call_and_handle_interrupt(
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py", line 41, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/lightnin

[34m[1mwandb[0m: - 0.005 MB of 0.005 MB uploaded (0.000 MB deduped)

[34m[1mwandb[0m: \ 0.005 MB of 0.005 MB uploaded (0.000 MB deduped)

[34m[1mwandb[0m: | 0.005 MB of 0.005 MB uploaded (0.000 MB deduped)

[34m[1mwandb[0m: / 0.005 MB of 0.005 MB uploaded (0.000 MB deduped)

[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:                  batchidx ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà
[34m[1mwandb[0m:               global_rank ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
[34m[1mwandb[0m: perf/tokens_per_sec.gpu.0 ‚ñÅ‚ñÅ‚ñÉ‚ñÉ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
[34m[1mwandb[0m:   perf/tokens_total.gpu.0 ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà
[34m[1mwandb[0m:              real_ctx_len ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
[34m[1mwandb[0m:                   substep ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚

In [6]:
# Lets export the model from the checkpoint
!cd "{TRAINER_DIR}" && \
    python3 export_checkpoint.py "../checkpoint/{FILENAME_PREFIX}-enwiki-4k-p1/last.ckpt" "../model/{FILENAME_PREFIX}-enwiki-4k-p1.pth" "bf16"
!cd "{TRAINER_DIR}" && ls -alh "../model/{FILENAME_PREFIX}-enwiki-4k-p1.pth"

[2023-09-29 10:00:23,854] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)


Traceback (most recent call last):
  File "/actions-runner/_work/RWKV-infctx-trainer/RWKV-infctx-trainer/RWKV-v5/export_checkpoint.py", line 651, in <module>
    convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, output_file, save_dtype=args.dtype)
  File "/actions-runner/_work/RWKV-infctx-trainer/RWKV-infctx-trainer/RWKV-v5/export_checkpoint.py", line 542, in convert_zero_checkpoint_to_fp32_state_dict
    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
  File "/actions-runner/_work/RWKV-infctx-trainer/RWKV-infctx-trainer/RWKV-v5/export_checkpoint.py", line 516, in get_fp32_state_dict_from_zero_checkpoint
    raise ValueError(f"Unable to find 'latest' file at {latest_path}")
ValueError: Unable to find 'latest' file at ../checkpoint/v5-L12-D2048-E0_01-enwiki-4k-p1/last.ckpt/latest


ls: cannot access '../model/v5-L12-D2048-E0_01-enwiki-4k-p1.pth': No such file or directory


In [7]:
# # Lets do a quick dragon prompt validation
!cd "{INFERENCE_DIR}" && \
    python3 dragon_test.py "../model/{FILENAME_PREFIX}-enwiki-4k-p1.pth" "cuda fp32"

[2023-09-29 10:00:29,417] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)


[RWKV.model] Running RWKV model using 'torch-jit' with torch '2.0.1+cu118'
Traceback (most recent call last):
  File "/actions-runner/_work/RWKV-infctx-trainer/RWKV-infctx-trainer/RWKV-v5/dragon_test.py", line 52, in <module>
    model = SimpleRWKV(MODEL_PATH, device=DEVICE)
  File "/actions-runner/_work/RWKV-infctx-trainer/RWKV-infctx-trainer/RWKV-v5/src/model.py", line 1420, in __init__
    self.model = RWKV(**model_config)
  File "/actions-runner/_work/RWKV-infctx-trainer/RWKV-infctx-trainer/RWKV-v5/src/model.py", line 566, in __init__
    raise ValueError(f"load_model file '{load_model}' does not exist")
ValueError: load_model file '../model/v5-L12-D2048-E0_01-enwiki-4k-p1.pth' does not exist
