Upload modeling_qwen3.py with huggingface_hub
Browse files
- modeling_qwen3.py +10 -10
modeling_qwen3.py
CHANGED
|
@@ -578,7 +578,7 @@ class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
|
|
| 578 |
# [ADD] Custom Logit Masking Logic (Inference Only)
|
| 579 |
# =================================================================
|
| 580 |
# 仅在非训练模式 (self.training == False) 且 input_ids 存在时执行
|
| 581 |
-
if not self.training and input_ids is not None:
|
| 582 |
# 1. 判断 Mask 触发条件
|
| 583 |
# input_ids shape: (batch_size, seq_len)
|
| 584 |
|
|
@@ -591,29 +591,29 @@ class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
|
|
| 591 |
if seq_len == 1:
|
| 592 |
# 检查 input_ids 是否在 [172207, 180398] 区间内
|
| 593 |
in_safe_range = ((input_ids >= 172207) & (input_ids <= 180398)).any(dim=-1)
|
| 594 |
-
# 如果不在安全区间,则满足条件 B
|
| 595 |
cond_b = ~in_safe_range
|
| 596 |
else:
|
| 597 |
-
# 长度不为1,条件 B 必定不满足
|
| 598 |
cond_b = torch.zeros_like(has_trigger_token, dtype=torch.bool)
|
| 599 |
|
| 600 |
-
#
|
| 601 |
rows_to_mask = has_trigger_token | cond_b
|
| 602 |
|
| 603 |
# 2. 执行 Mask 操作
|
| 604 |
if rows_to_mask.any():
|
| 605 |
-
#
|
|
|
|
|
|
|
| 606 |
target_discrete_tokens = [151691, 151692, 151693, 151695, 151696, 151697, 151698]
|
| 607 |
mask_indices = torch.tensor(target_discrete_tokens, device=logits.device)
|
| 608 |
neg = torch.finfo(logits.dtype).min
|
|
|
|
| 609 |
# (1) Mask 离散 Token
|
| 610 |
-
#
|
| 611 |
-
|
| 612 |
-
|
| 613 |
|
| 614 |
# (2) Mask 连续区间 [172206, 180398]
|
| 615 |
-
|
| 616 |
-
logits[rows_to_mask, :, 151727:180399] = neg
|
| 617 |
# =================================================================
|
| 618 |
|
| 619 |
loss = None
|
|
|
|
| 578 |
# [ADD] Custom Logit Masking Logic (Inference Only)
|
| 579 |
# =================================================================
|
| 580 |
# 仅在非训练模式 (self.training == False) 且 input_ids 存在时执行
|
| 581 |
+
if not self.training and input_ids is not None and labels is None:
|
| 582 |
# 1. 判断 Mask 触发条件
|
| 583 |
# input_ids shape: (batch_size, seq_len)
|
| 584 |
|
|
|
|
| 591 |
if seq_len == 1:
|
| 592 |
# 检查 input_ids 是否在 [172207, 180398] 区间内
|
| 593 |
in_safe_range = ((input_ids >= 172207) & (input_ids <= 180398)).any(dim=-1)
|
|
|
|
| 594 |
cond_b = ~in_safe_range
|
| 595 |
else:
|
|
|
|
| 596 |
cond_b = torch.zeros_like(has_trigger_token, dtype=torch.bool)
|
| 597 |
|
| 598 |
+
# 综合条件
|
| 599 |
rows_to_mask = has_trigger_token | cond_b
|
| 600 |
|
| 601 |
# 2. 执行 Mask 操作
|
| 602 |
if rows_to_mask.any():
|
| 603 |
+
# 修改点 2: 获取行索引,解决广播报错的关键
|
| 604 |
+
row_idxs = torch.nonzero(rows_to_mask, as_tuple=True)[0]
|
| 605 |
+
|
| 606 |
target_discrete_tokens = [151691, 151692, 151693, 151695, 151696, 151697, 151698]
|
| 607 |
mask_indices = torch.tensor(target_discrete_tokens, device=logits.device)
|
| 608 |
neg = torch.finfo(logits.dtype).min
|
| 609 |
+
|
| 610 |
# (1) Mask 离散 Token
|
| 611 |
+
# 修改点 3: 使用 row_idxs[:, None] 将形状变为 [N, 1],使其能与 mask_indices [7] 广播
|
| 612 |
+
if mask_indices.numel() > 0:
|
| 613 |
+
logits[row_idxs[:, None], :, mask_indices] = neg
|
| 614 |
|
| 615 |
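# (2) Mask 连续区间 [151727, 180398] — NOTE(review): this comment previously said [172206, 180398], but the slice below is 151727:180399 (end-exclusive); confirm which lower bound is intended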
|
| 616 |
+
logits[row_idxs, :, 151727:180399] = neg
|
|
|
|
| 617 |
# =================================================================
|
| 618 |
|
| 619 |
loss = None
|