from typing import Any, Optional, Tuple

from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F

from .models.kv_caching import KeysValues
from .models.slicer import Embedder, Head
from .models.transformer import Transformer


class WorldModel(nn.Module):
    """Autoregressive transformer world model over interleaved observation and action tokens."""

    def __init__(self, config: dict) -> None:
        super().__init__()
        self.obs_vocab_size, self.act_vocab_size = config["vocab_size"], config["act_vocab_size"]
        self.config = config
        self.transformer = Transformer(config)

        # Each block holds the observation tokens of one timestep followed by a single action token;
        # all_but_last_obs_tokens_pattern zeroes out the position of the last observation token.
        all_but_last_obs_tokens_pattern = torch.ones(config["tokens_per_block"])
        all_but_last_obs_tokens_pattern[-2] = 0
        act_tokens_pattern = torch.zeros(self.config["tokens_per_block"])
        act_tokens_pattern[-1] = 1
        obs_tokens_pattern = 1 - act_tokens_pattern

        self.pos_emb = nn.Embedding(config["max_tokens"], config["embed_dim"])

        # One embedding table per token type (actions and observations),
        # selected per position by the corresponding block mask.
        self.embedder = Embedder(
            max_blocks=config["max_blocks"],
            block_masks=[act_tokens_pattern, obs_tokens_pattern],
            embedding_tables=nn.ModuleList([
                nn.Embedding(self.act_vocab_size, config["embed_dim"]),
                nn.Embedding(self.obs_vocab_size, config["embed_dim"])
            ])
        )

        # Predicts logits over the observation vocabulary for the next observation token.
        self.head_observations = Head(
            max_blocks=config["max_blocks"],
            block_mask=all_but_last_obs_tokens_pattern,
            head_module=nn.Sequential(
                nn.Linear(config["embed_dim"], config["embed_dim"]),
                nn.ReLU(),
                nn.Linear(config["embed_dim"], self.obs_vocab_size)
            )
        )

        # Predicts the reward sign as 3 classes (-1, 0, +1), read off at action-token positions.
        self.head_rewards = Head(
            max_blocks=config["max_blocks"],
            block_mask=act_tokens_pattern,
            head_module=nn.Sequential(
                nn.Linear(config["embed_dim"], config["embed_dim"]),
                nn.ReLU(),
                nn.Linear(config["embed_dim"], 3)
            )
        )

        # Predicts episode termination as 2 classes, read off at action-token positions.
        self.head_ends = Head(
            max_blocks=config["max_blocks"],
            block_mask=act_tokens_pattern,
            head_module=nn.Sequential(
                nn.Linear(config["embed_dim"], config["embed_dim"]),
                nn.ReLU(),
                nn.Linear(config["embed_dim"], 2)
            )
        )

    def __repr__(self) -> str:
        return "world_model"

    def forward(self, tokens: torch.LongTensor, past_keys_values: Optional[KeysValues] = None) -> dict:
        num_steps = tokens.size(1)
        assert num_steps <= self.config["max_tokens"]
        # Offset positional embeddings by the number of tokens already stored in the KV cache.
        prev_steps = 0 if past_keys_values is None else past_keys_values.size

        sequences = self.embedder(tokens, num_steps, prev_steps) + self.pos_emb(prev_steps + torch.arange(num_steps, device=tokens.device))

        x = self.transformer(sequences, past_keys_values)

        logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps)
        logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps)
        logits_ends = self.head_ends(x, num_steps=num_steps, prev_steps=prev_steps)

        return {
            "output_sequence": x,
            "logits_observations": logits_observations,
            "logits_rewards": logits_rewards,
            "logits_ends": logits_ends
        }

    def generate_empty_keys_values(self, n: int = 1) -> KeysValues:
        return self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config["max_tokens"])

    def compute_labels_world_model(self, obs_tokens: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        assert torch.all(ends.sum(dim=1) <= 1)  # at most one episode end per sequence
        # Padded positions get label -100 so they are ignored by cross-entropy losses.
        mask_fill = torch.logical_not(mask_padding)
        labels_observations = rearrange(obs_tokens.masked_fill(mask_fill.unsqueeze(-1).expand_as(obs_tokens), -100), 'b t k -> b (t k)')[:, 1:]
        labels_rewards = (rewards.sign() + 1).masked_fill(mask_fill, -100).long()  # reward sign mapped to classes {0, 1, 2}
        labels_ends = ends.masked_fill(mask_fill, -100)
        return labels_observations.reshape(-1), labels_rewards.reshape(-1), labels_ends.reshape(-1)
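
# Usage sketch (illustrative only, not part of the module). A minimal example of how this world
# model might be driven, assuming a config dict with the keys referenced above ("vocab_size",
# "act_vocab_size", "tokens_per_block", "max_tokens", "max_blocks", "embed_dim") plus whatever
# Transformer itself expects; the concrete values below are hypothetical.
#
#     config = {
#         "vocab_size": 512,        # observation token vocabulary (e.g. from a discrete tokenizer)
#         "act_vocab_size": 6,      # number of discrete actions
#         "tokens_per_block": 17,   # 16 observation tokens + 1 action token per timestep
#         "max_blocks": 20,
#         "max_tokens": 20 * 17,
#         "embed_dim": 256,
#     }
#     world_model = WorldModel(config)
#
#     # One block per sequence: observation tokens followed by the action token.
#     obs_tokens = torch.randint(0, config["vocab_size"], (4, 16), dtype=torch.long)
#     act_tokens = torch.randint(0, config["act_vocab_size"], (4, 1), dtype=torch.long)
#     tokens = torch.cat([obs_tokens, act_tokens], dim=1)  # shape (4, 17)
#
#     kv_cache = world_model.generate_empty_keys_values(n=4)
#     outputs = world_model(tokens, past_keys_values=kv_cache)
#     # outputs["logits_observations"], outputs["logits_rewards"], outputs["logits_ends"]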