diff --git a/config.py b/config.py
new file mode 100644
index 0000000..9ef6140
--- /dev/null
+++ b/config.py
@@ -0,0 +1,31 @@
+import argparse
+
+def str2bool(v):
+    if isinstance(v, bool):
+        return v
+    if v.lower() in ('yes', 'true', 't', 'y', '1'):
+        return True
+    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+        return False
+    else:
+        raise argparse.ArgumentTypeError('Boolean value expected.')
+
+def get_args():
+    parser = argparse.ArgumentParser(description='')
+
+    #Input output folders settings
+    parser.add_argument('--decensor_input_path', dest='decensor_input_path', default='./decensor_input/', help='path to input images to be decensored by decensor.py')
+    parser.add_argument('--decensor_output_path', dest='decensor_output_path', default='./decensor_output/', help='path to output images generated by decensor.py')
+
+    #Decensor settings
+    parser.add_argument('--mask_color_red', dest='mask_color_red', default=0, type=int, help='red channel of mask color used in decensoring')
+    parser.add_argument('--mask_color_green', dest='mask_color_green', default=255, type=int, help='green channel of mask color used in decensoring')
+    parser.add_argument('--mask_color_blue', dest='mask_color_blue', default=0, type=int, help='blue channel of mask color used in decensoring')
+
+    parser.add_argument('--is_mosaic', dest='is_mosaic', default='False', type=str2bool, help='true if image has mosaic censoring, false otherwise')
+
+    args = parser.parse_args()
+    return args
+
+if __name__ == '__main__':
+    get_args()
\ No newline at end of file
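For reference, a minimal sketch of how these flags get consumed (the invocation shown is illustrative):

    # e.g.:  python decensor.py --is_mosaic False --mask_color_green 255
    import config

    args = config.get_args()
    mask_color = (args.mask_color_red, args.mask_color_green, args.mask_color_blue)  # (0, 255, 0) by default
    print(mask_color, args.is_mosaic)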
diff --git a/decensor.py b/decensor.py
new file mode 100644
index 0000000..801850e
--- /dev/null
+++ b/decensor.py
@@ -0,0 +1,149 @@
+import numpy as np
+from PIL import Image
+import os
+
+from copy import deepcopy
+
+import config
+from libs.pconv_hybrid_model import PConvUnet
+from libs.flood_fill import find_regions, expand_bounding
+
+class Decensor():
+
+    def __init__(self):
+        self.args = config.get_args()
+        self.decensor_mosaic = self.args.is_mosaic
+
+        self.mask_color = [self.args.mask_color_red/255.0, self.args.mask_color_green/255.0, self.args.mask_color_blue/255.0]
+
+        if not os.path.exists(self.args.decensor_output_path):
+            os.makedirs(self.args.decensor_output_path)
+
+        self.load_model()
+
+    def get_mask(self, ori, width, height):
+        mask = np.zeros(ori.shape, np.uint8)
+        #TODO: change to iterate over all images in batch when implementing batches
+        for row in range(height):
+            for col in range(width):
+                if np.array_equal(ori[0][row][col], self.mask_color):
+                    mask[0, row, col] = 1
+        return 1-mask
+
+    def load_model(self):
+        self.model = PConvUnet(weight_filepath='data/logs/')
+        self.model.load(
+            r"./models/model.h5",
+            train_bn=False,
+            lr=0.00005
+        )
+
+    def decensor_all_images_in_folder(self):
+        #model is loaded once in __init__ and reused for every image
+        subdir = self.args.decensor_input_path
+        files = os.listdir(subdir)
+
+        #decensor all .png images in the input folder, one at a time
+        for file in files:
+            file_path = os.path.join(subdir, file)
+            if os.path.isfile(file_path) and os.path.splitext(file_path)[1] == ".png":
+                print("Decensoring the image {file_path}".format(file_path=file_path))
+                censored_img = Image.open(file_path)
+                self.decensor_image(censored_img, file)
+
+    #decensors one image at a time
+    #TODO: decensor all cropped parts of the same image in a batch (then ori needs to be an array of those images, with additional changes)
+    def decensor_image(self, ori, file_name):
+        width, height = ori.size
+        #save the alpha channel if the image has an alpha channel
+        has_alpha = False
+        alpha_channel = None
+        if ori.mode == "RGBA":
+            has_alpha = True
+            alpha_channel = np.asarray(ori)[:,:,3]
+            alpha_channel = np.expand_dims(alpha_channel, axis=-1)
+            ori = ori.convert('RGB')
+
+        ori_array = np.asarray(ori)
+        ori_array = np.array(ori_array / 255.0)
+        ori_array = np.expand_dims(ori_array, axis=0)
+
+        mask = self.get_mask(ori_array, width, height)
+
+        regions = find_regions(ori)
+        print("Found {region_count} censored regions in this image!".format(region_count=len(regions)))
+
+        if len(regions) == 0 and not self.decensor_mosaic:
+            print("No green colored regions detected!")
+            return
+
+        output_img_array = ori_array[0].copy()
+
+        for region_counter, region in enumerate(regions, 1):
+            bounding_box = expand_bounding(ori, region)
+            crop_img = ori.crop(bounding_box)
+            #convert mask back to image
+            mask_reshaped = mask[0,:,:,:] * 255.0
+            mask_img = Image.fromarray(mask_reshaped.astype('uint8'))
+            #resize the cropped image
+            crop_img = crop_img.resize((512, 512))
+            crop_img_array = np.asarray(crop_img)
+            crop_img_array = crop_img_array / 255.0
+            crop_img_array = np.expand_dims(crop_img_array, axis=0)
+            #crop and resize the mask image
+            mask_img = mask_img.crop(bounding_box)
+            mask_img = mask_img.resize((512, 512))
+            #convert mask_img back to array
+            mask_array = np.asarray(mask_img)
+            mask_array = np.array(mask_array / 255.0)
+            #the mask has been upscaled so there will be values not equal to 0 or 1
+            mask_array[mask_array > 0] = 1
+            mask_array = np.expand_dims(mask_array, axis=0)
+
+            # Run predictions for this batch of images
+            pred_img_array = self.model.predict([crop_img_array, mask_array, mask_array])
+
+            pred_img_array = pred_img_array * 255.0
+            pred_img_array = np.squeeze(pred_img_array, axis=0)
+
+            #scale prediction image back to original size
+            bounding_width = bounding_box[2]-bounding_box[0]
+            bounding_height = bounding_box[3]-bounding_box[1]
+            #convert np array to image
+            pred_img = Image.fromarray(pred_img_array.astype('uint8'))
+            pred_img = pred_img.resize((bounding_width, bounding_height), resample=Image.BICUBIC)
+            pred_img_array = np.asarray(pred_img)
+            pred_img_array = pred_img_array / 255.0
+            pred_img_array = np.expand_dims(pred_img_array, axis=0)
+
+            for i in range(len(ori_array)):
+                if self.decensor_mosaic:
+                    output_img_array = pred_img_array[i]
+                else:
+                    for col in range(bounding_width):
+                        for row in range(bounding_height):
+                            #use names distinct from the loop bounds so the bounds are not overwritten mid-loop
+                            abs_x = col + bounding_box[0]
+                            abs_y = row + bounding_box[1]
+                            if (abs_x, abs_y) in region:
+                                output_img_array[abs_y][abs_x] = pred_img_array[i,:,:,:][row][col]
+            print("{region_counter} out of {region_count} regions decensored.".format(region_counter=region_counter, region_count=len(regions)))
+
+        output_img_array = output_img_array * 255.0
+
+        #restore the alpha channel
+        if has_alpha:
+            output_img_array = np.concatenate((output_img_array, alpha_channel), axis=2)
+
+        output_img = Image.fromarray(output_img_array.astype('uint8'))
+
+        #save the decensored image
+        save_path = os.path.join(self.args.decensor_output_path, file_name)
+        output_img.save(save_path)
+
+        print("Decensored image saved to {save_path}!".format(save_path=save_path))
+        return
+
+if __name__ == '__main__':
+    decensor = Decensor()
+    decensor.decensor_all_images_in_folder()
\ No newline at end of file
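The per-pixel loop in get_mask is O(width × height) in Python; a vectorized NumPy equivalent (a sketch, assuming the same (1, height, width, 3) normalized layout) would be:

    import numpy as np

    def get_mask_vectorized(ori, mask_color):
        # boolean (height, width) map of pixels matching the mask color
        # (np.isclose instead of exact equality, to tolerate float rounding)
        matches = np.all(np.isclose(ori[0], mask_color), axis=-1)
        mask = np.ones(ori.shape, np.uint8)
        mask[0, matches] = 0  # 0 inside the censored region, 1 elsewhere
        return mask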
diff --git a/flood_fill.py b/flood_fill.py
new file mode 100644
index 0000000..f206bbe
--- /dev/null
+++ b/flood_fill.py
@@ -0,0 +1,128 @@
+from PIL import Image, ImageDraw
+
+#find connected components (4-connectivity) of the mask color
+def find_regions(image):
+    pixel = image.load()
+    neighbors = dict()
+    width, height = image.size
+    for x in range(width):
+        for y in range(height):
+            if is_green(pixel[x,y]):
+                neighbors[x, y] = {(x,y)}
+    for x, y in neighbors:
+        candidates = (x + 1, y), (x, y + 1)
+        for candidate in candidates:
+            if candidate in neighbors:
+                neighbors[x, y].add(candidate)
+                neighbors[candidate].add((x, y))
+    closed_list = set()
+
+    def connected_component(pixel):
+        region = set()
+        open_list = {pixel}
+        while open_list:
+            pixel = open_list.pop()
+            closed_list.add(pixel)
+            open_list |= neighbors[pixel] - closed_list
+            region.add(pixel)
+        return region
+
+    regions = []
+    for pixel in neighbors:
+        if pixel not in closed_list:
+            regions.append(connected_component(pixel))
+    regions.sort(key=len, reverse=True)
+    return regions
+
+#note: the expanded box risks being bigger than the image
+def expand_bounding(img, region, expand_factor=1.5, min_size=256, max_size=512):
+    #expand bounding box to capture more context
+    x, y = zip(*region)
+    min_x, min_y, max_x, max_y = min(x), min(y), max(x), max(y)
+    width, height = img.size
+    bb_width = max_x - min_x
+    bb_height = max_y - min_y
+    x_center = (min_x + max_x)//2
+    y_center = (min_y + max_y)//2
+    current_size = max(bb_width, bb_height)
+    current_size = int(current_size * expand_factor)
+    if current_size > max_size:
+        current_size = max_size
+    elif current_size < min_size:
+        current_size = min_size
+    x1 = x_center - current_size//2
+    x2 = x_center + current_size//2
+    y1 = y_center - current_size//2
+    y2 = y_center + current_size//2
+    x1_square = x1
+    y1_square = y1
+    x2_square = x2
+    y2_square = y2
+    #move bounding boxes that are partially outside of the image inside the image
+    if (y1_square < 0 or y2_square > (height - 1)) and (x1_square < 0 or x2_square > (width - 1)):
+        #conservative square region
+        if x1_square < 0 and y1_square < 0:
+            #top left corner
+            x1_square = 0
+            y1_square = 0
+            x2_square = current_size
+            y2_square = current_size
+        elif x2_square > (width - 1) and y1_square < 0:
+            #top right corner
+            x1_square = width - current_size - 1
+            y1_square = 0
+            x2_square = width - 1
+            y2_square = current_size
+        elif x1_square < 0 and y2_square > (height - 1):
+            #bottom left corner
+            x1_square = 0
+            y1_square = height - current_size - 1
+            x2_square = current_size
+            y2_square = height - 1
+        elif x2_square > (width - 1) and y2_square > (height - 1):
+            #bottom right corner
+            x1_square = width - current_size - 1
+            y1_square = height - current_size - 1
+            x2_square = width - 1
+            y2_square = height - 1
+        else:
+            x1_square = x1
+            y1_square = y1
+            x2_square = x2
+            y2_square = y2
+    else:
+        if x1_square < 0:
+            difference = x1_square
+            x1_square -= difference
+            x2_square -= difference
+        if x2_square > (width - 1):
+            difference = x2_square - width + 1
+            x1_square -= difference
+            x2_square -= difference
+        if y1_square < 0:
+            difference = y1_square
+            y1_square -= difference
+            y2_square -= difference
+        if y2_square > (height - 1):
+            difference = y2_square - height + 1
+            y1_square -= difference
+            y2_square -= difference
+
+    #if the bounding box still goes outside of the image for some reason, fall back to the original, unexpanded values
+    if x1_square < 0 or y1_square < 0 or x2_square > width or y2_square > height:
+        print("bounding box out of bounds!")
+        print(x1_square, y1_square, x2_square, y2_square)
+        x1_square, y1_square, x2_square, y2_square = min_x, min_y, max_x, max_y
+    return x1_square, y1_square, x2_square, y2_square
+
+def is_green(pixel):
+    r, g, b = pixel
+    return r == 0 and g == 255 and b == 0
+
+if __name__ == '__main__':
+    image = Image.open('')
+    no_alpha_image = image.convert('RGB')
+    draw = ImageDraw.Draw(no_alpha_image)
+    for region in find_regions(no_alpha_image):
+        draw.rectangle(expand_bounding(no_alpha_image, region), outline=(0, 255, 0))
+    no_alpha_image.show()
\ No newline at end of file
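A quick way to exercise find_regions and expand_bounding on a real file ('sample.png' is a placeholder path):

    from PIL import Image
    from flood_fill import find_regions, expand_bounding

    img = Image.open('sample.png').convert('RGB')
    regions = find_regions(img)
    print('found {} censored regions'.format(len(regions)))
    if regions:
        print('expanded bounding box of largest region:', expand_bounding(img, regions[0]))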
diff --git a/libs/flood_fill.py b/libs/flood_fill.py
new file mode 100644
index 0000000..f206bbe
--- /dev/null
+++ b/libs/flood_fill.py
@@ -0,0 +1,128 @@
+from PIL import Image, ImageDraw
+
+#find connected components (4-connectivity) of the mask color
+def find_regions(image):
+    pixel = image.load()
+    neighbors = dict()
+    width, height = image.size
+    for x in range(width):
+        for y in range(height):
+            if is_green(pixel[x,y]):
+                neighbors[x, y] = {(x,y)}
+    for x, y in neighbors:
+        candidates = (x + 1, y), (x, y + 1)
+        for candidate in candidates:
+            if candidate in neighbors:
+                neighbors[x, y].add(candidate)
+                neighbors[candidate].add((x, y))
+    closed_list = set()
+
+    def connected_component(pixel):
+        region = set()
+        open_list = {pixel}
+        while open_list:
+            pixel = open_list.pop()
+            closed_list.add(pixel)
+            open_list |= neighbors[pixel] - closed_list
+            region.add(pixel)
+        return region
+
+    regions = []
+    for pixel in neighbors:
+        if pixel not in closed_list:
+            regions.append(connected_component(pixel))
+    regions.sort(key=len, reverse=True)
+    return regions
+
+#note: the expanded box risks being bigger than the image
+def expand_bounding(img, region, expand_factor=1.5, min_size=256, max_size=512):
+    #expand bounding box to capture more context
+    x, y = zip(*region)
+    min_x, min_y, max_x, max_y = min(x), min(y), max(x), max(y)
+    width, height = img.size
+    bb_width = max_x - min_x
+    bb_height = max_y - min_y
+    x_center = (min_x + max_x)//2
+    y_center = (min_y + max_y)//2
+    current_size = max(bb_width, bb_height)
+    current_size = int(current_size * expand_factor)
+    if current_size > max_size:
+        current_size = max_size
+    elif current_size < min_size:
+        current_size = min_size
+    x1 = x_center - current_size//2
+    x2 = x_center + current_size//2
+    y1 = y_center - current_size//2
+    y2 = y_center + current_size//2
+    x1_square = x1
+    y1_square = y1
+    x2_square = x2
+    y2_square = y2
+    #move bounding boxes that are partially outside of the image inside the image
+    if (y1_square < 0 or y2_square > (height - 1)) and (x1_square < 0 or x2_square > (width - 1)):
+        #conservative square region
+        if x1_square < 0 and y1_square < 0:
+            #top left corner
+            x1_square = 0
+            y1_square = 0
+            x2_square = current_size
+            y2_square = current_size
+        elif x2_square > (width - 1) and y1_square < 0:
+            #top right corner
+            x1_square = width - current_size - 1
+            y1_square = 0
+            x2_square = width - 1
+            y2_square = current_size
+        elif x1_square < 0 and y2_square > (height - 1):
+            #bottom left corner
+            x1_square = 0
+            y1_square = height - current_size - 1
+            x2_square = current_size
+            y2_square = height - 1
+        elif x2_square > (width - 1) and y2_square > (height - 1):
+            #bottom right corner
+            x1_square = width - current_size - 1
+            y1_square = height - current_size - 1
+            x2_square = width - 1
+            y2_square = height - 1
+        else:
+            x1_square = x1
+            y1_square = y1
+            x2_square = x2
+            y2_square = y2
+    else:
+        if x1_square < 0:
+            difference = x1_square
+            x1_square -= difference
+            x2_square -= difference
+        if x2_square > (width - 1):
+            difference = x2_square - width + 1
+            x1_square -= difference
+            x2_square -= difference
+        if y1_square < 0:
+            difference = y1_square
+            y1_square -= difference
+            y2_square -= difference
+        if y2_square > (height - 1):
+            difference = y2_square - height + 1
+            y1_square -= difference
+            y2_square -= difference
+
+    #if the bounding box still goes outside of the image for some reason, fall back to the original, unexpanded values
+    if x1_square < 0 or y1_square < 0 or x2_square > width or y2_square > height:
+        print("bounding box out of bounds!")
+        print(x1_square, y1_square, x2_square, y2_square)
+        x1_square, y1_square, x2_square, y2_square = min_x, min_y, max_x, max_y
+    return x1_square, y1_square, x2_square, y2_square
+
+def is_green(pixel):
+    r, g, b = pixel
+    return r == 0 and g == 255 and b == 0
+
+if __name__ == '__main__':
+    image = Image.open('')
+    no_alpha_image = image.convert('RGB')
+    draw = ImageDraw.Draw(no_alpha_image)
+    for region in find_regions(no_alpha_image):
+        draw.rectangle(expand_bounding(no_alpha_image, region), outline=(0, 255, 0))
+    no_alpha_image.show()
\ No newline at end of file
diff --git a/libs/pconv_hybrid_model.py b/libs/pconv_hybrid_model.py
new file mode 100644
index 0000000..ebffbe6
--- /dev/null
+++ b/libs/pconv_hybrid_model.py
@@ -0,0 +1,281 @@
+import os
+from datetime import datetime
+
+from keras.models import Model
+from keras.models import load_model
+from keras.optimizers import Adam
+from keras.layers import Input, Conv2D, UpSampling2D, Dropout, LeakyReLU, BatchNormalization, Activation
+from keras.layers.merge import Concatenate
+from keras.applications import VGG16
+from keras import backend as K
+from libs.pconv_layer import PConv2D
+
+
+class PConvUnet(object):
+
+    def __init__(self, img_rows=512, img_cols=512, weight_filepath=None):
+        """Create the PConvUnet. If variable image size, set img_rows and img_cols to None"""
+
+        # Settings
+        self.weight_filepath = weight_filepath
+        self.img_rows = img_rows
+        self.img_cols = img_cols
+        assert self.img_rows >= 256, 'Height must be at least 256 pixels'
+        assert self.img_cols >= 256, 'Width must be at least 256 pixels'
+
+        # Set current epoch
+        self.current_epoch = 0
+
+        # VGG layers to extract features from (first maxpooling layers, see p. 7 of the paper)
+        self.vgg_layers = [3, 6, 10]
+
+        # Get the vgg16 model for perceptual loss
+        self.vgg = self.build_vgg()
+
+        # Create UNet-like model
+        self.model = self.build_pconv_unet()
+
+    def build_vgg(self):
+        """
+        Load pre-trained VGG16 from keras applications
+        Extract features from the first three pooling layers for the loss function, see architecture at:
+        https://github.com/keras-team/keras/blob/master/keras/applications/vgg16.py
+        """
+        # Input image to extract features from
+        img = Input(shape=(self.img_rows, self.img_cols, 3))
+
+        # Get the vgg network from Keras applications
+        vgg = VGG16(weights="imagenet", include_top=False)
+
+        # Output the first three pooling layers
+        vgg.outputs = [vgg.layers[i].output for i in self.vgg_layers]
+
+        # Create model and compile
+        model = Model(inputs=img, outputs=vgg(img))
+        model.trainable = False
+        model.compile(loss='mse', optimizer='adam')
+
+        return model
+
+    def build_pconv_unet(self, train_bn=True, lr=0.0002):
+
+        # INPUTS
+        inputs_img = Input((self.img_rows, self.img_cols, 3))
+        inputs_mask = Input((self.img_rows, self.img_cols, 3))
+        loss_mask = Input((self.img_rows, self.img_cols, 3))
+
+        # ENCODER
+        def encoder_layer(img_in, mask_in, filters, kernel_size, bn=True):
+            conv, mask = PConv2D(filters, kernel_size, strides=2, padding='same')([img_in, mask_in])
+            if bn:
+                conv = BatchNormalization(name='EncBN'+str(encoder_layer.counter))(conv, training=train_bn)
+            conv = Activation('relu')(conv)
+            encoder_layer.counter += 1
+            return conv, mask
+        encoder_layer.counter = 0
+
+        e_conv1, e_mask1 = encoder_layer(inputs_img, inputs_mask, 64, 7, bn=False)
+        e_conv2, e_mask2 = encoder_layer(e_conv1, e_mask1, 128, 5)
+        e_conv3, e_mask3 = encoder_layer(e_conv2, e_mask2, 256, 5)
+        e_conv4, e_mask4 = encoder_layer(e_conv3, e_mask3, 512, 3)
+        e_conv5, e_mask5 = encoder_layer(e_conv4, e_mask4, 512, 3)
+        e_conv6, e_mask6 = encoder_layer(e_conv5, e_mask5, 512, 3)
+        e_conv7, e_mask7 = encoder_layer(e_conv6, e_mask6, 512, 3)
+        e_conv8, e_mask8 = encoder_layer(e_conv7, e_mask7, 512, 3)
+
+        # DECODER
+        def decoder_layer(img_in, mask_in, e_conv, e_mask, filters, kernel_size, bn=True):
+            up_img = UpSampling2D(size=(2,2))(img_in)
+            up_mask = UpSampling2D(size=(2,2))(mask_in)
+            concat_img = Concatenate(axis=3)([e_conv,up_img])
+            concat_mask = Concatenate(axis=3)([e_mask,up_mask])
+            conv, mask = PConv2D(filters, kernel_size, padding='same')([concat_img, concat_mask])
+            if bn:
+                conv = BatchNormalization()(conv)
+            conv = LeakyReLU(alpha=0.2)(conv)
+            return conv, mask
+
+        d_conv9, d_mask9 = decoder_layer(e_conv8, e_mask8, e_conv7, e_mask7, 512, 3)
+        d_conv10, d_mask10 = decoder_layer(d_conv9, d_mask9, e_conv6, e_mask6, 512, 3)
+        d_conv11, d_mask11 = decoder_layer(d_conv10, d_mask10, e_conv5, e_mask5, 512, 3)
+        d_conv12, d_mask12 = decoder_layer(d_conv11, d_mask11, e_conv4, e_mask4, 512, 3)
+        d_conv13, d_mask13 = decoder_layer(d_conv12, d_mask12, e_conv3, e_mask3, 256, 3)
+        d_conv14, d_mask14 = decoder_layer(d_conv13, d_mask13, e_conv2, e_mask2, 128, 3)
+        d_conv15, d_mask15 = decoder_layer(d_conv14, d_mask14, e_conv1, e_mask1, 64, 3)
+        d_conv16, d_mask16 = decoder_layer(d_conv15, d_mask15, inputs_img, inputs_mask, 3, 3, bn=False)
+        outputs = Conv2D(3, 1, activation='sigmoid')(d_conv16)
+
+        # Setup the model inputs / outputs
+        model = Model(inputs=[inputs_img, inputs_mask, loss_mask], outputs=outputs)
+
+        # Compile the model
+        model.compile(
+            optimizer=Adam(lr=lr),
+            loss=self.loss_total(loss_mask)
+        )
+
+        return model
+
+    def loss_total(self, mask):
+        """
+        Creates a loss function that sums all the loss components,
+        each weighted as in eq. 7 of the paper.
+        """
+        def loss(y_true, y_pred):
+
+            # Compute predicted image with non-hole pixels set to ground truth
+            y_comp = mask * y_true + (1-mask) * y_pred
+
+            # Compute the vgg features
+            vgg_out = self.vgg(y_pred)
+            vgg_gt = self.vgg(y_true)
+            vgg_comp = self.vgg(y_comp)
+
+            # Compute loss components
+            l1 = self.loss_valid(mask, y_true, y_pred)
+            l2 = self.loss_hole(mask, y_true, y_pred)
+            l3 = self.loss_perceptual(vgg_out, vgg_gt, vgg_comp)
+            l4 = self.loss_style(vgg_out, vgg_gt)
+            l5 = self.loss_style(vgg_comp, vgg_gt)
+            l6 = self.loss_tv(mask, y_comp)
+
+            # Return the weighted sum of the loss components
+            return l1 + 6*l2 + 0.05*l3 + 120*(l4+l5) + 0.1*l6
+
+        return loss
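+    # Written out, the weighting above corresponds to
+    #     L_total = L_valid + 6*L_hole + 0.05*L_perceptual + 120*(L_style_out + L_style_comp) + 0.1*L_tv
+    # (weights reconstructed from the return expression; matches eq. 7 of the paper)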
+
+    def loss_hole(self, mask, y_true, y_pred):
+        """Pixel L1 loss within the hole / mask"""
+        return self.l1((1-mask) * y_true, (1-mask) * y_pred)
+
+    def loss_valid(self, mask, y_true, y_pred):
+        """Pixel L1 loss outside the hole / mask"""
+        return self.l1(mask * y_true, mask * y_pred)
+
+    def loss_perceptual(self, vgg_out, vgg_gt, vgg_comp):
+        """Perceptual loss based on VGG16, see eq. 3 in the paper"""
+        loss = 0
+        for o, c, g in zip(vgg_out, vgg_comp, vgg_gt):
+            loss += self.l1(o, g) + self.l1(c, g)
+        return loss
+
+    def loss_style(self, output, vgg_gt):
+        """Style loss based on the output/composited image, used for both eq. 4 and 5 in the paper"""
+        loss = 0
+        for o, g in zip(output, vgg_gt):
+            loss += self.l1(self.gram_matrix(o), self.gram_matrix(g))
+        return loss
+
+    def loss_tv(self, mask, y_comp):
+        """Total variation loss, used for smoothing the hole region, see eq. 6"""
+
+        # Create dilated hole region using a 3x3 kernel of all 1s.
+        kernel = K.ones(shape=(3, 3, mask.shape[3], mask.shape[3]))
+        dilated_mask = K.conv2d(1-mask, kernel, data_format='channels_last', padding='same')
+
+        # Cast values to be [0., 1.], and compute dilated hole region of y_comp
+        dilated_mask = K.cast(K.greater(dilated_mask, 0), 'float32')
+        P = dilated_mask * y_comp
+
+        # Calculate total variation loss
+        a = self.l1(P[:,1:,:,:], P[:,:-1,:,:])
+        b = self.l1(P[:,:,1:,:], P[:,:,:-1,:])
+        return a+b
+
+    def fit(self, generator, epochs=10, plot_callback=None, *args, **kwargs):
+        """Fit the U-Net to a (images, targets) generator
+
+        param generator: training generator yielding (masked_image, original_image) tuples
+        param epochs: number of epochs to train for
+        param plot_callback: callback function taking the Unet model as parameter
+        """
+
+        # Loop over epochs
+        for _ in range(epochs):
+
+            # Fit the model
+            self.model.fit_generator(
+                generator,
+                epochs=self.current_epoch+1,
+                initial_epoch=self.current_epoch,
+                *args, **kwargs
+            )
+
+            # Update epoch
+            self.current_epoch += 1
+
+            # After each epoch predict on test images & show them
+            if plot_callback:
+                plot_callback(self.model)
+
+            # Save logfile
+            if self.weight_filepath:
+                self.save()
+
+    def predict(self, sample):
+        """Run prediction using this model"""
+        return self.model.predict(sample)
+
+    def summary(self):
+        """Get summary of the UNet model"""
+        print(self.model.summary())
+
+    def save(self):
+        self.model.save_weights(self.current_weightfile())
+
+    def load(self, filepath, train_bn=True, lr=0.0002):
+
+        # Create UNet-like model
+        self.model = self.build_pconv_unet(train_bn, lr)
+
+        # Load weights into model. Weight files are expected to be named
+        # "X_weights_<timestamp>.h5", with X being the epoch; fall back to
+        # epoch 1 for files that do not follow this convention (e.g. model.h5).
+        try:
+            self.current_epoch = int(os.path.basename(filepath).split("_")[0])
+        except ValueError:
+            self.current_epoch = 1
+        self.model.load_weights(filepath)
+
+    def current_weightfile(self):
+        assert self.weight_filepath is not None, 'Must specify location of logs'
+        return self.weight_filepath + "{}_weights_{}.h5".format(self.current_epoch, self.current_timestamp())
+
+    @staticmethod
+    def current_timestamp():
+        return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
+
+    @staticmethod
+    def l1(y_true, y_pred):
+        """Calculate the L1 loss used in all loss calculations"""
+        if K.ndim(y_true) == 4:
+            return K.sum(K.abs(y_pred - y_true), axis=[1,2,3])
+        elif K.ndim(y_true) == 3:
+            return K.sum(K.abs(y_pred - y_true), axis=[1,2])
+        else:
+            raise NotImplementedError("Calculating L1 loss on 1D tensors? Should not occur for this network")
+
+    @staticmethod
+    def gram_matrix(x, norm_by_channels=False):
+        """Calculate gram matrix used in style loss"""
+
+        # Assertions on input
+        assert K.ndim(x) == 4, 'Input tensor should be a 4d (B, H, W, C) tensor'
+        assert K.image_data_format() == 'channels_last', "Please use channels-last format"
+
+        # Permute channels and get resulting shape
+        x = K.permute_dimensions(x, (0, 3, 1, 2))
+        shape = K.shape(x)
+        B, C, H, W = shape[0], shape[1], shape[2], shape[3]
+
+        # Reshape x and do batch dot product
+        features = K.reshape(x, K.stack([B, C, H*W]))
+        gram = K.batch_dot(features, features, axes=2)
+
+        # Normalize with channels, height and width
+        gram = gram / K.cast(C * H * W, x.dtype)
+
+        return gram
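A minimal shape-level smoke test of the three-input predict() signature above (dummy data, untrained weights; it only verifies the array plumbing):

    import numpy as np
    from libs.pconv_hybrid_model import PConvUnet

    model = PConvUnet(weight_filepath='data/logs/')
    img = np.random.rand(1, 512, 512, 3)            # normalized image batch
    mask = np.ones((1, 512, 512, 3), np.float32)    # 1 = valid pixel, 0 = hole
    mask[:, 200:300, 200:300, :] = 0
    pred = model.predict([img, mask, mask])         # [image, conv mask, loss mask]
    print(pred.shape)                               # (1, 512, 512, 3)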
diff --git a/libs/pconv_layer.py b/libs/pconv_layer.py
new file mode 100644
index 0000000..9efae82
--- /dev/null
+++ b/libs/pconv_layer.py
@@ -0,0 +1,128 @@
+
+from keras.utils import conv_utils
+from keras import backend as K
+from keras.engine import InputSpec
+from keras.layers import Conv2D
+
+
+class PConv2D(Conv2D):
+    def __init__(self, *args, n_channels=3, mono=False, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.input_spec = [InputSpec(ndim=4), InputSpec(ndim=4)]
+
+    def build(self, input_shape):
+        """Adapted from original _Conv() layer of Keras
+        param input_shape: list of dimensions for [img, mask]
+        """
+
+        if self.data_format == 'channels_first':
+            channel_axis = 1
+        else:
+            channel_axis = -1
+
+        if input_shape[0][channel_axis] is None:
+            raise ValueError('The channel dimension of the inputs should be defined. Found `None`.')
+
+        self.input_dim = input_shape[0][channel_axis]
+
+        # Image kernel
+        kernel_shape = self.kernel_size + (self.input_dim, self.filters)
+        self.kernel = self.add_weight(shape=kernel_shape,
+                                      initializer=self.kernel_initializer,
+                                      name='img_kernel',
+                                      regularizer=self.kernel_regularizer,
+                                      constraint=self.kernel_constraint)
+        # Mask kernel (all ones, not trainable)
+        self.kernel_mask = K.ones(shape=self.kernel_size + (self.input_dim, self.filters))
+
+        if self.use_bias:
+            self.bias = self.add_weight(shape=(self.filters,),
+                                        initializer=self.bias_initializer,
+                                        name='bias',
+                                        regularizer=self.bias_regularizer,
+                                        constraint=self.bias_constraint)
+        else:
+            self.bias = None
+        self.built = True
+
+    def call(self, inputs, mask=None):
+        '''
+        We will be using the Keras conv2d method; essentially all we have
+        to do here is multiply the mask with the input X before we apply
+        the convolutions. For the mask itself, we apply convolutions with
+        all weights set to 1.
+        Subsequently, we set all mask values >0 to 1, and otherwise to 0.
+        '''
+
+        # Both image and mask must be supplied
+        if type(inputs) is not list or len(inputs) != 2:
+            raise Exception('PartialConvolution2D must be called on a list of two tensors [img, mask]. Instead got: ' + str(inputs))
+
+        # Create normalization. Slight change here compared to paper, using mean mask value instead of sum
+        # (K.epsilon() guards against division by zero for an all-zero mask)
+        normalization = K.mean(inputs[1], axis=[1,2], keepdims=True) + K.epsilon()
+        normalization = K.repeat_elements(normalization, inputs[1].shape[1], axis=1)
+        normalization = K.repeat_elements(normalization, inputs[1].shape[2], axis=2)
+
+        # Apply convolutions to image
+        img_output = K.conv2d(
+            (inputs[0]*inputs[1]) / normalization, self.kernel,
+            strides=self.strides,
+            padding=self.padding,
+            data_format=self.data_format,
+            dilation_rate=self.dilation_rate
+        )
+
+        # Apply convolutions to mask
+        mask_output = K.conv2d(
+            inputs[1], self.kernel_mask,
+            strides=self.strides,
+            padding=self.padding,
+            data_format=self.data_format,
+            dilation_rate=self.dilation_rate
+        )
+
+        # Where something happened, set 1, otherwise 0
+        mask_output = K.cast(K.greater(mask_output, 0), 'float32')
+
+        # Apply bias only to the image (if chosen to do so)
+        if self.use_bias:
+            img_output = K.bias_add(
+                img_output,
+                self.bias,
+                data_format=self.data_format)
+
+        # Apply activations on the image
+        if self.activation is not None:
+            img_output = self.activation(img_output)
+
+        return [img_output, mask_output]
+
+    def compute_output_shape(self, input_shape):
+        if self.data_format == 'channels_last':
+            space = input_shape[0][1:-1]
+            new_space = []
+            for i in range(len(space)):
+                new_dim = conv_utils.conv_output_length(
+                    space[i],
+                    self.kernel_size[i],
+                    padding=self.padding,
+                    stride=self.strides[i],
+                    dilation=self.dilation_rate[i])
+                new_space.append(new_dim)
+            new_shape = (input_shape[0][0],) + tuple(new_space) + (self.filters,)
+            return [new_shape, new_shape]
+        if self.data_format == 'channels_first':
+            # input_shape is a list of [img_shape, mask_shape], so index into the image shape
+            space = input_shape[0][2:]
+            new_space = []
+            for i in range(len(space)):
+                new_dim = conv_utils.conv_output_length(
+                    space[i],
+                    self.kernel_size[i],
+                    padding=self.padding,
+                    stride=self.strides[i],
+                    dilation=self.dilation_rate[i])
+                new_space.append(new_dim)
+            new_shape = (input_shape[0][0], self.filters) + tuple(new_space)
+            return [new_shape, new_shape]
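The mask update rule in call() (the output mask is 1 wherever the window saw at least one valid pixel) can be checked in isolation with plain NumPy (a sketch of the rule, not the layer itself):

    import numpy as np

    mask = np.zeros((6, 6), np.float32)
    mask[2:4, 2:4] = 1.0                   # small valid region
    out = np.zeros((4, 4), np.float32)     # 3x3 window, stride 1, 'valid' padding
    for i in range(4):
        for j in range(4):
            out[i, j] = float(mask[i:i+3, j:j+3].sum() > 0)
    print(out)  # the valid region grows by one pixel in every direction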
diff --git a/ui.py b/ui.py
new file mode 100644
index 0000000..59e3d6b
--- /dev/null
+++ b/ui.py
@@ -0,0 +1,486 @@
+"""
+Code illustration: 6.09
+
+    Modules imported here:
+        from tkinter import messagebox
+        from tkinter import filedialog
+
+    Attributes added here:
+        file_name = "untitled"
+
+    Methods modified here:
+        on_new_file_menu_clicked()
+        on_save_menu_clicked()
+        on_save_as_menu_clicked()
+        on_close_menu_clicked()
+        on_undo_menu_clicked()
+        on_canvas_zoom_in_menu_clicked()
+        on_canvas_zoom_out_menu_clicked()
+        on_about_menu_clicked()
+
+    Methods added here:
+        start_new_project()
+        actual_save()
+        close_window()
+        undo()
+        canvas_zoom_in()
+        canvas_zoom_out()
+
+@ Tkinter GUI Application Development Blueprints
+"""
+import math
+from PIL import Image, ImageTk, ImageDraw
+import tkinter as tk
+from tkinter import colorchooser
+from tkinter import ttk
+from tkinter import messagebox
+from tkinter import filedialog
+import framework
+import decensor
+
+
+class PaintApplication(framework.Framework):
+
+    def __init__(self, root):
+        super().__init__(root)
+
+        self.drawn_img = None
+        self.screen_width = root.winfo_screenwidth()
+        self.screen_height = root.winfo_screenheight()
+        self.start_x, self.start_y = 0, 0
+        self.end_x, self.end_y = 0, 0
+        self.current_item = None
+        self.fill = "#00ff00"
+        self.fill_pil = (0,255,0,255)
+        self.outline = "#00ff00"
+        self.brush_width = 2
+        self.background = 'white'
+        self.foreground = "#00ff00"
+        self.file_name = "untitled"
+        self.tool_bar_functions = (
+            "draw_line", "draw_irregular_line"
+        )
+        self.selected_tool_bar_function = self.tool_bar_functions[0]
+
+        self.create_gui()
+        self.bind_mouse()
+
+    def on_new_file_menu_clicked(self, event=None):
+        self.start_new_project()
+
+    def start_new_project(self):
+        self.canvas.delete(tk.ALL)
+        self.canvas.config(bg="#ffffff")
+        self.root.title('untitled')
+
+    def on_open_image_menu_clicked(self, event=None):
+        self.open_image()
+
+    def open_image(self):
+        self.file_name = filedialog.askopenfilename(master=self.root, filetypes=[("All Files", "*.*")], title="Open...")
+        print(self.file_name)
+        self.canvas.img = Image.open(self.file_name)
+        self.canvas.img_width, self.canvas.img_height = self.canvas.img.size
+        #keep a reference to the image to prevent garbage collection
+        #https://stackoverflow.com/questions/20061396/image-display-on-tkinter-canvas-not-working
+        self.canvas.tk_img = ImageTk.PhotoImage(self.canvas.img)
+        self.canvas.config(width=self.canvas.img_width, height=self.canvas.img_height)
+        self.canvas.create_image(self.canvas.img_width/2.0, self.canvas.img_height/2.0, image=self.canvas.tk_img)
+
+        self.drawn_img = Image.new("RGBA", self.canvas.img.size)
+        self.drawn_img_draw = ImageDraw.Draw(self.drawn_img)
+
+    def on_import_mask_clicked(self, event=None):
+        self.import_mask()
+
+    def display_canvas(self):
+        composite_img = Image.alpha_composite(self.canvas.img.convert('RGBA'), self.drawn_img).convert('RGB')
+        self.canvas.tk_img = ImageTk.PhotoImage(composite_img)
+        self.canvas.create_image(self.canvas.img_width/2.0, self.canvas.img_height/2.0, image=self.canvas.tk_img)
+
+    def import_mask(self):
+        file_name_mask = filedialog.askopenfilename(master=self.root, filetypes=[("All Files", "*.*")], title="Import mask...")
+        mask_img = Image.open(file_name_mask)
+        if mask_img.size != self.canvas.img.size:
+            messagebox.showerror("Import mask", "Mask image size does not match the original image size! Mask image not imported.")
+            return
+        self.drawn_img = mask_img
+        self.drawn_img_draw = ImageDraw.Draw(self.drawn_img)
+        self.display_canvas()
+
+    def on_save_menu_clicked(self, event=None):
+        if self.file_name == 'untitled':
+            self.on_save_as_menu_clicked()
+        else:
+            self.actual_save()
+
+    def on_save_as_menu_clicked(self):
+        file_name = filedialog.asksaveasfilename(
+            master=self.root, filetypes=[('All Files', ('*.ps', '*.ps'))], title="Save...")
+        if not file_name:
+            return
+        self.file_name = file_name
+        self.actual_save()
+
+    def actual_save(self):
+        self.canvas.postscript(file=self.file_name, colormode='color')
+        self.root.title(self.file_name)
+
+    def on_close_menu_clicked(self):
+        self.close_window()
+
+    def close_window(self):
+        if messagebox.askokcancel("Quit", "Do you really want to quit?"):
+            self.root.destroy()
+
+    def on_undo_menu_clicked(self, event=None):
+        self.undo()
+
+    def undo(self):
+        items_stack = list(self.canvas.find("all"))
+        try:
+            last_item_id = items_stack.pop()
+        except IndexError:
+            return
+        self.canvas.delete(last_item_id)
+
+    def on_canvas_zoom_in_menu_clicked(self):
+        self.canvas_zoom_in()
+
+    def on_canvas_zoom_out_menu_clicked(self):
+        self.canvas_zoom_out()
+
+    def canvas_zoom_in(self):
+        self.canvas.scale("all", 0, 0, 1.2, 1.2)
+        self.canvas.config(scrollregion=self.canvas.bbox(tk.ALL))
+        self.canvas.pack(side=tk.RIGHT, expand=tk.YES, fill=tk.BOTH)
+
+    def canvas_zoom_out(self):
+        self.canvas.scale("all", 0, 0, .8, .8)
+        self.canvas.config(scrollregion=self.canvas.bbox(tk.ALL))
+        self.canvas.pack(side=tk.RIGHT, expand=tk.YES, fill=tk.BOTH)
+
+    def on_decensor_menu_clicked(self, event=None):
+        combined_img = Image.alpha_composite(self.canvas.img.convert('RGBA'), self.drawn_img)
+        decensorer = decensor.Decensor()
+        decensorer.decensor_image(combined_img.convert('RGB'), self.file_name + ".png")
+        messagebox.showinfo(
+            "Decensoring", "Decensoring complete!")
+
+    def on_about_menu_clicked(self, event=None):
+        messagebox.showinfo(
+            "About", "Tkinter GUI Application\n Development Blueprints")
+
+    def get_all_configurations_for_item(self):
+        configuration_dict = {}
+        for key, value in self.canvas.itemconfig("current").items():
+            if value[-1] and value[-1] not in ["0", "0.0", "0,0", "current"]:
+                configuration_dict[key] = value[-1]
+        return configuration_dict
+
+    def canvas_function_wrapper(self, function_name, *arg, **kwargs):
+        func = getattr(self.canvas, function_name)
+        func(*arg, **kwargs)
+
+    def adjust_canvas_coords(self, x_coordinate, y_coordinate):
+        low_x, high_x = self.x_scroll.get()
+        low_y, high_y = self.y_scroll.get()
+        #800 here approximates the size of the scrollable area
+        return low_x * 800 + x_coordinate, low_y * 800 + y_coordinate
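+    # Note: Tkinter's canvas.canvasx()/canvasy() (used by the mouse handlers below) perform this
+    # window-to-canvas conversion natively; in general the offset is the scrollbar's low fraction
+    # multiplied by the scrollregion extent, i.e. low_x * scroll_width + x_coordinate.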
+
+    def create_circle(self, x, y, r, **kwargs):
+        return self.canvas.create_oval(x-r, y-r, x+r, y+r, **kwargs)
+
+    def draw_irregular_line(self):
+        #draw in PIL rather than directly on the canvas
+        self.drawn_img_draw.line((self.start_x, self.start_y, self.end_x, self.end_y), fill=self.fill_pil, width=int(self.brush_width))
+        self.drawn_img_draw.ellipse((self.end_x - self.brush_width/2.0, self.end_y - self.brush_width/2.0, self.end_x + self.brush_width/2.0, self.end_y + self.brush_width/2.0), fill=self.fill_pil)
+
+        self.display_canvas()
+
+        self.canvas.bind("<Button1-Motion>", self.draw_irregular_line_update_x_y)
+
+    def draw_irregular_line_update_x_y(self, event=None):
+        self.start_x, self.start_y = self.end_x, self.end_y
+        self.end_x, self.end_y = self.adjust_canvas_coords(event.x, event.y)
+        self.draw_irregular_line()
+
+    def draw_irregular_line_options(self):
+        self.create_fill_options_combobox()
+        self.create_width_options_combobox()
+
+    def on_tool_bar_button_clicked(self, button_index):
+        self.selected_tool_bar_function = self.tool_bar_functions[button_index]
+        self.remove_options_from_top_bar()
+        self.display_options_in_the_top_bar()
+        self.bind_mouse()
+
+    def float_range(self, x, y, step):
+        while x < y:
+            yield x
+            x += step
+
+    def set_foreground_color(self, event=None):
+        self.foreground = self.get_color_from_chooser(
+            self.foreground, "foreground")
+        self.color_palette.itemconfig(
+            self.foreground_palette, width=0, fill=self.foreground)
+
+    def set_background_color(self, event=None):
+        self.background = self.get_color_from_chooser(
+            self.background, "background")
+        self.color_palette.itemconfig(
+            self.background_palette, width=0, fill=self.background)
+
+    def get_color_from_chooser(self, initial_color, color_type="a"):
+        color = colorchooser.askcolor(
+            color=initial_color,
+            title="select {} color".format(color_type)
+        )[-1]
+        if color:
+            return color
+        # dialog has been cancelled
+        else:
+            return initial_color
+
+    def try_to_set_fill_after_palette_change(self):
+        try:
+            self.set_fill()
+        except:
+            pass
+
+    def try_to_set_outline_after_palette_change(self):
+        try:
+            self.set_outline()
+        except:
+            pass
+
+    def display_options_in_the_top_bar(self):
+        self.show_selected_tool_icon_in_top_bar(
+            self.selected_tool_bar_function)
+        options_function_name = "{}_options".format(self.selected_tool_bar_function)
+        func = getattr(self, options_function_name, self.function_not_defined)
+        func()
+
+    def draw_line_options(self):
+        self.create_fill_options_combobox()
+        self.create_width_options_combobox()
+
+    def create_fill_options_combobox(self):
+        tk.Label(self.top_bar, text='Fill:').pack(side="left")
+        self.fill_combobox = ttk.Combobox(
+            self.top_bar, state='readonly', width=5)
+        self.fill_combobox.pack(side="left")
+        self.fill_combobox['values'] = ('none', 'fg', 'bg', 'black', 'white')
+        self.fill_combobox.bind('<<ComboboxSelected>>', self.set_fill)
+        self.fill_combobox.set(self.fill)
+
+    def create_outline_options_combobox(self):
+        tk.Label(self.top_bar, text='Outline:').pack(side="left")
+        self.outline_combobox = ttk.Combobox(
+            self.top_bar, state='readonly', width=5)
+        self.outline_combobox.pack(side="left")
+        self.outline_combobox['values'] = (
+            'none', 'fg', 'bg', 'black', 'white')
+        self.outline_combobox.bind('<<ComboboxSelected>>', self.set_outline)
+        self.outline_combobox.set(self.outline)
+
+    def create_width_options_combobox(self):
+        tk.Label(self.top_bar, text='Width:').pack(side="left")
+        self.width_combobox = ttk.Combobox(
+            self.top_bar, state='readonly', width=3)
+        self.width_combobox.pack(side="left")
+        self.width_combobox['values'] = (
+            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 40, 50)
+        self.width_combobox.bind('<<ComboboxSelected>>', self.set_brush_width)
+        self.width_combobox.set(self.brush_width)
+
+    def set_fill(self, event=None):
+        fill_color = self.fill_combobox.get()
+        if fill_color == 'none':
+            self.fill = ''  # transparent
+        elif fill_color == 'fg':
+            self.fill = self.foreground
+        elif fill_color == 'bg':
+            self.fill = self.background
+        else:
+            self.fill = fill_color
+
+    def set_outline(self, event=None):
+        outline_color = self.outline_combobox.get()
+        if outline_color == 'none':
+            self.outline = ''  # transparent
+        elif outline_color == 'fg':
+            self.outline = self.foreground
+        elif outline_color == 'bg':
+            self.outline = self.background
+        else:
+            self.outline = outline_color
+
+    def set_brush_width(self, event):
+        self.brush_width = float(self.width_combobox.get())
+
+    def create_color_palette(self):
+        self.color_palette = tk.Canvas(self.tool_bar, height=55, width=55)
+        self.color_palette.grid(row=10, column=1, columnspan=2, pady=5, padx=3)
+        self.background_palette = self.color_palette.create_rectangle(
+            15, 15, 48, 48, outline=self.background, fill=self.background)
+        self.foreground_palette = self.color_palette.create_rectangle(
+            1, 1, 33, 33, outline=self.foreground, fill=self.foreground)
+        self.bind_color_palette()
+
+    def bind_color_palette(self):
+        self.color_palette.tag_bind(
+            self.background_palette, "<Button-1>", self.set_background_color)
+        self.color_palette.tag_bind(
+            self.foreground_palette, "<Button-1>", self.set_foreground_color)
+
+    def create_current_coordinate_label(self):
+        self.current_coordinate_label = tk.Label(
+            self.tool_bar, text='x:0\ny: 0 ')
+        self.current_coordinate_label.grid(
+            row=13, column=1, columnspan=2, pady=5, padx=1, sticky='w')
+
+    def show_current_coordinates(self, event=None):
+        x_coordinate = event.x
+        y_coordinate = event.y
+        coordinate_string = "x:{0}\ny:{1}".format(x_coordinate, y_coordinate)
+        self.current_coordinate_label.config(text=coordinate_string)
+
+    def function_not_defined(self):
+        pass
+
+    def execute_selected_method(self):
+        self.current_item = None
+        func = getattr(
+            self, self.selected_tool_bar_function, self.function_not_defined)
+        func()
+
+    def draw_line(self):
+        self.current_item = self.canvas.create_line(
+            self.start_x, self.start_y, self.end_x, self.end_y, fill=self.fill, width=self.brush_width)
+
+    def create_tool_bar_buttons(self):
+        for index, name in enumerate(self.tool_bar_functions):
+            icon = tk.PhotoImage(file='icons/' + name + '.gif')
+            self.button = tk.Button(
+                self.tool_bar, image=icon, command=lambda index=index: self.on_tool_bar_button_clicked(index))
+            self.button.grid(
+                row=index // 2, column=1 + index % 2, sticky='nsew')
+            self.button.image = icon
+
+    def remove_options_from_top_bar(self):
+        for child in self.top_bar.winfo_children():
+            child.destroy()
+
+    def show_selected_tool_icon_in_top_bar(self, function_name):
+        display_name = function_name.replace("_", " ").capitalize() + ":"
+        tk.Label(self.top_bar, text=display_name).pack(side="left")
+        photo = tk.PhotoImage(
+            file='icons/' + function_name + '.gif')
+        label = tk.Label(self.top_bar, image=photo)
+        label.image = photo
+        label.pack(side="left")
+
+    def bind_mouse(self):
+        self.canvas.bind("<Button-1>", self.on_mouse_button_pressed)
+        self.canvas.bind(
+            "<Button1-Motion>", self.on_mouse_button_pressed_motion)
+        self.canvas.bind(
+            "<Button1-ButtonRelease>", self.on_mouse_button_released)
+        self.canvas.bind("<Motion>", self.on_mouse_unpressed_motion)
+
+    def on_mouse_button_pressed(self, event):
+        self.start_x = self.end_x = self.canvas.canvasx(event.x)
+        self.start_y = self.end_y = self.canvas.canvasy(event.y)
+        self.execute_selected_method()
+
+    def on_mouse_button_pressed_motion(self, event):
+        self.end_x = self.canvas.canvasx(event.x)
+        self.end_y = self.canvas.canvasy(event.y)
+        self.canvas.delete(self.current_item)
+        self.execute_selected_method()
+
+    def on_mouse_button_released(self, event):
+        self.end_x = self.canvas.canvasx(event.x)
+        self.end_y = self.canvas.canvasy(event.y)
+
+    def on_mouse_unpressed_motion(self, event):
+        self.show_current_coordinates(event)
+
+    def create_gui(self):
+        self.create_menu()
+        self.create_top_bar()
+        self.create_tool_bar()
+        self.create_tool_bar_buttons()
+        self.create_drawing_canvas()
+        self.create_color_palette()
+        self.create_current_coordinate_label()
+        self.bind_menu_accelerator_keys()
+        self.show_selected_tool_icon_in_top_bar("draw_line")
+        self.draw_line_options()
+
+    def create_menu(self):
+        self.menubar = tk.Menu(self.root)
+        menu_definitions = (
+            'File- &New/Ctrl+N/self.on_new_file_menu_clicked, Open/Ctrl+O/self.on_open_image_menu_clicked, Import Mask/Ctrl+M/self.on_import_mask_clicked, Save/Ctrl+S/self.on_save_menu_clicked, SaveAs/ /self.on_save_as_menu_clicked, sep, Exit/Alt+F4/self.on_close_menu_clicked',
+            'Edit- Undo/Ctrl+Z/self.on_undo_menu_clicked, sep',
+            'View- Zoom in//self.on_canvas_zoom_in_menu_clicked,Zoom Out//self.on_canvas_zoom_out_menu_clicked',
+            'Decensor- Decensor/Ctrl+D/self.on_decensor_menu_clicked',
+            'About- About/F1/self.on_about_menu_clicked'
+        )
+        self.build_menu(menu_definitions)
+
+    def create_top_bar(self):
+        self.top_bar = tk.Frame(self.root, height=25, relief="raised")
+        self.top_bar.pack(fill="x", side="top", pady=2)
+
+    def create_tool_bar(self):
+        self.tool_bar = tk.Frame(self.root, relief="raised", width=50)
+        self.tool_bar.pack(fill="y", side="left", pady=3)
+
+    def create_drawing_canvas(self):
+        self.canvas_frame = tk.Frame(self.root, width=900, height=900)
+        self.canvas_frame.pack(side="right", expand="yes", fill="both")
+        self.canvas = tk.Canvas(self.canvas_frame, background="white",
+                                width=512, height=512, scrollregion=(0, 0, 512, 512))
+        self.create_scroll_bar()
+        self.canvas.pack(side=tk.RIGHT, expand=tk.YES, fill=tk.BOTH)
+
+        self.canvas.img = Image.open('./icons/canvas_top_test.png').convert('RGBA')
+        self.canvas.img = self.canvas.img.resize((512,512))
+        self.canvas.tk_img = ImageTk.PhotoImage(self.canvas.img)
+        self.canvas.create_image(256, 256, image=self.canvas.tk_img)
+
+    def create_scroll_bar(self):
+        self.x_scroll = tk.Scrollbar(self.canvas_frame, orient="horizontal")
+        self.x_scroll.pack(side="bottom", fill="x")
+        self.x_scroll.config(command=self.canvas.xview)
+        self.y_scroll = tk.Scrollbar(self.canvas_frame, orient="vertical")
+        self.y_scroll.pack(side="right", fill="y")
+        self.y_scroll.config(command=self.canvas.yview)
+        self.canvas.config(
+            xscrollcommand=self.x_scroll.set, yscrollcommand=self.y_scroll.set)
+
+    def bind_menu_accelerator_keys(self):
+        self.root.bind('<F1>', self.on_about_menu_clicked)
+        self.root.bind('<Control-N>', self.on_new_file_menu_clicked)
+        self.root.bind('<Control-n>', self.on_new_file_menu_clicked)
+        self.root.bind('<Control-S>', self.on_save_menu_clicked)
+        self.root.bind('<Control-s>', self.on_save_menu_clicked)
+        self.root.bind('<Control-Z>', self.on_undo_menu_clicked)
+        self.root.bind('<Control-z>', self.on_undo_menu_clicked)
+
+if __name__ == '__main__':
+    root = tk.Tk()
+    app = PaintApplication(root)
+    root.mainloop()
\ No newline at end of file
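A quick way to exercise the pipeline end to end without the GUI (the paths shown are the defaults from config.py):

    python decensor.py --decensor_input_path ./decensor_input/ --decensor_output_path ./decensor_output/ --is_mosaic False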