license: agpl-3.0

This repo catalogs my weights for use with my VALL-E implementation as I try to iron out the kinks.

The models are currently in a semi-usable state, and I'm releasing them now in the hope that they help jumpstart anyone else who wants to use them.

To reiterate, this is by no means complete. I am not passing this off as competitive.

Models

This repo contains the following configurations under ./models/:

  • config.retnet.yaml / ar+nar-retnet-8: The previously released weights.

    • This configuration utilizes a RetNet (retention-based "transformer") as the underlying architecture due to a number of misleading interpretations with comparisons, for better or for worse.
      • Prompt and response embeddings are summed (later RVQ levels get the previous RVQ levels' embeddings factored in).
      • Tokenizer is a homebrewed "naive" implementation.
    • This model received the most training time between my 4070Ti, my 7900XTX, and a few rental rigs to further training progress, entirely at bfloat16 with prodigyopt (and a few optimizer restarts).
    • The later part of training aimed to shuffle between speakers rather than the global pool of utterances to better focus on zero-shot performance. Due to this, I feel it achieved decent zero-shot performance.
    • However, because the dataset was aggressively trimmed to under 12 seconds for memory savings during training, it struggles to inference longer utterances. Additional training may fix this, as the following models seemed to adapt well to longer utterances.
      • From the ar+nar-llama-8 experiment, I believe this can be "fixed" with additional training on the currently processed dataset.
    • Prior testing showed that longer prompt durations result in better utterances.
    • Can benefit from additional training, but I recall the average loss being around 1.9 to 2.1.
      • However, due to regressions (or bias from working under llama), I don't think I can optimally train with a RetNet again (both in terms of VRAM consumption and throughput).
  • config.llama.yaml / ar+nar-llama-8: The most recent-ishly trained weights after learning from my mistakes.

    • This configuration utilizes Llama's attention-based transformer as the underlying architecture, making use of creature comforts like RoPE, GQA, and memory-efficient attention (trained under xformers, shouldn't really affect things).
      • Prompt and response embeddings are NOT summed (each RVQ level only attends to the current RVQ level).
      • Utilizes a HF tokenizer for "optimal" vocab.
      • The current RVQ level is included as a token as well to help guide NAR tasks better.
    • This model received a few days of training on my 4xV100s, stepping up the duration window to try and make the model better at inferencing longer utterances.
      • Some sessions ended up training the current duration window for a few epochs, but I don't know how much it affected things.
    • However, it seems to only do well with long utterances. Short utterances fumble. I believe further training with a variety of durations should allow the AR to handle a variety of durations.
      • I believe the "slowly stepping up the context length" only works for text, and not audio.
      • Addendum: Additional brief training for a variety of duration lengths seemed to have mostly fixed this issue.
      • Addendum addendum: Properly creating the position IDs per-segment, rather than across the whole sequence, also helps a lot.
    • Zero-shot performance leaves a bit to be desired, as it did not receive the special training prioritizing shuffling between speakers rather than the global pool of utterances.
      • Addendum: Additional brief training for sampling based on speaker per "epoch" (per dataloader, not dataset) seemed to slightly improve it.
    • Testing showed that, despite also stepping up the prompt duration, it really likes three second prompts.
    • Definitely needs additional training, but the next way to go is unknown.
      • Naturally, training it on a "next RVQ level is half as likely" distribution introduces some crust as the later RVQ levels are less accurate, introducing noise and artifacts.
      • As a fix for the above, naively training it on equally distributed RVQ levels does lobotomize the AR.
      • Additional training on the AR will see huge diminishing returns, so I don't know if it's worth doing so.
    • Seems to be a decent foundation for "distillation", at the very least for LoRA training.
      • Addendum: it seems to serve fine for patch-training a few extra tweaks, such as non-unified position IDs, split classifier heads, and para-parallel decoding for the AR.
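
The difference between the summed embeddings of ar+nar-retnet-8 and the per-level embeddings of ar+nar-llama-8 can be sketched roughly as follows. This is a toy, framework-free illustration, not the implementation's code: the table sizes and values, the `tables` layout, and the `embed_frame` helper are hypothetical stand-ins for learned embedding layers.

```python
from typing import List

# Hypothetical toy setup (not the real model): 3 RVQ levels, 4 codes per level,
# 2-dim embeddings. tables[level][code] is the embedding vector for that code.
tables: List[List[List[float]]] = [
    [[float(level * 10 + code), float(code)] for code in range(4)]
    for level in range(3)
]


def embed_frame(codes: List[int], level: int, summed: bool) -> List[float]:
    """Embed one audio frame whose codes are known up to RVQ `level`.

    codes[l] is the frame's code at RVQ level l.
    summed=True: add up the embeddings of levels 0..level (assumed to mirror
    audio_embeddings_sum, as in the RetNet model).
    summed=False: use only the embedding of `level` (as in the Llama model).
    """
    levels = range(level + 1) if summed else [level]
    out = [0.0] * len(tables[0][0])
    for lvl in levels:
        out = [a + b for a, b in zip(out, tables[lvl][codes[lvl]])]
    return out


print(embed_frame([1, 2, 3], level=2, summed=True))   # [36.0, 6.0]
print(embed_frame([1, 2, 3], level=2, summed=False))  # [23.0, 3.0]
```

With `summed=True`, a frame at level 2 carries information from levels 0 through 2; with `summed=False`, it only carries level 2. At level 0 the two schemes coincide.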

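The "next RVQ level is half as likely" training distribution mentioned above, versus equally distributed RVQ levels, can be sketched as below. The function names and the exact normalization are assumptions for illustration; the implementation may weight or offset levels differently (for example, by treating level 0 specially for the AR task).

```python
import random
from typing import List


def rvq_level_weights(n_levels: int, scheme: str = "halving") -> List[float]:
    """Normalized probability of training on each RVQ level (hypothetical helper).

    "halving": each next level is half as likely as the previous (1, 1/2, 1/4, ...).
    "uniform": every level is equally likely.
    """
    if scheme == "halving":
        weights = [0.5 ** level for level in range(n_levels)]
    elif scheme == "uniform":
        weights = [1.0] * n_levels
    else:
        raise ValueError(f"unknown scheme: {scheme!r}")
    total = sum(weights)
    return [w / total for w in weights]


def sample_rvq_level(rng: random.Random, n_levels: int, scheme: str = "halving") -> int:
    """Draw the RVQ level to train on for one sample."""
    weights = rvq_level_weights(n_levels, scheme)
    return rng.choices(range(n_levels), weights=weights, k=1)[0]


rng = random.Random(0)
print([round(w, 4) for w in rvq_level_weights(8)])
print(sample_rvq_level(rng, 8))
```

With 8 levels, the halving scheme picks level 0 with probability 128/255 (about 50%) and level 7 with probability 1/255 (about 0.4%), which illustrates why the later levels end up less accurate under it.
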
Experiments

Under ./models/experiments/ are some failed models, but they are included to serve as references for my errors. Do not use them unless you're curious, or know what you're doing.

  • config.llama.split.yaml / ar-llama-1 + nar-llama-8: The above model, but split and trained a little bit more.

    • This experiment is to see whether the AR and NAR benefit from being split up after enough pretraining, to un-"lobotomize" any penalties from attending to two different tasks (as the AR predicts the next token, and the NAR predicts the same token but at a different level).
    • I believe I trained each separate model for an extra day on another audio-duration window, for similar training lengths.
    • I don't think audio quality differs by a large enough amount to warrant splitting the model.
      • Addendum: From recent experiments, it does seem a NAR-only model is beneficial; I will need to explore this in the future.
  • config.dac.yaml / ar+nar-dac-llama-9: Utilizes Descript-Audio-Codec instead as the audio backend.

    • This utilizes the 44kHz model (erroneously at 44,000 Hz instead of 44,100 Hz) at 9 RVQ levels (mostly trained at 8, then the 9th was included).
      • Originally experimented with feeding 24kHz audio through the 44kHz model (naively assuming nothing would go wrong), but the artifacts in the output proved to be too much.
      • Later experimented with the 24kHz model, but training would always diverge.
    • Heavily benefits from inferencing only the first four RVQ levels; the levels after that include far too much noise in the final output.
      • I imagine the nature of DAC itself amplifies errors in the remaining RVQ levels (either due to less resiliency to errors in the codes, or each RVQ level affecting the final waveform more).
      • Addendum: restricting to the first four RVQ levels seems to help remove noisy artifacts, but quality is hindered as there are fewer RVQ levels to rely on.
    • Has not received as much training as the EnCodec-based models.
      • Because of this, performance leaves more to be desired.
    • Further experimentation is needed, but the next approach is unknown.
      • Train a NAR only model to help bolster the remaining RVQ levels (outputted utterances seem a bit sluggish).
      • Continue training the AR+NAR to try and bolster the AR tasks (as it's quite lacking at the moment).
      • Delve into other, exotic features, such as utilizing DAC's decoding embeddings (which might not be necessary at all since it seems fine at the moment).
        • Addendum: This seems unnecessary, as freezing these embeddings is harmful, and not freezing them will just inevitably cause them to shift elsewhere.
  • config.llama-x4.yaml / ar+nar-llama-8: The above ar+nar-llama-8 model, but with para-parallel decoding for the AR in-post.

    • This mostly serves as a proof-of-concept for speeding up inferencing by reducing the number of steps required, by decoding multiple tokens in parallel with a similar approach to how the NAR decodes in parallel.
      • Trained with the trainer's batch-by-durations sampler for a maximum duration batch size of 100 seconds (750 resp tokens), with ProdigyOpt at bfloat16 (no AMP) on my 4070Ti (because I can't be assed to fire up my 4xV100 machine again for a simple test).
    • The model definitely needs to be retrained, as there are some errors in the additional tokens.
      • If these cannot be ironed out with more training, then I imagine an approach similar to speculative decoding could work, where the nth tokens are discarded if the confidence is low.
      • Greedy sampling might be beneficial instead for this, as the NAR does benefit greatly from low temperatures / greedy sampling.
    • It seems naively just adjusting the "causal size" (amount of tokens to predict into the future, and in turn, how many tokens are returned per step) introduces crackles at fixed intervals.
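
Why decoding only the first few RVQ levels trades detail for fewer artifacts can be seen in a toy residual vector quantizer: each level only encodes what the previous levels left over, so dropping the later levels removes fine corrections (along with any errors in them). This is a generic, greedy RVQ sketch with made-up one-dimensional codebooks; DAC's actual quantizer differs in its details, and `rvq_encode` / `rvq_decode` are hypothetical names.

```python
from typing import List, Sequence

Vector = List[float]


def _nearest(codebook: Sequence[Vector], x: Vector) -> int:
    """Index of the codebook entry closest to x (squared Euclidean distance)."""
    return min(
        range(len(codebook)),
        key=lambda i: sum((c - v) ** 2 for c, v in zip(codebook[i], x)),
    )


def rvq_encode(x: Vector, codebooks: Sequence[Sequence[Vector]]) -> List[int]:
    """Greedy residual VQ: each level quantizes what the previous levels left over."""
    residual = list(x)
    codes: List[int] = []
    for codebook in codebooks:
        idx = _nearest(codebook, residual)
        codes.append(idx)
        residual = [r - c for r, c in zip(residual, codebook[idx])]
    return codes


def rvq_decode(codes: Sequence[int], codebooks: Sequence[Sequence[Vector]], max_levels: int) -> Vector:
    """Sum the chosen codebook vectors of only the first max_levels levels."""
    out = [0.0] * len(codebooks[0][0])
    for level, idx in enumerate(codes[:max_levels]):
        out = [o + c for o, c in zip(out, codebooks[level][idx])]
    return out


# Made-up codebooks: each level covers a finer range than the one before it.
books = [[[0.0], [4.0], [8.0]], [[-1.0], [0.0], [1.0]], [[-0.25], [0.0], [0.25]]]
codes = rvq_encode([5.2], books)
for n in range(1, 4):
    print(n, rvq_decode(codes, books, max_levels=n))
```

Here encoding `[5.2]` gives codes `[1, 2, 2]`, and decoding one, two, and three levels gives `[4.0]`, `[5.0]`, and `[5.25]`: each extra level refines the result, so a wrong code at a later level only perturbs it by that level's (smaller) step.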

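The speculative-decoding-style fallback floated above (discarding the nth tokens when confidence is low), combined with greedy picking, could look roughly like the following. This is a hypothetical sketch of an idea, not something the implementation does: `accept_tokens`, the `threshold` parameter, and the layout of `step_probs` (one probability distribution per predicted future position) are all assumptions.

```python
from typing import List, Sequence


def accept_tokens(step_probs: Sequence[Sequence[float]], threshold: float) -> List[int]:
    """Greedily pick tokens from one para-parallel AR step (hypothetical helper).

    step_probs[k] is the probability distribution for the token k positions ahead
    (k = 0 is the usual next token). The next token is always kept; each further
    token is kept only while its greedy probability stays >= threshold, so the
    first low-confidence token and everything after it are dropped and would be
    predicted again on the next step.
    """
    accepted: List[int] = []
    for k, probs in enumerate(step_probs):
        token = max(range(len(probs)), key=probs.__getitem__)
        if k > 0 and probs[token] < threshold:
            break
        accepted.append(token)
    return accepted


step = [[0.1, 0.9], [0.8, 0.2], [0.4, 0.6], [0.05, 0.95]]
print(accept_tokens(step, threshold=0.7))  # [1, 0]
print(accept_tokens(step, threshold=0.5))  # [1, 0, 1, 1]
```

A higher threshold means fewer tokens accepted per step (slower, but safer), while a threshold of 0 accepts every predicted token, which is what naively increasing the "causal size" amounts to.
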
Some additional configurations have been explored, but the experiments have not been fruitful:

  • Exotic wrappers like BitNet seemed to yield little gain in inferencing, somehow. The memory savings are pretty much unnecessary, as the models are already manageable at ~200M parameters.
  • Mamba / Mamba2-based models have shown that it's really hard to have an AR+NAR model. I really do not want to bother throwing compute at another meme arch that doesn't let me easily make use of all the other tech I'd throw at it.
  • A pure NAR (plus a length predictor) cannot be realized with the current architecture.
    • Transformer-based (or at least attention-based) models can't seem to handle generating the initial (RVQ level 0) tokens from "thin air" (be it from special tokens or from repeating the input prompt).
    • A diffusion-based model will definitely work, as those are good at generating from noise.
    • The performance gains seem nice as the biggest "bottleneck" is the initial (RVQ level 0) AR pass, but it seems to require a lot of effort.
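
Why the level-0 AR pass is the bottleneck, and what para-parallel decoding buys, comes down to counting forward passes. A rough sketch, assuming one AR step per `causal_size` level-0 tokens and one NAR pass per remaining RVQ level, and ignoring the stop token and any cost differences between passes; `decode_passes` is a hypothetical helper.

```python
import math


def decode_passes(n_frames: int, n_levels: int, causal_size: int = 1) -> int:
    """Forward passes to generate n_frames frames of n_levels RVQ codes.

    The AR emits RVQ level 0, `causal_size` tokens per step; the NAR then fills
    in each remaining level with one parallel pass per level.
    """
    if causal_size < 1:
        raise ValueError("causal_size must be >= 1")
    ar_steps = math.ceil(n_frames / causal_size)
    nar_steps = n_levels - 1
    return ar_steps + nar_steps


print(decode_passes(750, 8))                 # 757
print(decode_passes(750, 8, causal_size=4))  # 195
```

Assuming EnCodec's 24 kHz model at 75 frames per second, a 10-second utterance is 750 frames: with 8 RVQ levels that is 750 + 7 = 757 passes at one token per step, versus 188 + 7 = 195 at four tokens per step, so nearly all of the cost sits in the AR.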

Some current "architectural features" are in use, but their effects need further experimentation:

  • It is still a mystery whether split_classifier_heads (each RVQ level gets its own output head) is truly helpful or not.
  • audio_embeddings_sum is also a mystery: it's unclear whether each later RVQ level should "see" the past levels through summed embeddings, or if not doing so is preferable.
  • Disabling unified_position_ids seems to help quality more often than not, but I'm still unsure if it's beneficial in practice.
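
The difference between unified position IDs and per-segment position IDs can be illustrated as follows. The segment layout (e.g. text, then audio prompt, then response) and the `position_ids` helper are assumptions for illustration, not the implementation's code.

```python
from typing import List, Sequence


def position_ids(segment_lengths: Sequence[int], unified: bool) -> List[int]:
    """Position IDs for a sequence built by concatenating segments.

    unified=True: one running count across the whole sequence.
    unified=False: every segment (e.g. text, audio prompt, response) restarts at 0.
    """
    if unified:
        return list(range(sum(segment_lengths)))
    ids: List[int] = []
    for length in segment_lengths:
        ids.extend(range(length))
    return ids


print(position_ids([3, 2, 4], unified=True))   # [0, 1, 2, 3, 4, 5, 6, 7, 8]
print(position_ids([3, 2, 4], unified=False))  # [0, 1, 2, 0, 1, 0, 1, 2, 3]
```

With per-segment IDs, the response always starts at position 0 regardless of how long the text and prompt are, so its positions don't shift with the prompt duration.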