chua commited on
Commit
6564fb9
·
verified ·
1 Parent(s): 1c3f89f

Upload modeling_qwen3.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_qwen3.py +10 -10
modeling_qwen3.py CHANGED
@@ -578,7 +578,7 @@ class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
578
  # [ADD] Custom Logit Masking Logic (Inference Only)
579
  # =================================================================
580
  # 仅在非训练模式 (self.training == False) 且 input_ids 存在时执行
581
- if not self.training and input_ids is not None:
582
  # 1. 判断 Mask 触发条件
583
  # input_ids shape: (batch_size, seq_len)
584
 
@@ -591,29 +591,29 @@ class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
591
  if seq_len == 1:
592
  # 检查 input_ids 是否在 [172207, 180398] 区间内
593
  in_safe_range = ((input_ids >= 172207) & (input_ids <= 180398)).any(dim=-1)
594
- # 如果不在安全区间,则满足条件 B
595
  cond_b = ~in_safe_range
596
  else:
597
- # 长度不为1,条件 B 必定不满足
598
  cond_b = torch.zeros_like(has_trigger_token, dtype=torch.bool)
599
 
600
- # 综合条件: 满足 A 或 满足 B
601
  rows_to_mask = has_trigger_token | cond_b
602
 
603
  # 2. 执行 Mask 操作
604
  if rows_to_mask.any():
605
- # 离散 Token 列表
 
 
606
  target_discrete_tokens = [151691, 151692, 151693, 151695, 151696, 151697, 151698]
607
  mask_indices = torch.tensor(target_discrete_tokens, device=logits.device)
608
  neg = torch.finfo(logits.dtype).min
 
609
  # (1) Mask 离散 Token
610
- # logits[rows_to_mask] 选取需要 mask batch
611
- # [:, mask_indices] 选取特定的 token ID 列
612
- logits[rows_to_mask, :, mask_indices] = neg
613
 
614
  # (2) Mask 连续区间 [172206, 180398]
615
- # 注意:Python 切片右边界是开区间,所以要写到 180399
616
- logits[rows_to_mask, :, 151727:180399] = neg
617
  # =================================================================
618
 
619
  loss = None
 
578
  # [ADD] Custom Logit Masking Logic (Inference Only)
579
  # =================================================================
580
  # 仅在非训练模式 (self.training == False) 且 input_ids 存在时执行
581
+ if not self.training and input_ids is not None and labels is None:
582
  # 1. 判断 Mask 触发条件
583
  # input_ids shape: (batch_size, seq_len)
584
 
 
591
  if seq_len == 1:
592
  # 检查 input_ids 是否在 [172207, 180398] 区间内
593
  in_safe_range = ((input_ids >= 172207) & (input_ids <= 180398)).any(dim=-1)
 
594
  cond_b = ~in_safe_range
595
  else:
 
596
  cond_b = torch.zeros_like(has_trigger_token, dtype=torch.bool)
597
 
598
+ # 综合条件
599
  rows_to_mask = has_trigger_token | cond_b
600
 
601
  # 2. 执行 Mask 操作
602
  if rows_to_mask.any():
603
+ # 修改点 2: 获取行索引,解决广播报错的关键
604
+ row_idxs = torch.nonzero(rows_to_mask, as_tuple=True)[0]
605
+
606
  target_discrete_tokens = [151691, 151692, 151693, 151695, 151696, 151697, 151698]
607
  mask_indices = torch.tensor(target_discrete_tokens, device=logits.device)
608
  neg = torch.finfo(logits.dtype).min
609
+
610
  # (1) Mask 离散 Token
611
+ # 修改点 3: 使用 row_idxs[:, None] 将形状变为 [N, 1],使其能与 mask_indices [7] 广播
612
+ if mask_indices.numel() > 0:
613
+ logits[row_idxs[:, None], :, mask_indices] = neg
614
 
615
  # (2) Mask 连续区间 [172206, 180398]
616
+ logits[row_idxs, :, 151727:180399] = neg
 
617
  # =================================================================
618
 
619
  loss = None