Shared memory error

#15
by marktenenholtz - opened

Hi, thanks for the model!

Whenever I use the 8k or 128k version of this model on my 2x4090 rig, I get the following error. I tried tinkering around with the blocksparse configurations, but to no avail.

[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/home/mark/miniforge3/envs/unsloth/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 487, in checkpoint
[rank0]:     return CheckpointFunction.apply(function, preserve, *args)
[rank0]:   File "/home/mark/miniforge3/envs/unsloth/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:   File "/home/mark/miniforge3/envs/unsloth/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 262, in forward
[rank0]:     outputs = run_function(*args)
[rank0]:   File "/home/mark/miniforge3/envs/unsloth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/mark/miniforge3/envs/unsloth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/mark/miniforge3/envs/unsloth/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
[rank0]:     output = module._old_forward(*args, **kwargs)
[rank0]:   File "/home/mark/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3-small-128k-instruct/5e0fbf02d6d35e27bf7633df1b45494e57693d2f/modeling_phi3_small.py", line 671, in forward
[rank0]:     hidden_states, self_attn_weights, present_key_values = self.self_attn(
[rank0]:   File "/home/mark/miniforge3/envs/unsloth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/mark/miniforge3/envs/unsloth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/mark/miniforge3/envs/unsloth/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
[rank0]:     output = module._old_forward(*args, **kwargs)
[rank0]:   File "/home/mark/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3-small-128k-instruct/5e0fbf02d6d35e27bf7633df1b45494e57693d2f/modeling_phi3_small.py", line 616, in forward
[rank0]:     attn_function_output = self._apply_blocksparse_attention(
[rank0]:   File "/home/mark/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3-small-128k-instruct/5e0fbf02d6d35e27bf7633df1b45494e57693d2f/modeling_phi3_small.py", line 382, in _apply_blocksparse_attention
[rank0]:     context_layer = self._blocksparse_layer(
[rank0]:   File "/home/mark/miniforge3/envs/unsloth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/mark/miniforge3/envs/unsloth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/mark/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3-small-128k-instruct/5e0fbf02d6d35e27bf7633df1b45494e57693d2f/triton_blocksparse_attention_layer.py", line 165, in forward
[rank0]:     return blocksparse_flash_attn_padded_fwd(
[rank0]:   File "/home/mark/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3-small-128k-instruct/5e0fbf02d6d35e27bf7633df1b45494e57693d2f/triton_flash_blocksparse_attn.py", line 994, in blocksparse_flash_attn_padded_fwd
[rank0]:     _fwd_kernel_batch_inference[grid](
[rank0]:   File "/home/mark/miniforge3/envs/unsloth/lib/python3.10/site-packages/triton/runtime/jit.py", line 167, in <lambda>
[rank0]:     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
[rank0]:   File "/home/mark/miniforge3/envs/unsloth/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 305, in run
[rank0]:     return self.fn.run(*args, **kwargs)
[rank0]:   File "/home/mark/miniforge3/envs/unsloth/lib/python3.10/site-packages/triton/runtime/jit.py", line 425, in run
[rank0]:     kernel.run(grid_0, grid_1, grid_2, kernel.num_warps, kernel.num_ctas,  # number of warps/ctas per instance
[rank0]:   File "/home/mark/miniforge3/envs/unsloth/lib/python3.10/site-packages/triton/compiler/compiler.py", line 255, in __getattribute__
[rank0]:     self._init_handles()
[rank0]:   File "/home/mark/miniforge3/envs/unsloth/lib/python3.10/site-packages/triton/compiler/compiler.py", line 248, in _init_handles
[rank0]:     raise OutOfResources(self.shared, max_shared, "shared memory")
[rank0]: triton.runtime.autotuner.OutOfResources: out of resource: shared memory, Required: 180224, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.
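
For context, the two numbers in the exception line up with the GPU generation: the reported hardware limit of 101376 bytes is the roughly 99 KB per-thread-block shared-memory cap on Ada cards like the RTX 4090, while the kernel's default launch configuration asks for about 176 KB, which is why lowering num_stages (as suggested further down in this thread) shrinks the requirement. A small sketch to make the mismatch explicit; the byte counts are copied from the exception above rather than queried from the driver:

import torch

# Report which card and compute capability Triton is compiling for.
props = torch.cuda.get_device_properties(0)
print(f"{props.name} (sm_{props.major}{props.minor})")  # e.g. NVIDIA GeForce RTX 4090 (sm_89)

required_bytes = 180224  # "Required" in the OutOfResources message
limit_bytes = 101376     # "Hardware limit" in the OutOfResources message (~99 KB per block)
print(f"kernel wants {required_bytes / 1024:.0f} KiB, card allows {limit_bytes / 1024:.0f} KiB per block")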

Here is my Trainer config:

from transformers import TrainingArguments

args = TrainingArguments(
    'output',
    learning_rate=lr,
    gradient_accumulation_steps=accum,
    lr_scheduler_type='cosine',
    bf16=True,
    bf16_full_eval=True,
    tf32=True,
    optim='paged_adamw_8bit',
    evaluation_strategy="epoch",
    logging_steps=1,
    per_device_train_batch_size=bs,
    per_device_eval_batch_size=bs,
    greater_is_better=False,
    group_by_length=True,
    num_train_epochs=epochs,
    weight_decay=wd,
    save_strategy='epoch',
    save_total_limit=1,
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
    ddp_find_unused_parameters=False,
    gradient_checkpointing=True,
)
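
For completeness, a minimal sketch of how this config is wired up; the model id is the one from the traceback above, while the tokenizer handling and the datasets are placeholders rather than the exact setup used here:

from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer

model_id = "microsoft/Phi-3-small-128k-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    trust_remote_code=True,  # pulls in modeling_phi3_small.py and its Triton blocksparse kernels
)

trainer = Trainer(
    model=model,
    args=args,            # the TrainingArguments above
    tokenizer=tokenizer,
    train_dataset=None,   # placeholder: the tokenized train split goes here
    eval_dataset=None,    # placeholder: the tokenized eval split goes here
)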

I confirmed that I'm not running out of GPU memory. I ran Llama-3-8B with the same batch size (4) and parameters, and I even tried dropping Phi's batch size to 1.

I am getting the same error. Can someone help resolve it?
In my case, it sometimes works and sometimes I get this error; I'm not sure why. I am using a g5.12xlarge AWS SageMaker instance.

Microsoft org

Hi,
Can you try changing num_stages to 1 here? Specifically, in
triton_flash_blocksparse_attn.py::L1023:

_fwd_kernel_batch_inference[grid](
    q, k, v, out,
    sm_scale,
    q_batch_starts,
    q_batch_ends,
    k_batch_starts,
    k_batch_ends,
    q_batch_ids,
    q_start_sids,

    *q.stride(),
    *k.stride(),
    *v.stride(),
    *out.stride(),

    layout_crow_indices,
    layout_col_indices,
    *layout_crow_indices.stride(),
    *layout_col_indices.stride(),

    q_k_ratio,
    HAS_BATCH_DIM = True,
    D_HEAD = head_size,
    BLOCK_M = block_size,
    BLOCK_N = block_size,
    BLOCK_D = block_d,
    BLOCK_M_LOADING = 16 if q_len == 1 else block_size, # smaller for decoding
    EVEN_D = block_d == head_size,
    num_warps = 1 if q_len == 1 else 4,
    num_stages = 1  # <---- instead of 3
    )
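
The file to edit is the copy in the local Hugging Face modules cache (the same path that appears in the traceback above); a small sketch to locate it, assuming the default cache location:

from pathlib import Path

# The commit-hash directory differs per download, so glob for the file instead of hard-coding it.
cache_root = Path.home() / ".cache/huggingface/modules/transformers_modules/microsoft"
for path in cache_root.glob("Phi-3-small-*-instruct/*/triton_flash_blocksparse_attn.py"):
    print(path)  # change num_stages in the _fwd_kernel_batch_inference launch inside this file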

Let me know if that fixes the issue. If it does, then I can create a PR for this.

That does fix it @bapatra, but now I'm running into the same issue as this commenter: https://huggingface.co/microsoft/Phi-3-small-8k-instruct/discussions/11

As you can see from my config, I have bf16 and bf16_full_eval enabled. I even tried disabling tf32 but that didn't help.

@ericxihuilin I went through that thread and tried that. It didn't work for me. Not sure if it makes a difference or not, but I'm doing QLoRA training.
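
(For reference, "QLoRA" here means the base model is loaded in 4-bit before attaching LoRA adapters; a typical setup looks roughly like the following, with quantization settings that are common defaults rather than the exact ones used in this run:)

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # keep compute in bf16, matching the Trainer config
    bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-small-128k-instruct",
    quantization_config=bnb_config,
    trust_remote_code=True,
)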

@marktenenholtz
We are facing the same issue. Did you find a solution?

Microsoft org

The recommended adjustment layers are:

"target_modules": [
    "o_proj",
    "qkv_proj"
]
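
In a PEFT setup that would look something like the following; the LoRA rank, alpha, and dropout are placeholders, only target_modules follows the recommendation above:

from peft import LoraConfig

lora_config = LoraConfig(
    r=16,               # placeholder rank
    lora_alpha=32,      # placeholder scaling
    lora_dropout=0.05,  # placeholder dropout
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["o_proj", "qkv_proj"],  # the modules recommended above
)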

@LeeStott Using those target_modules gives the error:

ValueError: Target modules {'o_proj', 'qkv_proj'} not found in the base model. Please check the target modules and try again.
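
One way to debug that mismatch is to list the module names the loaded model actually exposes and pick target_modules from the result; a small sketch, assuming `model` is the Phi-3-small model loaded with trust_remote_code as above:

import torch

# Collect the leaf names of all Linear layers; LoRA target_modules must match these suffixes.
linear_names = sorted({
    name.split(".")[-1]
    for name, module in model.named_modules()
    if isinstance(module, torch.nn.Linear)
})
print(linear_names)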
