norm function changes dtypes

#13
Files changed (1) hide show
  1. modeling_intern_vit.py +3 -2
modeling_intern_vit.py CHANGED
@@ -293,9 +293,10 @@ class InternVisionEncoderLayer(nn.Module):
293
  Args:
294
  hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
295
  """
296
- hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)
297
 
298
- hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)
 
 
299
 
300
  return hidden_states
301
 
 
293
  Args:
294
  hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
295
  """
 
296
 
297
+ hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1)
298
+
299
+ hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2)
300
 
301
  return hidden_states
302