from keras.utils import conv_utils
from keras import backend as K
from keras.engine import InputSpec
from keras.layers import Conv2D


class PConv2D(Conv2D):
    """2D partial convolution layer operating on an ``[img, mask]`` pair.

    The image is multiplied by the mask and normalized by the mean mask
    value before the regular convolution is applied. The mask is convolved
    with a fixed all-ones kernel and then binarized, so the layer returns an
    updated mask alongside the convolved image.
    """

    def __init__(self, *args, n_channels=3, mono=False, **kwargs):
        """Create the layer.

        All positional and remaining keyword arguments are forwarded to
        ``Conv2D``. ``n_channels`` and ``mono`` are accepted for interface
        compatibility but are not used by this layer.
        """
        super().__init__(*args, **kwargs)
        # The layer takes two 4D tensors: the image and its mask.
        self.input_spec = [InputSpec(ndim=4), InputSpec(ndim=4)]

    def build(self, input_shape):
        """Adapted from original _Conv() layer of Keras

        param input_shape: list of dimensions for [img, mask]

        Raises ValueError if the channel dimension of the image is undefined.
        """
        if self.data_format == 'channels_first':
            channel_axis = 1
        else:
            channel_axis = -1

        if input_shape[0][channel_axis] is None:
            raise ValueError('The channel dimension of the inputs should be defined. Found `None`.')

        self.input_dim = input_shape[0][channel_axis]

        # Image kernel
        kernel_shape = self.kernel_size + (self.input_dim, self.filters)
        self.kernel = self.add_weight(shape=kernel_shape,
                                      initializer=self.kernel_initializer,
                                      name='img_kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)

        # Mask kernel: all ones, created with K.ones rather than add_weight,
        # so it is not registered through the layer's weight machinery.
        self.kernel_mask = K.ones(shape=self.kernel_size + (self.input_dim, self.filters))

        if self.use_bias:
            self.bias = self.add_weight(shape=(self.filters,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None

        self.built = True

    def call(self, inputs, mask=None):
        """Apply the partial convolution to ``[img, mask]``.

        The image is multiplied by the mask and divided by the mean mask
        value before being convolved with the trainable kernel. The mask is
        convolved with an all-ones kernel; every resulting value > 0 is set
        to 1 and everything else to 0.

        param inputs: list (or tuple) of two tensors ``[img, mask]``
        param mask: unused by this method
        returns: list ``[img_output, mask_output]``

        Raises ValueError if ``inputs`` is not a sequence of two tensors.
        """
        # Both image and mask must be supplied
        if not isinstance(inputs, (list, tuple)) or len(inputs) != 2:
            raise ValueError(
                'PartialConvolution2D must be called on a list of two '
                'tensors [img, mask]. Instead got: ' + str(inputs))

        img, img_mask = inputs

        # Create normalization. Slight change here compared to paper, using
        # mean mask value instead of sum. The mean is taken over the spatial
        # axes, whose positions depend on the data format.
        if self.data_format == 'channels_first':
            spatial_axes = [2, 3]
        else:
            spatial_axes = [1, 2]
        normalization = K.mean(img_mask, axis=spatial_axes, keepdims=True)
        # keepdims=True lets the division broadcast over the spatial axes, so
        # no explicit repetition (which needs static spatial sizes) is needed.
        # Clamp to epsilon so an all-zero mask yields zeros instead of NaN.
        normalization = K.maximum(normalization, K.epsilon())

        # Apply convolutions to image
        img_output = K.conv2d(
            (img * img_mask) / normalization, self.kernel,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate
        )

        # Apply convolutions to mask
        mask_output = K.conv2d(
            img_mask, self.kernel_mask,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate
        )

        # Where something happened, set 1, otherwise 0
        mask_output = K.cast(K.greater(mask_output, 0), 'float32')

        # Apply bias only to the image (if chosen to do so)
        if self.use_bias:
            img_output = K.bias_add(
                img_output,
                self.bias,
                data_format=self.data_format)

        # Apply activations on the image
        if self.activation is not None:
            img_output = self.activation(img_output)

        return [img_output, mask_output]

    def compute_output_shape(self, input_shape):
        """Return the output shapes for the image and the mask.

        param input_shape: list of shapes for [img, mask]; only the image
            shape is used, and both outputs share the same resulting shape.
        returns: list ``[new_shape, new_shape]``
        """
        img_shape = input_shape[0]
        channels_first = self.data_format == 'channels_first'

        # Spatial dimensions sit after (batch, channels) for channels_first
        # and between batch and channels for channels_last.
        space = img_shape[2:] if channels_first else img_shape[1:-1]

        new_space = tuple(
            conv_utils.conv_output_length(
                dim,
                kernel,
                padding=self.padding,
                stride=stride,
                dilation=dilation)
            for dim, kernel, stride, dilation in zip(
                space, self.kernel_size, self.strides, self.dilation_rate)
        )

        if channels_first:
            new_shape = (img_shape[0], self.filters) + new_space
        else:
            new_shape = (img_shape[0],) + new_space + (self.filters,)
        return [new_shape, new_shape]