Auraithm committed on
Commit 8b5e6db · verified · 1 Parent(s): a1dec5b

Upload modeling_sdar.py with huggingface_hub

Files changed (1)
  1. modeling_sdar.py +726 -136
modeling_sdar.py CHANGED
@@ -1,4 +1,3 @@
1
- print(f"--- I am EXECUTING modeling_sdar.py from location: {__file__} ---")
2
  # This file is modified based on https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3/modeling_qwen3.py.
3
  #
4
  # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
@@ -22,10 +21,11 @@ print(f"--- I am EXECUTING modeling_sdar.py from location: {__file__} ---")
22
  # See the License for the specific language governing permissions and
23
  # limitations under the License.
24
 
25
- from typing import Callable, Optional, Tuple, Union
26
 
27
  import torch
28
  from torch import nn
 
29
 
30
  from transformers.activations import ACT2FN
31
  from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
@@ -45,7 +45,8 @@ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_u
45
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
46
  from transformers.processing_utils import Unpack
47
  from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
48
- from .configuration_sdar import SDARConfig
 
49
 
50
  from flash_attn.ops.triton.layer_norm import rms_norm_fn as flash_rms_norm
51
 
@@ -71,6 +72,286 @@ if is_torch_flex_attn_available():
71
  logger = logging.get_logger(__name__)
72
 
73
 
74
  @use_kernel_forward_from_hub("RMSNorm")
75
  class SDARRMSNorm(nn.Module):
76
  def __init__(self, hidden_size, eps=1e-6):
@@ -129,6 +410,7 @@ def rotate_half(x):
129
 
130
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
131
  """Applies Rotary Position Embedding to the query and key tensors.
 
132
  Args:
133
  q (`torch.Tensor`): The query tensor.
134
  k (`torch.Tensor`): The key tensor.
@@ -255,8 +537,6 @@ class SDARAttention(nn.Module):
255
  hidden_states).view(hidden_shape)).transpose(1, 2)
256
  value_states = self.v_proj(hidden_states).view(
257
  hidden_shape).transpose(1, 2)
258
-
259
-
260
 
261
  cos, sin = position_embeddings
262
  query_states, key_states = apply_rotary_pos_emb(
@@ -274,96 +554,47 @@ class SDARAttention(nn.Module):
274
  value_states = torch.cat(
275
  [past_value_states, value_states], dim=-2)
276
 
277
- '''
278
- attention_mask = attention_mask.bool() if attention_mask is not None else None
279
- if torch.all(attention_mask): # decoding
280
- query_states = query_states.transpose(1, 2)
281
- key_states = key_states.transpose(1, 2)
282
- value_states = value_states.transpose(1, 2)
283
- attn_output = flash_attn_func(
284
- query_states,
285
- key_states,
286
- value_states,
287
- causal=False,
288
- softmax_scale=self.scaling
289
- )
290
-
291
- else: # prefilling
292
- attn_output = F.scaled_dot_product_attention(
293
  query=query_states,
294
  key=key_states,
295
  value=value_states,
296
- attn_mask=attention_mask,
297
- is_causal=False,
298
  scale=self.scaling,
299
- enable_gqa=True
300
- )
301
- attn_output = attn_output.transpose(1, 2).contiguous()
302
- '''
303
-
304
- #print(query_states.shape, key_states.shape, value_states.shape)
305
-
306
- # --- After RoPE and KV-cache handling, expand KV to all heads ---
307
- key_states = repeat_kv(key_states, self.num_key_value_groups) # [B, H, K, D]
308
- value_states = repeat_kv(value_states, self.num_key_value_groups) # [B, H, K, D]
309
-
310
- # --- Convert a 0/1 or bool 4D mask into an *additive* mask, and align to [B, H, Q, K] ---
311
- attn_mask = None
312
- if attention_mask is not None:
313
- k_len = key_states.shape[-2]
314
- am = attention_mask
315
- # Support either 2D [B, K] or 4D [B, 1/H, Q, K]
316
- if am.dim() == 2:
317
- am = am[:, None, None, :k_len] # -> [B,1,1,K]
318
- else:
319
- am = am[:, :, :, :k_len] # -> [B,1/H,Q,K]
320
-
321
- finfo_min = torch.finfo(query_states.dtype).min
322
- # 0/1 or bool -> float additive mask: 1->0, 0->-inf
323
- if am.dtype == torch.bool:
324
- zero = torch.zeros((), dtype=query_states.dtype, device=am.device)
325
- neginf = torch.full((), finfo_min, dtype=query_states.dtype, device=am.device)
326
- am = torch.where(am, zero, neginf)
327
- else:
328
- # For 0/1 float masks: values > 0 are treated as visible
329
- am = am.to(query_states.dtype)
330
- am = torch.where(am > 0, torch.zeros_like(am), torch.full_like(am, finfo_min))
331
-
332
- # Expand to all heads
333
- #if am.shape[1] == 1 and self.num_attention_heads > 1:
334
- # am = am.expand(am.shape[0], self.num_attention_heads, am.shape[2], am.shape[3])
335
-
336
- #attn_mask = am.contiguous()
337
- attn_mask = am
338
-
339
-
340
- bsz, q_len = input_shape
341
-
342
- if q_len == 1 and past_key_value is not None:
343
- # --- Decoding: flash-attn ---
344
- q = query_states.transpose(1, 2) # [B,Q,H,D]
345
- k = key_states.transpose(1, 2)
346
- v = value_states.transpose(1, 2)
347
- attn_output = flash_attn_func(
348
- q, k, v,
349
- causal=True, # For decoding, explicitly set causal=True
350
- softmax_scale=self.scaling
351
  )
352
- attn_output = attn_output.transpose(1, 2).contiguous()
353
  else:
354
- attn_output = F.scaled_dot_product_attention(
355
- query=query_states, # [B,H,Q,D]
356
- key=key_states, # [B,H,K,D]
357
- value=value_states, # [B,H,K,D]
358
- attn_mask=attn_mask, # float additive mask
359
- is_causal=False, # All constraints are already encoded in the mask
360
- scale=self.scaling,
361
- )
362
- attn_output = attn_output.transpose(1, 2).contiguous() # -> [B,Q,H,D]
363
-
364
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
365
  attn_output = self.o_proj(attn_output)
366
- return attn_output, None # , attn_weights
367
 
368
 
369
  class SDARDecoderLayer(GradientCheckpointingLayer):
@@ -739,6 +970,7 @@ class SDARModel(SDARPreTrainedModel):
739
  """
740
  Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
741
  `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
 
742
  Args:
743
  attention_mask (`torch.Tensor`):
744
  A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
@@ -834,6 +1066,79 @@ class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
834
  def get_decoder(self):
835
  return self.model
836
 
837
  @can_return_tuple
838
  @auto_docstring
839
  def forward(
@@ -849,65 +1154,344 @@ class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
849
  output_hidden_states: Optional[bool] = None,
850
  cache_position: Optional[torch.LongTensor] = None,
851
  logits_to_keep: Union[int, torch.Tensor] = 0,
852
  **kwargs: Unpack[KwargsForCausalLM],
853
  ) -> CausalLMOutputWithPast:
854
- r"""
855
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
856
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
857
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
858
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
859
- Example:
860
- ```python
861
- >>> from transformers import AutoTokenizer, SDARForCausalLM
862
- >>> model = SDARForCausalLM.from_pretrained("DiffuOpen/SDAR-1.7B-Chat")
863
- >>> tokenizer = AutoTokenizer.from_pretrained("DiffuOpen/SDAR-1.7B-Chat")
864
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
865
- >>> inputs = tokenizer(prompt, return_tensors="pt")
866
- >>> # Generate
867
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
868
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
869
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
870
- ```"""
871
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
872
  output_hidden_states = (
873
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
874
  )
875
 
876
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
877
- outputs: BaseModelOutputWithPast = self.model(
878
- input_ids=input_ids,
879
- attention_mask=attention_mask,
880
- position_ids=position_ids,
881
- past_key_values=past_key_values,
882
- inputs_embeds=inputs_embeds,
883
- use_cache=use_cache,
884
- output_attentions=output_attentions,
885
- output_hidden_states=output_hidden_states,
886
- cache_position=cache_position,
887
- **kwargs,
888
- )
889
 
890
- hidden_states = outputs.last_hidden_state
891
- # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
892
- slice_indices = slice(-logits_to_keep,
893
- None) if isinstance(logits_to_keep, int) else logits_to_keep
894
- hidden_states = hidden_states[:, slice_indices, :].contiguous()
895
- fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
896
- if fuse_linear_and_cross_entropy:
897
- # When using fused_linear_ce_loss, we do not compute the whole logits on HBM
898
- logits = None
899
  else:
900
- logits = self.lm_head(hidden_states)
901
 
902
- loss = None
903
- if labels is not None:
904
- # FusedLinearCrossEntropyLoss will be implemented by monkey patch when training
905
- # We don't use it when inferencing
906
- loss_fct = nn.CrossEntropyLoss() # nn.CE
907
- loss = loss_fct(
908
- logits.view(-1, self.config.vocab_size), labels.view(-1))
 
 
910
- return CausalLMOutputWithPast(
 
911
  loss=loss,
912
  logits=logits,
913
  past_key_values=outputs.past_key_values,
@@ -915,6 +1499,12 @@ class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
915
  attentions=outputs.attentions,
916
  )
917
 
918
 
919
  __all__ = [
920
  "SDARForCausalLM",
 
 
1
  # This file is modified based on https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3/modeling_qwen3.py.
2
  #
3
  # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 
21
  # See the License for the specific language governing permissions and
22
  # limitations under the License.
23
 
24
+ from typing import Callable, Optional, Tuple, Union, List
25
 
26
  import torch
27
  from torch import nn
28
+ from einops import rearrange
29
 
30
  from transformers.activations import ACT2FN
31
  from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
 
45
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
46
  from transformers.processing_utils import Unpack
47
  from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
48
+ from configuration_sdar import SDARConfig
49
+ from fused_linear_diffusion_cross_entropy import FusedLinearDiffusionCrossEntropyLoss
50
 
51
  from flash_attn.ops.triton.layer_norm import rms_norm_fn as flash_rms_norm
52
 
 
72
  logger = logging.get_logger(__name__)
73
 
74
 
75
+ def modify_padded_position_ids_2d(position_ids: torch.LongTensor) -> torch.LongTensor:
76
+ """
77
+ Modify a batch of packed position_ids using fully vectorized PyTorch operations.
78
+ This function assumes the input is a 2D tensor of shape (batch_size, sequence_length).
79
+ Each row of the batch is processed independently.
80
+
81
+ Args:
82
+ position_ids: A 2D PyTorch tensor, shape (batch_size, sequence_length).
83
+
84
+ Returns:
85
+ The modified position_ids tensor, shape (batch_size, sequence_length).
86
+ """
87
+ if position_ids.dim() != 2:
88
+ raise ValueError(f"Input tensor must be 2D, but got {position_ids.dim()} dimensions.")
89
+
90
+ batch_size, seq_len = position_ids.shape
91
+ device = position_ids.device
92
+
93
+ col_indices = torch.arange(seq_len, device=device, dtype=position_ids.dtype).expand(batch_size, -1)
94
+ mask = (position_ids != 0)
95
+
96
+ masked_indices = col_indices * mask
97
+ last_nonzero_idx = torch.max(masked_indices, dim=1).values
98
+ has_nonzero = torch.any(mask, dim=1)
99
+ pad_start_idx = torch.where(has_nonzero, last_nonzero_idx + 1, torch.tensor(0, device=device, dtype=position_ids.dtype))
100
+
101
+ padding_mask = col_indices >= pad_start_idx.unsqueeze(1)
102
+ new_pad_values = col_indices - pad_start_idx.unsqueeze(1)
103
+ position_ids = torch.where(padding_mask, new_pad_values, position_ids)
104
+
105
+ return position_ids
106
+
107
+
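A minimal usage sketch for modify_padded_position_ids_2d, assuming trailing pad slots carry position id 0 as the function expects; the pads are renumbered 0, 1, 2, ... so each one becomes its own length-1 packed segment:

import torch

pids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 0, 0]])
# The last non-zero position id sits at index 5, so padding starts at index 6.
print(modify_padded_position_ids_2d(pids))
# tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2]])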
108
+ def calculate_token_nums(position_ids: torch.Tensor):
109
+ """
110
+ Efficiently compute the length of every packed sequence in a batch with PyTorch.
111
+
112
+ Args:
113
+ position_ids (torch.Tensor): A 2D tensor of shape (batch_size, sequence_length).
114
+ E.g.: tensor([[0,1,2,3,4,0,1,2,3,4,5,0,1,2,3,0,0,0]])
115
+ Returns:
116
+ List[torch.Tensor]: One 1D tensor per batch item, holding the per-sequence lengths.
117
+ E.g.: [tensor([5, 6, 4, 1, 1, 1])]
118
+ """
119
+ # Check that the input is a 2D tensor
120
+ if position_ids.dim() != 2:
121
+ raise ValueError(f"输入必须是 2D Tensor,但得到了 {position_ids.dim()}D")
122
+
123
+ all_lengths = []
124
+
125
+ # Process the batch row by row. Each row holds a different number of
126
+ # sequences (ragged), so a Python loop over the batch dimension is the
127
+ # clearest approach; everything inside the loop is fully vectorized.
128
+ for pids_row in position_ids:
129
+ # Total length of the current row
130
+ seq_len = pids_row.shape[0]
131
+
132
+ # 1. Find the indices of all elements equal to 0
134
+ # pids_row == 0 yields a boolean tensor: [True, False, ..., True, ...]
135
+ # torch.nonzero returns the indices of those True values
136
+ # .flatten() turns the (N, 1) tensor into shape (N,)
136
+ zero_indices = torch.nonzero(pids_row == 0).flatten()
137
+
138
+ # 2. Append the total sequence length as an extra split point
140
+ # This is essential for computing the length of the last sequence
141
+ # Note: the new tensor must be on the same device as the original (cpu/cuda)
141
+ split_points = torch.cat([
142
+ zero_indices,
143
+ torch.tensor([seq_len], device=pids_row.device, dtype=zero_indices.dtype)
144
+ ])
145
+
146
+ # 3. Differences between adjacent split points are the lengths we want
148
+ # torch.diff([a, b, c, d]) returns [b-a, c-b, d-c]
148
+ lengths = torch.diff(split_points)
149
+
150
+ all_lengths.append(lengths)
151
+
152
+ return all_lengths
153
+
154
+
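A quick sanity check of calculate_token_nums against the docstring example; each zero marks the start of a packed segment, and the row length acts as the final split point:

import torch

pids = torch.tensor([[0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 0, 0, 0]])
print(calculate_token_nums(pids))
# [tensor([5, 6, 4, 1, 1, 1])]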
155
+ def forward_add_noise_packed(
156
+ inputs_ids: torch.Tensor,
157
+ num_tokens_list: List[torch.Tensor],
158
+ prompt_mask: torch.Tensor,
159
+ mask_id: int,
160
+ eps: float = 1e-3,
161
+ max_tries: int = 10,
162
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
163
+ """
164
+ Add noise to the token IDs of a batch of packed sequences.
165
+
166
+ This function keeps the logic of drawing an independent random noise rate for each
167
+ logical sample (concatenated within a batch item). It randomly replaces a fraction
168
+ of token IDs with mask_id, skipping every position flagged by prompt_mask.
169
+
170
+ Args:
171
+ inputs_ids (torch.Tensor):
172
+ Input token IDs, shape (bsz, total_tokens).
173
+ num_tokens_list (List[torch.Tensor]):
174
+ A list of tensors of length bsz. Each tensor records the length of every
175
+ logical sample in its batch item, e.g. [tensor([len1, len2]), tensor([len3, len4, len5])].
176
+ prompt_mask (torch.Tensor):
177
+ Boolean tensor of shape (bsz, total_tokens); True marks prompt positions
178
+ that must not be noised.
179
+ mask_id (int):
180
+ ID of the mask token used for replacement.
181
+ eps (float):
182
+ Small value keeping the noise rate t away from exactly 0, ensuring p_mask > 0.
183
+ max_tries (int):
184
+ Maximum attempts per batch item to ensure at least one non-prompt token is masked.
185
+
186
+ Returns:
187
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
188
+ - noisy_input_ids (torch.Tensor):
189
+ Noised token IDs, shape (bsz, total_tokens).
190
+ - final_masked_indices (torch.Tensor):
191
+ Boolean tensor, shape (bsz, total_tokens), marking the positions actually masked.
192
+ - p_masks (torch.Tensor):
193
+ A 1D tensor holding the actual noise rate of each masked token.
194
+ """
195
+ # 1. Validate inputs and get shapes
196
+ bsz, total_tokens = inputs_ids.shape
197
+ device = inputs_ids.device
198
+
199
+ # Consistency checks on the inputs
200
+ assert len(num_tokens_list) == bsz, f"len(num_tokens_list) ({len(num_tokens_list)}) must equal bsz ({bsz})"
201
+ assert prompt_mask.shape == (bsz, total_tokens), f"prompt_mask shape mismatch: expected {(bsz, total_tokens)}, got {prompt_mask.shape}"
202
+
203
+ # Result containers
204
+ noisy_ids_list = []
205
+ final_masked_indices_list = []
206
+ p_masks_per_token_list = []
207
+
208
+ # 2. Iterate over the batch dimension
210
+ # This is the most direct way to handle differing packing structures
210
+ for i in range(bsz):
211
+ # Slice out the current batch item
212
+ current_ids = inputs_ids[i:i+1] # shape: (1, total_tokens)
213
+ current_num_tokens = num_tokens_list[i]
214
+ current_prompt_mask = prompt_mask[i:i+1] # shape: (1, total_tokens)
215
+
216
+ num_samples_in_item = len(current_num_tokens)
217
+ # Verify that this batch item's token count matches
218
+ assert total_tokens == torch.sum(current_num_tokens), \
219
+ f"sum of num_tokens for batch item {i} ({torch.sum(current_num_tokens)}) does not match total_tokens ({total_tokens})"
220
+
221
+ eligible_for_masking = ~current_prompt_mask
222
+
223
+ # If no token can be masked, keep the original input and set p_mask to eps
224
+ if not eligible_for_masking.any():
225
+ noisy_ids_list.append(current_ids)
226
+ final_masked_indices_list.append(torch.zeros_like(current_prompt_mask, dtype=torch.bool))
227
+ # p_mask_per_token must be (1, total_tokens) so it can be concatenated later
228
+ p_masks_per_token_list.append(torch.full((1, total_tokens), eps, device=device, dtype=torch.float))
229
+ continue
230
+
231
+ # --- Sample a mask, ensuring at least one token gets masked ---
232
+ final_masked_indices_item = torch.zeros_like(current_prompt_mask, dtype=torch.bool)
233
+ p_mask_per_token = None
234
+
235
+ for _ in range(max_tries):
236
+ # 为每个逻辑样本生成一个独立的噪声率 t
237
+ t = torch.rand(num_samples_in_item, device=device)
238
+ p_mask_per_sample = (1 - eps) * t + eps
239
+
240
+ # Broadcast each sample's noise rate over all of its tokens
241
+ p_mask_per_token_1d = torch.repeat_interleave(p_mask_per_sample, current_num_tokens)
242
+ p_mask_per_token = p_mask_per_token_1d.unsqueeze(0) # shape: (1, total_tokens)
243
+
244
+ # Sample a random mask according to the noise rates
245
+ masked_indices = torch.rand_like(p_mask_per_token) < p_mask_per_token
246
+ # Apply the prompt mask so prompt tokens are never masked
247
+ final_masked_indices_item = masked_indices & eligible_for_masking
248
+
249
+ # Stop retrying once at least one token has been masked
250
+ if final_masked_indices_item.any():
251
+ break
252
+
253
+ # If nothing was masked after max_tries (very unlikely), force-mask one eligible token
254
+ if not final_masked_indices_item.any():
255
+ eligible_indices = torch.nonzero(eligible_for_masking.squeeze(0), as_tuple=True)[0]
256
+ if len(eligible_indices) > 0:
257
+ # Randomly pick one eligible position
258
+ random_choice = torch.randint(0, len(eligible_indices), (1,)).item()
259
+ force_mask_idx = eligible_indices[random_choice]
260
+ final_masked_indices_item[0, force_mask_idx] = True
261
+
262
+
263
+ # --- Produce the noisy IDs from the final mask ---
264
+ noisy_ids_item = torch.where(
265
+ final_masked_indices_item,
266
+ mask_id,
267
+ current_ids
268
+ )
269
+
270
+ # Store this batch item's results
271
+ noisy_ids_list.append(noisy_ids_item)
272
+ final_masked_indices_list.append(final_masked_indices_item)
273
+ p_masks_per_token_list.append(p_mask_per_token)
274
+
275
+ # 3. Concatenate the per-item results into the final batch tensors
276
+ noisy_input_ids = torch.cat(noisy_ids_list, dim=0)
277
+ final_masked_indices = torch.cat(final_masked_indices_list, dim=0)
278
+ p_mask_full = torch.cat(p_masks_per_token_list, dim=0)
279
+
280
+ # 4. Gather the noise rates at the masked positions
281
+ p_masks = p_mask_full[final_masked_indices]
282
+
283
+ return noisy_input_ids, final_masked_indices, p_masks
284
+
285
+
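A small sketch of forward_add_noise_packed on a toy batch; the mask id 151669 is a placeholder for illustration, the real value comes from config.mask_token_id:

import torch

ids = torch.tensor([[10, 11, 12, 20, 21, 22, 23]])
num_tokens = [torch.tensor([3, 4])]  # two packed segments of lengths 3 and 4
prompt = torch.tensor([[True, True, False, True, False, False, False]])
noisy, masked, p = forward_add_noise_packed(ids, num_tokens, prompt, mask_id=151669)
assert not masked[prompt].any()          # prompt tokens are never masked
assert (noisy[masked] == 151669).all()   # masked slots hold the mask id
assert p.numel() == masked.sum().item()  # one noise rate per masked token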
286
+ def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
287
+ """
288
+ Constructs the specialized block diffusion attention mask for training
289
+ composed of three masks:
290
+ - **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
291
+ - **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
292
+ - **Block Causal Mask (M_BC)**: Attention to update x0
293
+
294
+ Args:
295
+ b, h: Batch and head indices (ignored for mask logic).
296
+ q_idx, kv_idx: Query and Key indices.
297
+ n: Length of one copy of the sequence; indices >= n belong to x0.
298
+ block_size: Defines the block structure.
299
+
300
+ Returns:
301
+ A boolean attention mask.
302
+ """
303
+
304
+ # Indicate whether token belongs to xt or x0
305
+ x0_flag_q = q_idx >= n
306
+ x0_flag_kv = kv_idx >= n
307
+
308
+ # Compute block indices
309
+ block_q = torch.where(
310
+ x0_flag_q == 1, (q_idx - n) // block_size, q_idx // block_size
311
+ )
312
+ block_kv = torch.where(
313
+ x0_flag_kv == 1, (kv_idx - n) // block_size, kv_idx // block_size
314
+ )
315
+
316
+ # **1. Block Diagonal Mask (M_BD) **
317
+ block_diagonal = (block_q == block_kv) & (x0_flag_q == x0_flag_kv)
318
+
319
+ # **2. Offset Block-Causal Mask (M_OBC) **
320
+ offset_block_causal = (block_q > block_kv) & (
321
+ x0_flag_kv == 1) & (x0_flag_q == 0)
322
+
323
+ # **3. Block-Causal Mask (M_BC) **
324
+ block_causal = (block_q >= block_kv) & (x0_flag_kv == 1) & (x0_flag_q == 1)
325
+
326
+ # **4. Combine Masks **
327
+ return block_diagonal | offset_block_causal | block_causal
328
+
329
+
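The predicate can be probed densely for a single sample; with n = 4 and block_size = 2, rows 0..3 are the noised copy xt and rows 4..7 the clean copy x0:

import torch

n, block_size = 4, 2
q = torch.arange(2 * n)[:, None]
kv = torch.arange(2 * n)[None, :]
mask = block_diff_mask(None, None, q, kv, block_size=block_size, n=n)
# xt blocks attend to themselves (M_BD) and to strictly earlier x0 blocks (M_OBC);
# x0 blocks attend block-causally over x0 (M_BC).
print(mask.int())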
330
+ def block_attn_mask(num_tokens, block_size, device):
331
+ masks = []
332
+ for i in range(len(num_tokens)):
333
+ cur_masks = []
334
+ for num in num_tokens[i]:
335
+ # Each segment gets a (2*num) x (2*num) mask covering its xt and x0 copies
336
+ single_mask = block_diff_mask(
337
+ b=None,
338
+ h=None,
339
+ q_idx=torch.arange(num * 2, device=device)[:, None],
340
+ kv_idx=torch.arange(num * 2, device=device)[None, :],
341
+ block_size=block_size,
342
+ n=num,
343
+ )
344
+ cur_masks.append(single_mask)
345
+ masks.append(torch.block_diag(*cur_masks))
346
+ masks = torch.stack(masks, dim=0)
347
+ return masks
348
+
349
+
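block_attn_mask then stitches one such (2n x 2n) mask per packed segment into a block-diagonal batch mask; a batch item packing segments of lengths 3 and 2 yields shape (1, 10, 10):

import torch

num_tokens = [torch.tensor([3, 2])]
masks = block_attn_mask(num_tokens, block_size=2, device="cpu")
print(masks.shape)  # torch.Size([1, 10, 10]): 2*(3+2) rows, one block per segment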
350
+ @torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs")
351
+ def fused_flex_attention(query, key, value, attention_mask, **kwargs):
352
+ return flex_attention(query, key, value, block_mask=attention_mask, **kwargs)
353
+
354
+
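Compiling the flex_attention call with fullgraph=True lets the compiler fuse the mask predicate into the attention kernel; the "max-autotune-no-cudagraphs" mode presumably avoids CUDA-graph capture because the packed shapes and block masks vary between steps.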
355
  @use_kernel_forward_from_hub("RMSNorm")
356
  class SDARRMSNorm(nn.Module):
357
  def __init__(self, hidden_size, eps=1e-6):
 
410
 
411
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
412
  """Applies Rotary Position Embedding to the query and key tensors.
413
+
414
  Args:
415
  q (`torch.Tensor`): The query tensor.
416
  k (`torch.Tensor`): The key tensor.
 
537
  hidden_states).view(hidden_shape)).transpose(1, 2)
538
  value_states = self.v_proj(hidden_states).view(
539
  hidden_shape).transpose(1, 2)
540
 
541
  cos, sin = position_embeddings
542
  query_states, key_states = apply_rotary_pos_emb(
 
554
  value_states = torch.cat(
555
  [past_value_states, value_states], dim=-2)
556
 
557
+ if self.training:
558
+ attn_output, attn_weights = fused_flex_attention(
559
  query=query_states,
560
  key=key_states,
561
  value=value_states,
562
+ attention_mask=attention_mask,
563
+ enable_gqa=True,
564
  scale=self.scaling,
565
+ return_lse=True
566
  )
567
+ attn_weights = attn_weights.to(
568
+ value_states.dtype) if attn_weights is not None else None
569
+ attn_output = rearrange(attn_output, 'b h l d -> b l (h d)')
570
  else:
571
+ attention_mask = attention_mask.bool() if attention_mask is not None else None
572
+ attn_weights = None
573
+ if torch.all(attention_mask): # decoding
574
+ query_states = query_states.transpose(1, 2)
575
+ key_states = key_states.transpose(1, 2)
576
+ value_states = value_states.transpose(1, 2)
577
+ attn_output = flash_attn_func(
578
+ query_states,
579
+ key_states,
580
+ value_states,
581
+ causal=False,
582
+ softmax_scale=self.scaling
583
+ )
584
+ attn_output = rearrange(attn_output, 'b l h d -> b l (h d)')
585
+ else: # prefilling
586
+ attn_output = F.scaled_dot_product_attention(
587
+ query=query_states,
588
+ key=key_states,
589
+ value=value_states,
590
+ attn_mask=attention_mask,
591
+ is_causal=False,
592
+ scale=self.scaling,
593
+ enable_gqa=True
594
+ )
595
+ attn_output = rearrange(attn_output, 'b h l d -> b l (h d)')
596
  attn_output = self.o_proj(attn_output)
597
+ return attn_output, attn_weights
598
 
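The rewritten forward thus routes through three kernels: fused_flex_attention with the block-diffusion mask during training, flash_attn_func when the boolean mask is all ones (the pure decoding path), and scaled_dot_product_attention with an explicit mask for prefilling.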
599
 
600
  class SDARDecoderLayer(GradientCheckpointingLayer):
 
970
  """
971
  Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
972
  `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
973
+
974
  Args:
975
  attention_mask (`torch.Tensor`):
976
  A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
 
1066
  def get_decoder(self):
1067
  return self.model
1068
 
1069
+ def prepare_for_bd_training(self, inputs_ids, position_ids, prompt_mask, masked_indices=None, p_mask_input=None):
1070
+ bsz, seq_len = inputs_ids.shape
1071
+ num_tokens = calculate_token_nums(position_ids) # List[torch.Tensor]
1072
+
1073
+ # If masked_indices was passed in explicitly, use it directly
1074
+ if masked_indices is not None:
1075
+ # Manual-mask mode: used for RL training or fixed-mask experiments
1076
+ # Note: the caller has already restricted masked_indices to the response part (via & response_mask); no further filtering is needed
1077
+ noisy_inputs_ids = torch.where(masked_indices, self.config.mask_token_id, inputs_ids)
1078
+ logits_to_keep_half = masked_indices # (B, L) bool
1079
+ # Build a default p_mask: flattened noise rates of shape (M,), where M = sum(masked_indices)
1080
+ # The default 0.5 represents a medium noise level (for the diffusion loss)
1081
+ M = masked_indices.sum().item()
1082
+ p_mask = torch.full((M,), 0.5, device=inputs_ids.device, dtype=torch.float)
1083
+ else:
1084
+ # Random-mask mode: used for Block Diffusion pretraining
1085
+ # Returns: noisy_inputs_ids (B, L), logits_to_keep_half (B, L) bool, p_mask (M,) float
1086
+ noisy_inputs_ids, logits_to_keep_half, p_mask = forward_add_noise_packed(
1087
+ inputs_ids=inputs_ids,
1088
+ num_tokens_list=num_tokens,
1089
+ prompt_mask=prompt_mask,
1090
+ mask_id=self.config.mask_token_id,
1091
+ )
1092
+
1093
+ # Make sure both branches return consistent shapes
1094
+ # logits_to_keep_half: (B, L) bool - marks which positions were masked
1095
+ # p_mask: (M,) float - noise rate of each masked position, where M = sum(logits_to_keep_half)
1096
+ assert logits_to_keep_half.shape == (bsz, seq_len), f"logits_to_keep_half shape error: {logits_to_keep_half.shape}"
1097
+ assert p_mask.shape == (logits_to_keep_half.sum(),), f"p_mask shape error: {p_mask.shape}, expected ({logits_to_keep_half.sum()},)"
1098
+
1099
+ # If p_mask_input is provided (RL training), compute p_to_keep
1100
+ # p_to_keep selects, among the masked positions, those where p_mask is True
1101
+ p_to_keep = None
1102
+ if p_mask_input is not None:
1103
+ # Note: the caller has already restricted p_mask_input to the response part (via & response_mask); no further filtering is needed
1104
+ # p_mask_input (B, L), logits_to_keep_half (B, L)
1105
+ # p_to_keep (M,) bool, where M = sum(logits_to_keep_half)
1106
+ p_to_keep = p_mask_input[logits_to_keep_half]
1107
+
1108
+ router_noisy_part_list = []
1109
+ for i in range(bsz):
1110
+ cur_router_noisy_part = (torch.arange(num_tokens[i].shape[0] * 2) % 2 == 0).to(inputs_ids.device)
1111
+ cur_router_noisy_part = cur_router_noisy_part.repeat_interleave(num_tokens[i].repeat_interleave(2))
1112
+ router_noisy_part_list.append(cur_router_noisy_part)
1113
+ router_noisy_part = torch.stack(router_noisy_part_list, dim=0)
1114
+
1115
+ # concatenated inputs_ids: (bsz, seq_len x 2)
1116
+ concat_inputs_ids = inputs_ids.repeat(1, 2)
1117
+ # concatenated logits_to_keep: (bsz, seq_len x 2)
1118
+ logits_to_keep = torch.zeros(
1119
+ bsz, 2 * seq_len, dtype=torch.bool, device=inputs_ids.device)
1120
+ # concatenated position_ids: (bsz, seq_len x 2)
1121
+ concat_position_ids = torch.zeros(
1122
+ bsz, 2 * seq_len, dtype=position_ids.dtype, device=position_ids.device)
1123
+ for i in range(bsz):
1124
+ concat_inputs_ids[i][router_noisy_part[i]] = noisy_inputs_ids[i]
1125
+ concat_inputs_ids[i][~router_noisy_part[i]] = inputs_ids[i]
1126
+
1127
+ logits_to_keep[i][router_noisy_part[i]] = logits_to_keep_half[i]
1128
+
1129
+ concat_position_ids[i][router_noisy_part[i]] = position_ids[i]
1130
+ concat_position_ids[i][~router_noisy_part[i]] = position_ids[i]
1131
+
1132
+ # create flex_attention mask
1133
+ attention_mask = block_attn_mask(num_tokens, self.config.block_size, inputs_ids.device)
1134
+ flex_attention_mask_3d = create_block_mask(
1135
+ lambda b, h, q_idx, kv_idx: attention_mask[b, q_idx, kv_idx],
1136
+ B=attention_mask.size(0), H=None,
1137
+ Q_LEN=attention_mask.size(1), KV_LEN=attention_mask.size(2),
1138
+ )
1139
+
1140
+ return concat_inputs_ids, concat_position_ids, flex_attention_mask_3d, logits_to_keep_half, logits_to_keep, p_mask, p_to_keep
1141
+
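The interleaving that prepare_for_bd_training builds, traced by hand for one batch item packing segments of lengths 3 and 2; each segment contributes its noised copy xt first, then its clean copy x0:

import torch

num_tokens = torch.tensor([3, 2])
# Even-numbered stretches route the noisy copy, odd-numbered ones the clean copy:
router = (torch.arange(num_tokens.shape[0] * 2) % 2 == 0)
router = router.repeat_interleave(num_tokens.repeat_interleave(2))
print(router.int())  # tensor([1, 1, 1, 0, 0, 0, 1, 1, 0, 0])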
1142
  @can_return_tuple
1143
  @auto_docstring
1144
  def forward(
 
1154
  output_hidden_states: Optional[bool] = None,
1155
  cache_position: Optional[torch.LongTensor] = None,
1156
  logits_to_keep: Union[int, torch.Tensor] = 0,
1157
+ masked_indices: Optional[torch.Tensor] = None,
1158
+ return_logits: bool = False,
1159
+ # RL training parameters
1160
+ compute_rl_loss: bool = False,
1161
+ p_mask: Optional[torch.Tensor] = None,
1162
+ adv: Optional[torch.Tensor] = None,
1163
+ adv_optimization: bool = False,
1164
+ logp_old_tok: Optional[torch.Tensor] = None,
1165
+ logp_ref_tok: Optional[torch.Tensor] = None,
1166
+ is_real: Optional[torch.Tensor] = None,
1167
+ ppo_eps: float = 0.2,
1168
+ kl_beta: float = 0.0,
1169
+ use_kl_estimator_k3: bool = True,
1170
+ return_entropy: bool = False,
1171
+ dynamic_threshold: Optional[float] = None,
1172
+ loss_mean: bool = True,
1173
  **kwargs: Unpack[KwargsForCausalLM],
1174
  ) -> CausalLMOutputWithPast:
1175
+
1176
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1177
  output_hidden_states = (
1178
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1179
  )
1180
 
1181
+ if self.training:
1182
+ assert inputs_embeds is None, "only input_ids is supported during training"
1183
+
1184
+ prompt_mask = (labels == -100) if labels is not None else None
1185
+ position_ids = modify_padded_position_ids_2d(position_ids)
1186
+
1187
+ (
1188
+ concat_inputs_ids,
1189
+ concat_position_ids,
1190
+ flex_attention_mask_3d,
1191
+ logits_to_keep_half,
1192
+ logits_to_keep,
1193
+ p_mask_out,
1194
+ p_to_keep,
1195
+ ) = self.prepare_for_bd_training(
1196
+ input_ids, position_ids, prompt_mask, masked_indices, p_mask_input=p_mask
1197
+ )
1198
 
1199
+ outputs = self.model(
1200
+ input_ids=concat_inputs_ids,
1201
+ attention_mask=flex_attention_mask_3d,
1202
+ position_ids=concat_position_ids,
1203
+ output_attentions=output_attentions,
1204
+ output_hidden_states=output_hidden_states,
1205
+ return_dict=True,
1206
+ cache_position=cache_position,
1207
+ **kwargs,
1208
+ )
1209
+
1210
+ hidden_states = outputs.last_hidden_state
1211
+ hidden_states = hidden_states[logits_to_keep].contiguous()
1212
+
1213
+ # Initialize entropy
1214
+ entropy = torch.tensor(0.0, device=input_ids.device)
1215
+
1216
+ # ====================== RL loss (PPO) ======================
1217
+ if compute_rl_loss:
1218
+ assert p_to_keep is not None, "p_mask must be provided for RL loss computation."
1219
+ assert adv is not None, "adv must be provided for RL loss computation."
1220
+ assert is_real is not None, "is_real must be provided for RL loss computation."
1221
+ assert labels is not None, "labels must be provided for RL loss computation."
1222
+ assert masked_indices is not None, "masked_indices must be provided for RL loss computation."
1223
+
1224
+ device = input_ids.device
1225
+
1226
+ # logits (M, V), kept as-is
1227
+ logits = self.lm_head(hidden_states)
1228
+
1229
+ # mask, kept as-is
1230
+ is_real_tensor = (
1231
+ is_real.to(device=device, dtype=torch.bool)
1232
+ if torch.is_tensor(is_real)
1233
+ else torch.tensor(is_real, dtype=torch.bool, device=device)
1234
+ )
1235
+ p_mask_real = p_mask & is_real_tensor.unsqueeze(1) # (B, L)
1236
+ p_to_keep_real = p_mask_real[masked_indices] # (M,) bool
1237
+
1238
+ # Select the logits, kept as-is
1239
+ logits_p = logits[p_to_keep_real] # (N, V)
1240
+ N = p_to_keep_real.sum().item()
1241
+ total_response_tokens = (labels != -100).sum().item()
1242
+ total_p_mask = p_mask.sum().item()
1243
+ total_masked_indices = masked_indices.sum().item()
1244
+ total_is_real = is_real_tensor.sum().item() if is_real_tensor.dim() > 0 else (1 if is_real_tensor.item() else 0)
1245
+
1246
+
1247
+ # log_softmax
1248
+ log_probs_p = torch.nn.functional.log_softmax(logits_p, dim=-1)
1249
+
1250
+ # labels / logp, kept as-is
1251
+ labels_p = labels[masked_indices][p_to_keep_real] # (N,)
1252
+ logp_p = log_probs_p.gather(dim=-1, index=labels_p.unsqueeze(-1)).squeeze(-1)
1253
+
1254
+ # entropy (optional)
1255
+ if return_entropy:
1256
+ with torch.no_grad():
1257
+ entropy_p = -(log_probs_p.exp() * log_probs_p).sum(dim=-1)
1258
+ entropy = entropy_p.mean() if entropy_p.numel() > 0 else torch.tensor(0.0, device=device)
1259
+ del entropy_p
1260
+
1261
+ # Advantage handling
1262
+ adv_tensor = adv.to(device) if torch.is_tensor(adv) else torch.tensor(adv, dtype=torch.float, device=device)
1263
+ adv_optimization=False
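+ # NOTE: adv_optimization is hard-coded to False here, so the prefix-matching branch below never runs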
1264
+ if adv_optimization:
1265
+ # Token-level optimization: take the max advantage over identical prefixes (pruned version)
1266
+ response_mask = (labels != -100) # (B, L)
1267
+ bsz, seq_len = input_ids.shape
1268
+
1269
+ # Precompute each sample's response start position
1270
+ response_starts = torch.full((bsz,), seq_len, dtype=torch.long, device=device)
1271
+ for b in range(bsz):
1272
+ if response_mask[b].any():
1273
+ response_starts[b] = response_mask[b].long().argmax()
1274
+
1275
+ # Pruning 1: samples already at the max advantage are filled directly and skip comparison
1276
+ max_adv_value = adv_tensor.max()
1277
+ is_max_adv = (adv_tensor == max_adv_value) # (B,) bool
1278
+
1279
+ # Build the optimized advantage map (B, L), keeping dtype consistent with adv_tensor
1280
+ optimized_adv = torch.zeros_like(labels, dtype=adv_tensor.dtype)
1281
+
1282
+ # Directly fill samples that already carry the max advantage
1283
+ for b in range(bsz):
1284
+ if is_max_adv[b]:
1285
+ optimized_adv[b][response_mask[b]] = max_adv_value
1286
+
1287
+ # Statistics
1288
+ total_response_tokens = 0
1289
+ updated_tokens = 0
1290
+ skipped_tokens = 0
1291
+ original_adv_sum = 0.0
1292
+ optimized_adv_sum = 0.0
1293
+
1294
+ # Walk over positions, comparing prefixes in batch
1295
+ for pos in range(seq_len):
1296
+ valid_samples = response_mask[:, pos] # (B,)
1297
+ if not valid_samples.any():
1298
+ continue
1299
+
1300
+ # Pruning 2: exclude samples already at the max advantage
1301
+ valid_samples = valid_samples & ~is_max_adv
1302
+ if not valid_samples.any():
1303
+ # All samples are at the max; record the stats and skip
1304
+ max_count = (response_mask[:, pos] & is_max_adv).sum().item()
1305
+ total_response_tokens += max_count
1306
+ skipped_tokens += max_count
1307
+ original_adv_sum += max_adv_value.item() * max_count
1308
+ optimized_adv_sum += max_adv_value.item() * max_count
1309
+ continue
1310
+
1311
+ # Indices of all samples that still need processing
1312
+ valid_indices = valid_samples.nonzero(as_tuple=True)[0] # (N,)
1313
+
1314
+ for b in valid_indices:
1315
+ b_item = b.item()
1316
+ response_start = response_starts[b_item].item()
1317
+ prefix_len = pos + 1 - response_start
1318
+
1319
+ if prefix_len <= 0:
1320
+ optimized_adv[b_item, pos] = adv_tensor[b_item]
1321
+ continue
1322
+
1323
+ # Find all samples with the same response start that are valid at pos (including max-advantage ones)
1324
+ same_start_mask = (response_starts == response_start) & response_mask[:, pos]
1325
+ same_start_indices = same_start_mask.nonzero(as_tuple=True)[0]
1326
+
1327
+ if len(same_start_indices) == 1:
1328
+ # 只有自己,不需要比较
1329
+ optimized_adv[b_item, pos] = adv_tensor[b_item]
1330
+ total_response_tokens += 1
1331
+ original_adv_sum += adv_tensor[b_item].item()
1332
+ optimized_adv_sum += adv_tensor[b_item].item()
1333
+ continue
1334
+
1335
+ # Pruning 3: if a max-advantage sample is among the candidates, the max can be used directly
1336
+ has_max_in_candidates = (same_start_mask & is_max_adv).any()
1337
+
1338
+ prefix_end = pos + 1
1339
+ current_prefix = input_ids[b_item, response_start:prefix_end]
1340
+
1341
+ # Batched comparison: gather the prefixes of all candidate samples
1342
+ prefixes = input_ids[same_start_indices, response_start:prefix_end] # (M, prefix_len)
1343
+
1344
+ # Broadcast comparison: (M, prefix_len) vs (prefix_len,)
1345
+ matches = (prefixes == current_prefix.unsqueeze(0)).all(dim=1) # (M,)
1346
+
1347
+ # Samples whose prefix matches
1348
+ matching_indices = same_start_indices[matches]
1349
+
1350
+ # Take the max advantage among samples sharing the prefix
1351
+ original_adv_value = adv_tensor[b_item].item()
1352
+ if matching_indices.numel() > 0:
1353
+ # Pruning 4: if a max-advantage sample matched, use the max directly
1354
+ if has_max_in_candidates and is_max_adv[matching_indices].any():
1355
+ max_adv = max_adv_value
1356
+ else:
1357
+ max_adv = adv_tensor[matching_indices].max()
1358
+
1359
+ optimized_adv[b_item, pos] = max_adv
1360
+ # Stats
1361
+ if abs(max_adv.item() - original_adv_value) > 1e-6:
1362
+ updated_tokens += 1
1363
+ original_adv_sum += original_adv_value
1364
+ optimized_adv_sum += max_adv.item()
1365
+ else:
1366
+ optimized_adv[b_item, pos] = adv_tensor[b_item]
1367
+ original_adv_sum += original_adv_value
1368
+ optimized_adv_sum += original_adv_value
1369
+
1370
+ total_response_tokens += 1
1371
+
1372
+ # Print the statistics
1373
+ if total_response_tokens > 0:
1374
+ update_ratio = updated_tokens / total_response_tokens
1375
+ skip_ratio = skipped_tokens / total_response_tokens
1376
+ avg_original = original_adv_sum / total_response_tokens
1377
+ avg_optimized = optimized_adv_sum / total_response_tokens
1378
+ print(f"[Adv Optimization] Total: {total_response_tokens}, "
1379
+ f"Updated: {updated_tokens} ({update_ratio:.2%}), "
1380
+ f"Skipped: {skipped_tokens} ({skip_ratio:.2%}), "
1381
+ f"Avg adv: {avg_original:.4f} -> {avg_optimized:.4f} "
1382
+ f"(+{avg_optimized - avg_original:.4f})")
1383
+
1384
+ # Use the optimized advantage
1385
+ adv_expanded = optimized_adv
1386
+ else:
1387
+ # No optimization: use the raw per-sample advantage
1388
+ adv_expanded = adv_tensor.unsqueeze(1).expand_as(p_mask)
1389
+
1390
+ adv_p = adv_expanded[masked_indices][p_to_keep_real]
1391
+
1392
+ # old logp
1393
+ if logp_old_tok is not None and logp_old_tok.numel() > 0:
1394
+ logp_old_p = logp_old_tok.to(device)[masked_indices][p_to_keep_real]
1395
+ else:
1396
+ logp_old_p = logp_p.detach()
1397
+
1398
+ # ratio/exp
1399
+ ratio_p = (logp_p - logp_old_p).clamp(-10.0, 10.0).exp()
1400
+ clipped = ratio_p.clamp(1 - ppo_eps, 1 + ppo_eps+0.08)
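+ # Note the asymmetric clip range: the upper bound is widened by 0.08 relative to ppo_eps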
1401
+ surrogate_p = torch.minimum(ratio_p * adv_p, clipped * adv_p)
1402
+ # Report the ratio value farthest from 1
1403
+ # if not torch.allclose(ratio_p, torch.ones_like(ratio_p)):
1404
+ furthest_value = ratio_p[torch.abs(ratio_p - 1).argmax()]
1405
+ # print(f"Furthest ratio from 1: {furthest_value.item()}")
1406
+
1407
+ # Policy loss: use mean or sum based on loss_mean parameter
1408
+ num_masked = masked_indices.sum().item()
1409
+ num_loss_elements = surrogate_p.numel()
1410
+ print(f"masked_indices.sum()={num_masked}, surrogate_p.numel()={num_loss_elements}")
1411
+ if loss_mean:
1412
+ policy_loss = -surrogate_p.mean()
1413
+ else:
1414
+ policy_loss = -surrogate_p.sum()
1415
+
1416
+ # KL (optional)
1417
+ kl_loss = torch.tensor(0.0, device=device)
1418
+ if kl_beta > 0 and logp_ref_tok is not None:
1419
+ logp_ref_p = logp_ref_tok.to(device)[masked_indices][p_to_keep_real]
1420
+ kl_seq_p = logp_p - logp_ref_p
1421
+
1422
+ if use_kl_estimator_k3:
1423
+ kl_seq_p = (-kl_seq_p).clamp(-10.0, 10.0).exp() - 1.0 + kl_seq_p
1424
+
1425
+ # KL loss: use mean or sum based on loss_mean parameter
1426
+ if loss_mean:
1427
+ kl_loss = kl_beta * kl_seq_p.mean()
1428
+ else:
1429
+ kl_loss = kl_beta * kl_seq_p.sum()
1430
+ del logp_ref_p, kl_seq_p
1431
+
1432
+ loss = policy_loss + kl_loss
1433
+ kl_loss_value = kl_loss.detach().clone()
1434
+
1435
+ # Cleanup
1436
+ del logits, logits_p, log_probs_p, labels_p
1437
+ del is_real_tensor, p_mask_real, p_to_keep_real
1438
+ del adv_tensor, adv_expanded, adv_p
1439
+ del logp_p, logp_old_p, ratio_p, clipped, surrogate_p
1440
+ del policy_loss, kl_loss
1441
+
1442
+ logits = None
1443
+
1444
+ # ====================== GRPO / return logits ======================
1445
+ elif return_logits:
1446
+ logits = self.lm_head(hidden_states)
1447
+ loss = None
1448
+
1449
+ # ====================== Block Diffusion fused loss ======================
1450
+ else:
1451
+ assert labels is not None, "Labels must be provided for training."
1452
+ answer_len = (labels != -100).sum()
1453
+ loss_fct = FusedLinearDiffusionCrossEntropyLoss(reduction="sum")
1454
+ loss = loss_fct(
1455
+ x=hidden_states,
1456
+ target=labels[logits_to_keep_half].contiguous(),
1457
+ weight=self.lm_head.weight,
1458
+ bias=self.lm_head.bias,
1459
+ p_mask=p_mask_out,
1460
+ )
1461
+ loss = loss / answer_len
1462
+ logits = None
1463
+
1464
+ # ====================== eval / inference ======================
1465
  else:
1466
+ outputs: BaseModelOutputWithPast = self.model(
1467
+ input_ids=input_ids,
1468
+ attention_mask=attention_mask,
1469
+ position_ids=position_ids,
1470
+ past_key_values=past_key_values,
1471
+ inputs_embeds=inputs_embeds,
1472
+ use_cache=use_cache,
1473
+ output_attentions=output_attentions,
1474
+ output_hidden_states=output_hidden_states,
1475
+ cache_position=cache_position,
1476
+ **kwargs,
1477
+ )
1478
 
1479
+ hidden_states = outputs.last_hidden_state
1480
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1481
+ hidden_states = hidden_states[:, slice_indices, :].contiguous()
1482
+
1483
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
1484
+ if fuse_linear_and_cross_entropy:
1485
+ logits = None
1486
+ else:
1487
+ logits = self.lm_head(hidden_states)
1488
 
1489
+ loss = None
1490
+ if labels is not None:
1491
+ loss_fct = nn.CrossEntropyLoss()
1492
+ loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
1493
+
1494
+ output = CausalLMOutputWithPast(
1495
  loss=loss,
1496
  logits=logits,
1497
  past_key_values=outputs.past_key_values,
 
1499
  attentions=outputs.attentions,
1500
  )
1501
 
1502
+ if self.training and compute_rl_loss:
1503
+ output.entropy = entropy
1504
+ output.kl_loss = kl_loss_value if "kl_loss_value" in locals() else torch.tensor(0.0, device=input_ids.device)
1505
+
1506
+ return output
1507
+
1508
 
1509
  __all__ = [
1510
  "SDARForCausalLM",