gmastrapas committed · Commit c967bd1 · verified · 1 Parent(s): 3d813dc

Model update

Files changed (1):
  1. blocks_jvlm.py +23 -24
blocks_jvlm.py CHANGED
@@ -1,7 +1,6 @@
 # Copyright 2025 Jina AI. All rights reserved.
 
 from abc import ABCMeta, abstractmethod
-from contextlib import nullcontext
 from copy import deepcopy
 from functools import wraps
 from math import prod, sqrt
@@ -712,7 +711,6 @@ def eager_attention_forward(
     dropout: float = 0.0,
     **_,
 ):
-    assert isinstance(module.num_key_value_groups, int)
     key_states = repeat_kv(key, module.num_key_value_groups)
     value_states = repeat_kv(value, module.num_key_value_groups)
 
@@ -1239,7 +1237,13 @@ class VisionLanguageConnector(GradientCheckpointingLayer):
         if config.pooling_type == ImagePooling2DType.attention_2wide:
             pooling_input_size *= 2
 
-        attn_implementation, _ = self._resolve_attn_pooling(attn_implementation)
+        # Flash Attention can cause Inf grads in the attention pooling layer
+        # because of very large batch sizes. Setting this to sdpa does not cost us
+        # much since sequence lengths in the case of attention pooling are very
+        # small
+        attn_implementation = attn_implementation or 'eager'
+        if attn_implementation.startswith('flash'):
+            attn_implementation = 'sdpa'
         self.pooling = MHSDPA(
             config.attn_pooling_config,
             hidden_size=pooling_input_size,
@@ -1280,23 +1284,6 @@ class VisionLanguageConnector(GradientCheckpointingLayer):
         self.projector_dropout = Dropout(config.projector_dropout)
         self.feature_dropout = Dropout(config.feature_dropout)
 
-    @staticmethod
-    def _resolve_attn_pooling(attn_implementation: Optional[str] = None):
-        """
-        Flash Attention can cause Inf grads in the attention pooling layer because of
-        very large batch sizes. Setting this to sdpa does not cost us much since
-        sequence lengths in the case of attention pooling are tiny
-        """
-        attn_runtime_ctx = nullcontext()
-        if (
-            attn_implementation is not None
-            and attn_implementation.startswith('flash')
-        ):
-            attn_implementation = 'sdpa'
-            attn_runtime_ctx = sdpa_kernel(backends=[SDPBackend.MATH])
-
-        return attn_implementation, attn_runtime_ctx
-
     def forward(
         self,
         image_features: torch.Tensor,
@@ -1361,10 +1348,22 @@ class VisionLanguageConnector(GradientCheckpointingLayer):
             image_features = image_features.contiguous()
             if self.pooling_type == ImagePooling2DType.attention_meanq:
                 query = image_features.mean(-2, keepdim=True)
-                attn_implementation, attn_runtime_ctx = self._resolve_attn_pooling(
-                    attn_implementation
-                )
-                with attn_runtime_ctx:
+                # Flash Attention can cause Inf grads in the attention pooling layer
+                # because of very large batch sizes. Setting this to sdpa does not cost
+                # us much since sequence lengths in the case of attention pooling are
+                # very small
+                attn_implementation = attn_implementation or 'eager'
+                if attn_implementation.startswith('flash'):
+                    attn_implementation = 'sdpa'
+                if attn_implementation == 'sdpa':
+                    with sdpa_kernel(backends=[SDPBackend.MATH]):
+                        image_features, _ = self.pooling(
+                            xq=query,
+                            xk=image_features,
+                            attn_implementation='sdpa',
+                            **kwargs,
+                        )
+                else:
                     image_features, _ = self.pooling(
                         xq=query,
                         xk=image_features,
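
For context on what the new branch does at runtime, below is a minimal, self-contained sketch (not part of the repository's code) of pinning PyTorch's scaled_dot_product_attention to the MATH backend via sdpa_kernel, mirroring the fallback the updated forward() applies to attention pooling. The tensor shapes and the requested implementation string are hypothetical placeholders; only torch.nn.attention.sdpa_kernel, SDPBackend, and F.scaled_dot_product_attention are real PyTorch APIs.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Hypothetical attention-pooling shapes: very large batch, tiny sequence length,
# which is why trading a flash kernel for the math backend is cheap here.
q = torch.randn(4096, 8, 1, 64)   # (batch, heads, 1 pooled query, head_dim)
k = torch.randn(4096, 8, 4, 64)   # (batch, heads, 4 patch tokens, head_dim)
v = torch.randn(4096, 8, 4, 64)

requested = 'flash_attention_2'           # hypothetical caller preference
attn_implementation = requested or 'eager'
if attn_implementation.startswith('flash'):
    # Same fallback as the commit: never run the pooling attention through flash.
    attn_implementation = 'sdpa'

if attn_implementation == 'sdpa':
    # Restrict SDPA to the plain math backend for this call only.
    with sdpa_kernel(backends=[SDPBackend.MATH]):
        pooled = F.scaled_dot_product_attention(q, k, v)
else:
    pooled = F.scaled_dot_product_attention(q, k, v)

print(pooled.shape)  # torch.Size([4096, 8, 1, 64])

Under the MATH backend the attention is computed with ordinary PyTorch ops instead of a fused flash kernel, which sidesteps the Inf gradients mentioned in the diff's comment; with sequence lengths this short, the speed cost is negligible.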