| import torch |
| import torch.nn as nn |
| from huggingface_hub import PyTorchModelHubMixin |
|
|
class ConvBNRelu(nn.Module):
    """
    Building block used in the HiDDeN network: Convolution -> Batch Normalization -> GELU.

    Despite the class name, the activation is ``nn.GELU`` (not ReLU). The
    convolution is 3x3 with stride 1 and padding 1, so the spatial size
    (H, W) of the input is preserved and only the channel count changes.

    Args:
        channels_in: number of channels of the input feature map.
        channels_out: number of channels produced by the block.
    """

    def __init__(self, channels_in, channels_out):
        super().__init__()
        # Keep this Sequential layout: parameter names in the state dict
        # (``layers.0.*`` for the conv, ``layers.1.*`` for the batch norm) depend on it.
        self.layers = nn.Sequential(
            nn.Conv2d(channels_in, channels_out, 3, stride=1, padding=1),
            # Non-default eps (PyTorch's default is 1e-5); changing it alters
            # the normalisation numerics of any trained weights.
            nn.BatchNorm2d(channels_out, eps=1e-3),
            nn.GELU(),
        )

    def forward(self, x):
        """Apply conv -> batch norm -> GELU; output has ``channels_out`` channels and the input's H, W."""
        return self.layers(x)
|
|
|
|
class HiddenDecoder(nn.Module):
    """
    Decoder module: takes a watermarked image and extracts the watermark bits.

    Pipeline: a stack of ``ConvBNRelu`` blocks, global average pooling to a
    single vector per image, a square linear layer, and finally a sum over
    the ``redundancy`` outputs predicted for every bit.

    Args:
        num_blocks: depth control for the conv stack. The stack always holds
            one 3 -> ``channels`` block, then ``num_blocks - 1`` (when positive)
            ``channels`` -> ``channels`` blocks, then one block projecting to
            ``num_bits * redundancy`` channels.
        num_bits: length of the decoded message.
        channels: width of the hidden conv blocks.
        redundancy: number of outputs predicted per bit; they are summed in ``forward``.
    """

    def __init__(self, num_blocks, num_bits, channels, redundancy=1):
        super().__init__()
        width = num_bits * redundancy

        # Construction order decides how the global RNG is consumed during
        # weight init, so it stays: stem, hidden blocks, projection, pool, linear.
        stack = [ConvBNRelu(3, channels)]
        stack.extend(ConvBNRelu(channels, channels) for _ in range(num_blocks - 1))
        stack += [ConvBNRelu(channels, width), nn.AdaptiveAvgPool2d(output_size=(1, 1))]
        self.layers = nn.Sequential(*stack)

        self.linear = nn.Linear(width, width)

        self.num_bits = num_bits
        self.redundancy = redundancy

    def forward(self, img_w):
        """
        Decode per-bit scores from an image batch.

        Args:
            img_w: image tensor; the first conv expects 3 input channels,
                presumably shaped (B, 3, H, W) — TODO confirm with callers.

        Returns:
            Tensor of shape (B, num_bits): for each bit, the sum of its
            ``redundancy`` linear outputs.
        """
        pooled = self.layers(img_w)
        # Adaptive pooling leaves two trailing singleton dims; drop both.
        scores = self.linear(pooled[..., 0, 0])
        # Row-major view: bit i owns columns [i * redundancy, (i + 1) * redundancy).
        grouped = scores.view(-1, self.num_bits, self.redundancy)
        return grouped.sum(dim=-1)
|
|
|
|
class MsgExtractor(nn.Module, PyTorchModelHubMixin):
    """
    Wrap a decoder network with a trainable linear head.

    Inherits Hugging Face Hub save/load helpers from ``PyTorchModelHubMixin``.

    Args:
        hidden_decoder: module applied first to the input.
        in_features: size of the last dimension of ``hidden_decoder``'s
            output; it must match for the linear head to accept it.
        out_features: size of the last dimension of the final output.
    """

    def __init__(self, hidden_decoder: nn.Module, in_features: int, out_features: int):
        super().__init__()
        # Registration order (decoder first, then head) fixes parameter order.
        self.hidden_decoder = hidden_decoder
        self.head = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Run ``hidden_decoder`` on ``x`` and project its output with ``head``."""
        return self.head(self.hidden_decoder(x))
|
|