training: fix type mismatch when training (#6)
Browse files- convert moe results to input dtype (827ce49e9f70f875ec446521564bdb5acd03f534)
Co-authored-by: Chen <Jack477@users.noreply.huggingface.co>
- modeling_deepseek.py +1 -0
modeling_deepseek.py
CHANGED
|
@@ -577,6 +577,7 @@ class DeepseekV2MoE(nn.Module):
|
|
| 577 |
for i, expert in enumerate(self.experts):
|
| 578 |
y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
|
| 579 |
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
|
|
|
| 580 |
y = y.view(*orig_shape)
|
| 581 |
y = AddAuxiliaryLoss.apply(y, aux_loss)
|
| 582 |
else:
|
|
|
|
| 577 |
for i, expert in enumerate(self.experts):
|
| 578 |
y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
|
| 579 |
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
| 580 |
+
y = y.type(hidden_states.dtype)
|
| 581 |
y = y.view(*orig_shape)
|
| 582 |
y = AddAuxiliaryLoss.apply(y, aux_loss)
|
| 583 |
else:
|