"""
Usage:

python upsample_lora_rank.py --repo_id="cocktailpeanut/optimus" \
    --filename="optimus.safetensors" \
    --new_lora_path="optimus_16.safetensors" \
    --new_rank=16
"""

from typing import Optional

import fire
import safetensors.torch
import torch
from huggingface_hub import hf_hub_download


def orthogonal_extension(matrix, target_rows):
    """
    Extends the given matrix to have target_rows rows by adding orthogonal rows.

    Args:
        matrix (torch.Tensor): Original matrix of shape [original_rows, columns].
        target_rows (int): Desired number of rows.

    Returns:
        extended_matrix (torch.Tensor): Matrix of shape [target_rows, columns].
    """
    original_rows, cols = matrix.shape
    assert target_rows >= original_rows, "Target rows must be greater than or equal to original rows."

    # QR-decompose the transpose so that Q spans the row space of `matrix`; R is unused.
    Q, _ = torch.linalg.qr(matrix.T, mode="reduced")  # Q: [columns, original_rows]
    Q = Q.T  # Back to [original_rows, columns]; rows form an orthonormal basis.

    # Append random rows, Gram-Schmidt-orthogonalized against all rows collected so far.
    if target_rows > original_rows:
        additional_rows = target_rows - original_rows
        random_matrix = torch.randn(additional_rows, cols, dtype=matrix.dtype, device=matrix.device)
        for i in range(additional_rows):
            v = random_matrix[i]
            v = v - Q.T @ (Q @ v)  # Project out the component lying in the current row space.
            v = v / v.norm()
            Q = torch.cat([Q, v.unsqueeze(0)], dim=0)
    return Q
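

# A minimal, self-contained sanity check for `orthogonal_extension` (illustrative only;
# this helper is an addition and is not wired into the CLI below). Rows of the extended
# matrix should be orthonormal, i.e. E @ E.T ≈ I up to floating-point error.
def _demo_orthogonal_extension():
    E = orthogonal_extension(torch.randn(4, 64), 16)
    assert torch.allclose(E @ E.T, torch.eye(16), atol=1e-4)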


def increase_lora_rank_orthogonal(state_dict, target_rank=16):
    """
    Increases the rank of all LoRA matrices in the given state dict using orthogonal extension.

    Args:
        state_dict (dict): The state dict containing LoRA matrices.
        target_rank (int): Desired higher rank.

    Returns:
        new_state_dict (dict): State dict with increased-rank LoRA matrices.
    """
    new_state_dict = state_dict.copy()
    for key in state_dict.keys():
        if "lora_A.weight" in key:
            lora_A_key = key
            lora_B_key = key.replace("lora_A.weight", "lora_B.weight")
            if lora_B_key in state_dict:
                lora_A = state_dict[lora_A_key]
                dtype = lora_A.dtype
                device = "cuda" if torch.cuda.is_available() else "cpu"
                lora_A = lora_A.to(device, torch.float32)  # Upcast for a numerically stable QR.
                lora_B = state_dict[lora_B_key]
                lora_B = lora_B.to(device, torch.float32)

                original_rank = lora_A.shape[0]

                # Extend lora_A ([rank, in_features]) along its rows, and lora_B
                # ([out_features, rank]) along its rank dimension via its transpose.
                lora_A_new = orthogonal_extension(lora_A, target_rank).to(dtype)
                lora_B_new = orthogonal_extension(lora_B.T, target_rank).T.to(dtype)

                # Update the state dict
                new_state_dict[lora_A_key] = lora_A_new
                new_state_dict[lora_B_key] = lora_B_new

                print(
                    f"Increased rank of {lora_A_key} and {lora_B_key} from {original_rank} to {target_rank} using orthogonal extension"
                )

    return new_state_dict
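

# An illustrative shape check on a toy state dict (the `foo.` prefix is a placeholder,
# and this helper is an addition not wired into the CLI below): a rank-4 pair becomes a
# rank-16 pair while the in/out feature dimensions stay unchanged.
def _demo_increase_rank():
    toy = {
        "foo.lora_A.weight": torch.randn(4, 128),  # [rank, in_features]
        "foo.lora_B.weight": torch.randn(256, 4),  # [out_features, rank]
    }
    upsampled = increase_lora_rank_orthogonal(toy, target_rank=16)
    assert upsampled["foo.lora_A.weight"].shape == (16, 128)
    assert upsampled["foo.lora_B.weight"].shape == (256, 16)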


def compare_approximation_error(orig_state_dict, new_state_dict):
    for key in orig_state_dict:
        if "lora_A.weight" in key:
            lora_A_key = key
            lora_B_key = key.replace("lora_A.weight", "lora_B.weight")
            lora_A_old = orig_state_dict[lora_A_key]
            lora_B_old = orig_state_dict[lora_B_key]
            lora_A_new = new_state_dict[lora_A_key]
            lora_B_new = new_state_dict[lora_B_key]

            # Original delta_W (upcast to float32 so the error measurement itself is accurate,
            # and moved to the same device as the new tensors)
            delta_W_old = (lora_B_old.float() @ lora_A_old.float()).to(lora_B_new.device)

            # Approximated delta_W
            delta_W_new = lora_B_new.float() @ lora_A_new.float()

            # Compute the approximation error
            error = torch.norm(delta_W_old - delta_W_new, p="fro") / torch.norm(delta_W_old, p="fro")
            print(f"Relative error for {lora_A_key}: {error.item():.6f}")


def main(
    repo_id: str,
    filename: str,
    new_rank: int,
    check_error: bool = False,
    new_lora_path: Optional[str] = None,
):
    if new_lora_path is None:
        raise ValueError("Please provide a path to serialize the converted state dict.")

    ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
    original_state_dict = safetensors.torch.load_file(ckpt_path)
    new_state_dict = increase_lora_rank_orthogonal(original_state_dict, target_rank=new_rank)

    if check_error:
        compare_approximation_error(original_state_dict, new_state_dict)

    # Move everything back to CPU and make it contiguous before serialization.
    new_state_dict = {k: v.to("cpu").contiguous() for k, v in new_state_dict.items()}
    safetensors.torch.save_file(new_state_dict, new_lora_path)


if __name__ == "__main__":
    fire.Fire(main)
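
# A minimal usage sketch for the produced file (assumptions: the LoRA targets a model
# supported by Diffusers, and the base-model id below is only a placeholder;
# `AutoPipelineForText2Image.from_pretrained` and `load_lora_weights` are standard
# Diffusers APIs):
#
#     from diffusers import AutoPipelineForText2Image
#
#     pipe = AutoPipelineForText2Image.from_pretrained("<base-model-id>")
#     pipe.load_lora_weights(".", weight_name="optimus_16.safetensors")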