---
license: agpl-3.0
---

This repo catalogs my weights for use with my [VALL-E](https://github.com/e-c-k-e-r/vall-e) implementation as I try and iron out the kinks.

The model is currently in a *semi-usable* state, and I'm releasing these weights now in the hope that they help jumpstart anyone else who wants to use them.

To reiterate, this is ***by no means*** complete. I am not passing this off as competitive.

## Models

This repo contains the following configurations under `./models/`:

* `config.retnet.yaml` / `ar+nar-retnet-8`: The previously released weights.
	+ This configuration utilizes a RetNet (retention-based "transformer") as the underlying architecture, due to a number of misleading interpretations of comparisons, for better or for worse.
		+ Prompt and response embeddings are summed (further RVQ levels get the previous RVQ levels' embeddings factored in).
		+ Tokenizer is a homebrewed "naive" implementation.
	+ This model received the most training time, split across my 4070Ti, 7900XTX, and a few rental rigs to push training further, entirely at `bfloat16` with `prodigyopt` (and a few optimizer restarts).
	+ The later part of training aimed to shuffle between speakers rather than the global pool of utterances to better focus on zero-shot performance. Due to this, I feel it achieved *decent* zero-shot performance.
	+ However, because the dataset was aggressively trimmed to under 12 seconds for memory savings during training, it struggles to inference non-short utterances. Additional training may fix this; the following models seemed to adapt well to longer utterances.
        + From the `ar+nar-llama-8` experiment, I believe this can be "fixed" with additional training on the currently processed dataset.
	+ Prior testing showed that longer prompt durations result in better utterances.
    + *Can* benefit from additional training, but I recall the average loss being around `1.9` to `2.1`.
        + However, due to regressions (or bias from working under `llama`), I don't think I can optimally train with a RetNet again (both in terms of VRAM consumption and throughput).

* `config.llama.yaml` / `ar+nar-llama-8`: The most recent-ishly trained weights after learning from my mistakes.
	+ This configuration utilizes Llama's attention-based transformer as the underlying architecture, making use of creature comforts like RoPE, GQA, and memory-efficient attention (trained under `xformers`, shouldn't really affect things).
		+ Prompt and response embeddings are NOT summed (each RVQ level only attends to the current RVQ level).
		+ Utilizes a HF tokenizer for "optimal" vocab.
		+ The current RVQ level is included as a token as well to help guide NAR tasks better.
	+ This model received a few days of training on my 4xV100s, stepping up the duration window to *try* and get the model to inference longer utterances better.
		+ Some sessions end up training the current duration window for a few epochs, but I don't know how much it affected things.
	+ ~~However, it seems to *only* do well with long utterances. Short utterances fumble. I believe further training with a variety of durations should allow the AR to handle a variety of durations.~~
		- ~~I believe the "slowly stepping up the context length" only works for text, and not audio.~~
        - Addendum: Additional brief training for a variety of duration lengths seemed to have mostly fixed this issue.
        - Addendum addendum: Properly creating the position IDs per segment, rather than for the whole sequence, also helps a lot.
	+ Zero-shot performance leaves a bit to be desired, as it did not receive the special training prioritizing shuffling between speakers rather than the global pool of utterances.
        - Addendum: Additional brief training for sampling based on speaker per "epoch" (per dataloader, not dataset) seemed to slightly improve it.
	+ Testing showed that, despite also stepping up the prompt duration, it *really* likes three second prompts.
	+ Definitely needs additional training, but the best way to proceed is unknown.
        + Naturally, training it on a "next RVQ level is half as likely" distribution introduces some crust as the later RVQ levels are less accurate, introducing noise and artifacts.
        + As a fix for the above, naively training it on equally distributed RVQ levels *does* lobotomize the AR.
        + Additional training on the AR will see huge diminishing returns, so I don't know if it's worth doing so.
    + Seems to be a decent foundation for "distillation", at the very least for LoRA training.
    	- Addendum: it seems to serve fine for patch-training a few extra tweaks, such as non-unified position IDs, split classifier heads, and para-parallel decoding for the AR.
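The difference between the two embedding schemes above (summed for `ar+nar-retnet-8`, per-level for `ar+nar-llama-8`) can be sketched minimally in NumPy; the table shapes, names, and the exact summation rule here are illustrative assumptions, not the repo's actual API:

```python
import numpy as np

rng = np.random.default_rng(0)
n_levels, vocab, dim = 8, 1024, 256
tables = rng.standard_normal((n_levels, vocab, dim))  # one table per RVQ level

def embed_summed(codes):
    # retnet-style: the input folds in every level's embedding of its own codes
    return sum(tables[l][codes[l]] for l in range(codes.shape[0]))

def embed_current(codes, level):
    # llama-style: only the current RVQ level's codes are embedded
    return tables[level][codes[level]]

codes = rng.integers(0, vocab, (n_levels, 75))  # 8 levels x 75 frames
print(embed_summed(codes).shape)      # (75, 256)
print(embed_current(codes, 3).shape)  # (75, 256)
```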

## Experiments

Under `./models/experiments/` are some failed models, included to serve as references for my errors. Do ***not*** use them unless you're curious, or know what you're doing.

* `config.llama.split.yaml` / `ar-llama-1` + `nar-llama-8`: The above model, but split and trained a little bit more.
	+ This experiment is to see whether the AR and NAR benefitted from being split up after enough pretraining, to un-"lobotomize" any penalties from attending to two different tasks (as the AR predicts the next token, and the NAR predicts the same token but a different level).
	+ I believe I trained each split model for an extra day on another audio-duration window, for similar training lengths.
	+ ~~I don't think audio quality differs a non-trivial amount to warrant splitting the model.~~
        - Addendum: From recent experiments, it does seem a NAR-only model is beneficial; I will need to explore this in the future.

* `config.dac.yaml` / `ar+nar-dac-llama-9`: Utilizes [Descript-Audio-Codec](https://github.com/descriptinc/descript-audio-codec/) instead as the audio backend.
    + This utilizes the 44kHz (erroneously at 44,000 Hz instead of 44,100 Hz) model at 9 RVQ levels (mostly trained at 8, then the 9th was included).
        + Originally experimented with feeding 24kHz audio through the 44kHz model (naively assuming nothing would go wrong), but artifacts in the output proved to be too much.
        + Later experimented with the 24kHz model, but training would *always* diverge.
    + *Heavily* benefits from inferencing only the first four RVQ levels; later levels include far too much noise in the final output.
        + I imagine the nature of DAC itself amplifies errors in the remaining RVQ levels (either due to less resiliency to errors in the codes, or each RVQ level affecting the final waveform more).
        + Addendum: restricting to the first four RVQ levels seems to help remove noisy artifacts, but quality is hindered as there are fewer RVQ levels to rely on.
    + Has not received as much training as the EnCodec-based models.
        + Because of this, performance leaves more to be desired.
    + Further experimentation is needed, but the next approach is unknown.
        + Train a NAR only model to help bolster the remaining RVQ levels (outputted utterances seem a bit sluggish).
        + Continue training the AR+NAR to try and bolster the AR tasks (as it's quite lacking at the moment).
        + Delve into other, exotic features, such as utilizing DAC's decoding embeddings (which might not be necessary at all since it seems *fine* at the moment).
        	+ Addendum: This seems unnecessary, as freezing these embeddings is harmful, and not freezing them will just inevitably cause them to shift elsewhere.

* `config.dac-nar-len.yaml` / `nar-len-llama-9`: A DAC-based model, but a pure NAR model (plus an autoregressive length task).
	+ Originally thought to be bunk from inferencing tests having audio drastically drop off into silence, but I suppose it was just some issue that eventually resolved itself.
      + Addendum: I don't know what magic I did for that model, but I cannot recreate a decent EnCodec-backed model instead, despite the test trainer working fine.
	+ Suffers from the same problems the above model suffers from (terrible quality).
	+ *Huge* performance gains, but it may well suffer from some specific qualities in the outputs, if it does get trained right.

* `config.llama-x4.yaml` / `ar+nar-llama-8`: The above `ar+nar-llama-8` model, but with para-parallel decoding for the AR in-post.
	+ This mostly serves as a proof-of-concept for speeding up inferencing by reducing the number of steps required, by decoding multiple tokens in parallel with a similar approach to how the NAR decodes in parallel.
		+ Trained with the trainer's batch-by-durations sampler for a maximum duration batch size of 100 seconds (750 resp tokens), with ProdigyOpt at bfloat16 (no AMP) on my 4070Ti (because I can't be assed to fire up my 4xV100 machine again for a simple test).
	+ The model definitely needs to be retrained as there's some errors for the additional tokens.
		+ If these cannot be ironed out with more training, then I imagine a similar approach to speculative decoding, where the nth tokens are discarded if their confidence is low, would help.
		+ Greedy sampling might be beneficial instead for this, as the NAR does benefit greatly from low temperatures / greedy sampling.
    + It seems that naively adjusting the "causal size" (the number of tokens to predict into the future, and in turn, how many tokens are returned per step) introduces crackles at fixed intervals.

Some additional configurations have been explored, but experiments have not been fruitful:
* Exotic wrappers like `BitNet` seemed to yield little gain in inferencing, somehow. The memory savings are pretty much unnecessary, as the models are already manageable at ~200M parameters.
* Mamba / Mamba2-based models have shown that it's ***really*** hard to have an AR+NAR model. I really do not want to bother throwing the compute at another ~~meme~~ arch where I can't easily make use of all the other tech.

Some current "architectural features" are in use, but their effects need further experimentation:
* `split_classifier_heads` (each RVQ level gets its own output head): it's still a mystery whether this is truly helpful.
* `audio_embeddings_sum`: it's also a mystery whether it matters if each later RVQ level "sees" the past levels through summed embeddings, or if omitting the sum is preferable.
* Disabling `unified_position_ids` seems to help quality more often than not, but I'm still unsure whether it's beneficial in practice.
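The `unified_position_ids` behavior can be illustrated with a small sketch; the text/prompt/response segment layout here is an assumption for demonstration, not the repo's exact sequence format:

```python
# Illustrative sketch: unified position IDs run one counter across the whole
# concatenated sequence, while non-unified IDs restart at 0 per segment.
def position_ids(segment_lengths, unified=True):
    if unified:
        # one running count across all segments
        return list(range(sum(segment_lengths)))
    # non-unified: each segment restarts its positions at 0
    ids = []
    for length in segment_lengths:
        ids.extend(range(length))
    return ids

# e.g. a text segment of 4 tokens, a prompt of 3, and a response of 5
print(position_ids([4, 3, 5], unified=True))   # [0, 1, ..., 11]
print(position_ids([4, 3, 5], unified=False))  # [0,1,2,3, 0,1,2, 0,1,2,3,4]
```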