| import torch |
| from torch import nn |
| from addict import Dict |
|
|
| from rscd.models.decoderheads.pixel_decoder.msdeformattn import MSDeformAttnPixelDecoder4ScalesFASeg |
| from rscd.models.decoderheads.transformer_decoder import MultiScaleMaskedTransformerDecoder_OurDH_v5 |
|
|
| from torch.nn import functional as F |
|
|
class MaskFormerHead(nn.Module):
    """Segmentation head that chains a pixel decoder and a query-based predictor.

    The pixel decoder is an ``MSDeformAttnPixelDecoder4ScalesFASeg`` and the
    predictor is a ``MultiScaleMaskedTransformerDecoder_OurDH_v5``; both are
    built with fixed hyper-parameters, apart from the values passed to the
    constructor.

    Args:
        input_shape: Backbone feature description forwarded unchanged to the
            pixel decoder.
        num_classes: Class count forwarded to the predictor.
        num_queries: Query count forwarded to the predictor.
        dec_layers: Nominal decoder depth; the predictor itself is built with
            ``dec_layers - 1`` layers.
    """

    def __init__(self, input_shape, num_classes=1, num_queries=10, dec_layers=10):
        super().__init__()
        self.num_classes = num_classes
        self.num_queries = num_queries
        self.dec_layers = dec_layers
        self.pixel_decoder = self.pixel_decoder_init(input_shape)
        self.predictor = self.predictor_init()

    def pixel_decoder_init(self, input_shape):
        """Create the pixel decoder for ``input_shape`` with fixed settings."""
        # Positional order must match the decoder's constructor signature.
        decoder_args = (
            input_shape,
            0,                         # transformer_dropout
            8,                         # transformer_nheads
            1024,                      # transformer_dim_feedforward
            4,                         # transformer_enc_layers
            256,                       # conv_dim
            256,                       # mask_dim
            ["res3", "res4", "res5"],  # transformer_in_features
            4,                         # common_stride
        )
        return MSDeformAttnPixelDecoder4ScalesFASeg(*decoder_args)

    def predictor_init(self):
        """Create the transformer decoder that predicts classes and masks."""
        # Positional order must match the predictor's constructor signature.
        predictor_args = (
            256,                  # in_channels
            self.num_classes,     # num_classes
            True,                 # mask_classification
            256,                  # hidden_dim
            self.num_queries,     # num_queries
            8,                    # nheads
            1024,                 # dim_feedforward
            # NOTE(review): one layer fewer than requested; the reason is not
            # visible in this file -- confirm against the predictor.
            self.dec_layers - 1,  # dec_layers
            False,                # pre_norm
            256,                  # mask_dim
            False,                # enforce_input_project
        )
        return MultiScaleMaskedTransformerDecoder_OurDH_v5(*predictor_args)

    def forward(self, features, mask=None):
        """Decode ``features`` and return the predictor's output.

        Args:
            features: Input handed to ``pixel_decoder.forward_features``.
            mask: Optional value passed through to the predictor as its third
                argument.

        Returns:
            Whatever the predictor returns for the decoded features.
        """
        # The second output (transformer encoder features) is not used here.
        decoded = self.pixel_decoder.forward_features(features)
        mask_features, _encoder_features, multi_scale_features, pos_list_2d = decoded
        return self.predictor(multi_scale_features, mask_features, mask, pos_list_2d)
|
|
def dsconv_3x3(in_channel, out_channel):
    """Build a depthwise-separable 3x3 convolution block.

    The block is a per-channel 3x3 convolution (``groups=in_channel``,
    padding 1), a 1x1 convolution mapping ``in_channel`` to ``out_channel``,
    batch normalisation and an in-place ReLU.

    Args:
        in_channel: Number of input channels.
        out_channel: Number of output channels.

    Returns:
        An ``nn.Sequential`` holding the four layers in that order.
    """
    depthwise = nn.Conv2d(
        in_channel,
        in_channel,
        kernel_size=3,
        stride=1,
        padding=1,
        groups=in_channel,
    )
    pointwise = nn.Conv2d(
        in_channel,
        out_channel,
        kernel_size=1,
        stride=1,
        padding=0,
        groups=1,
    )
    normalise = nn.BatchNorm2d(out_channel)
    activate = nn.ReLU(inplace=True)
    return nn.Sequential(depthwise, pointwise, normalise, activate)
|
|
class SaELayer(nn.Module):
    """Channel gate built from four parallel bottleneck branches.

    The input is globally average-pooled to one value per channel, fed through
    four independent ``Linear -> ReLU`` branches (``fc1`` .. ``fc4``), and the
    concatenated branch outputs are mapped back to ``in_channel`` values by a
    ``Linear -> Sigmoid`` layer (``fc``).

    Args:
        in_channel: Number of channels of the input feature map. Must be at
            least ``reduction`` and divisible by it.
        reduction: Channel reduction factor of each branch.
    """

    def __init__(self, in_channel, reduction=32):
        super().__init__()
        assert in_channel >= reduction and in_channel % reduction == 0, 'invalid in_channel in SaElayer'
        self.reduction = reduction
        self.cardinality = 4
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        squeezed = in_channel // self.reduction
        # Keep the attribute names and construction order: both determine the
        # state_dict keys and the seeded initialisation of the weights.
        self.fc1 = self._make_branch(in_channel, squeezed)
        self.fc2 = self._make_branch(in_channel, squeezed)
        self.fc3 = self._make_branch(in_channel, squeezed)
        self.fc4 = self._make_branch(in_channel, squeezed)

        self.fc = nn.Sequential(
            nn.Linear(squeezed * self.cardinality, in_channel, bias=False),
            nn.Sigmoid(),
        )

    @staticmethod
    def _make_branch(in_features, out_features):
        """Return one bias-free ``Linear -> ReLU`` squeeze branch."""
        return nn.Sequential(
            nn.Linear(in_features, out_features, bias=False),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return per-channel gate values expanded to the shape of ``x``.

        Args:
            x: 4-D tensor; the first two dimensions are treated as batch and
                channel.

        Returns:
            Tensor with the same shape as ``x`` whose values are constant over
            the last two dimensions.
        """
        batch, channels, _height, _width = x.size()
        pooled = self.avg_pool(x).view(batch, channels)
        branch_outputs = [
            branch(pooled)
            for branch in (self.fc1, self.fc2, self.fc3, self.fc4)
        ]
        gate = self.fc(torch.cat(branch_outputs, dim=1))
        return gate.view(batch, channels, 1, 1).expand_as(x)
|
|
class TFF(nn.Module):
    """Fuse two same-shaped feature maps into a single feature map.

    The difference of the two inputs drives both a channel gate (``SaELayer``)
    and, combined with each input, a single-channel sigmoid gate per input.
    The gated inputs are concatenated and reduced by a depthwise-separable
    convolution.

    Args:
        in_channel: Channel count of each input feature map.
        out_channel: Channel count of the fused output.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        stacked = in_channel * 2
        # Attribute names and construction order are kept so that state_dict
        # keys and seeded initialisation stay the same.
        self.catconvA = dsconv_3x3(stacked, in_channel)
        self.catconvB = dsconv_3x3(stacked, in_channel)
        self.catconv = dsconv_3x3(stacked, out_channel)
        self.convA = nn.Conv2d(in_channel, 1, 1)
        self.convB = nn.Conv2d(in_channel, 1, 1)
        self.sigmoid = nn.Sigmoid()
        self.senetv2 = SaELayer(in_channel)

    def forward(self, xA, xB):
        """Return the fused feature map for the input pair ``xA`` and ``xB``.

        Both inputs are concatenated along dim 1 with their difference, so
        they must share a shape with ``in_channel`` channels.
        """
        diff = xA - xB
        channel_gate = self.senetv2(diff)

        # One single-channel gate per input, conditioned on the difference.
        gate_a = self.sigmoid(self.convA(self.catconvA(torch.cat([diff, xA], dim=1))))
        gate_b = self.sigmoid(self.convB(self.catconvB(torch.cat([diff, xB], dim=1))))

        gated_a = gate_a * xA * channel_gate
        gated_b = gate_b * xB * channel_gate

        return self.catconv(torch.cat([gated_a, gated_b], dim=1))
| |
class CDMask(nn.Module):
    """Fuse two sets of backbone features and decode them into masks.

    Each of the four feature levels is fused by its own ``TFF`` module; the
    fused levels are passed to a ``MaskFormerHead`` under the keys ``res2`` to
    ``res5``.

    Args:
        channels: Channel count per feature level; entries 0-3 are used to
            build the fusion modules, and every entry ``i`` is registered as
            ``res{i+2}`` with stride ``2 ** (i + 2)``.
        num_classes: Forwarded to ``MaskFormerHead``.
        num_queries: Forwarded to ``MaskFormerHead``.
        dec_layers: Forwarded to ``MaskFormerHead``.
    """

    def __init__(self, channels, num_classes=1, num_queries=10, dec_layers=14):
        super().__init__()
        self.channels = channels
        self.backbone_feature_shape = dict()
        for i, channel in enumerate(self.channels):
            self.backbone_feature_shape[f'res{i+2}'] = Dict({'channel': channel, 'stride': 2**(i+2)})

        # One fusion module per feature level; input and output widths match.
        self.tff1 = TFF(self.channels[0], self.channels[0])
        self.tff2 = TFF(self.channels[1], self.channels[1])
        self.tff3 = TFF(self.channels[2], self.channels[2])
        self.tff4 = TFF(self.channels[3], self.channels[3])

        self.sem_seg_head = MaskFormerHead(self.backbone_feature_shape, num_classes, num_queries, dec_layers)

    def semantic_inference(self, mask_cls, mask_pred):
        """Combine per-query class scores and masks into normalised class maps.

        Args:
            mask_cls: Class logits of shape ``(batch, queries, classes)``. They
                are soft-maxed over the last dimension and the channel at
                index 0 is dropped.
            mask_pred: Mask logits of shape ``(batch, queries, h, w)``.

        Returns:
            Detached tensor of shape ``(batch, classes - 1, h, w)``. Each
            ``(batch, class)`` map is min-max normalised over its spatial
            dimensions; a spatially constant map becomes all zeros.
        """
        mask_cls = F.softmax(mask_cls, dim=-1)[..., 1:]
        mask_pred = mask_pred.sigmoid()
        semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred).detach()

        # Per-map extrema over the spatial dimensions, without Python loops.
        minval = semseg.amin(dim=(-2, -1), keepdim=True)
        maxval = semseg.amax(dim=(-2, -1), keepdim=True)
        span = maxval - minval
        # A constant map has zero span; dividing by it would produce NaN.
        # Dividing by 1 instead leaves the (all-zero) numerator untouched.
        span = torch.where(span > 0, span, torch.ones_like(span))
        return (semseg - minval) / span

    def forward(self, inputs):
        """Run fusion, decoding and mask post-processing.

        Args:
            inputs: Pair ``(featuresA, featuresB)``; each element is indexed at
                positions 0-3 to obtain the four feature levels.

        Returns:
            A two-element list: the normalised masks from
            ``semantic_inference`` and the raw output dict of the head.
        """
        featuresA, featuresB = inputs
        fusers = (self.tff1, self.tff2, self.tff3, self.tff4)
        features = {
            f'res{level + 2}': fuse(featuresA[level], featuresB[level])
            for level, fuse in enumerate(fusers)
        }

        outputs = self.sem_seg_head(features)

        mask_cls_results = outputs["pred_logits"]
        mask_pred_results = outputs["pred_masks"]

        # Upsample the predicted masks by a factor of 4 in both dimensions.
        mask_pred_results = F.interpolate(
            mask_pred_results,
            scale_factor=(4, 4),
            mode="bilinear",
            align_corners=False,
        )
        pred_masks = self.semantic_inference(mask_cls_results, mask_pred_results)

        return [pred_masks, outputs]
| |
|
|