| | """ |
| | Image Encoder using pre-trained ResNet50. |
| | Implements the visual feature extraction module from the paper. |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torchvision.models import resnet50, ResNet50_Weights |
| |
|
| |
|
class ImageEncoder(nn.Module):
    """
    Image encoder using ResNet50 with a custom, zero-initialized final layer.

    The ResNet50 classification head is replaced by an identity, so the
    backbone emits feature vectors of size ``config.resnet_out_dim``. A linear
    layer then projects them to ``config.hidden_dim``.

    Critical: the projection is initialized with zeros as per the paper, so
    the encoder's output is all zeros until the projection has been trained.
    """

    def __init__(self, config, pretrained_weights_path: str = None):
        """
        Initialize the image encoder.

        Args:
            config: Configuration object. Must expose ``resnet_out_dim`` (the
                input size of the projection) and ``hidden_dim`` (the output
                feature size).
            pretrained_weights_path: Optional path to a ResNet50 state-dict
                file. When falsy, the backbone keeps its random
                initialization.
        """
        super().__init__()
        self.config = config

        # Build the architecture without fetching torchvision weights; optional
        # local weights are loaded below instead.
        self.resnet = resnet50(weights=None)

        if pretrained_weights_path:
            # map_location="cpu" lets a checkpoint saved on a GPU load on a
            # CPU-only machine; load_state_dict then copies the tensors into
            # the module's own parameters, so nothing else changes.
            # NOTE(review): weights_only=False unpickles arbitrary objects and
            # can execute code embedded in the file. Only load trusted files,
            # or switch to weights_only=True when the file is a plain
            # state_dict.
            state_dict = torch.load(
                pretrained_weights_path, map_location="cpu", weights_only=False
            )
            self.resnet.load_state_dict(state_dict)
            print(f"Loaded ResNet50 weights from {pretrained_weights_path}")

        # Drop the classification head so the backbone returns features.
        self.resnet.fc = nn.Identity()

        # Zero-initialized projection (per the paper).
        self.projection = nn.Linear(config.resnet_out_dim, config.hidden_dim)
        nn.init.zeros_(self.projection.weight)
        nn.init.zeros_(self.projection.bias)

        print("Initialized image encoder final layer with zeros")

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Encode a batch of per-mask images into feature vectors.

        Args:
            images: Tensor of shape [batch_size, num_masks, C, H, W].
                Single-channel inputs (C == 1) are replicated to three
                channels before the backbone. Any other channel count is
                passed through unchanged.

        Returns:
            Visual features of shape [batch_size, num_masks, hidden_dim].

        Raises:
            ValueError: If ``images`` is not 5-dimensional.
        """
        if images.dim() != 5:
            raise ValueError(
                "ImageEncoder expects images of shape "
                f"[batch_size, num_masks, C, H, W], got {tuple(images.shape)}"
            )
        batch_size, num_masks, C, H, W = images.shape

        # Fold the mask axis into the batch axis so the backbone sees 4-D
        # input. reshape (unlike view) also accepts non-contiguous tensors,
        # e.g. ones produced by an upstream permute/transpose.
        images_flat = images.reshape(batch_size * num_masks, C, H, W)

        if C == 1:
            # Replicate the single channel to three channels.
            images_flat = images_flat.repeat(1, 3, 1, 1)

        features = self.resnet(images_flat)
        features = self.projection(features)

        # Restore the mask axis. The Linear output is contiguous, so view is safe.
        features = features.view(batch_size, num_masks, self.config.hidden_dim)

        return features
| |
|