---
license: agpl-3.0
---

This repo catalogs my weights for use with my VALL-E implementation as I try and iron out the kinks.

The models are currently in a semi-usable state, and I'm releasing them now in hopes that they also help jumpstart anyone else that wants to use them.

To reiterate, this is by no means complete. I am not passing this off as competitive.

## Models

This repo contains the following configurations under `./models/`:

  • `config.retnet.yaml` / `ar+nar-retnet-8`: The previously released weights.

    • This configuration utilizes a RetNet (retention-based "transformer") as the underlying architecture, a choice made, for better or for worse, off the back of a number of misleading interpretations of comparisons.
      • Prompt and response embeddings are summed (each further RVQ level gets the previous RVQ levels' embeddings factored in; see the first sketch after this list).
      • Tokenizer is a homebrewed "naive" implementation.
    • This model received the most training time, split between my 4070Ti, 7900XTX, and a few rental rigs to push training further, entirely at bfloat16 with prodigyopt (and a few optimizer restarts).
    • The later part of training aimed to shuffle between speakers rather than the global pool of utterances to better focus on zero-shot performance (sketched after this list). Due to this, I feel it achieved decent zero-shot performance.
    • However, because the dataset was aggressively trimmed to under 12 seconds for memory savings during training, it struggles to inference non-short utterances. Additional training may fix this; the following models seemed to adapt well to longer utterances.
      • From the ar+nar-llama-8 experiment, I believe this can be "fixed" with additional training on the currently processed dataset.
    • Prior testing showed that longer prompt durations result in better utterances.
    • Can benefit from additional training, but I recall the average loss being around 1.9 to 2.1.
      • However, due to regressions (or bias from working under llama), I don't think I can optimally train with a RetNet again (both in terms of VRAM consumption and throughput).
  • `config.llama.yaml` / `ar+nar-llama-8`: The most recent-ishly trained weights after learning from my mistakes.

    • This configuration utilizes Llama's attention-based transformer as the underlying architecture, making use of creature comforts like RoPE, GQA, and memory-efficient attention (trained under xformers, which shouldn't really affect things).
      • Prompt and response embeddings are NOT summed (each RVQ level only attends to the current RVQ level).
      • Utilizes a HF tokenizer for "optimal" vocab.
      • The current RVQ level is included as a token as well to help guide NAR tasks better.
    • This model received a few days of training on my 4xV100s, stepping up the duration window to try and make the model better at inferencing longer utterances.
      • Some sessions ended up training the current duration window for a few epochs, but I don't know how much it affected things.
    • However, it seems to only do well with long utterances, while short utterances fumble. I believe further training with a variety of durations should let the AR handle short ones as well.
      • I believe the "slowly stepping up the context length" only works for text, and not audio.
      • Addendum: Additional brief training for a variety of duration lengths seemed to have mostly fixed this issue.
      • Addendum addendum: Properly creating the position IDs per segment rather than over the whole sequence also helps a lot (sketched at the end of this document).
    • Zero-shot performance leaves a bit to be desired, as it did not receive the special training prioritizing shuffling between speakers rather than the global pool of utterances.
      • Addendum: Additional brief training for sampling based on speaker per "epoch" (per dataloader, not dataset) seemed to slightly improve it.
    • Testing showed that, despite also stepping up the prompt duration, it really likes three-second prompts.
    • Definitely needs additional training, but the best way forward is unknown.
      • Naturally, training it on a "next RVQ level is half as likely" distribution (sketched after this list) introduces some crust, as the later RVQ levels are less accurate, introducing noise and artifacts.
      • As a fix for the above, naively training it on equally distributed RVQ levels does lobotomize the AR.
      • Additional training on the AR will see huge diminishing returns, so I don't know if it's worth doing so.
    • Seems to be a decent foundation for "distillation", at the very least for LoRA training.
      • Addendum: it seems to serve fine for patch-training a few extra tweaks, such as non-unified position IDs, split classifier heads, and para-parallel decoding for the AR.
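
For the curious, here's a minimal sketch of the two embedding schemes described above, plus the RVQ-level token the llama model feeds in. Names like `AudioEmbedding`, `n_rvq_levels`, and `d_model` are hypothetical; the actual implementation lives in the VALL-E repo and differs in detail.

```python
import torch
import torch.nn as nn

class AudioEmbedding(nn.Module):
    # Hypothetical sketch, not the repo's actual class.
    def __init__(self, n_rvq_levels=8, codebook_size=1024, d_model=1024, sum_levels=True):
        super().__init__()
        # one embedding table per RVQ level
        self.embs = nn.ModuleList(
            nn.Embedding(codebook_size, d_model) for _ in range(n_rvq_levels)
        )
        # the current RVQ level itself as a token, to help guide NAR tasks
        # (only the llama config does this)
        self.level_emb = nn.Embedding(n_rvq_levels, d_model)
        self.sum_levels = sum_levels

    def forward(self, codes: torch.Tensor, level: int) -> torch.Tensor:
        # codes: [batch, time, n_rvq_levels] of codebook indices
        if self.sum_levels:
            # ar+nar-retnet-8 style: sum this level with all prior levels,
            # so later levels "see" the earlier ones
            h = sum(self.embs[l](codes[..., l]) for l in range(level + 1))
        else:
            # ar+nar-llama-8 style: each level only embeds its own codes
            h = self.embs[level](codes[..., level])
        return h + self.level_emb.weight[level]
```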
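
Likewise, a sketch of the "next RVQ level is half as likely" training distribution mentioned above; purely illustrative.

```python
import random

def sample_rvq_level(n_levels: int = 8) -> int:
    # each successive RVQ level is half as likely to be picked for
    # training as the previous one: weights 1, 1/2, 1/4, ...
    weights = [2.0 ** -level for level in range(n_levels)]
    return random.choices(range(n_levels), weights=weights, k=1)[0]
```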
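
And a sketch of sampling per speaker rather than from the global pool of utterances; representing the dataset as (speaker, utterance) pairs is an assumption for illustration.

```python
import random
from collections import defaultdict

def speaker_balanced_pass(utterances):
    # group utterances by speaker, then draw one utterance per speaker
    # per pass, instead of sampling uniformly from the global pool
    by_speaker = defaultdict(list)
    for speaker, utterance in utterances:
        by_speaker[speaker].append(utterance)
    speakers = list(by_speaker)
    random.shuffle(speakers)
    for speaker in speakers:
        yield speaker, random.choice(by_speaker[speaker])
```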

Some additional configurations have been explored, but the experiments have not been fruitful:

  • Exotic wrappers like BitNet seemed to yield little gains in inferencing, somehow. The memory savings are pretty much unnecessary, as the models are already manageable at ~200M parameters.
  • Mamba / Mamba2-based models have shown that it's really hard to have an AR+NAR model. I really do not want to bother throwing compute at another meme arch where I can't easily make use of all the other tech out there.
  • A pure NAR (plus length predictor) cannot be realized with the current architecture.
    • Transformer-based (or at least attention-based) models can't seem to handle generating the initial (RVQ level 0) tokens from "thin air" (be it from special tokens or from repeating the input prompt).
    • A diffusion-based model will definitely work, as those are good at generating from noise.
    • The performance gains seem nice as the biggest "bottleneck" is the initial (RVQ level 0) AR pass, but it seems to require a lot of effort.
  • A model using Descript-Audio-Codec:
    • The 24kHz model will not converge no matter what. However, naively using just the first 8 RVQ levels might not be good enough, as there are too many codebooks for viable use.
    • The 44kHz model was erroneously assumed to be an even 44kHz, when in reality it's 44.1kHz. All of my audio has to be requantized, as there's some stuttering in it.
      • Because of this, training losses are high and it's having a hard time trying to converge.
    • It has sub-serviceable output for the first 4 RVQ levels, but it's massive cope to try and use it as a model.
    • I believe there's hope to use it when I requantize my audio properly.
  • A model with a causal size >1 (sampling more than one token for the AR):
    • Re-using an existing model or training from scratch did not yield fruitful results.
    • There's an inherent periodic stutter that doesn't seem trainable out; exotic sampling methods might be required.
    • Unfortunately it requires either:
      • something similar to Medusa heads, where there are additional parameters to perform speculative sampling (a sketch follows this list), or
      • a solution similar to what VALL-E 2 uses with group token embeddings or whatever, which will harm the NAR tasks in an AR+NAR model.
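
A rough sketch of the Medusa-style option, since it's the easier of the two to picture; every name here is hypothetical, and this is not what the repo ships.

```python
import torch
import torch.nn as nn

class MedusaHeads(nn.Module):
    # hypothetical sketch: k extra output heads over the final hidden
    # state, each predicting one of the next k tokens in a single AR step
    def __init__(self, d_model=1024, vocab_size=1024, causal_size=4):
        super().__init__()
        self.heads = nn.ModuleList(
            nn.Linear(d_model, vocab_size) for _ in range(causal_size)
        )

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        # hidden: [batch, time, d_model] -> logits: [batch, time, k, vocab]
        return torch.stack([head(hidden) for head in self.heads], dim=2)
```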

Some current "architectural features" are in use, but their effects need to be experimented with further:

  • `split_classifier_heads` (each RVQ level gets its own output head; sketched below): it's still a mystery whether it's truly helpful or not.
  • `audio_embeddings_sum`: it's also a mystery whether each later RVQ level should "see" the past levels through summed embeddings, or whether keeping them separate is preferable.
  • Disabling `unified_position_ids` (also sketched below) seems to help quality more often than not, but I'm still unsure if it's beneficial in practice.
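
To illustrate, a sketch of split classifier heads: one output projection per RVQ level instead of a single shared head. Names are hypothetical.

```python
import torch
import torch.nn as nn

class SplitClassifierHeads(nn.Module):
    # hypothetical sketch: one output head per RVQ level
    def __init__(self, d_model=1024, codebook_size=1024, n_rvq_levels=8):
        super().__init__()
        self.heads = nn.ModuleList(
            nn.Linear(d_model, codebook_size) for _ in range(n_rvq_levels)
        )

    def forward(self, hidden: torch.Tensor, level: int) -> torch.Tensor:
        # route to the head for the RVQ level currently being decoded
        return self.heads[level](hidden)
```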
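
And a sketch of unified versus per-segment (non-unified) position IDs, for a sequence built from concatenated segments (e.g. text, prompt, response); again, just an illustration.

```python
import torch

def make_position_ids(segment_lengths, unified=True):
    # unified=True numbers the whole sequence 0..N-1;
    # unified=False restarts the count at 0 for each segment
    if unified:
        return torch.arange(sum(segment_lengths))
    return torch.cat([torch.arange(n) for n in segment_lengths])

# e.g. segment_lengths = [3, 2, 4]:
#   unified:     tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])
#   per-segment: tensor([0, 1, 2, 0, 1, 0, 1, 2, 3])
```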