Upload modeling_sdar.py with huggingface_hub
modeling_sdar.py  +726 -136
CHANGED

@@ -1,4 +1,3 @@
-print(f"--- I am EXECUTING modeling_sdar.py from location: {__file__} ---")
 # This file is modified based on https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3/modeling_qwen3.py.
 #
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨

@@ -22,10 +21,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Callable, Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union, List

 import torch
 from torch import nn
+from einops import rearrange

 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache

@@ -45,7 +45,8 @@ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from transformers.processing_utils import Unpack
 from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
-from
+from configuration_sdar import SDARConfig
+from fused_linear_diffusion_cross_entropy import FusedLinearDiffusionCrossEntropyLoss

 from flash_attn.ops.triton.layer_norm import rms_norm_fn as flash_rms_norm

@@ -71,6 +72,286 @@ if is_torch_flex_attn_available():
 logger = logging.get_logger(__name__)


+def modify_padded_position_ids_2d(position_ids: torch.LongTensor) -> torch.LongTensor:
+    """
+    Rewrites the position_ids of a batch of packed sequences using fully vectorized PyTorch ops.
+    This function assumes the input is a 2D tensor of shape (batch_size, sequence_length).
+    Each row in the batch is processed independently.
+
+    Args:
+        position_ids: 2D PyTorch tensor of shape (batch_size, sequence_length).
+
+    Returns:
+        The modified position_ids tensor of shape (batch_size, sequence_length).
+    """
+    if position_ids.dim() != 2:
+        raise ValueError(f"Input tensor must be 2D, but got {position_ids.dim()} dimensions.")
+
+    batch_size, seq_len = position_ids.shape
+    device = position_ids.device
+
+    col_indices = torch.arange(seq_len, device=device, dtype=position_ids.dtype).expand(batch_size, -1)
+    mask = (position_ids != 0)
+
+    masked_indices = col_indices * mask
+    last_nonzero_idx = torch.max(masked_indices, dim=1).values
+    has_nonzero = torch.any(mask, dim=1)
+    pad_start_idx = torch.where(has_nonzero, last_nonzero_idx + 1, torch.tensor(0, device=device, dtype=position_ids.dtype))
+
+    padding_mask = col_indices >= pad_start_idx.unsqueeze(1)
+    new_pad_values = col_indices - pad_start_idx.unsqueeze(1)
+    position_ids = torch.where(padding_mask, new_pad_values, position_ids)
+
+    return position_ids
+
+
+def calculate_token_nums(position_ids: torch.Tensor):
+    """
+    Efficiently computes the length of every packed sequence in a batch with PyTorch.
+
+    Args:
+        position_ids (torch.Tensor): A 2D tensor of shape (batch_size, sequence_length).
+            Example: tensor([[0,1,2,3,4,0,1,2,3,4,5,0,1,2,3,0,0,0]])
+    Returns:
+        List[torch.Tensor]: One tensor per batch item holding the lengths of its sequences.
+            Example: [tensor([5, 6, 4, 1, 1, 1])]
+    """
+    # Check that the input is a 2D tensor
+    if position_ids.dim() != 2:
+        raise ValueError(f"Input must be a 2D tensor, but got {position_ids.dim()}D")
+
+    all_lengths = []
+
+    # Process the batch row by row. Because each row holds a different (ragged)
+    # number of sequence lengths, a Python loop over the batch dimension is the
+    # most efficient and clearest approach; the work inside it is fully vectorized.
+    for pids_row in position_ids:
+        # Total length of the current row
+        seq_len = pids_row.shape[0]
+
+        # 1. Find the indices of all elements equal to 0.
+        # pids_row == 0 yields a boolean tensor: [True, False, ..., True, ...];
+        # torch.nonzero returns the indices of the True values, and
+        # .flatten() turns the (N, 1) tensor into shape (N,).
+        zero_indices = torch.nonzero(pids_row == 0).flatten()
+
+        # 2. Append the total sequence length as an extra split point at the end.
+        # This is essential for computing the length of the last sequence.
+        # Note: the new tensor must live on the same device (cpu/cuda) as the original.
+        split_points = torch.cat([
+            zero_indices,
+            torch.tensor([seq_len], device=pids_row.device, dtype=zero_indices.dtype)
+        ])
+
+        # 3. The differences between adjacent split points are the lengths we want.
+        # torch.diff([a, b, c, d]) returns [b-a, c-b, d-c]
+        lengths = torch.diff(split_points)
+
+        all_lengths.append(lengths)
+
+    return all_lengths
+
+
+def forward_add_noise_packed(
+    inputs_ids: torch.Tensor,
+    num_tokens_list: List[torch.Tensor],
+    prompt_mask: torch.Tensor,
+    mask_id: int,
+    eps: float = 1e-3,
+    max_tries: int = 10,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Adds noise to the token IDs of a batch of packed sequences.
+
+    This function keeps the logic of drawing an independent random noise rate for
+    each logical sample (concatenated within each batch item). It randomly replaces
+    a fraction of token IDs with mask_id, skipping positions flagged by prompt_mask.
+
+    Args:
+        inputs_ids (torch.Tensor):
+            Input token ID tensor of shape (bsz, total_tokens).
+        num_tokens_list (List[torch.Tensor]):
+            A list of tensors of length bsz. Each tensor records the lengths of the
+            logical samples in the corresponding batch item, e.g.
+            [tensor([len1, len2]), tensor([len3, len4, len5])].
+        prompt_mask (torch.Tensor):
+            Boolean tensor of shape (bsz, total_tokens); True marks prompt positions
+            that must not receive noise.
+        mask_id (int):
+            ID of the mask token used for replacement.
+        eps (float):
+            Small value that keeps the noise rate t from being exactly 0, ensuring p_mask > 0.
+        max_tries (int):
+            Maximum number of attempts per batch item to guarantee that at least one
+            non-prompt token gets masked.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+            - noisy_input_ids (torch.Tensor):
+                Token IDs after adding noise, shape (bsz, total_tokens).
+            - final_masked_indices (torch.Tensor):
+                Boolean tensor marking which positions were actually masked, shape (bsz, total_tokens).
+            - p_masks (torch.Tensor):
+                A 1D tensor holding the actual noise rates of the masked tokens.
+    """
+    # 1. Validate and read shapes
+    bsz, total_tokens = inputs_ids.shape
+    device = inputs_ids.device
+
+    # Consistency checks on the inputs
+    assert len(num_tokens_list) == bsz, f"len(num_tokens_list) ({len(num_tokens_list)}) must equal bsz ({bsz})"
+    assert prompt_mask.shape == (bsz, total_tokens), f"prompt_mask shape mismatch, expected {(bsz, total_tokens)}, got {prompt_mask.shape}"
+
+    # Result containers
+    noisy_ids_list = []
+    final_masked_indices_list = []
+    p_masks_per_token_list = []
+
+    # 2. Iterate over the batch dimension.
+    # This is the most direct and effective way to handle differing packing structures.
+    for i in range(bsz):
+        # Slice out the data for the current batch item
+        current_ids = inputs_ids[i:i+1]  # shape: (1, total_tokens)
+        current_num_tokens = num_tokens_list[i]
+        current_prompt_mask = prompt_mask[i:i+1]  # shape: (1, total_tokens)
+
+        num_samples_in_item = len(current_num_tokens)
+        # Verify that the token count of this batch item adds up
+        assert total_tokens == torch.sum(current_num_tokens), \
+            f"sum of num_tokens for batch item {i} ({torch.sum(current_num_tokens)}) does not match total_tokens ({total_tokens})"
+
+        eligible_for_masking = ~current_prompt_mask
+
+        # If no token can be masked, use the original input and set p_mask to eps
+        if not eligible_for_masking.any():
+            noisy_ids_list.append(current_ids)
+            final_masked_indices_list.append(torch.zeros_like(current_prompt_mask, dtype=torch.bool))
+            # p_mask_per_token must have shape (1, total_tokens) so it can be concatenated later
+            p_masks_per_token_list.append(torch.full((1, total_tokens), eps, device=device, dtype=torch.float))
+            continue
+
+        # --- Try to generate a mask, ensuring at least one token gets masked ---
+        final_masked_indices_item = torch.zeros_like(current_prompt_mask, dtype=torch.bool)
+        p_mask_per_token = None
+
+        for _ in range(max_tries):
+            # Draw an independent noise rate t for each logical sample
+            t = torch.rand(num_samples_in_item, device=device)
+            p_mask_per_sample = (1 - eps) * t + eps
+
+            # Broadcast each sample's noise rate over all of its tokens
+            p_mask_per_token_1d = torch.repeat_interleave(p_mask_per_sample, current_num_tokens)
+            p_mask_per_token = p_mask_per_token_1d.unsqueeze(0)  # shape: (1, total_tokens)
+
+            # Sample a random mask according to the noise rates
+            masked_indices = torch.rand_like(p_mask_per_token) < p_mask_per_token
+            # Apply the prompt mask so prompt tokens are never masked
+            final_masked_indices_item = masked_indices & eligible_for_masking
+
+            # Stop retrying once at least one token has been masked
+            if final_masked_indices_item.any():
+                break
+
+        # If nothing is masked after max_tries (vanishingly unlikely), force-mask one eligible token
+        if not final_masked_indices_item.any():
+            eligible_indices = torch.nonzero(eligible_for_masking.squeeze(0), as_tuple=True)[0]
+            if len(eligible_indices) > 0:
+                # Pick a random maskable position
+                random_choice = torch.randint(0, len(eligible_indices), (1,)).item()
+                force_mask_idx = eligible_indices[random_choice]
+                final_masked_indices_item[0, force_mask_idx] = True
+
+        # --- Produce the noisy IDs from the final mask ---
+        noisy_ids_item = torch.where(
+            final_masked_indices_item,
+            mask_id,
+            current_ids
+        )
+
+        # Store the results for this batch item
+        noisy_ids_list.append(noisy_ids_item)
+        final_masked_indices_list.append(final_masked_indices_item)
+        p_masks_per_token_list.append(p_mask_per_token)
+
+    # 3. Stack the per-item results into the final batched tensors
+    noisy_input_ids = torch.cat(noisy_ids_list, dim=0)
+    final_masked_indices = torch.cat(final_masked_indices_list, dim=0)
+    p_mask_full = torch.cat(p_masks_per_token_list, dim=0)
+
+    # 4. Extract the noise rates at the masked positions
+    p_masks = p_mask_full[final_masked_indices]
+
+    return noisy_input_ids, final_masked_indices, p_masks
+
+
+def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
+    """
+    Constructs the specialized block diffusion attention mask for training
+    composed of three masks:
+    - **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
+    - **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
+    - **Block Causal Mask (M_BC)**: Attention to update x0
+
+    Args:
+        b, h: Batch and head indices (ignored for mask logic).
+        q_idx, kv_idx: Query and Key indices.
+        n: Offset separating the noisy xt tokens from the clean x0 tokens.
+        block_size: Defines the block structure.
+
+    Returns:
+        A boolean attention mask.
+    """
+
+    # Indicate whether token belongs to xt or x0
+    x0_flag_q = q_idx >= n
+    x0_flag_kv = kv_idx >= n
+
+    # Compute block indices
+    block_q = torch.where(
+        x0_flag_q == 1, (q_idx - n) // block_size, q_idx // block_size
+    )
+    block_kv = torch.where(
+        x0_flag_kv == 1, (kv_idx - n) // block_size, kv_idx // block_size
+    )
+
+    # **1. Block Diagonal Mask (M_BD) **
+    block_diagonal = (block_q == block_kv) & (x0_flag_q == x0_flag_kv)
+
+    # **2. Offset Block-Causal Mask (M_OBC) **
+    offset_block_causal = (block_q > block_kv) & (
+        x0_flag_kv == 1) & (x0_flag_q == 0)
+
+    # **3. Block-Causal Mask (M_BC) **
+    block_causal = (block_q >= block_kv) & (x0_flag_kv == 1) & (x0_flag_q == 1)
+
+    # **4. Combine Masks **
+    return block_diagonal | offset_block_causal | block_causal
+
+
+def block_attn_mask(num_tokens, block_size, device):
+    masks = []
+    for i in range(len(num_tokens)):
+        cur_masks = []
+        for num in num_tokens[i]:
+            # Build the full 2n x 2n mask for each packed sample (xt half + x0 half)
+            single_mask = block_diff_mask(
+                b=None,
+                h=None,
+                q_idx=torch.arange(num * 2, device=device)[:, None],
+                kv_idx=torch.arange(num * 2, device=device)[None, :],
+                block_size=block_size,
+                n=num,
+            )
+            cur_masks.append(single_mask)
+        masks.append(torch.block_diag(*cur_masks))
+    masks = torch.stack(masks, dim=0)
+    return masks
+
+
+@torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs")
+def fused_flex_attention(query, key, value, attention_mask, **kwargs):
+    return flex_attention(query, key, value, block_mask=attention_mask, **kwargs)
+
+
 @use_kernel_forward_from_hub("RMSNorm")
 class SDARRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
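
As a quick sanity check on the two packing helpers added above (a standalone sketch, not part of the uploaded file; the tensor follows the docstring example):

    import torch

    # One batch row packing sequences of lengths 5, 6 and 4, followed by 3 pad zeros
    pids = torch.tensor([[0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 0, 0, 0]])

    # The trailing pad zeros are rewritten into a ramp, so they form one last "sequence"
    fixed = modify_padded_position_ids_2d(pids)
    # fixed: [[0,1,2,3,4, 0,1,2,3,4,5, 0,1,2,3, 0,1,2]]

    # Sequence lengths are recovered from the positions of the zeros
    print(calculate_token_nums(pids))   # [tensor([5, 6, 4, 1, 1, 1])]
    print(calculate_token_nums(fixed))  # [tensor([5, 6, 4, 3])]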

@@ -129,6 +410,7 @@ def rotate_half(x):

 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
+
     Args:
         q (`torch.Tensor`): The query tensor.
         k (`torch.Tensor`): The key tensor.

@@ -255,8 +537,6 @@ class SDARAttention(nn.Module):
             hidden_states).view(hidden_shape)).transpose(1, 2)
         value_states = self.v_proj(hidden_states).view(
             hidden_shape).transpose(1, 2)
-
-

         cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(

@@ -274,96 +554,47 @@ class SDARAttention(nn.Module):
             value_states = torch.cat(
                 [past_value_states, value_states], dim=-2)

-        if torch.all(attention_mask): # decoding
-            query_states = query_states.transpose(1, 2)
-            key_states = key_states.transpose(1, 2)
-            value_states = value_states.transpose(1, 2)
-            attn_output = flash_attn_func(
-                query_states,
-                key_states,
-                value_states,
-                causal=False,
-                softmax_scale=self.scaling
-            )
-        else: # prefilling
-            attn_output = F.scaled_dot_product_attention(
-                query=query_states,
-                key=key_states,
-                value=value_states,
-                scale=self.scaling,
-            )
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        '''
-        #print(query_states.shape, key_states.shape, value_states.shape)
-
-        # --- After RoPE and KV-cache handling, expand KV to all heads ---
-        key_states = repeat_kv(key_states, self.num_key_value_groups)  # [B, H, K, D]
-        value_states = repeat_kv(value_states, self.num_key_value_groups)  # [B, H, K, D]
-
-        # --- Convert a 0/1 or bool 4D mask into an *additive* mask, and align to [B, H, Q, K] ---
-        attn_mask = None
-        if attention_mask is not None:
-            k_len = key_states.shape[-2]
-            am = attention_mask
-            # Support either 2D [B, K] or 4D [B, 1/H, Q, K]
-            if am.dim() == 2:
-                am = am[:, None, None, :k_len]  # -> [B,1,1,K]
-            else:
-                am = am[:, :, :, :k_len]  # -> [B,1/H,Q,K]
-
-            finfo_min = torch.finfo(query_states.dtype).min
-            # 0/1 or bool -> float additive mask: 1->0, 0->-inf
-            if am.dtype == torch.bool:
-                zero = torch.zeros((), dtype=query_states.dtype, device=am.device)
-                neginf = torch.full((), finfo_min, dtype=query_states.dtype, device=am.device)
-                am = torch.where(am, zero, neginf)
-            else:
-                # For 0/1 float masks: values > 0 are treated as visible
-                am = am.to(query_states.dtype)
-                am = torch.where(am > 0, torch.zeros_like(am), torch.full_like(am, finfo_min))
-
-            # Expand to all heads
-            #if am.shape[1] == 1 and self.num_attention_heads > 1:
-            #    am = am.expand(am.shape[0], self.num_attention_heads, am.shape[2], am.shape[3])
-
-            #attn_mask = am.contiguous()
-            attn_mask = am
-
-        bsz, q_len = input_shape
-
-        if q_len == 1 and past_key_value is not None:
-            # --- Decoding: flash-attn ---
-            q = query_states.transpose(1, 2)  # [B,Q,H,D]
-            k = key_states.transpose(1, 2)
-            v = value_states.transpose(1, 2)
-            attn_output = flash_attn_func(
-                q, k, v,
-                causal=True,  # For decoding, explicitly set causal=True
-                softmax_scale=self.scaling
-            )
-        else:
-        '''
+        if self.training:
+            attn_output, attn_weights = fused_flex_attention(
+                query=query_states,
+                key=key_states,
+                value=value_states,
+                attention_mask=attention_mask,
+                enable_gqa=True,
+                scale=self.scaling,
+                return_lse=True
+            )
+            attn_weights = attn_weights.to(
+                value_states.dtype) if attn_weights is not None else None
+            attn_output = rearrange(attn_output, 'b h l d -> b l (h d)')
+        else:
+            attention_mask = attention_mask.bool() if attention_mask is not None else None
+            attn_weights = None
+            if torch.all(attention_mask): # decoding
+                query_states = query_states.transpose(1, 2)
+                key_states = key_states.transpose(1, 2)
+                value_states = value_states.transpose(1, 2)
+                attn_output = flash_attn_func(
+                    query_states,
+                    key_states,
+                    value_states,
+                    causal=False,
+                    softmax_scale=self.scaling
+                )
+                attn_output = rearrange(attn_output, 'b l h d -> b l (h d)')
+            else: # prefilling
+                attn_output = F.scaled_dot_product_attention(
+                    query=query_states,
+                    key=key_states,
+                    value=value_states,
+                    attn_mask=attention_mask,
+                    is_causal=False,
+                    scale=self.scaling,
+                    enable_gqa=True
+                )
+                attn_output = rearrange(attn_output, 'b h l d -> b l (h d)')
         attn_output = self.o_proj(attn_output)
-        return attn_output,
+        return attn_output, attn_weights  # , attn_weights


 class SDARDecoderLayer(GradientCheckpointingLayer):
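
For intuition about the mask the training branch consumes, the sketch below (standalone, not part of the uploaded file) evaluates block_diff_mask, defined earlier in this diff, on a toy case of one sample with n = 4 tokens and block_size = 2, where indices 0-3 are the noisy xt copy and 4-7 the clean x0 copy:

    import torch

    n, block_size = 4, 2
    q_idx = torch.arange(2 * n)[:, None]
    kv_idx = torch.arange(2 * n)[None, :]
    mask = block_diff_mask(None, None, q_idx, kv_idx, block_size=block_size, n=n)
    print(mask.int())
    # tensor([[1, 1, 0, 0, 0, 0, 0, 0],   # xt block 0: sees only itself (M_BD)
    #         [1, 1, 0, 0, 0, 0, 0, 0],
    #         [0, 0, 1, 1, 1, 1, 0, 0],   # xt block 1: itself + earlier x0 block (M_OBC)
    #         [0, 0, 1, 1, 1, 1, 0, 0],
    #         [0, 0, 0, 0, 1, 1, 0, 0],   # x0 block 0: block-causal (M_BC)
    #         [0, 0, 0, 0, 1, 1, 0, 0],
    #         [0, 0, 0, 0, 1, 1, 1, 1],   # x0 block 1: x0 blocks 0 and 1
    #         [0, 0, 0, 0, 1, 1, 1, 1]])

During training this boolean mask is wrapped into a flex-attention BlockMask via create_block_mask (see prepare_for_bd_training below) and passed to fused_flex_attention.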

@@ -739,6 +970,7 @@ class SDARModel(SDARPreTrainedModel):
     """
     Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
     `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
     Args:
         attention_mask (`torch.Tensor`):
             A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.

@@ -834,6 +1066,79 @@ class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
     def get_decoder(self):
         return self.model

+    def prepare_for_bd_training(self, inputs_ids, position_ids, prompt_mask, masked_indices=None, p_mask_input=None):
+        bsz, seq_len = inputs_ids.shape
+        num_tokens = calculate_token_nums(position_ids)  # List[torch.Tensor]
+
+        # If masked_indices is passed in manually, use it directly
+        if masked_indices is not None:
+            # Manual-mask mode: used for RL training or fixed-mask experiments
+            # Note: the caller has already restricted masked_indices to the response
+            # part (via & response_mask), so no further filtering is needed
+            noisy_inputs_ids = torch.where(masked_indices, self.config.mask_token_id, inputs_ids)
+            logits_to_keep_half = masked_indices  # (B, L) bool
+            # Build a default p_mask: flattened noise rates of shape (M,), where M = sum(masked_indices)
+            # The default value 0.5 stands for a medium noise level (used by the diffusion loss)
+            M = masked_indices.sum().item()
+            p_mask = torch.full((M,), 0.5, device=inputs_ids.device, dtype=torch.float)
+        else:
+            # Random-mask mode: used for Block Diffusion pretraining
+            # Returns: noisy_inputs_ids (B, L), logits_to_keep_half (B, L) bool, p_mask (M,) float
+            noisy_inputs_ids, logits_to_keep_half, p_mask = forward_add_noise_packed(
+                inputs_ids=inputs_ids,
+                num_tokens_list=num_tokens,
+                prompt_mask=prompt_mask,
+                mask_id=self.config.mask_token_id,
+            )
+
+        # Make sure both branches return consistent shapes
+        # logits_to_keep_half: (B, L) bool - marks which positions are masked
+        # p_mask: (M,) float - noise rate of each masked position, where M = sum(logits_to_keep_half)
+        assert logits_to_keep_half.shape == (bsz, seq_len), f"logits_to_keep_half shape error: {logits_to_keep_half.shape}"
+        assert p_mask.shape == (logits_to_keep_half.sum(),), f"p_mask shape error: {p_mask.shape}, expected ({logits_to_keep_half.sum()},)"
+
+        # If p_mask_input is provided (for RL training), compute p_to_keep,
+        # which selects the positions with p_mask=True among the masked positions
+        p_to_keep = None
+        if p_mask_input is not None:
+            # Note: the caller has already restricted p_mask_input to the response
+            # part (via & response_mask), so no further filtering is needed
+            # p_mask_input (B, L), logits_to_keep_half (B, L)
+            # p_to_keep (M,) bool, where M = sum(logits_to_keep_half)
+            p_to_keep = p_mask_input[logits_to_keep_half]
+
+        router_noisy_part_list = []
+        for i in range(bsz):
+            cur_router_noisy_part = (torch.arange(num_tokens[i].shape[0] * 2) % 2 == 0).to(inputs_ids.device)
+            cur_router_noisy_part = cur_router_noisy_part.repeat_interleave(num_tokens[i].repeat_interleave(2))
+            router_noisy_part_list.append(cur_router_noisy_part)
+        router_noisy_part = torch.stack(router_noisy_part_list, dim=0)
+
+        # concatenated inputs_ids: (bsz, seq_len x 2)
+        concat_inputs_ids = inputs_ids.repeat(1, 2)
+        # concatenated logits_to_keep: (bsz, seq_len x 2)
+        logits_to_keep = torch.zeros(
+            bsz, 2 * seq_len, dtype=torch.bool, device=inputs_ids.device)
+        # concatenated position_ids: (bsz, seq_len x 2)
+        concat_position_ids = torch.zeros(
+            bsz, 2 * seq_len, dtype=position_ids.dtype, device=position_ids.device)
+        for i in range(bsz):
+            concat_inputs_ids[i][router_noisy_part[i]] = noisy_inputs_ids[i]
+            concat_inputs_ids[i][~router_noisy_part[i]] = inputs_ids[i]
+
+            logits_to_keep[i][router_noisy_part[i]] = logits_to_keep_half[i]
+
+            concat_position_ids[i][router_noisy_part[i]] = position_ids[i]
+            concat_position_ids[i][~router_noisy_part[i]] = position_ids[i]
+
+        # create flex_attention mask
+        attention_mask = block_attn_mask(num_tokens, self.config.block_size, inputs_ids.device)
+        flex_attention_mask_3d = create_block_mask(
+            lambda b, h, q_idx, kv_idx: attention_mask[b, q_idx, kv_idx],
+            B=attention_mask.size(0), H=None,
+            Q_LEN=attention_mask.size(1), KV_LEN=attention_mask.size(2),
+        )
+
+        return concat_inputs_ids, concat_position_ids, flex_attention_mask_3d, logits_to_keep_half, logits_to_keep, p_mask, p_to_keep
+
     @can_return_tuple
     @auto_docstring
     def forward(
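
To illustrate the doubled layout prepare_for_bd_training builds (a standalone sketch, not part of the uploaded file): for one row packing two samples of lengths 3 and 2, the concatenated row of length 2 * seq_len alternates noisy (xt) and clean (x0) copies sample by sample, and router_noisy_part marks the noisy halves:

    import torch

    num_tokens = torch.tensor([3, 2])
    router = (torch.arange(num_tokens.shape[0] * 2) % 2 == 0)
    # router: [True, False, True, False] -> xt(s1), x0(s1), xt(s2), x0(s2)
    router = router.repeat_interleave(num_tokens.repeat_interleave(2))
    # router: [T, T, T, F, F, F, T, T, F, F]
    # layout: [xt(s1) xt(s1) xt(s1) | x0(s1) x0(s1) x0(s1) | xt(s2) xt(s2) | x0(s2) x0(s2)]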

@@ -849,65 +1154,344 @@ class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
+        masked_indices: Optional[torch.Tensor] = None,
+        return_logits: bool = False,
+        # RL training parameters
+        compute_rl_loss: bool = False,
+        p_mask: Optional[torch.Tensor] = None,
+        adv: Optional[torch.Tensor] = None,
+        adv_optimization: bool = False,
+        logp_old_tok: Optional[torch.Tensor] = None,
+        logp_ref_tok: Optional[torch.Tensor] = None,
+        is_real: Optional[torch.Tensor] = None,
+        ppo_eps: float = 0.2,
+        kl_beta: float = 0.0,
+        use_kl_estimator_k3: bool = True,
+        return_entropy: bool = False,
+        dynamic_threshold: Optional[float] = None,
+        loss_mean: bool = True,
         **kwargs: Unpack[KwargsForCausalLM],
     ) -> CausalLMOutputWithPast:
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-        Example:
-        ```python
-        >>> from transformers import AutoTokenizer, SDARForCausalLM
-        >>> model = SDARForCausalLM.from_pretrained("DiffuOpen/SDAR-1.7B-Chat")
-        >>> tokenizer = AutoTokenizer.from_pretrained("DiffuOpen/SDAR-1.7B-Chat")
-        >>> prompt = "Hey, are you conscious? Can you talk to me?"
-        >>> inputs = tokenizer(prompt, return_tensors="pt")
-        >>> # Generate
-        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
-        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
-        ```"""
+
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )

+        if self.training:
+            assert inputs_embeds is None, "only support input_ids during training"
+
+            prompt_mask = (labels == -100) if labels is not None else None
+            position_ids = modify_padded_position_ids_2d(position_ids)
+
+            (
+                concat_inputs_ids,
+                concat_position_ids,
+                flex_attention_mask_3d,
+                logits_to_keep_half,
+                logits_to_keep,
+                p_mask_out,
+                p_to_keep,
+            ) = self.prepare_for_bd_training(
+                input_ids, position_ids, prompt_mask, masked_indices, p_mask_input=p_mask
+            )

+            outputs = self.model(
+                input_ids=concat_inputs_ids,
+                attention_mask=flex_attention_mask_3d,
+                position_ids=concat_position_ids,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=True,
+                cache_position=cache_position,
+                **kwargs,
+            )
+
+            hidden_states = outputs.last_hidden_state
+            hidden_states = hidden_states[logits_to_keep].contiguous()
+
+            # Initialize entropy
+            entropy = torch.tensor(0.0, device=input_ids.device)
+
+            # ====================== RL loss (PPO) ======================
+            if compute_rl_loss:
+                assert p_to_keep is not None, "p_mask must be provided for RL loss computation."
+                assert adv is not None, "adv must be provided for RL loss computation."
+                assert is_real is not None, "is_real must be provided for RL loss computation."
+                assert labels is not None, "labels must be provided for RL loss computation."
+                assert masked_indices is not None, "masked_indices must be provided for RL loss computation."
+
+                device = input_ids.device
+
+                # logits (M, V) — kept as is
+                logits = self.lm_head(hidden_states)
+
+                # mask — kept as is
+                is_real_tensor = (
+                    is_real.to(device=device, dtype=torch.bool)
+                    if torch.is_tensor(is_real)
+                    else torch.tensor(is_real, dtype=torch.bool, device=device)
+                )
+                p_mask_real = p_mask & is_real_tensor.unsqueeze(1)  # (B, L)
+                p_to_keep_real = p_mask_real[masked_indices]  # (M,) bool
+
+                # select logits — kept as is
+                logits_p = logits[p_to_keep_real]  # (N, V)
+                N = p_to_keep_real.sum().item()
+                total_response_tokens = (labels != -100).sum().item()
+                total_p_mask = p_mask.sum().item()
+                total_masked_indices = masked_indices.sum().item()
+                total_is_real = is_real_tensor.sum().item() if is_real_tensor.dim() > 0 else (1 if is_real_tensor.item() else 0)
+
+                # log_softmax
+                log_probs_p = torch.nn.functional.log_softmax(logits_p, dim=-1)
+
+                # labels / logp — kept as is
+                labels_p = labels[masked_indices][p_to_keep_real]  # (N,)
+                logp_p = log_probs_p.gather(dim=-1, index=labels_p.unsqueeze(-1)).squeeze(-1)
+
+                # entropy (optional)
+                if return_entropy:
+                    with torch.no_grad():
+                        entropy_p = -(log_probs_p.exp() * log_probs_p).sum(dim=-1)
+                        entropy = entropy_p.mean() if entropy_p.numel() > 0 else torch.tensor(0.0, device=device)
+                        del entropy_p
+
+                # advantage handling
+                adv_tensor = adv.to(device) if torch.is_tensor(adv) else torch.tensor(adv, dtype=torch.float, device=device)
+                adv_optimization = False
+                if adv_optimization:
+                    # Token-level optimization: take the maximum advantage over identical prefixes (pruned version)
+                    response_mask = (labels != -100)  # (B, L)
+                    bsz, seq_len = input_ids.shape
+
+                    # Precompute each sample's response start position
+                    response_starts = torch.full((bsz,), seq_len, dtype=torch.long, device=device)
+                    for b in range(bsz):
+                        if response_mask[b].any():
+                            response_starts[b] = response_mask[b].long().argmax()
+
+                    # Pruning 1: samples already at the maximum advantage are filled directly and skip comparison
+                    max_adv_value = adv_tensor.max()
+                    is_max_adv = (adv_tensor == max_adv_value)  # (B,) bool
+
+                    # Build the optimized advantage map (B, L); keep dtype consistent with adv_tensor
+                    optimized_adv = torch.zeros_like(labels, dtype=adv_tensor.dtype)
+
+                    # Fill samples that already have the maximum advantage
+                    for b in range(bsz):
+                        if is_max_adv[b]:
+                            optimized_adv[b][response_mask[b]] = max_adv_value
+
+                    # Statistics
+                    total_response_tokens = 0
+                    updated_tokens = 0
+                    skipped_tokens = 0
+                    original_adv_sum = 0.0
+                    optimized_adv_sum = 0.0
+
+                    # Process position by position, comparing prefixes in batch
+                    for pos in range(seq_len):
+                        valid_samples = response_mask[:, pos]  # (B,)
+                        if not valid_samples.any():
+                            continue
+
+                        # Pruning 2: exclude samples that already hold the maximum advantage
+                        valid_samples = valid_samples & ~is_max_adv
+                        if not valid_samples.any():
+                            # All samples are at the maximum; record the stats and skip
+                            max_count = (response_mask[:, pos] & is_max_adv).sum().item()
+                            total_response_tokens += max_count
+                            skipped_tokens += max_count
+                            original_adv_sum += max_adv_value.item() * max_count
+                            optimized_adv_sum += max_adv_value.item() * max_count
+                            continue
+
+                        # Indices of all samples that still need processing
+                        valid_indices = valid_samples.nonzero(as_tuple=True)[0]  # (N,)
+
+                        for b in valid_indices:
+                            b_item = b.item()
+                            response_start = response_starts[b_item].item()
+                            prefix_len = pos + 1 - response_start
+
+                            if prefix_len <= 0:
+                                optimized_adv[b_item, pos] = adv_tensor[b_item]
+                                continue
+
+                            # Find all samples with the same response start that are valid at pos (including max-advantage ones)
+                            same_start_mask = (response_starts == response_start) & response_mask[:, pos]
+                            same_start_indices = same_start_mask.nonzero(as_tuple=True)[0]
+
+                            if len(same_start_indices) == 1:
+                                # Only this sample; no comparison needed
+                                optimized_adv[b_item, pos] = adv_tensor[b_item]
+                                total_response_tokens += 1
+                                original_adv_sum += adv_tensor[b_item].item()
+                                optimized_adv_sum += adv_tensor[b_item].item()
+                                continue
+
+                            # Pruning 3: if the candidates include a max-advantage sample, the maximum can be used directly
+                            has_max_in_candidates = (same_start_mask & is_max_adv).any()
+
+                            prefix_end = pos + 1
+                            current_prefix = input_ids[b_item, response_start:prefix_end]
+
+                            # Batched comparison: extract the prefixes of all candidate samples
+                            prefixes = input_ids[same_start_indices, response_start:prefix_end]  # (M, prefix_len)
+
+                            # Broadcasted comparison: (M, prefix_len) vs (prefix_len,)
+                            matches = (prefixes == current_prefix.unsqueeze(0)).all(dim=1)  # (M,)
+
+                            # Samples whose prefixes match
+                            matching_indices = same_start_indices[matches]
+
+                            # Take the maximum advantage among samples with identical prefixes
+                            original_adv_value = adv_tensor[b_item].item()
+                            if matching_indices.numel() > 0:
+                                # Pruning 4: if a max-advantage sample matches, use the maximum directly
+                                if has_max_in_candidates and is_max_adv[matching_indices].any():
+                                    max_adv = max_adv_value
+                                else:
+                                    max_adv = adv_tensor[matching_indices].max()
+
+                                optimized_adv[b_item, pos] = max_adv
+                                # Statistics
+                                if abs(max_adv.item() - original_adv_value) > 1e-6:
+                                    updated_tokens += 1
+                                original_adv_sum += original_adv_value
+                                optimized_adv_sum += max_adv.item()
+                            else:
+                                optimized_adv[b_item, pos] = adv_tensor[b_item]
+                                original_adv_sum += original_adv_value
+                                optimized_adv_sum += original_adv_value
+
+                            total_response_tokens += 1
+
+                    # Report statistics
+                    if total_response_tokens > 0:
+                        update_ratio = updated_tokens / total_response_tokens
+                        skip_ratio = skipped_tokens / total_response_tokens
+                        avg_original = original_adv_sum / total_response_tokens
+                        avg_optimized = optimized_adv_sum / total_response_tokens
+                        print(f"[Adv Optimization] Total: {total_response_tokens}, "
+                              f"Updated: {updated_tokens} ({update_ratio:.2%}), "
+                              f"Skipped: {skipped_tokens} ({skip_ratio:.2%}), "
+                              f"Avg adv: {avg_original:.4f} -> {avg_optimized:.4f} "
+                              f"(+{avg_optimized - avg_original:.4f})")
+
+                    # Use the optimized advantage
+                    adv_expanded = optimized_adv
+                else:
+                    # No optimization: use the raw advantage directly
+                    adv_expanded = adv_tensor.unsqueeze(1).expand_as(p_mask)
+
+                adv_p = adv_expanded[masked_indices][p_to_keep_real]
+
+                # old logp
+                if logp_old_tok is not None and logp_old_tok.numel() > 0:
+                    logp_old_p = logp_old_tok.to(device)[masked_indices][p_to_keep_real]
+                else:
+                    logp_old_p = logp_p.detach()
+
+                # ratio/exp
+                ratio_p = (logp_p - logp_old_p).clamp(-10.0, 10.0).exp()
+                clipped = ratio_p.clamp(1 - ppo_eps, 1 + ppo_eps + 0.08)
+                surrogate_p = torch.minimum(ratio_p * adv_p, clipped * adv_p)
+                # Log the ratio value farthest from 1
+                # if not torch.allclose(ratio_p, torch.ones_like(ratio_p)):
+                furthest_value = ratio_p[torch.abs(ratio_p - 1).argmax()]
+                # print(f"Furthest ratio from 1: {furthest_value.item()}")
+
+                # Policy loss: use mean or sum based on loss_mean parameter
+                num_masked = masked_indices.sum().item()
+                num_loss_elements = surrogate_p.numel()
+                print(f"masked_indices.sum()={num_masked}, surrogate_p.numel()={num_loss_elements}")
+                if loss_mean:
+                    policy_loss = -surrogate_p.mean()
+                else:
+                    policy_loss = -surrogate_p.sum()
+
+                # KL (optional)
+                kl_loss = torch.tensor(0.0, device=device)
+                if kl_beta > 0 and logp_ref_tok is not None:
+                    logp_ref_p = logp_ref_tok.to(device)[masked_indices][p_to_keep_real]
+                    kl_seq_p = logp_p - logp_ref_p
+
+                    if use_kl_estimator_k3:
+                        kl_seq_p = (-kl_seq_p).clamp(-10.0, 10.0).exp() - 1.0 + kl_seq_p
+
+                    # KL loss: use mean or sum based on loss_mean parameter
+                    if loss_mean:
+                        kl_loss = kl_beta * kl_seq_p.mean()
+                    else:
+                        kl_loss = kl_beta * kl_seq_p.sum()
+                    del logp_ref_p, kl_seq_p
+
+                loss = policy_loss + kl_loss
+                kl_loss_value = kl_loss.detach().clone()
+
+                # Cleanup
+                del logits, logits_p, log_probs_p, labels_p
+                del is_real_tensor, p_mask_real, p_to_keep_real
+                del adv_tensor, adv_expanded, adv_p
+                del logp_p, logp_old_p, ratio_p, clipped, surrogate_p
+                del policy_loss, kl_loss
+
+                logits = None
+
+            # ====================== GRPO / return logits ======================
+            elif return_logits:
+                logits = self.lm_head(hidden_states)
+                loss = None
+
+            # ====================== Block Diffusion fused loss ======================
+            else:
+                assert labels is not None, "Labels must be provided for training."
+                answer_len = (labels != -100).sum()
+                loss_fct = FusedLinearDiffusionCrossEntropyLoss(reduction="sum")
+                loss = loss_fct(
+                    x=hidden_states,
+                    target=labels[logits_to_keep_half].contiguous(),
+                    weight=self.lm_head.weight,
+                    bias=self.lm_head.bias,
+                    p_mask=p_mask_out,
+                )
+                loss = loss / answer_len
+                logits = None
+
+        # ====================== eval / inference ======================
         else:
+            outputs: BaseModelOutputWithPast = self.model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                cache_position=cache_position,
+                **kwargs,
+            )

+            hidden_states = outputs.last_hidden_state
+            slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+            hidden_states = hidden_states[:, slice_indices, :].contiguous()
+
+            fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
+            if fuse_linear_and_cross_entropy:
+                logits = None
+            else:
+                logits = self.lm_head(hidden_states)

+            loss = None
+            if labels is not None:
+                loss_fct = nn.CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+        output = CausalLMOutputWithPast(
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
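
A note on the KL term in the RL branch above: with delta = logp - logp_ref, the line `(-kl_seq_p).clamp(-10.0, 10.0).exp() - 1.0 + kl_seq_p` is the k3 estimator r - 1 - log r with r = exp(-delta), which is non-negative and unbiased for KL(pi || pi_ref). A tiny standalone check (not part of the uploaded file):

    import torch

    delta = torch.tensor([0.5, -0.2, 0.0])  # logp - logp_ref per token
    k3 = (-delta).exp() - 1.0 + delta
    print(k3)  # tensor([0.1065, 0.0214, 0.0000]), always >= 0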

@@ -915,6 +1499,12 @@ class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
             attentions=outputs.attentions,
         )

+        if self.training and compute_rl_loss:
+            output.entropy = entropy
+            output.kl_loss = kl_loss_value if "kl_loss_value" in locals() else torch.tensor(0.0, device=input_ids.device)
+
+        return output
+

 __all__ = [
     "SDARForCausalLM",