Add files

This commit is contained in:
deeppomf 2018-10-20 00:11:20 -04:00
parent 44031fb53b
commit 40d3a50d45
7 changed files with 1349 additions and 0 deletions

29
config.py Normal file
View File

@ -0,0 +1,29 @@
import argparse
def str2bool(v):
if v.lower() in ('yes', 'true', 't', 'y', '1', True):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0', False):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
def get_args():
parser = argparse.ArgumentParser(description='')
#Input output folders settings
parser.add_argument('--decensor_input_path', dest='decensor_input_path', default='./decensor_input/', help='input images to be decensored by decensor.py path')
parser.add_argument('--decensor_output_path', dest='decensor_output_path', default='./decensor_output/', help='output images generated from running decensor.py path')
#Decensor settings
parser.add_argument('--mask_color_red', dest='mask_color_red', default=0, help='red channel of mask color in decensoring')
parser.add_argument('--mask_color_green', dest='mask_color_green', default=255, help='green channel of mask color in decensoring')
parser.add_argument('--mask_color_blue', dest='mask_color_blue', default=0, help='blue channel of mask color in decensoring')
parser.add_argument('--is_mosaic', dest='is_mosaic', default='False', type=str2bool, help='true if image has mosaic censoring, false otherwise')
args = parser.parse_args()
return args
if __name__ == '__main__':
get_args()

165
decensor.py Normal file
View File

@ -0,0 +1,165 @@
import numpy as np
from PIL import Image
import os
from copy import deepcopy
import config
from libs.pconv_hybrid_model import PConvUnet
from libs.flood_fill import find_regions, expand_bounding
class Decensor():
def __init__(self):
self.args = config.get_args()
self.decensor_mosaic = self.args.is_mosaic
self.mask_color = [self.args.mask_color_red/255.0, self.args.mask_color_green/255.0, self.args.mask_color_blue/255.0]
if not os.path.exists(self.args.decensor_output_path):
os.makedirs(self.args.decensor_output_path)
self.load_model()
def get_mask(self,ori, width, height):
mask = np.zeros(ori.shape, np.uint8)
#count = 0
#TODO: change to iterate over all images in batch when implementing batches
for row in range(height):
for col in range(width):
if np.array_equal(ori[0][row][col], self.mask_color):
mask[0, row, col] = 1
return 1-mask
def load_model(self):
self.model = PConvUnet(weight_filepath='data/logs/')
self.model.load(
r"./models/model.h5",
train_bn=False,
lr=0.00005
)
def decensor_all_images_in_folder(self):
#load model once at beginning and reuse same model
#self.load_model()
subdir = self.args.decensor_input_path
files = os.listdir(subdir)
#convert all images into np arrays and put them in a list
for file in files:
#print(file)
file_path = os.path.join(subdir, file)
if os.path.isfile(file_path) and os.path.splitext(file_path)[1] == ".png":
print("Decensoring the image {file_path}".format(file_path))
censored_img = Image.open(file_path)
self.decensor_image(censored_img, file)
#decensors one image at a time
#TODO: decensor all cropped parts of the same image in a batch (then i need input for ori an array of those images and make additional changes)
def decensor_image(self,ori, file_name):
width, height = ori.size
#save the alpha channel if the image has an alpha channel
has_alpha = False
alpha_channel = None
if (ori.mode == "RGBA"):
has_alpha = True
alpha_channel = np.asarray(ori)[:,:,3]
alpha_channel = np.expand_dims(alpha_channel, axis =-1)
ori = ori.convert('RGB')
ori_array = np.asarray(ori)
ori_array = np.array(ori_array / 255.0)
ori_array = np.expand_dims(ori_array, axis = 0)
mask = self.get_mask(ori_array, width, height)
regions = find_regions(ori)
print("Found {region_count} censored regions in this image!".format(region_count = len(regions)))
if len(regions) == 0 and not self.decensor_mosaic:
print("No green colored regions detected!")
return
output_img_array = ori_array[0].copy()
for region_counter, region in enumerate(regions, 1):
bounding_box = expand_bounding(ori, region)
crop_img = ori.crop(bounding_box)
#crop_img.show()
#convert mask back to image
mask_reshaped = mask[0,:,:,:] * 255.0
mask_img = Image.fromarray(mask_reshaped.astype('uint8'))
#resize the cropped images
crop_img = crop_img.resize((512, 512))
crop_img_array = np.asarray(crop_img)
crop_img_array = crop_img_array / 255.0
crop_img_array = np.expand_dims(crop_img_array, axis = 0)
#resize the mask images
mask_img = mask_img.crop(bounding_box)
mask_img = mask_img.resize((512, 512))
#mask_img.show()
#convert mask_img back to array
mask_array = np.asarray(mask_img)
mask_array = np.array(mask_array / 255.0)
#the mask has been upscaled so there will be values not equal to 0 or 1
#mask_array[mask_array < 0.01] = 0
mask_array[mask_array > 0] = 1
mask_array = np.expand_dims(mask_array, axis = 0)
# Run predictions for this batch of images
pred_img_array = self.model.predict([crop_img_array, mask_array, mask_array])
pred_img_array = pred_img_array * 255.0
pred_img_array = np.squeeze(pred_img_array, axis = 0)
#scale prediction image back to original size
bounding_width = bounding_box[2]-bounding_box[0]
bounding_height = bounding_box[3]-bounding_box[1]
#convert np array to image
# print(bounding_width,bounding_height)
# print(pred_img_array.shape)
pred_img = Image.fromarray(pred_img_array.astype('uint8'))
#pred_img.show()
pred_img = pred_img.resize((bounding_width, bounding_height), resample = Image.BICUBIC)
pred_img_array = np.asarray(pred_img)
pred_img_array = pred_img_array / 255.0
# print(pred_img_array.shape)
pred_img_array = np.expand_dims(pred_img_array, axis = 0)
for i in range(len(ori_array)):
if self.decensor_mosaic:
output_img_array = pred_img[i]
else:
for col in range(bounding_width):
for row in range(bounding_height):
bounding_width = col + bounding_box[0]
bounding_height = row + bounding_box[1]
if (bounding_width, bounding_height) in region:
output_img_array[bounding_height][bounding_width] = pred_img_array[i,:,:,:][row][col]
print("{region_counter} out of {region_count} regions decensored.".format(region_counter=region_counter, region_count=len(regions)))
output_img_array = output_img_array * 255.0
#restore the alpha channel
if has_alpha:
print(output_img_array.shape)
print(alpha_channel.shape)
output_img_array = np.concatenate((output_img_array, alpha_channel), axis = 2)
output_img = Image.fromarray(output_img_array.astype('uint8'))
#save the decensored image
#file_name, _ = os.path.splitext(file_name)
save_path = os.path.join(self.args.decensor_output_path, file_name)
output_img.save(save_path)
print("Decensored image saved to {save_path}!".format(save_path=save_path))
return
if __name__ == '__main__':
decensor = Decensor()
decensor.decensor_all_images_in_folder()

128
flood_fill.py Normal file
View File

@ -0,0 +1,128 @@
from PIL import Image, ImageDraw
#find strongly connected components with the mask color
def find_regions(image):
pixel = image.load()
neighbors = dict()
width, height = image.size
for x in range(width):
for y in range(height):
if is_green(pixel[x,y]):
neighbors[x, y] = {(x,y)}
for x, y in neighbors:
candidates = (x + 1, y), (x, y + 1)
for candidate in candidates:
if candidate in neighbors:
neighbors[x, y].add(candidate)
neighbors[candidate].add((x, y))
closed_list = set()
def connected_component(pixel):
region = set()
open_list = {pixel}
while open_list:
pixel = open_list.pop()
closed_list.add(pixel)
open_list |= neighbors[pixel] - closed_list
region.add(pixel)
return region
regions = []
for pixel in neighbors:
if pixel not in closed_list:
regions.append(connected_component(pixel))
regions.sort(key = len, reverse = True)
return regions
# risk of box being bigger than the image
def expand_bounding(img, region, expand_factor=1.5, min_size = 256, max_size=512):
#expand bounding box to capture more context
x, y = zip(*region)
min_x, min_y, max_x, max_y = min(x), min(y), max(x), max(y)
width, height = img.size
width_center = width//2
height_center = height//2
bb_width = max_x - min_x
bb_height = max_y - min_y
x_center = (min_x + max_x)//2
y_center = (min_y + max_y)//2
current_size = max(bb_width, bb_height)
current_size = int(current_size * expand_factor)
if current_size > max_size:
current_size = max_size
elif current_size < min_size:
current_size = min_size
x1 = x_center - current_size//2
x2 = x_center + current_size//2
y1 = y_center - current_size//2
y2 = y_center + current_size//2
x1_square = x1
y1_square = y1
x2_square = x2
y2_square = y2
#move bounding boxes that are partially outside of the image inside the image
if (y1_square < 0 or y2_square > (height - 1)) and (x1_square < 0 or x2_square > (width - 1)):
#conservative square region
if x1_square < 0 and y1_square < 0:
x1_square = 0
y1_square = 0
x2_square = current_size
y2_square = current_size
elif x2_square > (width - 1) and y2_square > (height - 1):
x1_square = width - current_size - 1
y1_square = 0
x2_square = width - 1
y2_square = current_size
elif x1_square < 0 and y2_square > (height - 1):
x1_square = 0
y1_square = height - current_size - 1
x2_square = current_size
y2_square = height - 1
elif x2_square > (width - 1) and y2_square > (height - 1):
x1_square = width - current_size - 1
y1_square = height - current_size - 1
x2_square = width - 1
y2_square = height - 1
else:
x1_square = x1
y1_square = y1
x2_square = x2
y2_square = y2
else:
if x1_square < 0:
difference = x1_square
x1_square -= difference
x2_square -= difference
if x2_square > (width - 1):
difference = x2_square - width + 1
x1_square -= difference
x2_square -= difference
if y1_square < 0:
difference = y1_square
y1_square -= difference
y2_square -= difference
if y2_square > (height - 1):
difference = y2_square - height + 1
y1_square -= difference
y2_square -= difference
# if y1_square < 0 or y2_square > (height - 1):
#if bounding box goes outside of the image for some reason, set bounds to original, unexpanded values
#print(width, height)
if x2_square > width or y2_square > height:
print("bounding box out of bounds!")
print(x1_square, y1_square, x2_square, y2_square)
x1_square, y1_square, x2_square, y2_square = min_x, min_y, max_x, max_y
return x1_square, y1_square, x2_square, y2_square
def is_green(pixel):
r, g, b = pixel
return r == 0 and g == 255 and b == 0
if __name__ == '__main__':
image = Image.open('')
no_alpha_image = image.convert('RGB')
draw = ImageDraw.Draw(no_alpha_image)
for region in find_regions(no_alpha_image):
draw.rectangle(expand_bounding(no_alpha_image, region), outline=(0, 255, 0))
no_alpha_image.show()

128
libs/flood_fill.py Normal file
View File

@ -0,0 +1,128 @@
from PIL import Image, ImageDraw
#find strongly connected components with the mask color
def find_regions(image):
pixel = image.load()
neighbors = dict()
width, height = image.size
for x in range(width):
for y in range(height):
if is_green(pixel[x,y]):
neighbors[x, y] = {(x,y)}
for x, y in neighbors:
candidates = (x + 1, y), (x, y + 1)
for candidate in candidates:
if candidate in neighbors:
neighbors[x, y].add(candidate)
neighbors[candidate].add((x, y))
closed_list = set()
def connected_component(pixel):
region = set()
open_list = {pixel}
while open_list:
pixel = open_list.pop()
closed_list.add(pixel)
open_list |= neighbors[pixel] - closed_list
region.add(pixel)
return region
regions = []
for pixel in neighbors:
if pixel not in closed_list:
regions.append(connected_component(pixel))
regions.sort(key = len, reverse = True)
return regions
# risk of box being bigger than the image
def expand_bounding(img, region, expand_factor=1.5, min_size = 256, max_size=512):
#expand bounding box to capture more context
x, y = zip(*region)
min_x, min_y, max_x, max_y = min(x), min(y), max(x), max(y)
width, height = img.size
width_center = width//2
height_center = height//2
bb_width = max_x - min_x
bb_height = max_y - min_y
x_center = (min_x + max_x)//2
y_center = (min_y + max_y)//2
current_size = max(bb_width, bb_height)
current_size = int(current_size * expand_factor)
if current_size > max_size:
current_size = max_size
elif current_size < min_size:
current_size = min_size
x1 = x_center - current_size//2
x2 = x_center + current_size//2
y1 = y_center - current_size//2
y2 = y_center + current_size//2
x1_square = x1
y1_square = y1
x2_square = x2
y2_square = y2
#move bounding boxes that are partially outside of the image inside the image
if (y1_square < 0 or y2_square > (height - 1)) and (x1_square < 0 or x2_square > (width - 1)):
#conservative square region
if x1_square < 0 and y1_square < 0:
x1_square = 0
y1_square = 0
x2_square = current_size
y2_square = current_size
elif x2_square > (width - 1) and y2_square > (height - 1):
x1_square = width - current_size - 1
y1_square = 0
x2_square = width - 1
y2_square = current_size
elif x1_square < 0 and y2_square > (height - 1):
x1_square = 0
y1_square = height - current_size - 1
x2_square = current_size
y2_square = height - 1
elif x2_square > (width - 1) and y2_square > (height - 1):
x1_square = width - current_size - 1
y1_square = height - current_size - 1
x2_square = width - 1
y2_square = height - 1
else:
x1_square = x1
y1_square = y1
x2_square = x2
y2_square = y2
else:
if x1_square < 0:
difference = x1_square
x1_square -= difference
x2_square -= difference
if x2_square > (width - 1):
difference = x2_square - width + 1
x1_square -= difference
x2_square -= difference
if y1_square < 0:
difference = y1_square
y1_square -= difference
y2_square -= difference
if y2_square > (height - 1):
difference = y2_square - height + 1
y1_square -= difference
y2_square -= difference
# if y1_square < 0 or y2_square > (height - 1):
#if bounding box goes outside of the image for some reason, set bounds to original, unexpanded values
#print(width, height)
if x2_square > width or y2_square > height:
print("bounding box out of bounds!")
print(x1_square, y1_square, x2_square, y2_square)
x1_square, y1_square, x2_square, y2_square = min_x, min_y, max_x, max_y
return x1_square, y1_square, x2_square, y2_square
def is_green(pixel):
r, g, b = pixel
return r == 0 and g == 255 and b == 0
if __name__ == '__main__':
image = Image.open('')
no_alpha_image = image.convert('RGB')
draw = ImageDraw.Draw(no_alpha_image)
for region in find_regions(no_alpha_image):
draw.rectangle(expand_bounding(no_alpha_image, region), outline=(0, 255, 0))
no_alpha_image.show()

276
libs/pconv_hybrid_model.py Normal file
View File

@ -0,0 +1,276 @@
import os
from datetime import datetime
from keras.models import Model
from keras.models import load_model
from keras.optimizers import Adam
from keras.layers import Input, Conv2D, UpSampling2D, Dropout, LeakyReLU, BatchNormalization, Activation
from keras.layers.merge import Concatenate
from keras.applications import VGG16
from keras import backend as K
from libs.pconv_layer import PConv2D
class PConvUnet(object):
def __init__(self, img_rows=512, img_cols=512, weight_filepath=None):
"""Create the PConvUnet. If variable image size, set img_rows and img_cols to None"""
# Settings
self.weight_filepath = weight_filepath
self.img_rows = img_rows
self.img_cols = img_cols
assert self.img_rows >= 256, 'Height must be >256 pixels'
assert self.img_cols >= 256, 'Width must be >256 pixels'
# Set current epoch
self.current_epoch = 0
# VGG layers to extract features from (first maxpooling layers, see pp. 7 of paper)
self.vgg_layers = [3, 6, 10]
# Get the vgg16 model for perceptual loss
self.vgg = self.build_vgg()
# Create UNet-like model
self.model = self.build_pconv_unet()
def build_vgg(self):
"""
Load pre-trained VGG16 from keras applications
Extract features to be used in loss function from last conv layer, see architecture at:
https://github.com/keras-team/keras/blob/master/keras/applications/vgg16.py
"""
# Input image to extract features from
img = Input(shape=(self.img_rows, self.img_cols, 3))
# Get the vgg network from Keras applications
vgg = VGG16(weights="imagenet", include_top=False)
# Output the first three pooling layers
vgg.outputs = [vgg.layers[i].output for i in self.vgg_layers]
# Create model and compile
model = Model(inputs=img, outputs=vgg(img))
model.trainable = False
model.compile(loss='mse', optimizer='adam')
return model
def build_pconv_unet(self, train_bn=True, lr=0.0002):
# INPUTS
inputs_img = Input((self.img_rows, self.img_cols, 3))
inputs_mask = Input((self.img_rows, self.img_cols, 3))
loss_mask = Input((self.img_rows, self.img_cols, 3))
# ENCODER
def encoder_layer(img_in, mask_in, filters, kernel_size, bn=True):
conv, mask = PConv2D(filters, kernel_size, strides=2, padding='same')([img_in, mask_in])
if bn:
conv = BatchNormalization(name='EncBN'+str(encoder_layer.counter))(conv, training=train_bn)
conv = Activation('relu')(conv)
encoder_layer.counter += 1
return conv, mask
encoder_layer.counter = 0
e_conv1, e_mask1 = encoder_layer(inputs_img, inputs_mask, 64, 7, bn=False)
e_conv2, e_mask2 = encoder_layer(e_conv1, e_mask1, 128, 5)
e_conv3, e_mask3 = encoder_layer(e_conv2, e_mask2, 256, 5)
e_conv4, e_mask4 = encoder_layer(e_conv3, e_mask3, 512, 3)
e_conv5, e_mask5 = encoder_layer(e_conv4, e_mask4, 512, 3)
e_conv6, e_mask6 = encoder_layer(e_conv5, e_mask5, 512, 3)
e_conv7, e_mask7 = encoder_layer(e_conv6, e_mask6, 512, 3)
e_conv8, e_mask8 = encoder_layer(e_conv7, e_mask7, 512, 3)
# DECODER
def decoder_layer(img_in, mask_in, e_conv, e_mask, filters, kernel_size, bn=True):
up_img = UpSampling2D(size=(2,2))(img_in)
up_mask = UpSampling2D(size=(2,2))(mask_in)
concat_img = Concatenate(axis=3)([e_conv,up_img])
concat_mask = Concatenate(axis=3)([e_mask,up_mask])
conv, mask = PConv2D(filters, kernel_size, padding='same')([concat_img, concat_mask])
if bn:
conv = BatchNormalization()(conv)
conv = LeakyReLU(alpha=0.2)(conv)
return conv, mask
d_conv9, d_mask9 = decoder_layer(e_conv8, e_mask8, e_conv7, e_mask7, 512, 3)
d_conv10, d_mask10 = decoder_layer(d_conv9, d_mask9, e_conv6, e_mask6, 512, 3)
d_conv11, d_mask11 = decoder_layer(d_conv10, d_mask10, e_conv5, e_mask5, 512, 3)
d_conv12, d_mask12 = decoder_layer(d_conv11, d_mask11, e_conv4, e_mask4, 512, 3)
d_conv13, d_mask13 = decoder_layer(d_conv12, d_mask12, e_conv3, e_mask3, 256, 3)
d_conv14, d_mask14 = decoder_layer(d_conv13, d_mask13, e_conv2, e_mask2, 128, 3)
d_conv15, d_mask15 = decoder_layer(d_conv14, d_mask14, e_conv1, e_mask1, 64, 3)
d_conv16, d_mask16 = decoder_layer(d_conv15, d_mask15, inputs_img, inputs_mask, 3, 3, bn=False)
outputs = Conv2D(3, 1, activation = 'sigmoid')(d_conv16)
# Setup the model inputs / outputs
model = Model(inputs=[inputs_img, inputs_mask, loss_mask], outputs=outputs)
# Compile the model
model.compile(
optimizer = Adam(lr=lr),
loss=self.loss_total(loss_mask)
)
return model
def loss_total(self, mask):
"""
Creates a loss function which sums all the loss components
and multiplies by their weights. See paper eq. 7.
"""
def loss(y_true, y_pred):
# Compute predicted image with non-hole pixels set to ground truth
y_comp = mask * y_true + (1-mask) * y_pred
# Compute the vgg features
vgg_out = self.vgg(y_pred)
vgg_gt = self.vgg(y_true)
vgg_comp = self.vgg(y_comp)
# Compute loss components
l1 = self.loss_valid(mask, y_true, y_pred)
l2 = self.loss_hole(mask, y_true, y_pred)
l3 = self.loss_perceptual(vgg_out, vgg_gt, vgg_comp)
l4 = self.loss_style(vgg_out, vgg_gt)
l5 = self.loss_style(vgg_comp, vgg_gt)
l6 = self.loss_tv(mask, y_comp)
# Return loss function
return l1 + 6*l2 + 0.05*l3 + 120*(l4+l5) + 0.1*l6
return loss
def loss_hole(self, mask, y_true, y_pred):
"""Pixel L1 loss within the hole / mask"""
return self.l1((1-mask) * y_true, (1-mask) * y_pred)
def loss_valid(self, mask, y_true, y_pred):
"""Pixel L1 loss outside the hole / mask"""
return self.l1(mask * y_true, mask * y_pred)
def loss_perceptual(self, vgg_out, vgg_gt, vgg_comp):
"""Perceptual loss based on VGG16, see. eq. 3 in paper"""
loss = 0
for o, c, g in zip(vgg_out, vgg_comp, vgg_gt):
loss += self.l1(o, g) + self.l1(c, g)
return loss
def loss_style(self, output, vgg_gt):
"""Style loss based on output/computation, used for both eq. 4 & 5 in paper"""
loss = 0
for o, g in zip(output, vgg_gt):
loss += self.l1(self.gram_matrix(o), self.gram_matrix(g))
return loss
def loss_tv(self, mask, y_comp):
"""Total variation loss, used for smoothing the hole region, see. eq. 6"""
# Create dilated hole region using a 3x3 kernel of all 1s.
kernel = K.ones(shape=(3, 3, mask.shape[3], mask.shape[3]))
dilated_mask = K.conv2d(1-mask, kernel, data_format='channels_last', padding='same')
# Cast values to be [0., 1.], and compute dilated hole region of y_comp
dilated_mask = K.cast(K.greater(dilated_mask, 0), 'float32')
P = dilated_mask * y_comp
# Calculate total variation loss
a = self.l1(P[:,1:,:,:], P[:,:-1,:,:])
b = self.l1(P[:,:,1:,:], P[:,:,:-1,:])
return a+b
def fit(self, generator, epochs=10, plot_callback=None, *args, **kwargs):
"""Fit the U-Net to a (images, targets) generator
param generator: training generator yielding (maskes_image, original_image) tuples
param epochs: number of epochs to train for
param plot_callback: callback function taking Unet model as parameter
"""
# Loop over epochs
for _ in range(epochs):
# Fit the model
self.model.fit_generator(
generator,
epochs=self.current_epoch+1,
initial_epoch=self.current_epoch,
*args, **kwargs
)
# Update epoch
self.current_epoch += 1
# After each epoch predict on test images & show them
if plot_callback:
plot_callback(self.model)
# Save logfile
if self.weight_filepath:
self.save()
def predict(self, sample):
"""Run prediction using this model"""
return self.model.predict(sample)
def summary(self):
"""Get summary of the UNet model"""
print(self.model.summary())
def save(self):
self.model.save_weights(self.current_weightfile())
def load(self, filepath, train_bn=True, lr=0.0002):
# Create UNet-like model
self.model = self.build_pconv_unet(train_bn, lr)
# Load weights into model
#epoch = 50
epoch = int(os.path.basename(filepath).split("_")[0])
assert epoch > 0, "Could not parse weight file. Should start with 'X_', with X being the epoch"
self.current_epoch = epoch
self.model.load_weights(filepath)
def current_weightfile(self):
assert self.weight_filepath != None, 'Must specify location of logs'
return self.weight_filepath + "{}_weights_{}.h5".format(self.current_epoch, self.current_timestamp())
@staticmethod
def current_timestamp():
return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
@staticmethod
def l1(y_true, y_pred):
"""Calculate the L1 loss used in all loss calculations"""
if K.ndim(y_true) == 4:
return K.sum(K.abs(y_pred - y_true), axis=[1,2,3])
elif K.ndim(y_true) == 3:
return K.sum(K.abs(y_pred - y_true), axis=[1,2])
else:
raise NotImplementedError("Calculating L1 loss on 1D tensors? should not occur for this network")
@staticmethod
def gram_matrix(x, norm_by_channels=False):
"""Calculate gram matrix used in style loss"""
# Assertions on input
assert K.ndim(x) == 4, 'Input tensor should be a 4d (B, H, W, C) tensor'
assert K.image_data_format() == 'channels_last', "Please use channels-last format"
# Permute channels and get resulting shape
x = K.permute_dimensions(x, (0, 3, 1, 2))
shape = K.shape(x)
B, C, H, W = shape[0], shape[1], shape[2], shape[3]
# Reshape x and do batch dot product
features = K.reshape(x, K.stack([B, C, H*W]))
gram = K.batch_dot(features, features, axes=2)
# Normalize with channels, height and width
gram = gram / K.cast(C * H * W, x.dtype)
return gram

126
libs/pconv_layer.py Normal file
View File

@ -0,0 +1,126 @@
from keras.utils import conv_utils
from keras import backend as K
from keras.engine import InputSpec
from keras.layers import Conv2D
class PConv2D(Conv2D):
def __init__(self, *args, n_channels=3, mono=False, **kwargs):
super().__init__(*args, **kwargs)
self.input_spec = [InputSpec(ndim=4), InputSpec(ndim=4)]
def build(self, input_shape):
"""Adapted from original _Conv() layer of Keras
param input_shape: list of dimensions for [img, mask]
"""
if self.data_format == 'channels_first':
channel_axis = 1
else:
channel_axis = -1
if input_shape[0][channel_axis] is None:
raise ValueError('The channel dimension of the inputs should be defined. Found `None`.')
self.input_dim = input_shape[0][channel_axis]
# Image kernel
kernel_shape = self.kernel_size + (self.input_dim, self.filters)
self.kernel = self.add_weight(shape=kernel_shape,
initializer=self.kernel_initializer,
name='img_kernel',
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
# Mask kernel
self.kernel_mask = K.ones(shape=self.kernel_size + (self.input_dim, self.filters))
if self.use_bias:
self.bias = self.add_weight(shape=(self.filters,),
initializer=self.bias_initializer,
name='bias',
regularizer=self.bias_regularizer,
constraint=self.bias_constraint)
else:
self.bias = None
self.built = True
def call(self, inputs, mask=None):
'''
We will be using the Keras conv2d method, and essentially we have
to do here is multiply the mask with the input X, before we apply the
convolutions. For the mask itself, we apply convolutions with all weights
set to 1.
Subsequently, we set all mask values >0 to 1, and otherwise 0
'''
# Both image and mask must be supplied
if type(inputs) is not list or len(inputs) != 2:
raise Exception('PartialConvolution2D must be called on a list of two tensors [img, mask]. Instead got: ' + str(inputs))
# Create normalization. Slight change here compared to paper, using mean mask value instead of sum
normalization = K.mean(inputs[1], axis=[1,2], keepdims=True)
normalization = K.repeat_elements(normalization, inputs[1].shape[1], axis=1)
normalization = K.repeat_elements(normalization, inputs[1].shape[2], axis=2)
# Apply convolutions to image
img_output = K.conv2d(
(inputs[0]*inputs[1]) / normalization, self.kernel,
strides=self.strides,
padding=self.padding,
data_format=self.data_format,
dilation_rate=self.dilation_rate
)
# Apply convolutions to mask
mask_output = K.conv2d(
inputs[1], self.kernel_mask,
strides=self.strides,
padding=self.padding,
data_format=self.data_format,
dilation_rate=self.dilation_rate
)
# Where something happened, set 1, otherwise 0
mask_output = K.cast(K.greater(mask_output, 0), 'float32')
# Apply bias only to the image (if chosen to do so)
if self.use_bias:
img_output = K.bias_add(
img_output,
self.bias,
data_format=self.data_format)
# Apply activations on the image
if self.activation is not None:
img_output = self.activation(img_output)
return [img_output, mask_output]
def compute_output_shape(self, input_shape):
if self.data_format == 'channels_last':
space = input_shape[0][1:-1]
new_space = []
for i in range(len(space)):
new_dim = conv_utils.conv_output_length(
space[i],
self.kernel_size[i],
padding=self.padding,
stride=self.strides[i],
dilation=self.dilation_rate[i])
new_space.append(new_dim)
new_shape = (input_shape[0][0],) + tuple(new_space) + (self.filters,)
return [new_shape, new_shape]
if self.data_format == 'channels_first':
space = input_shape[2:]
new_space = []
for i in range(len(space)):
new_dim = conv_utils.conv_output_length(
space[i],
self.kernel_size[i],
padding=self.padding,
stride=self.strides[i],
dilation=self.dilation_rate[i])
new_space.append(new_dim)
new_shape = (input_shape[0], self.filters) + tuple(new_space)
return [new_shape, new_shape]

497
ui.py Normal file
View File

@ -0,0 +1,497 @@
"""
Code illustration: 6.09
Modules imported here:
from tkinter import messagebox
from tkinter import filedialog
Attributes added here:
file_name = "untitled"
Methods modified here:
on_new_file_menu_clicked()
on_save_menu_clicked()
on_save_as_menu_clicked()
on_close_menu_clicked()
on_undo_menu_clicked()
on_canvas_zoom_in_menu_clicked()
on_canvas_zoom_out_menu_clicked()
on_about_menu_clicked()
Methods added here
start_new_project()
actual_save()
close_window()
undo()
canvas_zoom_in()
canvas_zoom_out()
@ Tkinter GUI Application Development Blueprints
"""
import math
from PIL import Image, ImageTk, ImageDraw
import tkinter as tk
from tkinter import colorchooser
from tkinter import ttk
from tkinter import messagebox
from tkinter import filedialog
import framework
import decensor
class PaintApplication(framework.Framework):
def __init__(self, root):
super().__init__(root)
self.drawn_img = None
self.screen_width = root.winfo_screenwidth()
self.screen_height = root.winfo_screenheight()
self.start_x, self.start_y = 0, 0
self.end_x, self.end_y = 0, 0
self.current_item = None
self.fill = "#00ff00"
self.fill_pil = (0,255,0,255)
self.outline = "#00ff00"
self.brush_width = 2
self.background = 'white'
self.foreground = "#00ff00"
self.file_name = "Untitled"
self.tool_bar_functions = (
"draw_line", "draw_irregular_line"
)
self.selected_tool_bar_function = self.tool_bar_functions[0]
self.create_gui()
self.bind_mouse()
def on_new_file_menu_clicked(self, event=None):
self.start_new_project()
def start_new_project(self):
self.canvas.delete(tk.ALL)
self.canvas.config(bg="#ffffff")
self.root.title('untitled')
def on_open_image_menu_clicked(self, event=None):
self.open_image()
def open_image(self):
self.file_name = filedialog.askopenfilename(master=self.root, filetypes = [("All Files","*.*")], title="Open...")
print(self.file_name)
self.canvas.img = Image.open(self.file_name)
self.canvas.img_width, self.canvas.img_height = self.canvas.img.size
#make reference to image to prevent garbage collection
#https://stackoverflow.com/questions/20061396/image-display-on-tkinter-canvas-not-working
self.canvas.tk_img = ImageTk.PhotoImage(self.canvas.img)
self.canvas.config(width=self.canvas.img_width, height=self.canvas.img_height)
self.canvas.create_image(self.canvas.img_width/2.0,self.canvas.img_height/2.0,image=self.canvas.tk_img)
self.drawn_img = Image.new("RGBA", self.canvas.img.size)
self.drawn_img_draw = ImageDraw.Draw(self.drawn_img)
def on_import_mask_clicked(self, event=None):
self.import_mask()
def display_canvas(self):
composite_img = Image.alpha_composite(self.canvas.img.convert('RGBA'), self.drawn_img).convert('RGB')
self.canvas.tk_img = ImageTk.PhotoImage(composite_img)
self.canvas.create_image(self.canvas.img_width/2.0,self.canvas.img_height/2.0,image=self.canvas.tk_img)
def import_mask(self):
file_name_mask = filedialog.askopenfilename(master=self.root, filetypes = [("All Files","*.*")], title="Import mask...")
mask_img = Image.open(file_name_mask)
if (mask_img.size != self.canvas.img.size):
messagebox.showerror("Import mask", "Mask image size does not match the original image size! Mask image not imported.")
return
self.drawn_img = mask_img
self.drawn_img_draw = ImageDraw.Draw(self.drawn_img)
self.display_canvas()
def on_save_menu_clicked(self, event=None):
if self.file_name == 'untitled':
self.on_save_as_menu_clicked()
else:
self.actual_save()
def on_save_as_menu_clicked(self):
file_name = filedialog.asksaveasfilename(
master=self.root, filetypes=[('All Files', ('*.ps', '*.ps'))], title="Save...")
if not file_name:
return
self.file_name = file_name
self.actual_save()
def actual_save(self):
self.canvas.postscript(file=self.file_name, colormode='color')
self.root.title(self.file_name)
def on_close_menu_clicked(self):
self.close_window()
def close_window(self):
if messagebox.askokcancel("Quit", "Do you really want to quit?"):
self.root.destroy()
def on_undo_menu_clicked(self, event=None):
self.undo()
def undo(self):
items_stack = list(self.canvas.find("all"))
try:
last_item_id = items_stack.pop()
except IndexError:
return
self.canvas.delete(last_item_id)
def on_canvas_zoom_in_menu_clicked(self):
self.canvas_zoom_in()
def on_canvas_zoom_out_menu_clicked(self):
self.canvas_zoom_out()
def canvas_zoom_in(self):
self.canvas.scale("all", 0, 0, 1.2, 1.2)
self.canvas.config(scrollregion=self.canvas.bbox(tk.ALL))
self.canvas.pack(side=tk.RIGHT, expand=tk.YES, fill=tk.BOTH)
def canvas_zoom_out(self):
self.canvas.scale("all", 0, 0, .8, .8)
self.canvas.config(scrollregion=self.canvas.bbox(tk.ALL))
self.canvas.pack(side=tk.RIGHT, expand=tk.YES, fill=tk.BOTH)
def on_decensor_menu_clicked(self, event=None):
combined_img = Image.alpha_composite(self.canvas.img.convert('RGBA'), self.drawn_img)
decensorer = decensor.Decensor()
decensorer.decensor_image(combined_img.convert('RGB'), self.file_name + ".png")
messagebox.showinfo(
"Decensoring", "Decensoring complete!")
def on_about_menu_clicked(self, event=None):
# messagebox.showinfo(
# "Decensoring", "Decensoring in progress.")
messagebox.showinfo(
"About", "Tkinter GUI Application\n Development Blueprints")
def get_all_configurations_for_item(self):
configuration_dict = {}
for key, value in self.canvas.itemconfig("current").items():
if value[-1] and value[-1] not in ["0", "0.0", "0,0", "current"]:
configuration_dict[key] = value[-1]
return configuration_dict
def canvas_function_wrapper(self, function_name, *arg, **kwargs):
func = getattr(self.canvas, function_name)
func(*arg, **kwargs)
def adjust_canvas_coords(self, x_coordinate, y_coordinate):
# low_x, high_x = self.x_scroll.get()
# percent_x = low_x/(1+low_x-high_x)
# low_y, high_y = self.y_scroll.get()
# percent_y = low_y/(1+low_y-high_y)
low_x, high_x = self.x_scroll.get()
low_y, high_y = self.y_scroll.get()
#length_y = high_y - low_y
return low_x * 800 + x_coordinate, low_y * 800 + y_coordinate
def create_circle(self, x, y, r, **kwargs):
return self.canvas.create_oval(x-r, y-r, x+r, y+r, **kwargs)
def draw_irregular_line(self):
# self.current_item = self.canvas.create_line(
# self.start_x, self.start_y, self.end_x, self.end_y, fill=self.fill, width=self.brush_width)
# self.current_item = self.create_circle(self.end_x, self.end_y, self.brush_width/2.0, fill=self.fill, width=0)
#draw in PIL
self.drawn_img_draw.line((self.start_x, self.start_y, self.end_x, self.end_y), fill=self.fill_pil, width=int(self.brush_width))
self.drawn_img_draw.ellipse((self.end_x - self.brush_width/2.0, self.end_y - self.brush_width/2.0, self.end_x + self.brush_width/2.0, self.end_y + self.brush_width/2.0), fill=self.fill_pil)
self.display_canvas()
# composite_img = Image.alpha_composite(self.canvas.img.convert('RGBA'), self.drawn_img).convert('RGB')
# self.canvas.tk_img = ImageTk.PhotoImage(composite_img)
# self.canvas.create_image(self.canvas.img_width/2.0,self.canvas.img_height/2.0,image=self.canvas.tk_img)
self.canvas.bind("<B1-Motion>", self.draw_irregular_line_update_x_y)
def draw_irregular_line_update_x_y(self, event=None):
self.start_x, self.start_y = self.end_x, self.end_y
self.end_x, self.end_y = self.adjust_canvas_coords(event.x, event.y)
self.draw_irregular_line()
def draw_irregular_line_options(self):
self.create_fill_options_combobox()
self.create_width_options_combobox()
def on_tool_bar_button_clicked(self, button_index):
self.selected_tool_bar_function = self.tool_bar_functions[button_index]
self.remove_options_from_top_bar()
self.display_options_in_the_top_bar()
self.bind_mouse()
def float_range(self, x, y, step):
while x < y:
yield x
x += step
def set_foreground_color(self, event=None):
self.foreground = self.get_color_from_chooser(
self.foreground, "foreground")
self.color_palette.itemconfig(
self.foreground_palette, width=0, fill=self.foreground)
def set_background_color(self, event=None):
self.background = self.get_color_from_chooser(
self.background, "background")
self.color_palette.itemconfig(
self.background_palette, width=0, fill=self.background)
def get_color_from_chooser(self, initial_color, color_type="a"):
color = colorchooser.askcolor(
color=initial_color,
title="select {} color".format(color_type)
)[-1]
if color:
return color
# dialog has been cancelled
else:
return initial_color
def try_to_set_fill_after_palette_change(self):
try:
self.set_fill()
except:
pass
def try_to_set_outline_after_palette_change(self):
try:
self.set_outline()
except:
pass
def display_options_in_the_top_bar(self):
self.show_selected_tool_icon_in_top_bar(
self.selected_tool_bar_function)
options_function_name = "{}_options".format(self.selected_tool_bar_function)
func = getattr(self, options_function_name, self.function_not_defined)
func()
def draw_line_options(self):
self.create_fill_options_combobox()
self.create_width_options_combobox()
def create_fill_options_combobox(self):
tk.Label(self.top_bar, text='Fill:').pack(side="left")
self.fill_combobox = ttk.Combobox(
self.top_bar, state='readonly', width=5)
self.fill_combobox.pack(side="left")
self.fill_combobox['values'] = ('none', 'fg', 'bg', 'black', 'white')
self.fill_combobox.bind('<<ComboboxSelected>>', self.set_fill)
self.fill_combobox.set(self.fill)
def create_outline_options_combobox(self):
tk.Label(self.top_bar, text='Outline:').pack(side="left")
self.outline_combobox = ttk.Combobox(
self.top_bar, state='readonly', width=5)
self.outline_combobox.pack(side="left")
self.outline_combobox['values'] = (
'none', 'fg', 'bg', 'black', 'white')
self.outline_combobox.bind('<<ComboboxSelected>>', self.set_outline)
self.outline_combobox.set(self.outline)
def create_width_options_combobox(self):
tk.Label(self.top_bar, text='Width:').pack(side="left")
self.width_combobox = ttk.Combobox(
self.top_bar, state='readonly', width=3)
self.width_combobox.pack(side="left")
self.width_combobox['values'] = (
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 40, 50)
self.width_combobox.bind('<<ComboboxSelected>>', self.set_brush_width)
self.width_combobox.set(self.brush_width)
def set_fill(self, event=None):
fill_color = self.fill_combobox.get()
if fill_color == 'none':
self.fill = '' # transparent
elif fill_color == 'fg':
self.fill = self.foreground
elif fill_color == 'bg':
self.fill = self.background
else:
self.fill = fill_color
def set_outline(self, event=None):
outline_color = self.outline_combobox.get()
if outline_color == 'none':
self.outline = '' # transparent
elif outline_color == 'fg':
self.outline = self.foreground
elif outline_color == 'bg':
self.outline = self.background
else:
self.outline = outline_color
def set_brush_width(self, event):
self.brush_width = float(self.width_combobox.get())
def create_color_palette(self):
self.color_palette = tk.Canvas(self.tool_bar, height=55, width=55)
self.color_palette.grid(row=10, column=1, columnspan=2, pady=5, padx=3)
self.background_palette = self.color_palette.create_rectangle(
15, 15, 48, 48, outline=self.background, fill=self.background)
self.foreground_palette = self.color_palette.create_rectangle(
1, 1, 33, 33, outline=self.foreground, fill=self.foreground)
self.bind_color_palette()
def bind_color_palette(self):
self.color_palette.tag_bind(
self.background_palette, "<Button-1>", self.set_background_color)
self.color_palette.tag_bind(
self.foreground_palette, "<Button-1>", self.set_foreground_color)
def create_current_coordinate_label(self):
self.current_coordinate_label = tk.Label(
self.tool_bar, text='x:0\ny: 0 ')
self.current_coordinate_label.grid(
row=13, column=1, columnspan=2, pady=5, padx=1, sticky='w')
def show_current_coordinates(self, event=None):
x_coordinate = event.x
y_coordinate = event.y
coordinate_string = "x:{0}\ny:{1}".format(x_coordinate, y_coordinate)
self.current_coordinate_label.config(text=coordinate_string)
def function_not_defined(self):
pass
def execute_selected_method(self):
self.current_item = None
func = getattr(
self, self.selected_tool_bar_function, self.function_not_defined)
func()
def draw_line(self):
self.current_item = self.canvas.create_line(
self.start_x, self.start_y, self.end_x, self.end_y, fill=self.fill, width=self.brush_width)
# self.drawn_img_draw.line((self.start_x, self.start_y, self.end_x, self.end_y), fill=self.fill_pil, width=self.brush_width)
def create_tool_bar_buttons(self):
for index, name in enumerate(self.tool_bar_functions):
icon = tk.PhotoImage(file='icons/' + name + '.gif')
self.button = tk.Button(
self.tool_bar, image=icon, command=lambda index=index: self.on_tool_bar_button_clicked(index))
self.button.grid(
row=index // 2, column=1 + index % 2, sticky='nsew')
self.button.image = icon
def remove_options_from_top_bar(self):
for child in self.top_bar.winfo_children():
child.destroy()
def show_selected_tool_icon_in_top_bar(self, function_name):
display_name = function_name.replace("_", " ").capitalize() + ":"
tk.Label(self.top_bar, text=display_name).pack(side="left")
photo = tk.PhotoImage(
file='icons/' + function_name + '.gif')
label = tk.Label(self.top_bar, image=photo)
label.image = photo
label.pack(side="left")
def bind_mouse(self):
self.canvas.bind("<Button-1>", self.on_mouse_button_pressed)
self.canvas.bind(
"<Button1-Motion>", self.on_mouse_button_pressed_motion)
self.canvas.bind(
"<Button1-ButtonRelease>", self.on_mouse_button_released)
self.canvas.bind("<Motion>", self.on_mouse_unpressed_motion)
def on_mouse_button_pressed(self, event):
self.start_x = self.end_x = self.canvas.canvasx(event.x)
self.start_y = self.end_y = self.canvas.canvasy(event.y)
self.execute_selected_method()
def on_mouse_button_pressed_motion(self, event):
self.end_x = self.canvas.canvasx(event.x)
self.end_y = self.canvas.canvasy(event.y)
self.canvas.delete(self.current_item)
self.execute_selected_method()
def on_mouse_button_released(self, event):
self.end_x = self.canvas.canvasx(event.x)
self.end_y = self.canvas.canvasy(event.y)
def on_mouse_unpressed_motion(self, event):
self.show_current_coordinates(event)
def create_gui(self):
self.create_menu()
self.create_top_bar()
self.create_tool_bar()
self.create_tool_bar_buttons()
self.create_drawing_canvas()
self.create_color_palette()
self.create_current_coordinate_label()
self.bind_menu_accelrator_keys()
self.show_selected_tool_icon_in_top_bar("draw_line")
self.draw_line_options()
def create_menu(self):
self.menubar = tk.Menu(self.root)
menu_definitions = (
'File- &New/Ctrl+N/self.on_new_file_menu_clicked, Open/Ctrl+O/self.on_open_image_menu_clicked, Import Mask/Ctrl+M/self.on_import_mask_clicked, Save/Ctrl+S/self.on_save_menu_clicked, SaveAs/ /self.on_save_as_menu_clicked, sep, Exit/Alt+F4/self.on_close_menu_clicked',
'Edit- Undo/Ctrl+Z/self.on_undo_menu_clicked, sep',
'View- Zoom in//self.on_canvas_zoom_in_menu_clicked,Zoom Out//self.on_canvas_zoom_out_menu_clicked',
'Decensor- Decensor/Ctrl+D/self.on_decensor_menu_clicked',
'About- About/F1/self.on_about_menu_clicked'
)
self.build_menu(menu_definitions)
def create_top_bar(self):
self.top_bar = tk.Frame(self.root, height=25, relief="raised")
self.top_bar.pack(fill="x", side="top", pady=2)
def create_tool_bar(self):
self.tool_bar = tk.Frame(self.root, relief="raised", width=50)
self.tool_bar.pack(fill="y", side="left", pady=3)
def create_drawing_canvas(self):
self.canvas_frame = tk.Frame(self.root, width=900, height=900)
self.canvas_frame.pack(side="right", expand="yes", fill="both")
self.canvas = tk.Canvas(self.canvas_frame, background="white",
width=512, height=512, scrollregion=(0, 0, 512, 512))
self.create_scroll_bar()
self.canvas.pack(side=tk.RIGHT, expand=tk.YES, fill=tk.BOTH)
self.canvas.img = Image.open('./icons/canvas_top_test.png').convert('RGBA')
self.canvas.img = self.canvas.img.resize((512,512))
self.canvas.tk_img = ImageTk.PhotoImage(self.canvas.img)
self.canvas.create_image(256,256,image=self.canvas.tk_img)
def create_scroll_bar(self):
self.x_scroll = tk.Scrollbar(self.canvas_frame, orient="horizontal")
self.x_scroll.pack(side="bottom", fill="x")
self.x_scroll.config(command=self.canvas.xview)
self.y_scroll = tk.Scrollbar(self.canvas_frame, orient="vertical")
self.y_scroll.pack(side="right", fill="y")
self.y_scroll.config(command=self.canvas.yview)
self.canvas.config(
xscrollcommand=self.x_scroll.set, yscrollcommand=self.y_scroll.set)
def bind_menu_accelrator_keys(self):
self.root.bind('<KeyPress-F1>', self.on_about_menu_clicked)
self.root.bind('<Control-N>', self.on_new_file_menu_clicked)
self.root.bind('<Control-n>', self.on_new_file_menu_clicked)
self.root.bind('<Control-s>', self.on_save_menu_clicked)
self.root.bind('<Control-S>', self.on_save_menu_clicked)
self.root.bind('<Control-z>', self.on_undo_menu_clicked)
self.root.bind('<Control-Z>', self.on_undo_menu_clicked)
if __name__ == '__main__':
root = tk.Tk()
app = PaintApplication(root)
root.mainloop()