Model update
blocks_jvlm.py  +23 -24  CHANGED
@@ -1,7 +1,6 @@
 # Copyright 2025 Jina AI. All rights reserved.

 from abc import ABCMeta, abstractmethod
-from contextlib import nullcontext
 from copy import deepcopy
 from functools import wraps
 from math import prod, sqrt
@@ -712,7 +711,6 @@ def eager_attention_forward(
     dropout: float = 0.0,
     **_,
 ):
-    assert isinstance(module.num_key_value_groups, int)
     key_states = repeat_kv(key, module.num_key_value_groups)
     value_states = repeat_kv(value, module.num_key_value_groups)

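For context, repeat_kv here follows the usual grouped-query-attention pattern: each key/value head is tiled num_key_value_groups times so the KV heads line up with the query heads. A minimal sketch of that helper, assuming the (batch, num_kv_heads, seq_len, head_dim) layout common in Hugging Face-style models; the shapes below are illustrative, not taken from this file.

import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Tile each KV head n_rep times:
    # (batch, num_kv_heads, seq, dim) -> (batch, num_kv_heads * n_rep, seq, dim)
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim
    )
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)


key = torch.randn(2, 4, 128, 64)   # 4 KV heads
key_states = repeat_kv(key, 3)     # e.g. 12 query heads -> num_key_value_groups == 3
print(key_states.shape)            # torch.Size([2, 12, 128, 64])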
@@ -1239,7 +1237,13 @@ class VisionLanguageConnector(GradientCheckpointingLayer):
         if config.pooling_type == ImagePooling2DType.attention_2wide:
             pooling_input_size *= 2

-
+        # Flash Attention can cause Inf grads in the attention pooling layer
+        # because of very large batch sizes. Setting this to sdpa does not cost us
+        # much since sequence lengths in the case of attention pooling are very
+        # small
+        attn_implementation = attn_implementation or 'eager'
+        if attn_implementation.startswith('flash'):
+            attn_implementation = 'sdpa'
         self.pooling = MHSDPA(
             config.attn_pooling_config,
             hidden_size=pooling_input_size,
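The inlined fallback reduces to a small pure function; a minimal sketch of the same selection logic, with the name resolve_pooling_attn_impl chosen here purely for illustration (it is not part of the file).

from typing import Optional


def resolve_pooling_attn_impl(attn_implementation: Optional[str]) -> str:
    # Default to eager, and downgrade any flash variant
    # (e.g. 'flash_attention_2') to 'sdpa' for the attention pooling layer.
    attn_implementation = attn_implementation or 'eager'
    if attn_implementation.startswith('flash'):
        attn_implementation = 'sdpa'
    return attn_implementation


assert resolve_pooling_attn_impl(None) == 'eager'
assert resolve_pooling_attn_impl('flash_attention_2') == 'sdpa'
assert resolve_pooling_attn_impl('sdpa') == 'sdpa'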
@@ -1280,23 +1284,6 @@ class VisionLanguageConnector(GradientCheckpointingLayer):
         self.projector_dropout = Dropout(config.projector_dropout)
         self.feature_dropout = Dropout(config.feature_dropout)

-    @staticmethod
-    def _resolve_attn_pooling(attn_implementation: Optional[str] = None):
-        """
-        Flash Attention can cause Inf grads in the attention pooling layer because of
-        very large batch sizes. Setting this to sdpa does not cost us much since
-        sequence lengths in the case of attention pooling are tiny
-        """
-        attn_runtime_ctx = nullcontext()
-        if (
-            attn_implementation is not None
-            and attn_implementation.startswith('flash')
-        ):
-            attn_implementation = 'sdpa'
-            attn_runtime_ctx = sdpa_kernel(backends=[SDPBackend.MATH])
-
-        return attn_implementation, attn_runtime_ctx
-
     def forward(
         self,
         image_features: torch.Tensor,
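The removed _resolve_attn_pooling helper paired the implementation string with a context manager that pins PyTorch's scaled_dot_product_attention to the math backend, which is the same thing the new forward code now does inline. A self-contained sketch of that sdpa_kernel pattern, assuming a PyTorch version that ships torch.nn.attention.sdpa_kernel; the tensor shapes are illustrative, not taken from this model.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(2, 8, 4, 64)     # (batch, heads, query_len, head_dim)
k = torch.randn(2, 8, 196, 64)
v = torch.randn(2, 8, 196, 64)

# Restrict scaled_dot_product_attention to the math backend inside the block,
# trading some speed for the numerically safer unfused kernel (cheap here,
# since the pooling sequence lengths are tiny).
with sdpa_kernel(backends=[SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([2, 8, 4, 64])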
@@ -1361,10 +1348,22 @@ class VisionLanguageConnector(GradientCheckpointingLayer):
         image_features = image_features.contiguous()
         if self.pooling_type == ImagePooling2DType.attention_meanq:
             query = image_features.mean(-2, keepdim=True)
-
-
-
-
+        # Flash Attention can cause Inf grads in the attention pooling layer
+        # because of very large batch sizes. Setting this to sdpa does not cost
+        # us much since sequence lengths in the case of attention pooling are
+        # very small
+        attn_implementation = attn_implementation or 'eager'
+        if attn_implementation.startswith('flash'):
+            attn_implementation = 'sdpa'
+        if attn_implementation == 'sdpa':
+            with sdpa_kernel(backends=[SDPBackend.MATH]):
+                image_features, _ = self.pooling(
+                    xq=query,
+                    xk=image_features,
+                    attn_implementation='sdpa',
+                    **kwargs,
+                )
+        else:
             image_features, _ = self.pooling(
                 xq=query,
                 xk=image_features,
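For readers following the attention_meanq path: the mean of the image features acts as a single query that attends over all patch features, pooling the sequence to length one. A minimal sketch of that idea using plain scaled_dot_product_attention, without the learned projections or multi-head handling that the model's MHSDPA layer adds; shapes are illustrative.

import torch
import torch.nn.functional as F

# Illustrative shapes: 2 images, 196 patch features, hidden size 64, one head.
image_features = torch.randn(2, 196, 64)

# attention_meanq pooling: the mean feature is the lone query token.
query = image_features.mean(-2, keepdim=True)   # (2, 1, 64)
pooled = F.scaled_dot_product_attention(
    query.unsqueeze(1),                          # add a head dim -> (2, 1, 1, 64)
    image_features.unsqueeze(1),                 # keys   -> (2, 1, 196, 64)
    image_features.unsqueeze(1),                 # values -> (2, 1, 196, 64)
).squeeze(1)                                     # drop the head dim -> (2, 1, 64)

print(pooled.shape)  # torch.Size([2, 1, 64])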