inferencerlabs committed
Commit 1d854df · verified · 1 Parent(s): 32d49f3

Upload complete model

Files changed (1)
  1. modeling_deepseek.py +1808 -0
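Because the repository ships its own modeling code (this file plus configuration_deepseek.py), it is normally loaded through transformers' remote-code path rather than a built-in model class. A minimal sketch, assuming a placeholder repository id (substitute the actual one for this repo):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "inferencerlabs/<model-name>"  # placeholder, not taken from this page
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,  # assumption; choose a dtype your hardware supports
    trust_remote_code=True,      # required so the repo's modeling_deepseek.py is used
)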
modeling_deepseek.py ADDED
@@ -0,0 +1,1808 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch DeepSeek model."""
21
+ import math
22
+ import warnings
23
+ from typing import List, Optional, Tuple, Union
24
+
25
+ import numpy as np
26
+ import torch
27
+ import torch.distributed as dist
28
+ import torch.nn.functional as F
29
+ import torch.utils.checkpoint
30
+ from torch import nn
31
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
32
+ from transformers.activations import ACT2FN
33
+ from transformers.cache_utils import Cache, DynamicCache
34
+ from transformers.modeling_attn_mask_utils import \
35
+ _prepare_4d_causal_attention_mask
36
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
37
+ CausalLMOutputWithPast,
38
+ SequenceClassifierOutputWithPast)
39
+ from transformers.modeling_utils import PreTrainedModel
40
+ from transformers.pytorch_utils import (ALL_LAYERNORM_LAYERS,
41
+ is_torch_greater_or_equal_than_1_13)
42
+ from transformers.utils import (add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ is_flash_attn_2_available,
45
+ is_flash_attn_greater_or_equal_2_10, logging,
46
+ replace_return_docstrings)
47
+ from transformers.utils.import_utils import is_torch_fx_available
48
+
49
+ from .configuration_deepseek import DeepseekV3Config
50
+
51
+ if is_flash_attn_2_available():
52
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
53
+ from flash_attn.bert_padding import pad_input # noqa
54
+ from flash_attn.bert_padding import index_first_axis, unpad_input
55
+
56
+ # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
57
+ # It means that the function will not be traced through and simply appear as a node in the graph.
58
+ if is_torch_fx_available():
59
+ if not is_torch_greater_or_equal_than_1_13:
60
+ import torch.fx
61
+
62
+ _prepare_4d_causal_attention_mask = torch.fx.wrap(
63
+ _prepare_4d_causal_attention_mask)
64
+
65
+ logger = logging.get_logger(__name__)
66
+
67
+ _CONFIG_FOR_DOC = "DeepseekV3Config"
68
+
69
+
70
+ def _get_unpad_data(attention_mask):
71
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
72
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
73
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
74
+ cu_seqlens = F.pad(
75
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
76
+ return (
77
+ indices,
78
+ cu_seqlens,
79
+ max_seqlen_in_batch,
80
+ )
81
+
82
+
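# Illustrative example (not part of the uploaded file): for a padded batch with
# attention_mask = [[1, 1, 1, 0], [1, 1, 1, 1]], _get_unpad_data returns
#     indices             -> [0, 1, 2, 4, 5, 6, 7]   (positions of real tokens in the flattened mask)
#     cu_seqlens          -> [0, 3, 7]                (cumulative per-sequence lengths, int32)
#     max_seqlen_in_batch -> 4
# which is the (indices, cu_seqlens, max_seqlen) triple consumed by the flash-attention
# unpadding path further below.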
83
+ # Code adapted from transformers 4.48.3 to accommodate breaking changes in newer transformers versions.
84
+ def get_usable_length(past_key_value,
85
+ new_seq_length: int,
86
+ layer_idx: Optional[int] = 0) -> int:
87
+ max_length = past_key_value.get_max_cache_shape()
88
+ previous_seq_length = past_key_value.get_seq_length(layer_idx)
89
+ if max_length is not None and max_length > 0 and previous_seq_length + new_seq_length > max_length:
90
+ return max_length - new_seq_length
91
+ return previous_seq_length
92
+
93
+
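# Illustrative usage (not part of the uploaded file), assuming an unbounded DynamicCache:
#
#     cache = DynamicCache()                       # empty cache, get_max_cache_shape() is None
#     ...                                          # after 10 tokens have been cached for layer 0
#     get_usable_length(cache, new_seq_length=4)   # -> 10, so kv_seq_len becomes 10 + 4 = 14 below
#
# This mirrors the cache.get_usable_length() helper that older transformers releases exposed
# directly on the Cache classes.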
94
+ class DeepseekV3RMSNorm(nn.Module):
95
+
96
+ def __init__(self, hidden_size, eps=1e-6):
97
+ """
98
+ DeepseekV3RMSNorm is equivalent to T5LayerNorm
99
+ """
100
+ super().__init__()
101
+ self.weight = nn.Parameter(torch.ones(hidden_size))
102
+ self.variance_epsilon = eps
103
+
104
+ def forward(self, hidden_states):
105
+ input_dtype = hidden_states.dtype
106
+ hidden_states = hidden_states.to(torch.float32)
107
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
108
+ hidden_states = hidden_states * torch.rsqrt(variance +
109
+ self.variance_epsilon)
110
+ return self.weight * hidden_states.to(input_dtype)
111
+
112
+
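# Illustrative note (not part of the uploaded file): like T5LayerNorm, this normalises by the
# root mean square only (no mean subtraction, no bias), i.e.
#
#     y = weight * x / sqrt(mean(x**2, dim=-1, keepdim=True) + eps)
#
# with the statistics computed in float32 and the result cast back to the input dtype.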
113
+ ALL_LAYERNORM_LAYERS.append(DeepseekV3RMSNorm)
114
+
115
+
116
+ class DeepseekV3RotaryEmbedding(nn.Module):
117
+
118
+ def __init__(self,
119
+ dim,
120
+ max_position_embeddings=2048,
121
+ base=10000,
122
+ device=None):
123
+ super().__init__()
124
+
125
+ self.dim = dim
126
+ self.max_position_embeddings = max_position_embeddings
127
+ self.base = base
128
+ inv_freq = 1.0 / (self.base**(
129
+ torch.arange(0, self.dim, 2).float().to(device) / self.dim))
130
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
131
+
132
+ # Build here to make `torch.jit.trace` work.
133
+ self._set_cos_sin_cache(
134
+ seq_len=max_position_embeddings,
135
+ device=self.inv_freq.device,
136
+ dtype=torch.get_default_dtype(),
137
+ )
138
+ self.max_seq_len_cached = None
139
+
140
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
141
+ self.max_seq_len_cached = seq_len
142
+ t = torch.arange(self.max_seq_len_cached,
143
+ device=device,
144
+ dtype=self.inv_freq.dtype)
145
+
146
+ freqs = torch.outer(t, self.inv_freq.to(t.device))
147
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
148
+ emb = torch.cat((freqs, freqs), dim=-1)
149
+ self.register_buffer("cos_cached",
150
+ emb.cos().to(dtype),
151
+ persistent=False)
152
+ self.register_buffer("sin_cached",
153
+ emb.sin().to(dtype),
154
+ persistent=False)
155
+
156
+ def forward(self, x, seq_len=None):
157
+ # x: [bs, num_attention_heads, seq_len, head_size]
158
+ if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
159
+ self._set_cos_sin_cache(seq_len=seq_len,
160
+ device=x.device,
161
+ dtype=x.dtype)
162
+
163
+ return (
164
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
165
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
166
+ )
167
+
168
+
169
+ # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV3
170
+ class DeepseekV3LinearScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):
171
+ """DeepseekV3RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
172
+
173
+ def __init__(
174
+ self,
175
+ dim,
176
+ max_position_embeddings=2048,
177
+ base=10000,
178
+ device=None,
179
+ scaling_factor=1.0,
180
+ ):
181
+ self.scaling_factor = scaling_factor
182
+ super().__init__(dim, max_position_embeddings, base, device)
183
+
184
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
185
+ self.max_seq_len_cached = seq_len
186
+ t = torch.arange(self.max_seq_len_cached,
187
+ device=device,
188
+ dtype=self.inv_freq.dtype)
189
+ t = t / self.scaling_factor
190
+
191
+ freqs = torch.outer(t, self.inv_freq)
192
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
193
+ emb = torch.cat((freqs, freqs), dim=-1)
194
+ self.register_buffer("cos_cached",
195
+ emb.cos().to(dtype),
196
+ persistent=False)
197
+ self.register_buffer("sin_cached",
198
+ emb.sin().to(dtype),
199
+ persistent=False)
200
+
201
+
202
+ # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV3
203
+ class DeepseekV3DynamicNTKScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):
204
+ """DeepseekV3RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
205
+
206
+ def __init__(
207
+ self,
208
+ dim,
209
+ max_position_embeddings=2048,
210
+ base=10000,
211
+ device=None,
212
+ scaling_factor=1.0,
213
+ ):
214
+ self.scaling_factor = scaling_factor
215
+ super().__init__(dim, max_position_embeddings, base, device)
216
+
217
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
218
+ self.max_seq_len_cached = seq_len
219
+
220
+ if seq_len > self.max_position_embeddings:
221
+ base = self.base * ((self.scaling_factor * seq_len /
222
+ self.max_position_embeddings) -
223
+ (self.scaling_factor - 1))**(self.dim /
224
+ (self.dim - 2))
225
+ inv_freq = 1.0 / (base**(
226
+ torch.arange(0, self.dim, 2).float().to(device) / self.dim))
227
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
228
+
229
+ t = torch.arange(self.max_seq_len_cached,
230
+ device=device,
231
+ dtype=self.inv_freq.dtype)
232
+
233
+ freqs = torch.outer(t, self.inv_freq)
234
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
235
+ emb = torch.cat((freqs, freqs), dim=-1)
236
+ self.register_buffer("cos_cached",
237
+ emb.cos().to(dtype),
238
+ persistent=False)
239
+ self.register_buffer("sin_cached",
240
+ emb.sin().to(dtype),
241
+ persistent=False)
242
+
243
+
244
+ # Inverse dim formula to find dim based on number of rotations
245
+ def yarn_find_correction_dim(num_rotations,
246
+ dim,
247
+ base=10000,
248
+ max_position_embeddings=2048):
249
+ return (dim * math.log(max_position_embeddings /
250
+ (num_rotations * 2 * math.pi))) / (2 *
251
+ math.log(base))
252
+
253
+
254
+ # Find dim range bounds based on rotations
255
+ def yarn_find_correction_range(low_rot,
256
+ high_rot,
257
+ dim,
258
+ base=10000,
259
+ max_position_embeddings=2048):
260
+ low = math.floor(
261
+ yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
262
+ high = math.ceil(
263
+ yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
264
+ return max(low, 0), min(high, dim - 1) # Clamp values just in case
265
+
266
+
267
+ def yarn_get_mscale(scale=1, mscale=1):
268
+ if scale <= 1:
269
+ return 1.0
270
+ return 0.1 * mscale * math.log(scale) + 1.0
271
+
272
+
273
+ def yarn_linear_ramp_mask(min, max, dim):
274
+ if min == max:
275
+ max += 0.001 # Prevent singularity
276
+
277
+ linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
278
+ ramp_func = torch.clamp(linear_func, 0, 1)
279
+ return ramp_func
280
+
281
+
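# Illustrative numbers (not part of the uploaded file): for a context-extension factor of
# scale=40 with mscale=1, yarn_get_mscale returns 0.1 * 1 * math.log(40) + 1 ≈ 1.369
# (scale <= 1 returns exactly 1.0), and yarn_linear_ramp_mask(2, 5, 8) produces
#     [0.000, 0.000, 0.000, 0.333, 0.667, 1.000, 1.000, 1.000]
# which DeepseekV3YarnRotaryEmbedding below uses to blend interpolated and extrapolated
# inverse frequencies per rotary dimension.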
282
+ class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding):
283
+
284
+ def __init__(
285
+ self,
286
+ dim,
287
+ max_position_embeddings=2048,
288
+ base=10000,
289
+ device=None,
290
+ scaling_factor=1.0,
291
+ original_max_position_embeddings=4096,
292
+ beta_fast=32,
293
+ beta_slow=1,
294
+ mscale=1,
295
+ mscale_all_dim=0,
296
+ ):
297
+ self.scaling_factor = scaling_factor
298
+ self.original_max_position_embeddings = original_max_position_embeddings
299
+ self.beta_fast = beta_fast
300
+ self.beta_slow = beta_slow
301
+ self.mscale = mscale
302
+ self.mscale_all_dim = mscale_all_dim
303
+ super().__init__(dim, max_position_embeddings, base, device)
304
+
305
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
306
+ self.max_seq_len_cached = seq_len
307
+ dim = self.dim
308
+
309
+ freq_extra = 1.0 / (self.base**(
310
+ torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
311
+ freq_inter = 1.0 / (self.scaling_factor * self.base**(
312
+ torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
313
+
314
+ low, high = yarn_find_correction_range(
315
+ self.beta_fast,
316
+ self.beta_slow,
317
+ dim,
318
+ self.base,
319
+ self.original_max_position_embeddings,
320
+ )
321
+ inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
322
+ device=device, dtype=torch.float32)
323
+ inv_freq = freq_inter * (1 -
324
+ inv_freq_mask) + freq_extra * inv_freq_mask
325
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
326
+
327
+ t = torch.arange(seq_len, device=device, dtype=torch.float32)
328
+
329
+ freqs = torch.outer(t, inv_freq)
330
+
331
+ _mscale = float(
332
+ yarn_get_mscale(self.scaling_factor, self.mscale) /
333
+ yarn_get_mscale(self.scaling_factor, self.mscale_all_dim))
334
+
335
+ emb = torch.cat((freqs, freqs), dim=-1)
336
+ self.register_buffer("cos_cached", (emb.cos() * _mscale).to(dtype),
337
+ persistent=False)
338
+ self.register_buffer("sin_cached", (emb.sin() * _mscale).to(dtype),
339
+ persistent=False)
340
+
341
+
342
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
343
+ def rotate_half(x):
344
+ """Rotates half the hidden dims of the input."""
345
+ x1 = x[..., :x.shape[-1] // 2]
346
+ x2 = x[..., x.shape[-1] // 2:]
347
+ return torch.cat((-x2, x1), dim=-1)
348
+
349
+
350
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
351
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
352
+ """Applies Rotary Position Embedding to the query and key tensors.
353
+
354
+ Args:
355
+ q (`torch.Tensor`): The query tensor.
356
+ k (`torch.Tensor`): The key tensor.
357
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
358
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
359
+ position_ids (`torch.Tensor`):
360
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
361
+ used to pass offsetted position ids when working with a KV-cache.
362
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
363
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
364
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
365
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
366
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
367
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
368
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
369
+ Returns:
370
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
371
+ """
372
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
373
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
374
+
375
+ b, h, s, d = q.shape
376
+ q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
377
+
378
+ b, h, s, d = k.shape
379
+ k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
380
+
381
+ q_embed = (q * cos) + (rotate_half(q) * sin)
382
+ k_embed = (k * cos) + (rotate_half(k) * sin)
383
+ return q_embed, k_embed
384
+
385
+
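# Illustrative note (not part of the uploaded file): unlike the upstream Llama helper this
# was copied from, the view/transpose/reshape above first converts q and k from an
# interleaved pair layout to the half-split layout that rotate_half expects; e.g. for
# head_dim=4 the last dimension is permuted [x0, x1, x2, x3] -> [x0, x2, x1, x3]
# (even-indexed components first, then odd-indexed) before cos/sin are applied.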
386
+ class DeepseekV3MLP(nn.Module):
387
+
388
+ def __init__(self, config, hidden_size=None, intermediate_size=None):
389
+ super().__init__()
390
+ self.config = config
391
+ self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
392
+ self.intermediate_size = (config.intermediate_size if intermediate_size
393
+ is None else intermediate_size)
394
+
395
+ self.gate_proj = nn.Linear(self.hidden_size,
396
+ self.intermediate_size,
397
+ bias=False)
398
+ self.up_proj = nn.Linear(self.hidden_size,
399
+ self.intermediate_size,
400
+ bias=False)
401
+ self.down_proj = nn.Linear(self.intermediate_size,
402
+ self.hidden_size,
403
+ bias=False)
404
+ self.act_fn = ACT2FN[config.hidden_act]
405
+
406
+ def forward(self, x):
407
+ down_proj = self.down_proj(
408
+ self.act_fn(self.gate_proj(x)) * self.up_proj(x))
409
+ return down_proj
410
+
411
+
412
+ class MoEGate(nn.Module):
413
+
414
+ def __init__(self, config):
415
+ super().__init__()
416
+ self.config = config
417
+ self.top_k = config.num_experts_per_tok
418
+ self.n_routed_experts = config.n_routed_experts
419
+ self.routed_scaling_factor = config.routed_scaling_factor
420
+ self.scoring_func = config.scoring_func
421
+ self.seq_aux = config.seq_aux
422
+ self.topk_method = config.topk_method
423
+ self.n_group = config.n_group
424
+ self.topk_group = config.topk_group
425
+
426
+ # topk selection algorithm
427
+ self.norm_topk_prob = config.norm_topk_prob
428
+ self.gating_dim = config.hidden_size
429
+ self.weight = nn.Parameter(
430
+ torch.empty((self.n_routed_experts, self.gating_dim)))
431
+ if self.topk_method == "noaux_tc":
432
+ self.e_score_correction_bias = nn.Parameter(
433
+ torch.empty((self.n_routed_experts)))
434
+ self.reset_parameters()
435
+
436
+ def reset_parameters(self) -> None:
437
+ import torch.nn.init as init
438
+
439
+ init.kaiming_uniform_(self.weight, a=math.sqrt(5))
440
+
441
+ def forward(self, hidden_states):
442
+ bsz, seq_len, h = hidden_states.shape
443
+ ### compute gating score
444
+ hidden_states = hidden_states.view(-1, h)
445
+ logits = F.linear(hidden_states.type(torch.float32),
446
+ self.weight.type(torch.float32), None)
447
+ if self.scoring_func == "sigmoid":
448
+ scores = logits.sigmoid()
449
+ else:
450
+ raise NotImplementedError(
451
+ f"unsupported scoring function for MoE gating: {self.scoring_func}"
452
+ )
453
+
454
+ ### select top-k experts
455
+ if self.topk_method == "noaux_tc":
456
+ assert not self.training
457
+ scores_for_choice = scores.view(
458
+ bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
459
+ group_scores = (scores_for_choice.view(
460
+ bsz * seq_len, self.n_group,
461
+ -1).topk(2, dim=-1)[0].sum(dim=-1)) # [n, n_group]
462
+ group_idx = torch.topk(group_scores,
463
+ k=self.topk_group,
464
+ dim=-1,
465
+ sorted=False)[1] # [n, top_k_group]
466
+ group_mask = torch.zeros_like(group_scores) # [n, n_group]
467
+ group_mask.scatter_(1, group_idx, 1) # [n, n_group]
468
+ score_mask = (group_mask.unsqueeze(-1).expand(
469
+ bsz * seq_len, self.n_group,
470
+ self.n_routed_experts // self.n_group).reshape(
471
+ bsz * seq_len, -1)) # [n, e]
472
+ tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(),
473
+ 0.0) # [n, e]
474
+ _, topk_idx = torch.topk(tmp_scores,
475
+ k=self.top_k,
476
+ dim=-1,
477
+ sorted=False)
478
+ topk_weight = scores.gather(1, topk_idx)
479
+ else:
480
+ raise NotImplementedError(
481
+ f"unsupported TopK method for MoE gating: {self.topk_method}"
482
+ )
483
+
484
+ ### norm gate to sum 1
485
+ if self.top_k > 1 and self.norm_topk_prob:
486
+ denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
487
+ topk_weight = topk_weight / denominator
488
+ topk_weight = topk_weight * self.routed_scaling_factor # must multiply by the scaling factor
489
+
490
+ return topk_idx, topk_weight
491
+
492
+
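# Illustrative shapes (not part of the uploaded file), assuming the published DeepSeek-V3
# routing configuration (n_routed_experts=256, n_group=8, topk_group=4, num_experts_per_tok=8):
#
#     gate = MoEGate(config)
#     topk_idx, topk_weight = gate(hidden_states)   # hidden_states: [bsz, seq, hidden]
#     # topk_idx:    [bsz * seq, 8] expert ids, drawn only from the 4 highest-scoring
#     #              groups of 32 experts each (group scores use the correction bias)
#     # topk_weight: [bsz * seq, 8] sigmoid gate values without the bias, renormalised
#     #              to sum to 1 when norm_topk_prob is set, then scaled by routed_scaling_factor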
493
+ class DeepseekV3MoE(nn.Module):
494
+ """
495
+ A mixed expert module containing shared experts.
496
+ """
497
+
498
+ def __init__(self, config):
499
+ super().__init__()
500
+ self.config = config
501
+ self.num_experts_per_tok = config.num_experts_per_tok
502
+
503
+ if hasattr(config, "ep_size") and config.ep_size > 1:
504
+ assert config.ep_size == dist.get_world_size()
505
+ self.ep_size = config.ep_size
506
+ self.experts_per_rank = config.n_routed_experts // config.ep_size
507
+ self.ep_rank = dist.get_rank()
508
+ self.experts = nn.ModuleList([
509
+ (DeepseekV3MLP(config,
510
+ intermediate_size=config.moe_intermediate_size)
511
+ if i >= self.ep_rank * self.experts_per_rank
512
+ and i < (self.ep_rank + 1) * self.experts_per_rank else None)
513
+ for i in range(config.n_routed_experts)
514
+ ])
515
+ else:
516
+ self.ep_size = 1
517
+ self.experts_per_rank = config.n_routed_experts
518
+ self.ep_rank = 0
519
+ self.experts = nn.ModuleList([
520
+ DeepseekV3MLP(config,
521
+ intermediate_size=config.moe_intermediate_size)
522
+ for i in range(config.n_routed_experts)
523
+ ])
524
+ self.gate = MoEGate(config)
525
+ if config.n_shared_experts is not None:
526
+ intermediate_size = config.moe_intermediate_size * config.n_shared_experts
527
+ self.shared_experts = DeepseekV3MLP(
528
+ config=config, intermediate_size=intermediate_size)
529
+
530
+ def forward(self, hidden_states):
531
+ identity = hidden_states
532
+ orig_shape = hidden_states.shape
533
+ topk_idx, topk_weight = self.gate(hidden_states)
534
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
535
+ flat_topk_idx = topk_idx.view(-1)
536
+ if not self.training:
537
+ y = self.moe_infer(hidden_states, topk_idx,
538
+ topk_weight).view(*orig_shape)
539
+ if self.config.n_shared_experts is not None:
540
+ y = y + self.shared_experts(identity)
541
+ return y
542
+
543
+ @torch.no_grad()
544
+ def moe_infer(self, x, topk_ids, topk_weight):
545
+ cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
546
+ cnts.scatter_(1, topk_ids, 1)
547
+ tokens_per_expert = cnts.sum(dim=0)
548
+ idxs = topk_ids.view(-1).argsort()
549
+ sorted_tokens = x[idxs // topk_ids.shape[1]]
550
+ sorted_tokens_shape = sorted_tokens.shape
551
+ if self.ep_size > 1:
552
+ tokens_per_ep_rank = tokens_per_expert.view(self.ep_size,
553
+ -1).sum(dim=1)
554
+ tokens_per_expert_group = tokens_per_expert.new_empty(
555
+ tokens_per_expert.shape[0])
556
+ dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
557
+ output_splits = (tokens_per_expert_group.view(
558
+ self.ep_size, -1).sum(1).cpu().numpy().tolist())
559
+ gathered_tokens = sorted_tokens.new_empty(
560
+ tokens_per_expert_group.sum(dim=0).cpu().item(),
561
+ sorted_tokens.shape[1])
562
+ input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
563
+ dist.all_to_all(
564
+ list(gathered_tokens.split(output_splits)),
565
+ list(sorted_tokens.split(input_split_sizes)),
566
+ )
567
+ tokens_per_expert_post_gather = tokens_per_expert_group.view(
568
+ self.ep_size, self.experts_per_rank).sum(dim=0)
569
+ gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0], ),
570
+ dtype=np.int32)
571
+ s = 0
572
+ for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
573
+ gatherd_idxs[s:s + k] = i % self.experts_per_rank
574
+ s += k
575
+ gatherd_idxs = gatherd_idxs.argsort()
576
+ sorted_tokens = gathered_tokens[gatherd_idxs]
577
+ tokens_per_expert = tokens_per_expert_post_gather
578
+ tokens_per_expert = tokens_per_expert.cpu().numpy()
579
+
580
+ outputs = []
581
+ start_idx = 0
582
+ for i, num_tokens in enumerate(tokens_per_expert):
583
+ end_idx = start_idx + num_tokens
584
+ if num_tokens == 0:
585
+ continue
586
+ expert = self.experts[i + self.ep_rank * self.experts_per_rank]
587
+ tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
588
+ expert_out = expert(tokens_for_this_expert)
589
+ outputs.append(expert_out)
590
+ start_idx = end_idx
591
+
592
+ outs = torch.cat(outputs,
593
+ dim=0) if len(outputs) else sorted_tokens.new_empty(0)
594
+ if self.ep_size > 1:
595
+ new_x = torch.empty_like(outs)
596
+ new_x[gatherd_idxs] = outs
597
+ gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
598
+ dist.all_to_all(
599
+ list(gathered_tokens.split(input_split_sizes)),
600
+ list(new_x.split(output_splits)),
601
+ )
602
+ outs = gathered_tokens
603
+
604
+ new_x = torch.empty_like(outs)
605
+ new_x[idxs] = outs
606
+ final_out = (new_x.view(
607
+ *topk_ids.shape, -1).type(topk_weight.dtype).mul_(
608
+ topk_weight.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype))
609
+ return final_out
610
+
611
+
612
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
613
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
614
+ """
615
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
616
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
617
+ """
618
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
619
+ if n_rep == 1:
620
+ return hidden_states
621
+ hidden_states = hidden_states[:, :,
622
+ None, :, :].expand(batch,
623
+ num_key_value_heads,
624
+ n_rep, slen, head_dim)
625
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
626
+ head_dim)
627
+
628
+
629
+ # Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV3
630
+ class DeepseekV3Attention(nn.Module):
631
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
632
+
633
+ def __init__(self,
634
+ config: DeepseekV3Config,
635
+ layer_idx: Optional[int] = None):
636
+ super().__init__()
637
+ self.config = config
638
+ self.layer_idx = layer_idx
639
+ if layer_idx is None:
640
+ logger.warning_once(
641
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead "
642
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
643
+ "when creating this class.")
644
+
645
+ self.attention_dropout = config.attention_dropout
646
+ self.hidden_size = config.hidden_size
647
+ self.num_heads = config.num_attention_heads
648
+
649
+ self.max_position_embeddings = config.max_position_embeddings
650
+ self.rope_theta = config.rope_theta
651
+ self.q_lora_rank = config.q_lora_rank
652
+ self.qk_rope_head_dim = config.qk_rope_head_dim
653
+ self.kv_lora_rank = config.kv_lora_rank
654
+ self.v_head_dim = config.v_head_dim
655
+ self.qk_nope_head_dim = config.qk_nope_head_dim
656
+ self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
657
+
658
+ self.is_causal = True
659
+
660
+ if self.q_lora_rank is None:
661
+ self.q_proj = nn.Linear(self.hidden_size,
662
+ self.num_heads * self.q_head_dim,
663
+ bias=False)
664
+ else:
665
+ self.q_a_proj = nn.Linear(self.hidden_size,
666
+ config.q_lora_rank,
667
+ bias=config.attention_bias)
668
+ self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
669
+ self.q_b_proj = nn.Linear(config.q_lora_rank,
670
+ self.num_heads * self.q_head_dim,
671
+ bias=False)
672
+
673
+ self.kv_a_proj_with_mqa = nn.Linear(
674
+ self.hidden_size,
675
+ config.kv_lora_rank + config.qk_rope_head_dim,
676
+ bias=config.attention_bias,
677
+ )
678
+ self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank)
679
+ self.kv_b_proj = nn.Linear(
680
+ config.kv_lora_rank,
681
+ self.num_heads *
682
+ (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
683
+ bias=False,
684
+ )
685
+
686
+ self.o_proj = nn.Linear(
687
+ self.num_heads * self.v_head_dim,
688
+ self.hidden_size,
689
+ bias=config.attention_bias,
690
+ )
691
+ self._init_rope()
692
+
693
+ self.softmax_scale = self.q_head_dim**(-0.5)
694
+ if self.config.rope_scaling is not None:
695
+ mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
696
+ scaling_factor = self.config.rope_scaling["factor"]
697
+ if mscale_all_dim:
698
+ mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
699
+ self.softmax_scale = self.softmax_scale * mscale * mscale
700
+
701
+ def _init_rope(self):
702
+ if self.config.rope_scaling is None:
703
+ self.rotary_emb = DeepseekV3RotaryEmbedding(
704
+ self.qk_rope_head_dim,
705
+ max_position_embeddings=self.max_position_embeddings,
706
+ base=self.rope_theta,
707
+ )
708
+ else:
709
+ scaling_type = self.config.rope_scaling["type"]
710
+ scaling_factor = self.config.rope_scaling["factor"]
711
+ if scaling_type == "linear":
712
+ self.rotary_emb = DeepseekV3LinearScalingRotaryEmbedding(
713
+ self.qk_rope_head_dim,
714
+ max_position_embeddings=self.max_position_embeddings,
715
+ scaling_factor=scaling_factor,
716
+ base=self.rope_theta,
717
+ )
718
+ elif scaling_type == "dynamic":
719
+ self.rotary_emb = DeepseekV3DynamicNTKScalingRotaryEmbedding(
720
+ self.qk_rope_head_dim,
721
+ max_position_embeddings=self.max_position_embeddings,
722
+ scaling_factor=scaling_factor,
723
+ base=self.rope_theta,
724
+ )
725
+ elif scaling_type == "yarn":
726
+ kwargs = {
727
+ key: self.config.rope_scaling[key]
728
+ for key in [
729
+ "original_max_position_embeddings",
730
+ "beta_fast",
731
+ "beta_slow",
732
+ "mscale",
733
+ "mscale_all_dim",
734
+ ] if key in self.config.rope_scaling
735
+ }
736
+ self.rotary_emb = DeepseekV3YarnRotaryEmbedding(
737
+ self.qk_rope_head_dim,
738
+ max_position_embeddings=self.max_position_embeddings,
739
+ scaling_factor=scaling_factor,
740
+ base=self.rope_theta,
741
+ **kwargs,
742
+ )
743
+ else:
744
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
745
+
746
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
747
+ return (tensor.view(bsz, seq_len, self.num_heads,
748
+ self.v_head_dim).transpose(1, 2).contiguous())
749
+
750
+ def forward(
751
+ self,
752
+ hidden_states: torch.Tensor,
753
+ attention_mask: Optional[torch.Tensor] = None,
754
+ position_ids: Optional[torch.LongTensor] = None,
755
+ past_key_value: Optional[Cache] = None,
756
+ output_attentions: bool = False,
757
+ use_cache: bool = False,
758
+ **kwargs,
759
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
760
+ Optional[Tuple[torch.Tensor]]]:
761
+ if "padding_mask" in kwargs:
762
+ warnings.warn(
763
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
764
+ )
765
+ bsz, q_len, _ = hidden_states.size()
766
+
767
+ if self.q_lora_rank is None:
768
+ q = self.q_proj(hidden_states)
769
+ else:
770
+ q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
771
+ q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
772
+ q_nope, q_pe = torch.split(
773
+ q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
774
+
775
+ compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
776
+ compressed_kv, k_pe = torch.split(
777
+ compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
778
+ k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
779
+ kv = (self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(
780
+ bsz, q_len, self.num_heads,
781
+ self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2))
782
+
783
+ k_nope, value_states = torch.split(
784
+ kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
785
+ kv_seq_len = value_states.shape[-2]
786
+ if past_key_value is not None:
787
+ if self.layer_idx is None:
788
+ raise ValueError(
789
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
790
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
791
+ "with a layer index.")
792
+ kv_seq_len += get_usable_length(past_key_value, kv_seq_len,
793
+ self.layer_idx)
794
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
795
+
796
+ q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
797
+
798
+ query_states = k_pe.new_empty(bsz, self.num_heads, q_len,
799
+ self.q_head_dim)
800
+ query_states[:, :, :, :self.qk_nope_head_dim] = q_nope
801
+ query_states[:, :, :, self.qk_nope_head_dim:] = q_pe
802
+
803
+ key_states = k_pe.new_empty(bsz, self.num_heads, q_len,
804
+ self.q_head_dim)
805
+ key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
806
+ key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
807
+ if past_key_value is not None:
808
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
809
+ key_states, value_states = past_key_value.update(
810
+ key_states, value_states, self.layer_idx, cache_kwargs)
811
+
812
+ attn_weights = (
813
+ torch.matmul(query_states, key_states.transpose(2, 3)) *
814
+ self.softmax_scale)
815
+
816
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
817
+ raise ValueError(
818
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
819
+ f" {attn_weights.size()}")
820
+ assert attention_mask is not None
821
+ if attention_mask is not None:
822
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
823
+ raise ValueError(
824
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
825
+ )
826
+ attn_weights = attn_weights + attention_mask
827
+
828
+ # upcast attention to fp32
829
+ attn_weights = nn.functional.softmax(attn_weights,
830
+ dim=-1,
831
+ dtype=torch.float32).to(
832
+ query_states.dtype)
833
+ attn_weights = nn.functional.dropout(attn_weights,
834
+ p=self.attention_dropout,
835
+ training=self.training)
836
+ attn_output = torch.matmul(attn_weights, value_states)
837
+
838
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
839
+ raise ValueError(
840
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is"
841
+ f" {attn_output.size()}")
842
+
843
+ attn_output = attn_output.transpose(1, 2).contiguous()
844
+
845
+ attn_output = attn_output.reshape(bsz, q_len,
846
+ self.num_heads * self.v_head_dim)
847
+
848
+ attn_output = self.o_proj(attn_output)
849
+
850
+ if not output_attentions:
851
+ attn_weights = None
852
+
853
+ return attn_output, attn_weights, past_key_value
854
+
855
+
856
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV3
857
+ class DeepseekV3FlashAttention2(DeepseekV3Attention):
858
+ """
859
+ DeepseekV3 flash attention module. This module inherits from `DeepseekV3Attention`, as the weights of the module stay
860
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
861
+ flash attention and deal with padding tokens in case the input contains any of them.
862
+ """
863
+
864
+ def __init__(self, *args, **kwargs):
865
+ super().__init__(*args, **kwargs)
866
+
867
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
868
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
869
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
870
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10(
871
+ )
872
+
873
+ def forward(
874
+ self,
875
+ hidden_states: torch.Tensor,
876
+ attention_mask: Optional[torch.LongTensor] = None,
877
+ position_ids: Optional[torch.LongTensor] = None,
878
+ past_key_value: Optional[Cache] = None,
879
+ output_attentions: bool = False,
880
+ use_cache: bool = False,
881
+ **kwargs,
882
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
883
+ Optional[Tuple[torch.Tensor]]]:
884
+ # DeepseekV3FlashAttention2 attention does not support output_attentions
885
+ if "padding_mask" in kwargs:
886
+ warnings.warn(
887
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
888
+ )
889
+
890
+ # overwrite attention_mask with padding_mask
891
+ attention_mask = kwargs.pop("padding_mask")
892
+
893
+ output_attentions = False
894
+
895
+ bsz, q_len, _ = hidden_states.size()
896
+
897
+ if self.q_lora_rank is None:
898
+ q = self.q_proj(hidden_states)
899
+ else:
900
+ q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
901
+ q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
902
+ q_nope, q_pe = torch.split(
903
+ q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
904
+
905
+ # Flash attention requires the input to have the shape
906
+ # batch_size x seq_length x num_heads x head_dim
907
+ # therefore we just need to keep the original shape
908
+ compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
909
+ compressed_kv, k_pe = torch.split(
910
+ compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
911
+ k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
912
+ kv = (self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(
913
+ bsz, q_len, self.num_heads,
914
+ self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2))
915
+
916
+ k_nope, value_states = torch.split(
917
+ kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
918
+ kv_seq_len = value_states.shape[-2]
919
+
920
+ kv_seq_len = value_states.shape[-2]
921
+ if past_key_value is not None:
922
+ kv_seq_len += get_usable_length(past_key_value, kv_seq_len,
923
+ self.layer_idx)
924
+
925
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
926
+ q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
927
+
928
+ query_states = k_pe.new_empty(bsz, self.num_heads, q_len,
929
+ self.q_head_dim)
930
+ query_states[:, :, :, :self.qk_nope_head_dim] = q_nope
931
+ query_states[:, :, :, self.qk_nope_head_dim:] = q_pe
932
+
933
+ key_states = k_pe.new_empty(bsz, self.num_heads, q_len,
934
+ self.q_head_dim)
935
+ key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
936
+ key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
937
+
938
+ if self.q_head_dim != self.v_head_dim:
939
+ value_states = F.pad(value_states,
940
+ [0, self.q_head_dim - self.v_head_dim])
941
+
942
+ if past_key_value is not None:
943
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
944
+ key_states, value_states = past_key_value.update(
945
+ key_states, value_states, self.layer_idx, cache_kwargs)
946
+
947
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
948
+ # to be able to avoid many of these transpose/reshape/view.
949
+ query_states = query_states.transpose(1, 2)
950
+ key_states = key_states.transpose(1, 2)
951
+ value_states = value_states.transpose(1, 2)
952
+
953
+ dropout_rate = self.attention_dropout if self.training else 0.0
954
+
955
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
956
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
957
+ # cast them back in the correct dtype just to be sure everything works as expected.
958
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
959
+ # in fp32. (DeepseekV3RMSNorm handles it correctly)
960
+
961
+ input_dtype = query_states.dtype
962
+ if input_dtype == torch.float32:
963
+ # Handle the case where the model is quantized
964
+ if hasattr(self.config, "_pre_quantization_dtype"):
965
+ target_dtype = self.config._pre_quantization_dtype
966
+ elif torch.is_autocast_enabled():
967
+ target_dtype = torch.get_autocast_gpu_dtype()
968
+ else:
969
+ target_dtype = (self.q_proj.weight.dtype if self.q_lora_rank
970
+ is None else self.q_a_proj.weight.dtype)
971
+
972
+ logger.warning_once(
973
+ f"The input hidden states seem to be silently cast to float32; this might be related to"
974
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
975
+ f" {target_dtype}.")
976
+
977
+ query_states = query_states.to(target_dtype)
978
+ key_states = key_states.to(target_dtype)
979
+ value_states = value_states.to(target_dtype)
980
+
981
+ attn_output = self._flash_attention_forward(
982
+ query_states,
983
+ key_states,
984
+ value_states,
985
+ attention_mask,
986
+ q_len,
987
+ dropout=dropout_rate,
988
+ softmax_scale=self.softmax_scale,
989
+ )
990
+ if self.q_head_dim != self.v_head_dim:
991
+ attn_output = attn_output[:, :, :, :self.v_head_dim]
992
+
993
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads *
994
+ self.v_head_dim).contiguous()
995
+ attn_output = self.o_proj(attn_output)
996
+
997
+ if not output_attentions:
998
+ attn_weights = None
999
+
1000
+ return attn_output, attn_weights, past_key_value
1001
+
1002
+ def _flash_attention_forward(
1003
+ self,
1004
+ query_states,
1005
+ key_states,
1006
+ value_states,
1007
+ attention_mask,
1008
+ query_length,
1009
+ dropout=0.0,
1010
+ softmax_scale=None,
1011
+ ):
1012
+ """
1013
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
1014
+ first unpads the input, then computes the attention scores and re-pads the final attention output.
1015
+
1016
+ Args:
1017
+ query_states (`torch.Tensor`):
1018
+ Input query states to be passed to Flash Attention API
1019
+ key_states (`torch.Tensor`):
1020
+ Input key states to be passed to Flash Attention API
1021
+ value_states (`torch.Tensor`):
1022
+ Input value states to be passed to Flash Attention API
1023
+ attention_mask (`torch.Tensor`):
1024
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
1025
+ position of padding tokens and 1 for the position of non-padding tokens.
1026
+ dropout (`float`, *optional*):
1027
+ Attention dropout
1028
+ softmax_scale (`float`, *optional*):
1029
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
1030
+ """
1031
+ if not self._flash_attn_uses_top_left_mask:
1032
+ causal = self.is_causal
1033
+ else:
1034
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV3FlashAttention2 __init__.
1035
+ causal = self.is_causal and query_length != 1
1036
+
1037
+ # Contains at least one padding token in the sequence
1038
+ if attention_mask is not None:
1039
+ batch_size = query_states.shape[0]
1040
+ (
1041
+ query_states,
1042
+ key_states,
1043
+ value_states,
1044
+ indices_q,
1045
+ cu_seq_lens,
1046
+ max_seq_lens,
1047
+ ) = self._upad_input(query_states, key_states, value_states,
1048
+ attention_mask, query_length)
1049
+
1050
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
1051
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
1052
+
1053
+ attn_output_unpad = flash_attn_varlen_func(
1054
+ query_states,
1055
+ key_states,
1056
+ value_states,
1057
+ cu_seqlens_q=cu_seqlens_q,
1058
+ cu_seqlens_k=cu_seqlens_k,
1059
+ max_seqlen_q=max_seqlen_in_batch_q,
1060
+ max_seqlen_k=max_seqlen_in_batch_k,
1061
+ dropout_p=dropout,
1062
+ softmax_scale=softmax_scale,
1063
+ causal=causal,
1064
+ )
1065
+
1066
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size,
1067
+ query_length)
1068
+ else:
1069
+ attn_output = flash_attn_func(
1070
+ query_states,
1071
+ key_states,
1072
+ value_states,
1073
+ dropout,
1074
+ softmax_scale=softmax_scale,
1075
+ causal=causal,
1076
+ )
1077
+
1078
+ return attn_output
1079
+
1080
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask,
1081
+ query_length):
1082
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
1083
+ attention_mask)
1084
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
1085
+
1086
+ key_layer = index_first_axis(
1087
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
1088
+ head_dim),
1089
+ indices_k,
1090
+ )
1091
+ value_layer = index_first_axis(
1092
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
1093
+ head_dim),
1094
+ indices_k,
1095
+ )
1096
+ if query_length == kv_seq_len:
1097
+ query_layer = index_first_axis(
1098
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads,
1099
+ head_dim),
1100
+ indices_k,
1101
+ )
1102
+ cu_seqlens_q = cu_seqlens_k
1103
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
1104
+ indices_q = indices_k
1105
+ elif query_length == 1:
1106
+ max_seqlen_in_batch_q = 1
1107
+ cu_seqlens_q = torch.arange(
1108
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
1109
+ ) # There is a memcpy here, that is very bad.
1110
+ indices_q = cu_seqlens_q[:-1]
1111
+ query_layer = query_layer.squeeze(1)
1112
+ else:
1113
+ # The -q_len: slice assumes left padding.
1114
+ attention_mask = attention_mask[:, -query_length:]
1115
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
1116
+ query_layer, attention_mask)
1117
+
1118
+ return (
1119
+ query_layer,
1120
+ key_layer,
1121
+ value_layer,
1122
+ indices_q,
1123
+ (cu_seqlens_q, cu_seqlens_k),
1124
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
1125
+ )
1126
+
1127
+
1128
+ ATTENTION_CLASSES = {
1129
+ "eager": DeepseekV3Attention,
1130
+ "flash_attention_2": DeepseekV3FlashAttention2,
1131
+ }
1132
+
1133
+
1134
+ class DeepseekV3DecoderLayer(nn.Module):
1135
+
1136
+ def __init__(self, config: DeepseekV3Config, layer_idx: int):
1137
+ super().__init__()
1138
+ self.hidden_size = config.hidden_size
1139
+
1140
+ self.self_attn = ATTENTION_CLASSES[config._attn_implementation](
1141
+ config=config, layer_idx=layer_idx)
1142
+
1143
+ self.mlp = (DeepseekV3MoE(config) if
1144
+ (config.n_routed_experts is not None
1145
+ and layer_idx >= config.first_k_dense_replace
1146
+ and layer_idx % config.moe_layer_freq == 0) else
1147
+ DeepseekV3MLP(config))
1148
+ self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size,
1149
+ eps=config.rms_norm_eps)
1150
+ self.post_attention_layernorm = DeepseekV3RMSNorm(
1151
+ config.hidden_size, eps=config.rms_norm_eps)
1152
+
1153
+ def forward(
1154
+ self,
1155
+ hidden_states: torch.Tensor,
1156
+ attention_mask: Optional[torch.Tensor] = None,
1157
+ position_ids: Optional[torch.LongTensor] = None,
1158
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
1159
+ output_attentions: Optional[bool] = False,
1160
+ use_cache: Optional[bool] = False,
1161
+ **kwargs,
1162
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor,
1163
+ torch.FloatTensor]]]:
1164
+ """
1165
+ Args:
1166
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
1167
+ attention_mask (`torch.FloatTensor`, *optional*):
1168
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
1169
+ query_sequence_length, key_sequence_length)` if default attention is used.
1170
+ output_attentions (`bool`, *optional*):
1171
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1172
+ returned tensors for more detail.
1173
+ use_cache (`bool`, *optional*):
1174
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1175
+ (see `past_key_values`).
1176
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
1177
+ """
1178
+ if "padding_mask" in kwargs:
1179
+ warnings.warn(
1180
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
1181
+ )
1182
+ residual = hidden_states
1183
+
1184
+ hidden_states = self.input_layernorm(hidden_states)
1185
+
1186
+ # Self Attention
1187
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
1188
+ hidden_states=hidden_states,
1189
+ attention_mask=attention_mask,
1190
+ position_ids=position_ids,
1191
+ past_key_value=past_key_value,
1192
+ output_attentions=output_attentions,
1193
+ use_cache=use_cache,
1194
+ **kwargs,
1195
+ )
1196
+ hidden_states = residual + hidden_states
1197
+
1198
+ # Fully Connected
1199
+ residual = hidden_states
1200
+ hidden_states = self.post_attention_layernorm(hidden_states)
1201
+ hidden_states = self.mlp(hidden_states)
1202
+ hidden_states = residual + hidden_states
1203
+
1204
+ outputs = (hidden_states, )
1205
+
1206
+ if output_attentions:
1207
+ outputs += (self_attn_weights, )
1208
+
1209
+ if use_cache:
1210
+ outputs += (present_key_value, )
1211
+
1212
+ return outputs
1213
+
1214
+
1215
+ DeepseekV3_START_DOCSTRING = r"""
1216
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1217
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
1218
+ etc.)
1219
+
1220
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1221
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1222
+ and behavior.
1223
+
1224
+ Parameters:
1225
+ config ([`DeepseekV3Config`]):
1226
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
1227
+ load the weights associated with the model, only the configuration. Check out the
1228
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1229
+ """
1230
+
1231
+
1232
+ @add_start_docstrings(
1233
+ "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
1234
+ DeepseekV3_START_DOCSTRING,
1235
+ )
1236
+ class DeepseekV3PreTrainedModel(PreTrainedModel):
1237
+ config_class = DeepseekV3Config
1238
+ base_model_prefix = "model"
1239
+ supports_gradient_checkpointing = True
1240
+ _no_split_modules = ["DeepseekV3DecoderLayer"]
1241
+ _skip_keys_device_placement = "past_key_values"
1242
+ _supports_flash_attn_2 = True
1243
+ _supports_cache_class = True
1244
+
1245
+ def _init_weights(self, module):
1246
+ std = self.config.initializer_range
1247
+ if isinstance(module, nn.Linear):
1248
+ module.weight.data.normal_(mean=0.0, std=std)
1249
+ if module.bias is not None:
1250
+ module.bias.data.zero_()
1251
+ elif isinstance(module, nn.Embedding):
1252
+ module.weight.data.normal_(mean=0.0, std=std)
1253
+ if module.padding_idx is not None:
1254
+ module.weight.data[module.padding_idx].zero_()
1255
+
1256
+
1257
+ DeepseekV3_INPUTS_DOCSTRING = r"""
1258
+ Args:
1259
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1260
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1261
+ it.
1262
+
1263
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1264
+ [`PreTrainedTokenizer.__call__`] for details.
1265
+
1266
+ [What are input IDs?](../glossary#input-ids)
1267
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1268
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1269
+
1270
+ - 1 for tokens that are **not masked**,
1271
+ - 0 for tokens that are **masked**.
1272
+
1273
+ [What are attention masks?](../glossary#attention-mask)
1274
+
1275
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1276
+ [`PreTrainedTokenizer.__call__`] for details.
1277
+
1278
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
1279
+ `past_key_values`).
1280
+
1281
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1282
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1283
+ information on the default strategy.
1284
+
1285
+ - 1 indicates the head is **not masked**,
1286
+ - 0 indicates the head is **masked**.
1287
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1288
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1289
+ config.n_positions - 1]`.
1290
+
1291
+ [What are position IDs?](../glossary#position-ids)
1292
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
1293
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1294
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
1295
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
1296
+
1297
+ Two formats are allowed:
1298
+ - a [`~cache_utils.Cache`] instance;
1299
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1300
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
1301
+ cache format.
1302
+
1303
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
1304
+ legacy cache format will be returned.
1305
+
1306
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1307
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1308
+ of shape `(batch_size, sequence_length)`.
1309
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1310
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1311
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1312
+ model's internal embedding lookup matrix.
1313
+ use_cache (`bool`, *optional*):
1314
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1315
+ `past_key_values`).
1316
+ output_attentions (`bool`, *optional*):
1317
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1318
+ tensors for more detail.
1319
+ output_hidden_states (`bool`, *optional*):
1320
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1321
+ more detail.
1322
+ return_dict (`bool`, *optional*):
1323
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1324
+ """
1325
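+ # Illustrative note (added; not from the upstream file): the two `past_key_values` formats described
+ # above can be constructed as either a `DynamicCache()` from `transformers.cache_utils`, or the legacy
+ # per-layer `tuple(key, value)` tuples with tensors of shape
+ # (batch_size, num_heads, sequence_length, embed_size_per_head); legacy input is converted internally
+ # via `DynamicCache.from_legacy_cache` in the forward pass below.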
+
1326
+
1327
+ @add_start_docstrings(
1328
+ "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
1329
+ DeepseekV3_START_DOCSTRING,
1330
+ )
1331
+ class DeepseekV3Model(DeepseekV3PreTrainedModel):
1332
+ """
1333
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV3DecoderLayer`]
1334
+
1335
+ Args:
1336
+ config: DeepseekV3Config
1337
+ """
1338
+
1339
+ def __init__(self, config: DeepseekV3Config):
1340
+ super().__init__(config)
1341
+ self.padding_idx = config.pad_token_id
1342
+ self.vocab_size = config.vocab_size
1343
+
1344
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
1345
+ self.padding_idx)
1346
+ self.layers = nn.ModuleList([
1347
+ DeepseekV3DecoderLayer(config, layer_idx)
1348
+ for layer_idx in range(config.num_hidden_layers)
1349
+ ])
1350
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1351
+ self.norm = DeepseekV3RMSNorm(config.hidden_size,
1352
+ eps=config.rms_norm_eps)
1353
+
1354
+ self.gradient_checkpointing = False
1355
+ # Initialize weights and apply final processing
1356
+ self.post_init()
1357
+
1358
+ def get_input_embeddings(self):
1359
+ return self.embed_tokens
1360
+
1361
+ def set_input_embeddings(self, value):
1362
+ self.embed_tokens = value
1363
+
1364
+ @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
1365
+ def forward(
1366
+ self,
1367
+ input_ids: torch.LongTensor = None,
1368
+ attention_mask: Optional[torch.Tensor] = None,
1369
+ position_ids: Optional[torch.LongTensor] = None,
1370
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1371
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1372
+ use_cache: Optional[bool] = None,
1373
+ output_attentions: Optional[bool] = None,
1374
+ output_hidden_states: Optional[bool] = None,
1375
+ return_dict: Optional[bool] = None,
1376
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1377
+ output_attentions = (output_attentions if output_attentions is not None
1378
+ else self.config.output_attentions)
1379
+ output_hidden_states = (output_hidden_states
1380
+ if output_hidden_states is not None else
1381
+ self.config.output_hidden_states)
1382
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1383
+
1384
+ return_dict = (return_dict if return_dict is not None else
1385
+ self.config.use_return_dict)
1386
+
1387
+ # retrieve input_ids and inputs_embeds
1388
+ if input_ids is not None and inputs_embeds is not None:
1389
+ raise ValueError(
1390
+ "You cannot specify both input_ids and inputs_embeds at the same time"
1391
+ )
1392
+ elif input_ids is not None:
1393
+ batch_size, seq_length = input_ids.shape[:2]
1394
+ elif inputs_embeds is not None:
1395
+ batch_size, seq_length = inputs_embeds.shape[:2]
1396
+ else:
1397
+ raise ValueError(
1398
+ "You have to specify either input_ids or inputs_embeds")
1399
+
1400
+ past_key_values_length = 0
1401
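+ # When caching is enabled, wrap a legacy tuple cache in a DynamicCache and read how many tokens it
+ # already holds so the new position ids start right after the cached prefix.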
+ if use_cache:
1402
+ use_legacy_cache = not isinstance(past_key_values, Cache)
1403
+ if use_legacy_cache:
1404
+ past_key_values = DynamicCache.from_legacy_cache(
1405
+ past_key_values)
1406
+ past_key_values_length = get_usable_length(past_key_values,
1407
+ seq_length)
1408
+
1409
+ if position_ids is None:
1410
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1411
+ position_ids = torch.arange(
1412
+ past_key_values_length,
1413
+ seq_length + past_key_values_length,
1414
+ dtype=torch.long,
1415
+ device=device,
1416
+ )
1417
+ position_ids = position_ids.unsqueeze(0)
1418
+
1419
+ if inputs_embeds is None:
1420
+ inputs_embeds = self.embed_tokens(input_ids)
1421
+
1422
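+ # Mask handling: the flash-attention path keeps the raw 2D padding mask (and drops it entirely when
+ # no position is actually padded), while the default path expands it into a 4D causal mask.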
+ if self._use_flash_attention_2:
1423
+ # 2d mask is passed through the layers
1424
+ attention_mask = (attention_mask if
1425
+ (attention_mask is not None
1426
+ and 0 in attention_mask) else None)
1427
+ else:
1428
+ # 4d mask is passed through the layers
1429
+ attention_mask = _prepare_4d_causal_attention_mask(
1430
+ attention_mask,
1431
+ (batch_size, seq_length),
1432
+ inputs_embeds,
1433
+ past_key_values_length,
1434
+ )
1435
+
1436
+ # embed positions
1437
+ hidden_states = inputs_embeds
1438
+
1439
+ # decoder layers
1440
+ all_hidden_states = () if output_hidden_states else None
1441
+ all_self_attns = () if output_attentions else None
1442
+ next_decoder_cache = None
1443
+
1444
+ for decoder_layer in self.layers:
1445
+ if output_hidden_states:
1446
+ all_hidden_states += (hidden_states, )
1447
+
1448
+ layer_outputs = decoder_layer(
1449
+ hidden_states,
1450
+ attention_mask=attention_mask,
1451
+ position_ids=position_ids,
1452
+ past_key_value=past_key_values,
1453
+ output_attentions=output_attentions,
1454
+ use_cache=use_cache,
1455
+ )
1456
+
1457
+ hidden_states = layer_outputs[0]
1458
+
1459
+ if use_cache:
1460
+ next_decoder_cache = layer_outputs[
1461
+ 2 if output_attentions else 1]
1462
+
1463
+ if output_attentions:
1464
+ all_self_attns += (layer_outputs[1], )
1465
+
1466
+ hidden_states = self.norm(hidden_states)
1467
+
1468
+ # add hidden states from the last decoder layer
1469
+ if output_hidden_states:
1470
+ all_hidden_states += (hidden_states, )
1471
+
1472
+ next_cache = None
1473
+ if use_cache:
1474
+ next_cache = (next_decoder_cache.to_legacy_cache()
1475
+ if use_legacy_cache else next_decoder_cache)
1476
+ if not return_dict:
1477
+ return tuple(
1478
+ v for v in
1479
+ [hidden_states, next_cache, all_hidden_states, all_self_attns]
1480
+ if v is not None)
1481
+ return BaseModelOutputWithPast(
1482
+ last_hidden_state=hidden_states,
1483
+ past_key_values=next_cache,
1484
+ hidden_states=all_hidden_states,
1485
+ attentions=all_self_attns,
1486
+ )
1487
+
1488
+
1489
+ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
1490
+ _tied_weights_keys = ["lm_head.weight"]
1491
+
1492
+ def __init__(self, config):
1493
+ super().__init__(config)
1494
+ self.model = DeepseekV3Model(config)
1495
+ self.vocab_size = config.vocab_size
1496
+ self.lm_head = nn.Linear(config.hidden_size,
1497
+ config.vocab_size,
1498
+ bias=False)
1499
+
1500
+ # Initialize weights and apply final processing
1501
+ self.post_init()
1502
+
1503
+ def get_input_embeddings(self):
1504
+ return self.model.embed_tokens
1505
+
1506
+ def set_input_embeddings(self, value):
1507
+ self.model.embed_tokens = value
1508
+
1509
+ def get_output_embeddings(self):
1510
+ return self.lm_head
1511
+
1512
+ def set_output_embeddings(self, new_embeddings):
1513
+ self.lm_head = new_embeddings
1514
+
1515
+ def set_decoder(self, decoder):
1516
+ self.model = decoder
1517
+
1518
+ def get_decoder(self):
1519
+ return self.model
1520
+
1521
+ @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
1522
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast,
1523
+ config_class=_CONFIG_FOR_DOC)
1524
+ def forward(
1525
+ self,
1526
+ input_ids: torch.LongTensor = None,
1527
+ attention_mask: Optional[torch.Tensor] = None,
1528
+ position_ids: Optional[torch.LongTensor] = None,
1529
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1530
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1531
+ labels: Optional[torch.LongTensor] = None,
1532
+ use_cache: Optional[bool] = None,
1533
+ output_attentions: Optional[bool] = None,
1534
+ output_hidden_states: Optional[bool] = None,
1535
+ return_dict: Optional[bool] = None,
1536
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1537
+ r"""
1538
+ Args:
1539
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1540
+ Labels for computing the language modeling loss. Indices should either be in `[0, ...,
1541
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1542
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1543
+
1544
+ Returns:
1545
+
1546
+ Example:
1547
+
1548
+ ```python
1549
+ >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM
1550
+
1551
+ >>> model = DeepseekV3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1552
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1553
+
1554
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1555
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1556
+
1557
+ >>> # Generate
1558
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1559
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1560
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1561
+ ```"""
1562
+ output_attentions = (output_attentions if output_attentions is not None
1563
+ else self.config.output_attentions)
1564
+ output_hidden_states = (output_hidden_states
1565
+ if output_hidden_states is not None else
1566
+ self.config.output_hidden_states)
1567
+ return_dict = (return_dict if return_dict is not None else
1568
+ self.config.use_return_dict)
1569
+
1570
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
1571
+ outputs = self.model(
1572
+ input_ids=input_ids,
1573
+ attention_mask=attention_mask,
1574
+ position_ids=position_ids,
1575
+ past_key_values=past_key_values,
1576
+ inputs_embeds=inputs_embeds,
1577
+ use_cache=use_cache,
1578
+ output_attentions=output_attentions,
1579
+ output_hidden_states=output_hidden_states,
1580
+ return_dict=return_dict,
1581
+ )
1582
+
1583
+ hidden_states = outputs[0]
1584
+ logits = self.lm_head(hidden_states)
1585
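+ # Upcast logits to float32 for a numerically stable loss computation.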
+ logits = logits.float()
1586
+
1587
+ loss = None
1588
+ if labels is not None:
1589
+ # Shift so that tokens < n predict n
1590
+ shift_logits = logits[..., :-1, :].contiguous()
1591
+ shift_labels = labels[..., 1:].contiguous()
1592
+ # Flatten the tokens
1593
+ loss_fct = CrossEntropyLoss()
1594
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1595
+ shift_labels = shift_labels.view(-1)
1596
+ # Enable model parallelism
1597
+ shift_labels = shift_labels.to(shift_logits.device)
1598
+ loss = loss_fct(shift_logits, shift_labels)
1599
+
1600
+ if not return_dict:
1601
+ output = (logits, ) + outputs[1:]
1602
+ return (loss, ) + output if loss is not None else output
1603
+
1604
+ return CausalLMOutputWithPast(
1605
+ loss=loss,
1606
+ logits=logits,
1607
+ past_key_values=outputs.past_key_values,
1608
+ hidden_states=outputs.hidden_states,
1609
+ attentions=outputs.attentions,
1610
+ )
1611
+
1612
+ def prepare_inputs_for_generation(
1613
+ self,
1614
+ input_ids,
1615
+ past_key_values=None,
1616
+ attention_mask=None,
1617
+ inputs_embeds=None,
1618
+ **kwargs,
1619
+ ):
1620
+ if past_key_values is not None:
1621
+ if isinstance(past_key_values, Cache):
1622
+ cache_length = past_key_values.get_seq_length()
1623
+ # `seen_tokens` may not exist in some transformers versions, so use getattr for safe access
1624
+ past_length = getattr(past_key_values, 'seen_tokens',
1625
+ cache_length)
1626
+ max_cache_length = past_key_values.get_max_length()
1627
+ else:
1628
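+ # Legacy tuple cache: the sequence dimension of the first layer's key tensor gives the number of
+ # already-cached tokens, and there is no maximum cache length to enforce.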
+ cache_length = past_length = past_key_values[0][0].shape[2]
1629
+ max_cache_length = None
1630
+
1631
+ # Keep only the unprocessed tokens:
1632
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1633
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1634
+ # input)
1635
+ if (attention_mask is not None
1636
+ and attention_mask.shape[1] > input_ids.shape[1]):
1637
+ input_ids = input_ids[:, -(attention_mask.shape[1] -
1638
+ past_length):]
1639
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1640
+ # input_ids based on the past_length.
1641
+ elif past_length < input_ids.shape[1]:
1642
+ input_ids = input_ids[:, past_length:]
1643
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1644
+
1645
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1646
+ if (max_cache_length is not None and attention_mask is not None
1647
+ and cache_length + input_ids.shape[1] > max_cache_length):
1648
+ attention_mask = attention_mask[:, -max_cache_length:]
1649
+
1650
+ position_ids = kwargs.get("position_ids", None)
1651
+ if attention_mask is not None and position_ids is None:
1652
+ # create position_ids on the fly for batch generation
1653
+ position_ids = attention_mask.long().cumsum(-1) - 1
1654
+ position_ids.masked_fill_(attention_mask == 0, 1)
1655
+ if past_key_values:
1656
+ position_ids = position_ids[:, -input_ids.shape[1]:]
1657
+
1658
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1659
+ if inputs_embeds is not None and past_key_values is None:
1660
+ model_inputs = {"inputs_embeds": inputs_embeds}
1661
+ else:
1662
+ model_inputs = {"input_ids": input_ids}
1663
+
1664
+ model_inputs.update({
1665
+ "position_ids": position_ids,
1666
+ "past_key_values": past_key_values,
1667
+ "use_cache": kwargs.get("use_cache"),
1668
+ "attention_mask": attention_mask,
1669
+ })
1670
+ return model_inputs
1671
+
1672
+ @staticmethod
1673
+ def _reorder_cache(past_key_values, beam_idx):
1674
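+ # Reorder each layer's cached key/value tensors along the batch dimension so the cache follows the
+ # beams selected during beam search.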
+ reordered_past = ()
1675
+ for layer_past in past_key_values:
1676
+ reordered_past += (tuple(
1677
+ past_state.index_select(0, beam_idx.to(past_state.device))
1678
+ for past_state in layer_past), )
1679
+ return reordered_past
1680
+
1681
+
1682
+ @add_start_docstrings(
1683
+ """
1684
+ The DeepseekV3 Model transformer with a sequence classification head on top (linear layer).
1685
+
1686
+ [`DeepseekV3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1687
+ (e.g. GPT-2) do.
1688
+
1689
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1690
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1691
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1692
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1693
+ each row of the batch).
1694
+ """,
1695
+ DeepseekV3_START_DOCSTRING,
1696
+ )
1697
+ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
1698
+
1699
+ def __init__(self, config):
1700
+ super().__init__(config)
1701
+ self.num_labels = config.num_labels
1702
+ self.model = DeepseekV3Model(config)
1703
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1704
+
1705
+ # Initialize weights and apply final processing
1706
+ self.post_init()
1707
+
1708
+ def get_input_embeddings(self):
1709
+ return self.model.embed_tokens
1710
+
1711
+ def set_input_embeddings(self, value):
1712
+ self.model.embed_tokens = value
1713
+
1714
+ @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
1715
+ def forward(
1716
+ self,
1717
+ input_ids: torch.LongTensor = None,
1718
+ attention_mask: Optional[torch.Tensor] = None,
1719
+ position_ids: Optional[torch.LongTensor] = None,
1720
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1721
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1722
+ labels: Optional[torch.LongTensor] = None,
1723
+ use_cache: Optional[bool] = None,
1724
+ output_attentions: Optional[bool] = None,
1725
+ output_hidden_states: Optional[bool] = None,
1726
+ return_dict: Optional[bool] = None,
1727
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1728
+ r"""
1729
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1730
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1731
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1732
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1733
+ """
1734
+ return_dict = (return_dict if return_dict is not None else
1735
+ self.config.use_return_dict)
1736
+
1737
+ transformer_outputs = self.model(
1738
+ input_ids,
1739
+ attention_mask=attention_mask,
1740
+ position_ids=position_ids,
1741
+ past_key_values=past_key_values,
1742
+ inputs_embeds=inputs_embeds,
1743
+ use_cache=use_cache,
1744
+ output_attentions=output_attentions,
1745
+ output_hidden_states=output_hidden_states,
1746
+ return_dict=return_dict,
1747
+ )
1748
+ hidden_states = transformer_outputs[0]
1749
+ logits = self.score(hidden_states)
1750
+
1751
+ if input_ids is not None:
1752
+ batch_size = input_ids.shape[0]
1753
+ else:
1754
+ batch_size = inputs_embeds.shape[0]
1755
+
1756
+ if self.config.pad_token_id is None and batch_size != 1:
1757
+ raise ValueError(
1758
+ "Cannot handle batch sizes > 1 if no padding token is defined."
1759
+ )
1760
+ if self.config.pad_token_id is None:
1761
+ sequence_lengths = -1
1762
+ else:
1763
+ if input_ids is not None:
1764
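+ # The index of the first pad token minus one is the last non-padded position; if a row contains no
+ # padding, argmax returns 0 and the resulting -1 wraps around to the final token.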
+ sequence_lengths = (torch.eq(
1765
+ input_ids, self.config.pad_token_id).int().argmax(-1) -
1766
+ 1).to(logits.device)
1767
+ else:
1768
+ sequence_lengths = -1
1769
+
1770
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device),
1771
+ sequence_lengths]
1772
+
1773
+ loss = None
1774
+ if labels is not None:
1775
+ labels = labels.to(logits.device)
1776
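+ # Infer the problem type when it is not set explicitly: a single label means regression, integer
+ # labels mean single-label classification, and floating-point labels mean multi-label classification.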
+ if self.config.problem_type is None:
1777
+ if self.num_labels == 1:
1778
+ self.config.problem_type = "regression"
1779
+ elif self.num_labels > 1 and (labels.dtype == torch.long
1780
+ or labels.dtype == torch.int):
1781
+ self.config.problem_type = "single_label_classification"
1782
+ else:
1783
+ self.config.problem_type = "multi_label_classification"
1784
+
1785
+ if self.config.problem_type == "regression":
1786
+ loss_fct = MSELoss()
1787
+ if self.num_labels == 1:
1788
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1789
+ else:
1790
+ loss = loss_fct(pooled_logits, labels)
1791
+ elif self.config.problem_type == "single_label_classification":
1792
+ loss_fct = CrossEntropyLoss()
1793
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels),
1794
+ labels.view(-1))
1795
+ elif self.config.problem_type == "multi_label_classification":
1796
+ loss_fct = BCEWithLogitsLoss()
1797
+ loss = loss_fct(pooled_logits, labels)
1798
+ if not return_dict:
1799
+ output = (pooled_logits, ) + transformer_outputs[1:]
1800
+ return ((loss, ) + output) if loss is not None else output
1801
+
1802
+ return SequenceClassifierOutputWithPast(
1803
+ loss=loss,
1804
+ logits=pooled_logits,
1805
+ past_key_values=transformer_outputs.past_key_values,
1806
+ hidden_states=transformer_outputs.hidden_states,
1807
+ attentions=transformer_outputs.attentions,
1808
+ )
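+
+
+ # Minimal usage sketch (added for illustration; not part of the upstream DeepSeek file). It assumes a
+ # checkpoint path, given as a placeholder below, whose config resolves to this module when loaded with
+ # `trust_remote_code=True`.
+ if __name__ == "__main__":
+     from transformers import AutoModelForCausalLM, AutoTokenizer
+
+     checkpoint = "path/to/deepseek-v3-checkpoint"  # placeholder; substitute the real repo id or local path
+     tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
+     model = AutoModelForCausalLM.from_pretrained(
+         checkpoint, trust_remote_code=True, torch_dtype=torch.bfloat16
+     )
+
+     prompt = "Hey, are you conscious? Can you talk to me?"
+     inputs = tokenizer(prompt, return_tensors="pt")
+     # use_cache=True exercises the DynamicCache / prepare_inputs_for_generation path defined above.
+     generated = model.generate(**inputs, max_new_tokens=30, use_cache=True)
+     print(tokenizer.decode(generated[0], skip_special_tokens=True))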