# @ CBAM forward pass: channel attention first, then spatial attention applied
# to the channel-refined features. Each attention map is applied by
# elementwise multiplication with broadcasting (the 1-D per-channel map and the
# 2-D spatial map are smaller than the feature map they scale).
channel_attention_map_1D = channel_attention_module(feature_map)
# @ c new_feature_map_CAM: new feature map after channel attention module
new_feature_map_CAM = elementwise_mul(channel_attention_map_1D, feature_map)
# Spatial attention is computed from, and applied to, the CAM output
# (not the raw input feature_map).
spatial_attention_map_2D = spatial_attention_module(new_feature_map_CAM)
# @ c new_feature_map_SAM: new feature map after spatial attention module
new_feature_map_SAM = elementwise_mul(spatial_attention_map_2D, new_feature_map_CAM)
# Backward-compatible alias for the earlier name of this result.
new_feature_map_SA = new_feature_map_SAM
# @ Channel Attention Module
# Builds a per-channel weight vector from two pooled descriptors of
# feature_map (max and average) passed through one shared MLP, then rescales
# feature_map with it.
# NOTE(review): max_pool/avg_pool here are presumably global over the spatial
# dims (one value per channel), unlike the channel-axis pooling in the spatial
# module — confirm their semantics.
# @ c feat_max_in_CAM: feature after max pool in channel attention module
feat_max_in_CAM = max_pool(feature_map)
# @ c feat_avg_in_CAM: feature after average pool in channel attention module
feat_avg_in_CAM = avg_pool(feature_map)


# @ c MLP_shared_net: define MLP shared net
def MLP_shared_net(x):
    """Shared MLP: FC1 -> ReLU -> FC2.

    The same weights are used for both pooled descriptors.
    """
    return FC2(ReLU(FC1(x)))


# No activation on either branch: both stay raw MLP outputs until merged.
feat_max_MLP_in_CAM = MLP_shared_net(feat_max_in_CAM)
feat_avg_MLP_in_CAM = MLP_shared_net(feat_avg_in_CAM)
# Sigmoid is applied once, after merging, so every channel weight lies in
# (0, 1). Gating only one branch would leave the sum unbounded.
channel_attention_1D = sigmoid(elementwise_sum(feat_max_MLP_in_CAM, feat_avg_MLP_in_CAM))
# Note: use broadcasting in elementwise multiplication (per-channel weights
# spread over the spatial dims).
feat_after_CAM = elementwise_mul(channel_attention_1D, feature_map)
# @ Spatial Attention Module
# Builds a 2-D spatial weight map from the channel-refined features
# (feat_after_CAM) and applies it to them.
# NOTE(review): max_pool/avg_pool here are presumably taken ACROSS the channel
# axis (one value per spatial location), unlike in the channel module — confirm.
feat_max_in_SAM = max_pool(feat_after_CAM)
feat_avg_in_SAM = avg_pool(feat_after_CAM)
# Presumably stacks the two pooled maps along the channel axis, giving a
# 2-channel input for the conv — confirm concat's axis.
feat_cat_in_SAM = concat(feat_max_in_SAM, feat_avg_in_SAM)


# @ Define conv net
def conv_net(x):
    """7x7 conv down to a single map, then sigmoid to get weights in (0, 1).

    padding=3 preserves the spatial size (at stride 1), so the resulting map
    can be multiplied back onto feat_after_CAM.
    """
    return sigmoid(conv2d(x, out_channels=1, kernel=(7, 7), padding=3))


feat_conv_in_SAM = conv_net(feat_cat_in_SAM)
# Apply the spatial map, mirroring feat_after_CAM in the channel module.
# Note: use broadcasting in elementwise multiplication (one map spread over
# all channels).
feat_after_SAM = elementwise_mul(feat_conv_in_SAM, feat_after_CAM)