| | import torch |
| | from transformers import PreTrainedModel, AutoModelForCausalLM, AutoConfig |
| | from transformers.models.auto.auto_factory import _BaseAutoModelClass |
| | from open_r1.rmt.MemoryCell import MemoryCell |
| | from open_r1.rmt.RecurrentWrapper import RecurrentWrapper |
| | from open_r1.rmt.PreTrainedRMTConfig import PreTrainedRMTConfig |
| |
|
| |
|
| | |
| | class RecurrentMemoryTransformer(PreTrainedModel): |
| | """ |
| | Recurrent Memory Transformer モデルクラス |
| | 長い文脈をセグメント単位で処理し、メモリを使って情報を保持するトランスフォーマーモデル |
| | """ |
| | |
| | config_class = PreTrainedRMTConfig |
| | auto_model_class = "AutoModelForCausalLM" |
| | |
| | |
| | _keys_to_ignore_on_load_missing = [r"position_ids"] |
| | |
| | |
| | AUTO_MAP = { |
| | "AutoModelForCausalLM": "RecurrentMemoryTransformer", |
| | } |
| | |
| | def __init__(self, config, base_model=None): |
| | """ |
| | 初期化 |
| | |
| | Parameters |
| | ---------- |
| | config : PreTrainedRMTConfig |
| | モデルの設定 |
| | base_model : PreTrainedModel, optional |
| | ベースとなるトランスフォーマーモデル |
| | """ |
| | super().__init__(config) |
| | |
| | |
| | if base_model is None: |
| | |
| | if not hasattr(config, "base_model_type"): |
| | raise ValueError("configにbase_model_typeが指定されていません。RMTの設定にはベースモデルタイプが必要です。") |
| | base_model_type = config.base_model_type |
| | |
| | |
| | base_config = AutoConfig.from_pretrained(base_model_type) |
| | |
| | |
| | rmt_specific_params = ['model_type', 'is_memory_all', 'max_n_segments', 'input_seg_len', |
| | 'output_seg_len', 'align', 'num_mem_tokens', 'base_model_type'] |
| | for key, value in config.__dict__.items(): |
| | if key not in rmt_specific_params and not key.startswith('_'): |
| | setattr(base_config, key, value) |
| | |
| | |
| | base_model = AutoModelForCausalLM.from_config(base_config) |
| | |
| | |
| | memory_cell = MemoryCell(base_model, config.num_mem_tokens) |
| | self.recurrent_wrapper = RecurrentWrapper( |
| | memory_cell=memory_cell, |
| | is_memory_all=config.is_memory_all, |
| | max_n_segments=config.max_n_segments, |
| | input_seg_len=config.input_seg_len, |
| | output_seg_len=config.output_seg_len, |
| | align=config.align |
| | ) |
| | |
| | def get_base_model(self): |
| | """ |
| | ベースモデルを取得 |
| | """ |
| | return self.recurrent_wrapper.memory_cell.model |
| | |
| | def forward(self, input_ids=None, attention_mask=None, labels=None, labels_mask=None, |
| | inputs_embeds=None, output_attentions=None, output_hidden_states=None): |
| | """ |
| | モデルの順伝播 |
| | |
| | Parameters |
| | ---------- |
| | input_ids : torch.Tensor, optional |
| | 入力テンソル |
| | attention_mask : torch.Tensor, optional |
| | アテンションマスク |
| | labels : torch.Tensor, optional |
| | ラベルテンソル |
| | labels_mask : torch.Tensor, optional |
| | ラベルマスク |
| | inputs_embeds : torch.Tensor, optional |
| | 入力埋め込み |
| | output_attentions : bool, optional |
| | アテンション重みを出力するかどうか |
| | output_hidden_states : bool, optional |
| | 隠れ状態を出力するかどうか |
| | """ |
| | forward_kwargs = {} |
| | if input_ids is not None: |
| | forward_kwargs["input_ids"] = input_ids |
| | if labels is not None: |
| | forward_kwargs["labels"] = labels |
| | if attention_mask is not None: |
| | forward_kwargs["attention_mask"] = attention_mask |
| | if labels_mask is not None: |
| | forward_kwargs["labels_mask"] = labels_mask |
| | if inputs_embeds is not None: |
| | forward_kwargs["inputs_embeds"] = inputs_embeds |
| | if output_attentions is not None: |
| | forward_kwargs["output_attentions"] = output_attentions |
| | if output_hidden_states is not None: |
| | forward_kwargs["output_hidden_states"] = output_hidden_states |
| | |
| | |
| | |
| | |
| | out = self.recurrent_wrapper.forward(**forward_kwargs) |
| | """ |
| | # デバッグ出力を削除(または必要に応じてコメント化) |
| | # print(out["loss"]) |
| | |
| | # 分散環境で損失が二重計算されないよう、ワールドサイズで割る |
| | # これは処理済みの場合は不要なので、環境変数などで制御することも可能 |
| | if torch.distributed.is_initialized() and "loss" in out and out["loss"] is not None: |
| | # 既にDeepSpeedが処理している可能性があるため、確認が必要 |
| | # テスト目的で一時的に追加(実際の環境に合わせて調整が必要) |
| | # world_size = torch.distributed.get_world_size() |
| | # out["loss"] = out["loss"] / world_size |
| | pass |
| | """ |
| | return out |
| | |
| | def generate(self, **kwargs): |
| | """ |
| | テキスト生成 |
| | """ |
| | return self.recurrent_wrapper.generate(**kwargs) |
| | |
| | def generate_with_tokenizer(self, tokenizer, input_text, **kwargs): |
| | """ |
| | トークナイザーを用いたテキスト生成 |
| | """ |
| | return self.recurrent_wrapper.generate_with_tokenizer(tokenizer, input_text, **kwargs) |
| | |
| | def get_input_embeddings(self): |
| | """ |
| | 入力埋め込みを取得 |
| | """ |
| | return self.get_base_model().get_input_embeddings() |
| | |
| | def set_input_embeddings(self, embeddings): |
| | """ |
| | 入力埋め込みを設定 |
| | """ |
| | self.get_base_model().set_input_embeddings(embeddings) |
| | |
| | def get_output_embeddings(self): |
| | """ |
| | 出力埋め込みを取得 |
| | """ |
| | return self.get_base_model().get_output_embeddings() |
| | |
| | def resize_token_embeddings(self, new_num_tokens): |
| | """ |
| | トークン埋め込みのサイズを変更 |
| | """ |
| | self.get_base_model().resize_token_embeddings(new_num_tokens) |
| | return self.get_input_embeddings() |
| |
|
| | RecurrentMemoryTransformer.register_for_auto_class("AutoModelForCausalLM") |