This repo catalogs my weights for use with my VALL-E implementation as I try and iron out the kinks.

The models are currently in a semi-usable state, and I'm releasing them now in the hope that they help jumpstart anyone else who wants to use them.

To reiterate, this is by no means complete. I am not passing this off as competitive.

Models

This repo contains the following configurations under ./models/:

  • config.retnet.yaml / ar+nar-retnet-8: The previously released weights.

    • This configuration uses a RetNet (a retention-based "transformer") as the underlying architecture, a choice made on the back of a number of misleading interpretations of comparisons, for better or for worse.
      • Prompt and response embeddings are summed (further RVQ levels get the previous RVQ levels' embeddings factored in).
      • Tokenizer is a homebrewed "naive" implementation.
    • This model received the most training time, spread across my 4070Ti, my 7900XTX, and a few rental rigs to push training further, entirely at bfloat16 with prodigyopt (and a few optimizer restarts).
    • The later part of training aimed to shuffle between speakers rather than the global pool of utterances to better focus on zero-shot performance. Due to this, I feel it achieved decent zero-shot performance.
    • However, because the dataset was aggressively trimmed to under 12 seconds for memory savings during training, it struggles when inferencing longer utterances. Additional training may fix this, as the following models seemed to adapt well to longer utterances.
      • From the ar+nar-llama-8 experiment, I believe this can be "fixed" with additional training on the currently processed dataset.
    • Prior testing showed that longer prompt durations result in better utterances.
    • Can benefit from additional training, but I recall the average loss being around 1.9 to 2.1.
      • However, due to regressions (or bias from working under llama), I don't think I can optimally train with a RetNet again (both in terms of VRAM consumption and throughput).
  • config.llama.yaml / ar+nar-llama-8: The most recent-ishly trained weights after learning from my mistakes.

    • This configuration utilizes Llama's attention-based transformer as the underlying architecture, making use of creature comforts like RoPE, GQA, and memory-efficient attention (trained under xformers, shouldn't really affect things).
      • Prompt and response embeddings are NOT summed (each RVQ level only attends to the current RVQ level).
      • Utilizes a HF tokenizer for "optimal" vocab.
      • The current RVQ level is included as a token as well to help guide NAR tasks better.
    • This model received a few days of training on my 4xV100s, stepping up the duration window to help the model better inference longer utterances.
      • Some sessions ended up training on the current duration window for a few epochs, but I don't know how much that affected things.
    • However, it seems to only do well with long utterances. Short utterances fumble. I believe further training with a variety of durations should allow the AR to handle a variety of durations.
      • I believe the "slowly stepping up the context length" only works for text, and not audio.
      • Addendum: Additional brief training on a variety of durations seems to have mostly fixed this issue.
      • Addendum addendum: Properly creating the position IDs per segment, rather than over the whole sequence, also helps a lot.
    • Zero-shot performance leaves a bit to be desired, as it did not receive the special training prioritizing shuffling between speakers rather than the global pool of utterances.
      • Addendum: Additional brief training for sampling based on speaker per "epoch" (per dataloader, not dataset) seemed to slightly improve it.
    • Testing showed that, despite also stepping up the prompt duration, it really likes three second prompts.
    • Definitely needs additional training, but the next way to go is unknown.
      • Naturally, training it on a "next RVQ level is half as likely" distribution introduces some crust as the later RVQ levels are less accurate, introducing noise and artifacts.
      • As a fix for the above, naively training it on equally distributed RVQ levels does lobotomize the AR.
      • Additional training on the AR will see huge diminishing returns, so I don't know if it's worth doing so.
    • Seems to be a decent foundation for "distillation", at the very least for LoRA training.
      • Addendum: it seems to serve fine for patch-training a few extra tweaks: non-unified position IDs, split classifier heads, and para-parallel decoding for the AR.
  • config.llama-tts+stt.yaml / ar+nar-tts+stt-llama-8: The above, but partially trained for STT.

    • These weights use the above weights but with additional training for the default tts task and a new stt task (at a 3:1 ratio).
    • Was initially trained with duration_range: [3.0, 60.0] and sample_shuffle: True for a few hours, but then pivoted to my standard duration_range: [3.0, 12.0] and sample_shuffle: False.
      • Will need the former training to "undo" any issues with durations, as they usually came up before.
    • stt task simply takes a piece of audio and outputs a transcription using IPA phonemes (that the model already is trained against for its text inputs).
      • Can be done with --task=stt and an empty ("") text input through the CLI interface or the Speech-to-Text tab in the web UI.
    • This mainly serves as a stepping stone before pivoting towards SpeechX tasks.
      • I first need a good mechanism to make sure I can extend existing weights with additional tasks, but with a simple enough task.
      • This also seems to help bolster the initial TTS task by helping the model have a better internal state (or something to that tune).
    • STT is not perfect against voices that aren't close to a normal speaking voice (as per the dataset), unlike TTS where you can easily have "sounds close enough" and room for errors.
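
To make the "next RVQ level is half as likely" note under ar+nar-llama-8 concrete, here is a minimal sketch of that sampling distribution next to the equally distributed alternative. It assumes 8 RVQ levels (reading the -8 suffix that way), and the function names are made up for illustration rather than taken from the training code.

```python
def rvq_level_probs(n_levels: int = 8, ratio: float = 0.5) -> list[float]:
    """Probability of picking each RVQ level when every level is `ratio`
    times as likely as the one before it (ratio=0.5 -> "half as likely")."""
    weights = [ratio ** level for level in range(n_levels)]
    total = sum(weights)
    return [w / total for w in weights]


def uniform_probs(n_levels: int = 8) -> list[float]:
    """The "equally distributed" alternative: every level is equally likely."""
    return [1.0 / n_levels] * n_levels


# Print both distributions side by side (8 levels is an assumption).
halved = rvq_level_probs()
flat = uniform_probs()
for level, (h, u) in enumerate(zip(halved, flat)):
    print(f"RVQ level {level}: halved={h:.4f} uniform={u:.4f}")
```

Under these assumptions level 0 is drawn roughly half the time and level 7 well under 1% of the time, which is consistent with the later levels ending up less accurate.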

Some additional configurations have been explored, but the experiments have not been fruitful:

  • Exotic wrappers like BitNet seemed to yield little gains in inferencing, somehow. The memory savings are pretty much unnecessary as the models are already manageable at ~200M parameters.
  • Mamba / Mamba2-based models have shown that it's really hard to have an AR+NAR model. I really do not want to bother throwing compute at another meme arch that I can't easily apply all the other tech to.
  • A pure NAR (plus length predictor) cannot be realized with the current architecture.
    • Transformer-based (or at least attention-based) models can't seem to handle generating the initial (RVQ level 0) tokens from "thin air" (be it from special tokens or from repeating the input prompt).
    • A diffusion-based model will definitely work, as those are good at generating from noise.
    • The performance gains seem nice as the biggest "bottleneck" is the initial (RVQ level 0) AR pass, but it seems to require a lot of effort.
  • A model using Descript-Audio-Codec:
    • The 24kHz model will not converge no matter what. However, naively using just the first 8 RVQ levels might not be good enough, as there are too many codebooks for viable use.
    • The 44kHz model was erroneously assumed to be an even 44kHz, when in reality it's 44.1kHz. All of my audio has to be requantized, as there's some stuttering in it.
      • Because of this, training losses are high and it's having a hard time trying to converge.
    • It has sub-serviceable output for the first 4 RVQ levels, but it's massive cope to try and use it as a model.
    • I believe there's hope to use it when I requantize my audio properly.
      • Addendum: even after properly processing my audio, the loss is actually worse than before. I imagine DAC just cannot be used as an intermediary for an LM.
  • A model with a causal size >1 (sampling more than one token per step for the AR):
    • Re-using an existing model or training from scratch does not yield fruitful results.
    • There's an inherent periodic stutter that can't seem to be trained out; fixing it might require exotic sampling methods.
    • Unfortunately it requires either:
      • something similar to Medusa heads, where there are additional parameters to perform speculative sampling, or
      • a solution similar to what VALL-E 2 uses with group token embeddings or whatever, which will harm the NAR tasks in an AR+NAR model.
    • I just don't understand where the issue lies, since parallel decoding does work, as evidenced by the NAR.
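
For a sense of scale on the 44kHz vs 44.1kHz mix-up noted for the Descript-Audio-Codec experiment, the sketch below works out how far the two sample rates drift apart. It only does the arithmetic on the two rates; it does not model the codec and is not a claim about what actually causes the stuttering. The function names are hypothetical.

```python
import math

ASSUMED_HZ = 44_000  # the rate that was mistakenly assumed
ACTUAL_HZ = 44_100   # the rate the codec model actually uses


def speed_ratio(assumed_hz: int = ASSUMED_HZ, actual_hz: int = ACTUAL_HZ) -> float:
    """Playback speed factor when samples prepared at `assumed_hz`
    are interpreted at `actual_hz`."""
    return actual_hz / assumed_hz


def pitch_shift_cents(assumed_hz: int = ASSUMED_HZ, actual_hz: int = ACTUAL_HZ) -> float:
    """The same mismatch expressed as a pitch shift in cents
    (1200 cents per octave)."""
    return 1200.0 * math.log2(actual_hz / assumed_hz)


def sample_count_gap(seconds: float, assumed_hz: int = ASSUMED_HZ,
                     actual_hz: int = ACTUAL_HZ) -> int:
    """Difference in sample count for `seconds` of audio at the two rates."""
    return round(seconds * (actual_hz - assumed_hz))


print(f"speed ratio: {speed_ratio():.5f}")
print(f"pitch shift: {pitch_shift_cents():.2f} cents")
print(f"sample gap over a 12 s clip: {sample_count_gap(12.0)}")
```

The mismatch is small per second (100 samples), but it accumulates over a clip, which is why audio prepared at the wrong rate has to be reprocessed rather than patched.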

Some current "architectural features" are in use, but their effects need to be experimented with further:

  • It's still a mystery whether split_classifier_heads is truly helpful (each RVQ level gets its own output head).
  • It's also a mystery whether audio_embeddings_sum matters: whether each later RVQ level should "see" the past levels through summed embeddings, or whether not doing so is preferable.
  • Disabling unified_position_ids seems to help quality more often than not, but I'm still unsure if it's beneficial in practice.
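
As an illustration of what disabling unified_position_ids changes, the sketch below contrasts one position counter running over the whole sequence with a counter that restarts for each segment. The segment lengths and function names are hypothetical, and the real implementation may treat separators or special tokens differently.

```python
def unified_position_ids(segment_lengths: list[int]) -> list[int]:
    """One counter running over the whole concatenated sequence."""
    return list(range(sum(segment_lengths)))


def per_segment_position_ids(segment_lengths: list[int]) -> list[int]:
    """Counter restarts at 0 at the start of every segment."""
    ids: list[int] = []
    for length in segment_lengths:
        ids.extend(range(length))
    return ids


# Hypothetical lengths for [text, audio prompt, audio response].
lengths = [3, 2, 4]
print(unified_position_ids(lengths))      # [0, 1, 2, 3, 4, 5, 6, 7, 8]
print(per_segment_position_ids(lengths))  # [0, 1, 2, 0, 1, 0, 1, 2, 3]
```

With per-segment IDs, the first token of the response always sits at position 0 regardless of how long the text or prompt was.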

LoRAs

This repo also contains some LoRAs to serve as a reference under ./loras/.

Using a LoRA is the same as using a base model, except you're required to have the base model already (obviously). Just load from the LoRA's config YAML instead.

The only caveat is that my original dataset does contain these samples already, but given the sheer size of it, they're probably underutilized.

  • Even so, the base model already has almost adequate output for these speakers, just not enough to be satisfactory.

  • config.lora.glados.yaml / lora-glados-r128-a128:

    • A simple LoRA of GLaDOS from both Portal and Portal 2.
    • Trained for 250 steps (48000 samples, 821 samples per epoch).
  • config.lora.sam.yaml / lora-sam-r128-a128:

    • A simple LoRA of Sam from the non-remaster Sam and Max Telltale games.
    • Trained for 250 steps (48000 samples, 1555 samples per epoch).
  • config.lora.max.yaml / lora-max-r128-a128:

    • A simple LoRA of Max from the non-remaster Sam and Max Telltale games.
    • Trained for 250 steps (48000 samples, 1292 samples per epoch).
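
The training figures listed for each LoRA can be unpacked with a little arithmetic. The sketch below assumes "48000 samples" is the total number of samples seen across the 250 steps; reading the resulting samples-per-step figure as an effective batch size is an inference, not something the listing states.

```python
TOTAL_SAMPLES = 48_000
STEPS = 250

# Samples per epoch as listed for each LoRA.
SAMPLES_PER_EPOCH = {"glados": 821, "sam": 1555, "max": 1292}


def samples_per_step(total_samples: int = TOTAL_SAMPLES, steps: int = STEPS) -> float:
    """Samples consumed per optimizer step (assumed effective batch size)."""
    return total_samples / steps


def epochs_seen(samples_per_epoch: int, total_samples: int = TOTAL_SAMPLES) -> float:
    """How many passes over a speaker's dataset the total sample count implies."""
    return total_samples / samples_per_epoch


print(f"samples per step: {samples_per_step():.0f}")
for name, per_epoch in SAMPLES_PER_EPOCH.items():
    print(f"{name}: ~{epochs_seen(per_epoch):.1f} epochs")
```

Under that reading, the smaller GLaDOS dataset is repeated noticeably more often than the Sam or Max datasets for the same step count.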