Jina AI org
edited Feb 27

Replaced Python `math` operations with torch operations to avoid compilation errors.

  • Results were verified with the script below, which rebuilds the ALiBi tensor with both the old and the new implementation; the results are identical.
  • check inference with small/base/de checkpoint
  • port to es backbone
import torch
import math

def rebuild_alibi_tensor(
    size: int,
    device: "torch.device | str | None" = None,
    n_heads: int = 8,
) -> torch.Tensor:
    """Build the symmetric ALiBi attention bias using only torch operations.

    Follows https://github.com/ofirpress/attention_with_linear_biases/issues/5
    (Implementation 1). In the causal case you can exploit the fact that
    softmax is invariant to a uniform translation of the logits, which makes
    the math work out *after* applying causal masking. If no causal masking
    will be applied, it is necessary to construct the diagonal mask.

    Args:
        size: Sequence length; the bias covers ``size x size`` positions.
        device: Device to build the tensor on (``None`` uses the default).
        n_heads: Number of attention heads (default 8). Powers of two get the
            paper's geometric slopes. Other counts get the paper's workaround:
            the slopes of the nearest lower power of two, followed by every
            other slope of the next power of two.

    Returns:
        Tensor of shape ``[1, n_heads, size, size]`` in the default float
        dtype, with entry ``[0, h, i, j] == -slope_h * |i - j|``.

    Raises:
        ValueError: If ``n_heads`` is smaller than 1.
    """
    if n_heads < 1:
        raise ValueError(f"n_heads must be a positive integer, got {n_heads}")

    def _slopes_power_of_2(n: int) -> torch.Tensor:
        # Only called with powers of two, for which 2 ** -(log2(n) - 3) == 8 / n
        # exactly. The slopes are start * start ** i for i in [0, n).
        # float64 mirrors the Python-float arithmetic of the math-based version;
        # the result is cast to the default dtype once, below.
        start = 2.0 ** (-8.0 / n)
        indices = torch.arange(n, dtype=torch.float64)
        return start * torch.pow(start, indices)

    # 2 ** floor(log2(n_heads)), computed with integer ops only.
    closest_power_of_2 = 1 << (n_heads.bit_length() - 1)
    slopes = _slopes_power_of_2(closest_power_of_2)
    if closest_power_of_2 != n_heads:
        # Non-power-of-2 head count: append every other slope of the next
        # power of two until there are n_heads slopes in total.
        extra = _slopes_power_of_2(2 * closest_power_of_2)[0::2]
        slopes = torch.cat([slopes, extra[: n_heads - closest_power_of_2]])
    slopes = -slopes.to(dtype=torch.get_default_dtype()).to(device)

    positions = torch.arange(size, device=device)
    # [max_token_length, max_token_length]: symmetric distance |j - i|.
    relative_position = torch.abs(positions[None, :] - positions[:, None])
    # [n_heads, max_token_length, max_token_length] via broadcasting.
    alibi = slopes[:, None, None] * relative_position
    # [1, n_heads, max_token_length, max_token_length]
    alibi = alibi.unsqueeze(0)
    assert alibi.shape == torch.Size([1, n_heads, size, size])
    return alibi


def rebuild_alibi_tensor_original(
    size: int, device: 'cuda' = None
):
    """Reference ALiBi bias builder that computes the head slopes with ``math``.

    This is the math-based variant kept for comparison with the torch-only
    ``rebuild_alibi_tensor``. Slopes are computed as Python floats and then
    converted to a tensor with ``torch.Tensor``.

    Follows https://github.com/ofirpress/attention_with_linear_biases/issues/5
    (Implementation 1). In the causal case you can exploit the fact that
    softmax is invariant to a uniform translation of the logits, which makes
    the math work out *after* applying causal masking. If no causal masking
    will be applied, it is necessary to construct the diagonal mask.

    Returns a ``[1, 8, size, size]`` tensor whose entry ``[0, h, i, j]`` is
    ``-slope_h * |i - j|``.
    """
    head_count = 8

    def power_of_2_slopes(count):
        base = 2 ** (-(2 ** -(math.log2(count) - 3)))
        return [base * base**k for k in range(count)]

    def head_slopes(count):
        # The paper only trains models with 2^a heads; the geometric slopes
        # have good properties only then. For other counts, take the slopes of
        # the nearest lower power of two plus every other slope of the next one.
        if not math.log2(count).is_integer():
            lower = 2 ** math.floor(math.log2(count))
            upper_slopes = power_of_2_slopes(2 * lower)
            return power_of_2_slopes(lower) + upper_slopes[::2][: count - lower]
        return power_of_2_slopes(count)

    positions = torch.arange(size, device=device)
    distance = torch.abs(positions[None, :] - positions[:, None])
    # [head_count, max_token_length, max_token_length]
    distance = distance.unsqueeze(0).expand(head_count, -1, -1)
    slope_values = torch.Tensor(head_slopes(head_count)).to(device) * -1
    # [1, head_count, max_token_length, max_token_length]
    alibi = (slope_values[:, None, None] * distance)[None]
    assert alibi.shape == torch.Size([1, head_count, size, size])
    return alibi

if __name__ == "__main__":
    # The torch-only rewrite must reproduce the math-based reference exactly,
    # so compare them element-wise for a few sequence lengths.
    for size in (1, 8, 64):
        a = rebuild_alibi_tensor(size)
        b = rebuild_alibi_tensor_original(size)
        print(f"size={size}: {torch.equal(a, b)}")

Output: True
bwang0911 changed pull request status to closed

Sign up or log in to comment