| | import torch |
| | import torch.nn as nn |
| | from transformers import PreTrainedModel |
| | from transformers.modeling_outputs import CausalLMOutputWithPast |
| | from typing import Any, cast |
| |
|
| | from .attention import ParallelAttentionBlock, KVCache |
| | from .phi2_configuration import Phi2Config |
| |
|
| |
|
class Phi2PreTrainedModel(PreTrainedModel):
    """Base class wiring Phi2Config into the Hugging Face PreTrainedModel machinery.

    Provides weight initialization and the generation-step input preparation
    shared by Phi2Model and Phi2ModelForCausalLM.
    """

    config_class = Phi2Config
    supports_gradient_checkpointing = False

    def __init__(self, config: Phi2Config):
        super().__init__(config)
        # PreTrainedModel.__init__ already stores the config; kept for explicitness
        # and to preserve the original attribute assignment order.
        self.config = config

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize one submodule's weights; invoked by post_init() over the tree.

        Linear/Embedding weights are drawn from N(0, weight_initialization_range);
        biases are zeroed; LayerNorm is reset to the identity transform.
        """
        if isinstance(module, (nn.Linear,)):
            module.weight.data.normal_(mean=0.0, std=self.config.weight_initialization_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.weight_initialization_range)
            if module.padding_idx is not None:
                # Padding token embedding stays zero so it contributes nothing.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: KVCache | None = None,
        key_padding_mask: torch.LongTensor | torch.BoolTensor | None = None,
        **kwargs,
    ) -> dict[str, Any]:
        """Build the forward() kwargs for a single generation step.

        On the first step a fresh KVCache is allocated; on subsequent steps the
        cache offset is advanced and only the newest token is fed forward.

        Args:
            input_ids: (batch, seqlen) token ids accumulated so far.
            past_key_values: cache from the previous step, or None on step one.
            key_padding_mask: optional padding mask, passed through unchanged.

        Returns:
            Dict with keys "input_ids", "kv_cache", "key_padding_mask".
        """
        kv_cache = past_key_values
        # BUGFIX: was `if not kv_cache:`, which relies on KVCache truthiness —
        # an empty-but-valid cache evaluating falsy would be silently discarded
        # and re-allocated, resetting generation state. Check identity with None.
        if kv_cache is None:
            kv_cache = KVCache(
                max_seqlen=self.config.initial_cos_sin_cache_len,
                max_batch_size=input_ids.shape[0],
                seqlen_offset=0,
                batch_size_offset=0,
                kv_block_map={},
                lengths_per_sample=None,
            )
        else:
            # Everything before the final position is already cached, so only
            # the last token needs to run through the model.
            kv_cache.seqlen_offset = input_ids.shape[1] - 1
            input_ids = cast(torch.LongTensor, input_ids[:, -1].unsqueeze(-1))

        return {
            "input_ids": input_ids,
            "kv_cache": kv_cache,
            "key_padding_mask": key_padding_mask,
        }
| |
|
| |
|
class Embedding(nn.Module):
    """Token-id lookup table followed by embedding dropout."""

    def __init__(
        self,
        vocab_size: int,
        d_embedding: int,
        embd_pdrop: float,
    ) -> None:
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, d_embedding)
        self.dropout = nn.Dropout(embd_pdrop)

    def forward(
        self,
        input_ids: torch.LongTensor,
    ) -> torch.FloatTensor:
        # Collapse any leading dimensions to (batch, seqlen) before the lookup.
        seqlen = input_ids.size(-1)
        hidden = self.embeddings(input_ids.view(-1, seqlen))
        return self.dropout(hidden)
| |
|
| |
|
class Phi2Model(Phi2PreTrainedModel):
    """Phi-2 transformer backbone: token embedding plus a stack of parallel
    attention blocks. Returns hidden states; the LM head lives in
    Phi2ModelForCausalLM."""

    def __init__(self, config: Phi2Config) -> None:
        super().__init__(config)
        self.embedding = Embedding(
            vocab_size=config.vocab_size,
            d_embedding=config.d_embedding,
            embd_pdrop=config.embd_pdrop,
        )
        self.parallel_blocks = nn.ModuleList([
            ParallelAttentionBlock(
                resid_pdrop=config.resid_pdrop,
                layer_norm_epsilon=config.layer_norm_epsilon,
                d_embedding=config.d_embedding,
                n_attn_heads=config.n_attn_heads,
                block_n=i,
                initial_cos_sin_cache_len=config.initial_cos_sin_cache_len,
                attn_pdrop=config.attn_pdrop,
                use_flash_rotary=config.use_flash_rotary,
                use_flash_attn=config.use_flash_attn,
                use_fused_dense=config.use_fused_dense,
                checkpointing=config.checkpointing,
            )
            for i in range(config.n_attn_blocks)
        ])
        self.gradient_checkpointing_disable()
        self.post_init()  # runs _init_weights over all submodules

    # BUGFIX: these HF hooks were dead code inside a string literal, which
    # breaks PreTrainedModel features that depend on them (e.g.
    # resize_token_embeddings, weight tying) — the base-class versions raise
    # NotImplementedError. Restoring them only adds methods, so it is
    # backward-compatible for all existing callers.
    def get_input_embeddings(self) -> nn.Embedding:
        """Return the raw token-embedding table (HF accessor hook)."""
        return self.embedding.embeddings

    def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
        """Replace the token-embedding table (HF mutator hook)."""
        self.embedding.embeddings = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor,
        kv_cache: KVCache | None = None,
        key_padding_mask: torch.BoolTensor | None = None,
    ) -> torch.FloatTensor:
        """Embed input_ids and pass the hidden states through every block.

        Args:
            input_ids: (batch, seqlen) token ids.
            kv_cache: optional generation cache, forwarded to each block.
            key_padding_mask: optional padding mask, forwarded to each block.

        Returns:
            Final hidden states from the last attention block.
        """
        x = self.embedding(input_ids)
        for block in self.parallel_blocks:
            x = block(
                x,
                kv_cache=kv_cache,
                key_padding_mask=key_padding_mask,
            )
        return x
| |
|
| |
|
class Phi2ModelForCausalLM(Phi2PreTrainedModel):
    """Phi-2 backbone topped with a LayerNorm + Linear language-modeling head."""

    def __init__(self, config: Phi2Config) -> None:
        super().__init__(config)
        self.model = Phi2Model(config)
        self.lm_head_layer_norm = nn.LayerNorm(config.d_embedding, eps=config.layer_norm_epsilon)
        self.lm_head_linear = nn.Linear(config.d_embedding, config.vocab_size)
        self.loss_fn = nn.CrossEntropyLoss()
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        kv_cache: KVCache | None = None,
        key_padding_mask: torch.BoolTensor | None = None,
        labels: torch.LongTensor | None = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run the backbone and LM head; compute cross-entropy loss when labels are given.

        NOTE(review): labels are consumed as-is — no one-position shift is
        applied here, so callers appear to be expected to pass pre-shifted
        labels. Confirm against the training pipeline.
        """
        hidden = self.model(input_ids, kv_cache=kv_cache, key_padding_mask=key_padding_mask)
        hidden = self.lm_head_layer_norm(hidden)
        # Promote head output to fp32 for a numerically stable loss.
        logits = self.lm_head_linear(hidden).to(torch.float32)

        if labels is None:
            loss = None
        else:
            loss = self.loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))

        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=kv_cache)
| |
|