mirror of
https://github.com/Deepshift/DeepCreamPy.git
synced 2025-03-23 13:20:54 +00:00
Add files
This commit is contained in:
parent
44031fb53b
commit
40d3a50d45
29
config.py
Normal file
29
config.py
Normal file
@ -0,0 +1,29 @@
|
||||
import argparse
|
||||
|
||||
def str2bool(v):
|
||||
if v.lower() in ('yes', 'true', 't', 'y', '1', True):
|
||||
return True
|
||||
elif v.lower() in ('no', 'false', 'f', 'n', '0', False):
|
||||
return False
|
||||
else:
|
||||
raise argparse.ArgumentTypeError('Boolean value expected.')
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='')
|
||||
|
||||
#Input output folders settings
|
||||
parser.add_argument('--decensor_input_path', dest='decensor_input_path', default='./decensor_input/', help='input images to be decensored by decensor.py path')
|
||||
parser.add_argument('--decensor_output_path', dest='decensor_output_path', default='./decensor_output/', help='output images generated from running decensor.py path')
|
||||
|
||||
#Decensor settings
|
||||
parser.add_argument('--mask_color_red', dest='mask_color_red', default=0, help='red channel of mask color in decensoring')
|
||||
parser.add_argument('--mask_color_green', dest='mask_color_green', default=255, help='green channel of mask color in decensoring')
|
||||
parser.add_argument('--mask_color_blue', dest='mask_color_blue', default=0, help='blue channel of mask color in decensoring')
|
||||
|
||||
parser.add_argument('--is_mosaic', dest='is_mosaic', default='False', type=str2bool, help='true if image has mosaic censoring, false otherwise')
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
if __name__ == '__main__':
|
||||
get_args()
|
165
decensor.py
Normal file
165
decensor.py
Normal file
@ -0,0 +1,165 @@
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import os
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
import config
|
||||
from libs.pconv_hybrid_model import PConvUnet
|
||||
from libs.flood_fill import find_regions, expand_bounding
|
||||
|
||||
class Decensor():
|
||||
|
||||
def __init__(self):
|
||||
self.args = config.get_args()
|
||||
self.decensor_mosaic = self.args.is_mosaic
|
||||
|
||||
self.mask_color = [self.args.mask_color_red/255.0, self.args.mask_color_green/255.0, self.args.mask_color_blue/255.0]
|
||||
|
||||
if not os.path.exists(self.args.decensor_output_path):
|
||||
os.makedirs(self.args.decensor_output_path)
|
||||
|
||||
self.load_model()
|
||||
|
||||
def get_mask(self,ori, width, height):
|
||||
mask = np.zeros(ori.shape, np.uint8)
|
||||
#count = 0
|
||||
#TODO: change to iterate over all images in batch when implementing batches
|
||||
for row in range(height):
|
||||
for col in range(width):
|
||||
if np.array_equal(ori[0][row][col], self.mask_color):
|
||||
mask[0, row, col] = 1
|
||||
return 1-mask
|
||||
|
||||
def load_model(self):
|
||||
self.model = PConvUnet(weight_filepath='data/logs/')
|
||||
self.model.load(
|
||||
r"./models/model.h5",
|
||||
train_bn=False,
|
||||
lr=0.00005
|
||||
)
|
||||
|
||||
def decensor_all_images_in_folder(self):
|
||||
#load model once at beginning and reuse same model
|
||||
#self.load_model()
|
||||
|
||||
subdir = self.args.decensor_input_path
|
||||
files = os.listdir(subdir)
|
||||
|
||||
#convert all images into np arrays and put them in a list
|
||||
for file in files:
|
||||
#print(file)
|
||||
file_path = os.path.join(subdir, file)
|
||||
if os.path.isfile(file_path) and os.path.splitext(file_path)[1] == ".png":
|
||||
print("Decensoring the image {file_path}".format(file_path))
|
||||
censored_img = Image.open(file_path)
|
||||
self.decensor_image(censored_img, file)
|
||||
|
||||
#decensors one image at a time
|
||||
#TODO: decensor all cropped parts of the same image in a batch (then i need input for ori an array of those images and make additional changes)
|
||||
def decensor_image(self,ori, file_name):
|
||||
width, height = ori.size
|
||||
#save the alpha channel if the image has an alpha channel
|
||||
has_alpha = False
|
||||
alpha_channel = None
|
||||
if (ori.mode == "RGBA"):
|
||||
has_alpha = True
|
||||
alpha_channel = np.asarray(ori)[:,:,3]
|
||||
alpha_channel = np.expand_dims(alpha_channel, axis =-1)
|
||||
ori = ori.convert('RGB')
|
||||
|
||||
ori_array = np.asarray(ori)
|
||||
ori_array = np.array(ori_array / 255.0)
|
||||
ori_array = np.expand_dims(ori_array, axis = 0)
|
||||
|
||||
mask = self.get_mask(ori_array, width, height)
|
||||
|
||||
regions = find_regions(ori)
|
||||
print("Found {region_count} censored regions in this image!".format(region_count = len(regions)))
|
||||
|
||||
if len(regions) == 0 and not self.decensor_mosaic:
|
||||
print("No green colored regions detected!")
|
||||
return
|
||||
|
||||
output_img_array = ori_array[0].copy()
|
||||
|
||||
for region_counter, region in enumerate(regions, 1):
|
||||
bounding_box = expand_bounding(ori, region)
|
||||
crop_img = ori.crop(bounding_box)
|
||||
#crop_img.show()
|
||||
#convert mask back to image
|
||||
mask_reshaped = mask[0,:,:,:] * 255.0
|
||||
mask_img = Image.fromarray(mask_reshaped.astype('uint8'))
|
||||
#resize the cropped images
|
||||
crop_img = crop_img.resize((512, 512))
|
||||
crop_img_array = np.asarray(crop_img)
|
||||
crop_img_array = crop_img_array / 255.0
|
||||
crop_img_array = np.expand_dims(crop_img_array, axis = 0)
|
||||
#resize the mask images
|
||||
mask_img = mask_img.crop(bounding_box)
|
||||
mask_img = mask_img.resize((512, 512))
|
||||
#mask_img.show()
|
||||
#convert mask_img back to array
|
||||
mask_array = np.asarray(mask_img)
|
||||
mask_array = np.array(mask_array / 255.0)
|
||||
#the mask has been upscaled so there will be values not equal to 0 or 1
|
||||
#mask_array[mask_array < 0.01] = 0
|
||||
mask_array[mask_array > 0] = 1
|
||||
mask_array = np.expand_dims(mask_array, axis = 0)
|
||||
|
||||
# Run predictions for this batch of images
|
||||
pred_img_array = self.model.predict([crop_img_array, mask_array, mask_array])
|
||||
|
||||
pred_img_array = pred_img_array * 255.0
|
||||
pred_img_array = np.squeeze(pred_img_array, axis = 0)
|
||||
|
||||
#scale prediction image back to original size
|
||||
bounding_width = bounding_box[2]-bounding_box[0]
|
||||
bounding_height = bounding_box[3]-bounding_box[1]
|
||||
#convert np array to image
|
||||
|
||||
# print(bounding_width,bounding_height)
|
||||
# print(pred_img_array.shape)
|
||||
|
||||
pred_img = Image.fromarray(pred_img_array.astype('uint8'))
|
||||
#pred_img.show()
|
||||
pred_img = pred_img.resize((bounding_width, bounding_height), resample = Image.BICUBIC)
|
||||
pred_img_array = np.asarray(pred_img)
|
||||
pred_img_array = pred_img_array / 255.0
|
||||
|
||||
# print(pred_img_array.shape)
|
||||
pred_img_array = np.expand_dims(pred_img_array, axis = 0)
|
||||
|
||||
for i in range(len(ori_array)):
|
||||
if self.decensor_mosaic:
|
||||
output_img_array = pred_img[i]
|
||||
else:
|
||||
for col in range(bounding_width):
|
||||
for row in range(bounding_height):
|
||||
bounding_width = col + bounding_box[0]
|
||||
bounding_height = row + bounding_box[1]
|
||||
if (bounding_width, bounding_height) in region:
|
||||
output_img_array[bounding_height][bounding_width] = pred_img_array[i,:,:,:][row][col]
|
||||
print("{region_counter} out of {region_count} regions decensored.".format(region_counter=region_counter, region_count=len(regions)))
|
||||
|
||||
output_img_array = output_img_array * 255.0
|
||||
|
||||
#restore the alpha channel
|
||||
if has_alpha:
|
||||
print(output_img_array.shape)
|
||||
print(alpha_channel.shape)
|
||||
output_img_array = np.concatenate((output_img_array, alpha_channel), axis = 2)
|
||||
|
||||
output_img = Image.fromarray(output_img_array.astype('uint8'))
|
||||
|
||||
#save the decensored image
|
||||
#file_name, _ = os.path.splitext(file_name)
|
||||
save_path = os.path.join(self.args.decensor_output_path, file_name)
|
||||
output_img.save(save_path)
|
||||
|
||||
print("Decensored image saved to {save_path}!".format(save_path=save_path))
|
||||
return
|
||||
|
||||
if __name__ == '__main__':
|
||||
decensor = Decensor()
|
||||
decensor.decensor_all_images_in_folder()
|
128
flood_fill.py
Normal file
128
flood_fill.py
Normal file
@ -0,0 +1,128 @@
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
#find strongly connected components with the mask color
|
||||
def find_regions(image):
|
||||
pixel = image.load()
|
||||
neighbors = dict()
|
||||
width, height = image.size
|
||||
for x in range(width):
|
||||
for y in range(height):
|
||||
if is_green(pixel[x,y]):
|
||||
neighbors[x, y] = {(x,y)}
|
||||
for x, y in neighbors:
|
||||
candidates = (x + 1, y), (x, y + 1)
|
||||
for candidate in candidates:
|
||||
if candidate in neighbors:
|
||||
neighbors[x, y].add(candidate)
|
||||
neighbors[candidate].add((x, y))
|
||||
closed_list = set()
|
||||
|
||||
def connected_component(pixel):
|
||||
region = set()
|
||||
open_list = {pixel}
|
||||
while open_list:
|
||||
pixel = open_list.pop()
|
||||
closed_list.add(pixel)
|
||||
open_list |= neighbors[pixel] - closed_list
|
||||
region.add(pixel)
|
||||
return region
|
||||
|
||||
regions = []
|
||||
for pixel in neighbors:
|
||||
if pixel not in closed_list:
|
||||
regions.append(connected_component(pixel))
|
||||
regions.sort(key = len, reverse = True)
|
||||
return regions
|
||||
|
||||
# risk of box being bigger than the image
|
||||
def expand_bounding(img, region, expand_factor=1.5, min_size = 256, max_size=512):
|
||||
#expand bounding box to capture more context
|
||||
x, y = zip(*region)
|
||||
min_x, min_y, max_x, max_y = min(x), min(y), max(x), max(y)
|
||||
width, height = img.size
|
||||
width_center = width//2
|
||||
height_center = height//2
|
||||
bb_width = max_x - min_x
|
||||
bb_height = max_y - min_y
|
||||
x_center = (min_x + max_x)//2
|
||||
y_center = (min_y + max_y)//2
|
||||
current_size = max(bb_width, bb_height)
|
||||
current_size = int(current_size * expand_factor)
|
||||
if current_size > max_size:
|
||||
current_size = max_size
|
||||
elif current_size < min_size:
|
||||
current_size = min_size
|
||||
x1 = x_center - current_size//2
|
||||
x2 = x_center + current_size//2
|
||||
y1 = y_center - current_size//2
|
||||
y2 = y_center + current_size//2
|
||||
x1_square = x1
|
||||
y1_square = y1
|
||||
x2_square = x2
|
||||
y2_square = y2
|
||||
#move bounding boxes that are partially outside of the image inside the image
|
||||
if (y1_square < 0 or y2_square > (height - 1)) and (x1_square < 0 or x2_square > (width - 1)):
|
||||
#conservative square region
|
||||
if x1_square < 0 and y1_square < 0:
|
||||
x1_square = 0
|
||||
y1_square = 0
|
||||
x2_square = current_size
|
||||
y2_square = current_size
|
||||
elif x2_square > (width - 1) and y2_square > (height - 1):
|
||||
x1_square = width - current_size - 1
|
||||
y1_square = 0
|
||||
x2_square = width - 1
|
||||
y2_square = current_size
|
||||
elif x1_square < 0 and y2_square > (height - 1):
|
||||
x1_square = 0
|
||||
y1_square = height - current_size - 1
|
||||
x2_square = current_size
|
||||
y2_square = height - 1
|
||||
elif x2_square > (width - 1) and y2_square > (height - 1):
|
||||
x1_square = width - current_size - 1
|
||||
y1_square = height - current_size - 1
|
||||
x2_square = width - 1
|
||||
y2_square = height - 1
|
||||
else:
|
||||
x1_square = x1
|
||||
y1_square = y1
|
||||
x2_square = x2
|
||||
y2_square = y2
|
||||
else:
|
||||
if x1_square < 0:
|
||||
difference = x1_square
|
||||
x1_square -= difference
|
||||
x2_square -= difference
|
||||
if x2_square > (width - 1):
|
||||
difference = x2_square - width + 1
|
||||
x1_square -= difference
|
||||
x2_square -= difference
|
||||
if y1_square < 0:
|
||||
difference = y1_square
|
||||
y1_square -= difference
|
||||
y2_square -= difference
|
||||
if y2_square > (height - 1):
|
||||
difference = y2_square - height + 1
|
||||
y1_square -= difference
|
||||
y2_square -= difference
|
||||
# if y1_square < 0 or y2_square > (height - 1):
|
||||
|
||||
#if bounding box goes outside of the image for some reason, set bounds to original, unexpanded values
|
||||
#print(width, height)
|
||||
if x2_square > width or y2_square > height:
|
||||
print("bounding box out of bounds!")
|
||||
print(x1_square, y1_square, x2_square, y2_square)
|
||||
x1_square, y1_square, x2_square, y2_square = min_x, min_y, max_x, max_y
|
||||
return x1_square, y1_square, x2_square, y2_square
|
||||
|
||||
def is_green(pixel):
|
||||
r, g, b = pixel
|
||||
return r == 0 and g == 255 and b == 0
|
||||
|
||||
if __name__ == '__main__':
|
||||
image = Image.open('')
|
||||
no_alpha_image = image.convert('RGB')
|
||||
draw = ImageDraw.Draw(no_alpha_image)
|
||||
for region in find_regions(no_alpha_image):
|
||||
draw.rectangle(expand_bounding(no_alpha_image, region), outline=(0, 255, 0))
|
||||
no_alpha_image.show()
|
128
libs/flood_fill.py
Normal file
128
libs/flood_fill.py
Normal file
@ -0,0 +1,128 @@
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
#find strongly connected components with the mask color
|
||||
def find_regions(image):
|
||||
pixel = image.load()
|
||||
neighbors = dict()
|
||||
width, height = image.size
|
||||
for x in range(width):
|
||||
for y in range(height):
|
||||
if is_green(pixel[x,y]):
|
||||
neighbors[x, y] = {(x,y)}
|
||||
for x, y in neighbors:
|
||||
candidates = (x + 1, y), (x, y + 1)
|
||||
for candidate in candidates:
|
||||
if candidate in neighbors:
|
||||
neighbors[x, y].add(candidate)
|
||||
neighbors[candidate].add((x, y))
|
||||
closed_list = set()
|
||||
|
||||
def connected_component(pixel):
|
||||
region = set()
|
||||
open_list = {pixel}
|
||||
while open_list:
|
||||
pixel = open_list.pop()
|
||||
closed_list.add(pixel)
|
||||
open_list |= neighbors[pixel] - closed_list
|
||||
region.add(pixel)
|
||||
return region
|
||||
|
||||
regions = []
|
||||
for pixel in neighbors:
|
||||
if pixel not in closed_list:
|
||||
regions.append(connected_component(pixel))
|
||||
regions.sort(key = len, reverse = True)
|
||||
return regions
|
||||
|
||||
# risk of box being bigger than the image
|
||||
def expand_bounding(img, region, expand_factor=1.5, min_size = 256, max_size=512):
|
||||
#expand bounding box to capture more context
|
||||
x, y = zip(*region)
|
||||
min_x, min_y, max_x, max_y = min(x), min(y), max(x), max(y)
|
||||
width, height = img.size
|
||||
width_center = width//2
|
||||
height_center = height//2
|
||||
bb_width = max_x - min_x
|
||||
bb_height = max_y - min_y
|
||||
x_center = (min_x + max_x)//2
|
||||
y_center = (min_y + max_y)//2
|
||||
current_size = max(bb_width, bb_height)
|
||||
current_size = int(current_size * expand_factor)
|
||||
if current_size > max_size:
|
||||
current_size = max_size
|
||||
elif current_size < min_size:
|
||||
current_size = min_size
|
||||
x1 = x_center - current_size//2
|
||||
x2 = x_center + current_size//2
|
||||
y1 = y_center - current_size//2
|
||||
y2 = y_center + current_size//2
|
||||
x1_square = x1
|
||||
y1_square = y1
|
||||
x2_square = x2
|
||||
y2_square = y2
|
||||
#move bounding boxes that are partially outside of the image inside the image
|
||||
if (y1_square < 0 or y2_square > (height - 1)) and (x1_square < 0 or x2_square > (width - 1)):
|
||||
#conservative square region
|
||||
if x1_square < 0 and y1_square < 0:
|
||||
x1_square = 0
|
||||
y1_square = 0
|
||||
x2_square = current_size
|
||||
y2_square = current_size
|
||||
elif x2_square > (width - 1) and y2_square > (height - 1):
|
||||
x1_square = width - current_size - 1
|
||||
y1_square = 0
|
||||
x2_square = width - 1
|
||||
y2_square = current_size
|
||||
elif x1_square < 0 and y2_square > (height - 1):
|
||||
x1_square = 0
|
||||
y1_square = height - current_size - 1
|
||||
x2_square = current_size
|
||||
y2_square = height - 1
|
||||
elif x2_square > (width - 1) and y2_square > (height - 1):
|
||||
x1_square = width - current_size - 1
|
||||
y1_square = height - current_size - 1
|
||||
x2_square = width - 1
|
||||
y2_square = height - 1
|
||||
else:
|
||||
x1_square = x1
|
||||
y1_square = y1
|
||||
x2_square = x2
|
||||
y2_square = y2
|
||||
else:
|
||||
if x1_square < 0:
|
||||
difference = x1_square
|
||||
x1_square -= difference
|
||||
x2_square -= difference
|
||||
if x2_square > (width - 1):
|
||||
difference = x2_square - width + 1
|
||||
x1_square -= difference
|
||||
x2_square -= difference
|
||||
if y1_square < 0:
|
||||
difference = y1_square
|
||||
y1_square -= difference
|
||||
y2_square -= difference
|
||||
if y2_square > (height - 1):
|
||||
difference = y2_square - height + 1
|
||||
y1_square -= difference
|
||||
y2_square -= difference
|
||||
# if y1_square < 0 or y2_square > (height - 1):
|
||||
|
||||
#if bounding box goes outside of the image for some reason, set bounds to original, unexpanded values
|
||||
#print(width, height)
|
||||
if x2_square > width or y2_square > height:
|
||||
print("bounding box out of bounds!")
|
||||
print(x1_square, y1_square, x2_square, y2_square)
|
||||
x1_square, y1_square, x2_square, y2_square = min_x, min_y, max_x, max_y
|
||||
return x1_square, y1_square, x2_square, y2_square
|
||||
|
||||
def is_green(pixel):
|
||||
r, g, b = pixel
|
||||
return r == 0 and g == 255 and b == 0
|
||||
|
||||
if __name__ == '__main__':
|
||||
image = Image.open('')
|
||||
no_alpha_image = image.convert('RGB')
|
||||
draw = ImageDraw.Draw(no_alpha_image)
|
||||
for region in find_regions(no_alpha_image):
|
||||
draw.rectangle(expand_bounding(no_alpha_image, region), outline=(0, 255, 0))
|
||||
no_alpha_image.show()
|
276
libs/pconv_hybrid_model.py
Normal file
276
libs/pconv_hybrid_model.py
Normal file
@ -0,0 +1,276 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from keras.models import Model
|
||||
from keras.models import load_model
|
||||
from keras.optimizers import Adam
|
||||
from keras.layers import Input, Conv2D, UpSampling2D, Dropout, LeakyReLU, BatchNormalization, Activation
|
||||
from keras.layers.merge import Concatenate
|
||||
from keras.applications import VGG16
|
||||
from keras import backend as K
|
||||
from libs.pconv_layer import PConv2D
|
||||
|
||||
|
||||
class PConvUnet(object):
|
||||
|
||||
def __init__(self, img_rows=512, img_cols=512, weight_filepath=None):
|
||||
"""Create the PConvUnet. If variable image size, set img_rows and img_cols to None"""
|
||||
|
||||
# Settings
|
||||
self.weight_filepath = weight_filepath
|
||||
self.img_rows = img_rows
|
||||
self.img_cols = img_cols
|
||||
assert self.img_rows >= 256, 'Height must be >256 pixels'
|
||||
assert self.img_cols >= 256, 'Width must be >256 pixels'
|
||||
|
||||
# Set current epoch
|
||||
self.current_epoch = 0
|
||||
|
||||
# VGG layers to extract features from (first maxpooling layers, see pp. 7 of paper)
|
||||
self.vgg_layers = [3, 6, 10]
|
||||
|
||||
# Get the vgg16 model for perceptual loss
|
||||
self.vgg = self.build_vgg()
|
||||
|
||||
# Create UNet-like model
|
||||
self.model = self.build_pconv_unet()
|
||||
|
||||
def build_vgg(self):
|
||||
"""
|
||||
Load pre-trained VGG16 from keras applications
|
||||
Extract features to be used in loss function from last conv layer, see architecture at:
|
||||
https://github.com/keras-team/keras/blob/master/keras/applications/vgg16.py
|
||||
"""
|
||||
# Input image to extract features from
|
||||
img = Input(shape=(self.img_rows, self.img_cols, 3))
|
||||
|
||||
# Get the vgg network from Keras applications
|
||||
vgg = VGG16(weights="imagenet", include_top=False)
|
||||
|
||||
# Output the first three pooling layers
|
||||
vgg.outputs = [vgg.layers[i].output for i in self.vgg_layers]
|
||||
|
||||
# Create model and compile
|
||||
model = Model(inputs=img, outputs=vgg(img))
|
||||
model.trainable = False
|
||||
model.compile(loss='mse', optimizer='adam')
|
||||
|
||||
return model
|
||||
|
||||
def build_pconv_unet(self, train_bn=True, lr=0.0002):
|
||||
|
||||
# INPUTS
|
||||
inputs_img = Input((self.img_rows, self.img_cols, 3))
|
||||
inputs_mask = Input((self.img_rows, self.img_cols, 3))
|
||||
loss_mask = Input((self.img_rows, self.img_cols, 3))
|
||||
|
||||
# ENCODER
|
||||
def encoder_layer(img_in, mask_in, filters, kernel_size, bn=True):
|
||||
conv, mask = PConv2D(filters, kernel_size, strides=2, padding='same')([img_in, mask_in])
|
||||
if bn:
|
||||
conv = BatchNormalization(name='EncBN'+str(encoder_layer.counter))(conv, training=train_bn)
|
||||
conv = Activation('relu')(conv)
|
||||
encoder_layer.counter += 1
|
||||
return conv, mask
|
||||
encoder_layer.counter = 0
|
||||
|
||||
e_conv1, e_mask1 = encoder_layer(inputs_img, inputs_mask, 64, 7, bn=False)
|
||||
e_conv2, e_mask2 = encoder_layer(e_conv1, e_mask1, 128, 5)
|
||||
e_conv3, e_mask3 = encoder_layer(e_conv2, e_mask2, 256, 5)
|
||||
e_conv4, e_mask4 = encoder_layer(e_conv3, e_mask3, 512, 3)
|
||||
e_conv5, e_mask5 = encoder_layer(e_conv4, e_mask4, 512, 3)
|
||||
e_conv6, e_mask6 = encoder_layer(e_conv5, e_mask5, 512, 3)
|
||||
e_conv7, e_mask7 = encoder_layer(e_conv6, e_mask6, 512, 3)
|
||||
e_conv8, e_mask8 = encoder_layer(e_conv7, e_mask7, 512, 3)
|
||||
|
||||
# DECODER
|
||||
def decoder_layer(img_in, mask_in, e_conv, e_mask, filters, kernel_size, bn=True):
|
||||
up_img = UpSampling2D(size=(2,2))(img_in)
|
||||
up_mask = UpSampling2D(size=(2,2))(mask_in)
|
||||
concat_img = Concatenate(axis=3)([e_conv,up_img])
|
||||
concat_mask = Concatenate(axis=3)([e_mask,up_mask])
|
||||
conv, mask = PConv2D(filters, kernel_size, padding='same')([concat_img, concat_mask])
|
||||
if bn:
|
||||
conv = BatchNormalization()(conv)
|
||||
conv = LeakyReLU(alpha=0.2)(conv)
|
||||
return conv, mask
|
||||
|
||||
d_conv9, d_mask9 = decoder_layer(e_conv8, e_mask8, e_conv7, e_mask7, 512, 3)
|
||||
d_conv10, d_mask10 = decoder_layer(d_conv9, d_mask9, e_conv6, e_mask6, 512, 3)
|
||||
d_conv11, d_mask11 = decoder_layer(d_conv10, d_mask10, e_conv5, e_mask5, 512, 3)
|
||||
d_conv12, d_mask12 = decoder_layer(d_conv11, d_mask11, e_conv4, e_mask4, 512, 3)
|
||||
d_conv13, d_mask13 = decoder_layer(d_conv12, d_mask12, e_conv3, e_mask3, 256, 3)
|
||||
d_conv14, d_mask14 = decoder_layer(d_conv13, d_mask13, e_conv2, e_mask2, 128, 3)
|
||||
d_conv15, d_mask15 = decoder_layer(d_conv14, d_mask14, e_conv1, e_mask1, 64, 3)
|
||||
d_conv16, d_mask16 = decoder_layer(d_conv15, d_mask15, inputs_img, inputs_mask, 3, 3, bn=False)
|
||||
outputs = Conv2D(3, 1, activation = 'sigmoid')(d_conv16)
|
||||
|
||||
# Setup the model inputs / outputs
|
||||
model = Model(inputs=[inputs_img, inputs_mask, loss_mask], outputs=outputs)
|
||||
|
||||
# Compile the model
|
||||
model.compile(
|
||||
optimizer = Adam(lr=lr),
|
||||
loss=self.loss_total(loss_mask)
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
def loss_total(self, mask):
|
||||
"""
|
||||
Creates a loss function which sums all the loss components
|
||||
and multiplies by their weights. See paper eq. 7.
|
||||
"""
|
||||
def loss(y_true, y_pred):
|
||||
|
||||
# Compute predicted image with non-hole pixels set to ground truth
|
||||
y_comp = mask * y_true + (1-mask) * y_pred
|
||||
|
||||
# Compute the vgg features
|
||||
vgg_out = self.vgg(y_pred)
|
||||
vgg_gt = self.vgg(y_true)
|
||||
vgg_comp = self.vgg(y_comp)
|
||||
|
||||
# Compute loss components
|
||||
l1 = self.loss_valid(mask, y_true, y_pred)
|
||||
l2 = self.loss_hole(mask, y_true, y_pred)
|
||||
l3 = self.loss_perceptual(vgg_out, vgg_gt, vgg_comp)
|
||||
l4 = self.loss_style(vgg_out, vgg_gt)
|
||||
l5 = self.loss_style(vgg_comp, vgg_gt)
|
||||
l6 = self.loss_tv(mask, y_comp)
|
||||
|
||||
# Return loss function
|
||||
return l1 + 6*l2 + 0.05*l3 + 120*(l4+l5) + 0.1*l6
|
||||
|
||||
return loss
|
||||
|
||||
def loss_hole(self, mask, y_true, y_pred):
|
||||
"""Pixel L1 loss within the hole / mask"""
|
||||
return self.l1((1-mask) * y_true, (1-mask) * y_pred)
|
||||
|
||||
def loss_valid(self, mask, y_true, y_pred):
|
||||
"""Pixel L1 loss outside the hole / mask"""
|
||||
return self.l1(mask * y_true, mask * y_pred)
|
||||
|
||||
def loss_perceptual(self, vgg_out, vgg_gt, vgg_comp):
|
||||
"""Perceptual loss based on VGG16, see. eq. 3 in paper"""
|
||||
loss = 0
|
||||
for o, c, g in zip(vgg_out, vgg_comp, vgg_gt):
|
||||
loss += self.l1(o, g) + self.l1(c, g)
|
||||
return loss
|
||||
|
||||
def loss_style(self, output, vgg_gt):
|
||||
"""Style loss based on output/computation, used for both eq. 4 & 5 in paper"""
|
||||
loss = 0
|
||||
for o, g in zip(output, vgg_gt):
|
||||
loss += self.l1(self.gram_matrix(o), self.gram_matrix(g))
|
||||
return loss
|
||||
|
||||
def loss_tv(self, mask, y_comp):
|
||||
"""Total variation loss, used for smoothing the hole region, see. eq. 6"""
|
||||
|
||||
# Create dilated hole region using a 3x3 kernel of all 1s.
|
||||
kernel = K.ones(shape=(3, 3, mask.shape[3], mask.shape[3]))
|
||||
dilated_mask = K.conv2d(1-mask, kernel, data_format='channels_last', padding='same')
|
||||
|
||||
# Cast values to be [0., 1.], and compute dilated hole region of y_comp
|
||||
dilated_mask = K.cast(K.greater(dilated_mask, 0), 'float32')
|
||||
P = dilated_mask * y_comp
|
||||
|
||||
# Calculate total variation loss
|
||||
a = self.l1(P[:,1:,:,:], P[:,:-1,:,:])
|
||||
b = self.l1(P[:,:,1:,:], P[:,:,:-1,:])
|
||||
return a+b
|
||||
|
||||
def fit(self, generator, epochs=10, plot_callback=None, *args, **kwargs):
|
||||
"""Fit the U-Net to a (images, targets) generator
|
||||
|
||||
param generator: training generator yielding (maskes_image, original_image) tuples
|
||||
param epochs: number of epochs to train for
|
||||
param plot_callback: callback function taking Unet model as parameter
|
||||
"""
|
||||
|
||||
# Loop over epochs
|
||||
for _ in range(epochs):
|
||||
|
||||
# Fit the model
|
||||
self.model.fit_generator(
|
||||
generator,
|
||||
epochs=self.current_epoch+1,
|
||||
initial_epoch=self.current_epoch,
|
||||
*args, **kwargs
|
||||
)
|
||||
|
||||
# Update epoch
|
||||
self.current_epoch += 1
|
||||
|
||||
# After each epoch predict on test images & show them
|
||||
if plot_callback:
|
||||
plot_callback(self.model)
|
||||
|
||||
# Save logfile
|
||||
if self.weight_filepath:
|
||||
self.save()
|
||||
|
||||
def predict(self, sample):
|
||||
"""Run prediction using this model"""
|
||||
return self.model.predict(sample)
|
||||
|
||||
def summary(self):
|
||||
"""Get summary of the UNet model"""
|
||||
print(self.model.summary())
|
||||
|
||||
def save(self):
|
||||
self.model.save_weights(self.current_weightfile())
|
||||
|
||||
def load(self, filepath, train_bn=True, lr=0.0002):
|
||||
|
||||
# Create UNet-like model
|
||||
self.model = self.build_pconv_unet(train_bn, lr)
|
||||
|
||||
# Load weights into model
|
||||
#epoch = 50
|
||||
epoch = int(os.path.basename(filepath).split("_")[0])
|
||||
assert epoch > 0, "Could not parse weight file. Should start with 'X_', with X being the epoch"
|
||||
self.current_epoch = epoch
|
||||
self.model.load_weights(filepath)
|
||||
|
||||
def current_weightfile(self):
|
||||
assert self.weight_filepath != None, 'Must specify location of logs'
|
||||
return self.weight_filepath + "{}_weights_{}.h5".format(self.current_epoch, self.current_timestamp())
|
||||
|
||||
@staticmethod
|
||||
def current_timestamp():
|
||||
return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
||||
|
||||
@staticmethod
|
||||
def l1(y_true, y_pred):
|
||||
"""Calculate the L1 loss used in all loss calculations"""
|
||||
if K.ndim(y_true) == 4:
|
||||
return K.sum(K.abs(y_pred - y_true), axis=[1,2,3])
|
||||
elif K.ndim(y_true) == 3:
|
||||
return K.sum(K.abs(y_pred - y_true), axis=[1,2])
|
||||
else:
|
||||
raise NotImplementedError("Calculating L1 loss on 1D tensors? should not occur for this network")
|
||||
|
||||
@staticmethod
|
||||
def gram_matrix(x, norm_by_channels=False):
|
||||
"""Calculate gram matrix used in style loss"""
|
||||
|
||||
# Assertions on input
|
||||
assert K.ndim(x) == 4, 'Input tensor should be a 4d (B, H, W, C) tensor'
|
||||
assert K.image_data_format() == 'channels_last', "Please use channels-last format"
|
||||
|
||||
# Permute channels and get resulting shape
|
||||
x = K.permute_dimensions(x, (0, 3, 1, 2))
|
||||
shape = K.shape(x)
|
||||
B, C, H, W = shape[0], shape[1], shape[2], shape[3]
|
||||
|
||||
# Reshape x and do batch dot product
|
||||
features = K.reshape(x, K.stack([B, C, H*W]))
|
||||
gram = K.batch_dot(features, features, axes=2)
|
||||
|
||||
# Normalize with channels, height and width
|
||||
gram = gram / K.cast(C * H * W, x.dtype)
|
||||
|
||||
return gram
|
126
libs/pconv_layer.py
Normal file
126
libs/pconv_layer.py
Normal file
@ -0,0 +1,126 @@
|
||||
|
||||
from keras.utils import conv_utils
|
||||
from keras import backend as K
|
||||
from keras.engine import InputSpec
|
||||
from keras.layers import Conv2D
|
||||
|
||||
|
||||
class PConv2D(Conv2D):
|
||||
def __init__(self, *args, n_channels=3, mono=False, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.input_spec = [InputSpec(ndim=4), InputSpec(ndim=4)]
|
||||
|
||||
def build(self, input_shape):
|
||||
"""Adapted from original _Conv() layer of Keras
|
||||
param input_shape: list of dimensions for [img, mask]
|
||||
"""
|
||||
|
||||
if self.data_format == 'channels_first':
|
||||
channel_axis = 1
|
||||
else:
|
||||
channel_axis = -1
|
||||
|
||||
if input_shape[0][channel_axis] is None:
|
||||
raise ValueError('The channel dimension of the inputs should be defined. Found `None`.')
|
||||
|
||||
self.input_dim = input_shape[0][channel_axis]
|
||||
|
||||
# Image kernel
|
||||
kernel_shape = self.kernel_size + (self.input_dim, self.filters)
|
||||
self.kernel = self.add_weight(shape=kernel_shape,
|
||||
initializer=self.kernel_initializer,
|
||||
name='img_kernel',
|
||||
regularizer=self.kernel_regularizer,
|
||||
constraint=self.kernel_constraint)
|
||||
# Mask kernel
|
||||
self.kernel_mask = K.ones(shape=self.kernel_size + (self.input_dim, self.filters))
|
||||
|
||||
if self.use_bias:
|
||||
self.bias = self.add_weight(shape=(self.filters,),
|
||||
initializer=self.bias_initializer,
|
||||
name='bias',
|
||||
regularizer=self.bias_regularizer,
|
||||
constraint=self.bias_constraint)
|
||||
else:
|
||||
self.bias = None
|
||||
self.built = True
|
||||
|
||||
def call(self, inputs, mask=None):
|
||||
'''
|
||||
We will be using the Keras conv2d method, and essentially we have
|
||||
to do here is multiply the mask with the input X, before we apply the
|
||||
convolutions. For the mask itself, we apply convolutions with all weights
|
||||
set to 1.
|
||||
Subsequently, we set all mask values >0 to 1, and otherwise 0
|
||||
'''
|
||||
|
||||
# Both image and mask must be supplied
|
||||
if type(inputs) is not list or len(inputs) != 2:
|
||||
raise Exception('PartialConvolution2D must be called on a list of two tensors [img, mask]. Instead got: ' + str(inputs))
|
||||
|
||||
# Create normalization. Slight change here compared to paper, using mean mask value instead of sum
|
||||
normalization = K.mean(inputs[1], axis=[1,2], keepdims=True)
|
||||
normalization = K.repeat_elements(normalization, inputs[1].shape[1], axis=1)
|
||||
normalization = K.repeat_elements(normalization, inputs[1].shape[2], axis=2)
|
||||
|
||||
# Apply convolutions to image
|
||||
img_output = K.conv2d(
|
||||
(inputs[0]*inputs[1]) / normalization, self.kernel,
|
||||
strides=self.strides,
|
||||
padding=self.padding,
|
||||
data_format=self.data_format,
|
||||
dilation_rate=self.dilation_rate
|
||||
)
|
||||
|
||||
# Apply convolutions to mask
|
||||
mask_output = K.conv2d(
|
||||
inputs[1], self.kernel_mask,
|
||||
strides=self.strides,
|
||||
padding=self.padding,
|
||||
data_format=self.data_format,
|
||||
dilation_rate=self.dilation_rate
|
||||
)
|
||||
|
||||
# Where something happened, set 1, otherwise 0
|
||||
mask_output = K.cast(K.greater(mask_output, 0), 'float32')
|
||||
|
||||
# Apply bias only to the image (if chosen to do so)
|
||||
if self.use_bias:
|
||||
img_output = K.bias_add(
|
||||
img_output,
|
||||
self.bias,
|
||||
data_format=self.data_format)
|
||||
|
||||
# Apply activations on the image
|
||||
if self.activation is not None:
|
||||
img_output = self.activation(img_output)
|
||||
|
||||
return [img_output, mask_output]
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
if self.data_format == 'channels_last':
|
||||
space = input_shape[0][1:-1]
|
||||
new_space = []
|
||||
for i in range(len(space)):
|
||||
new_dim = conv_utils.conv_output_length(
|
||||
space[i],
|
||||
self.kernel_size[i],
|
||||
padding=self.padding,
|
||||
stride=self.strides[i],
|
||||
dilation=self.dilation_rate[i])
|
||||
new_space.append(new_dim)
|
||||
new_shape = (input_shape[0][0],) + tuple(new_space) + (self.filters,)
|
||||
return [new_shape, new_shape]
|
||||
if self.data_format == 'channels_first':
|
||||
space = input_shape[2:]
|
||||
new_space = []
|
||||
for i in range(len(space)):
|
||||
new_dim = conv_utils.conv_output_length(
|
||||
space[i],
|
||||
self.kernel_size[i],
|
||||
padding=self.padding,
|
||||
stride=self.strides[i],
|
||||
dilation=self.dilation_rate[i])
|
||||
new_space.append(new_dim)
|
||||
new_shape = (input_shape[0], self.filters) + tuple(new_space)
|
||||
return [new_shape, new_shape]
|
497
ui.py
Normal file
497
ui.py
Normal file
@ -0,0 +1,497 @@
|
||||
"""
|
||||
Code illustration: 6.09
|
||||
|
||||
Modules imported here:
|
||||
from tkinter import messagebox
|
||||
from tkinter import filedialog
|
||||
|
||||
Attributes added here:
|
||||
file_name = "untitled"
|
||||
|
||||
Methods modified here:
|
||||
on_new_file_menu_clicked()
|
||||
on_save_menu_clicked()
|
||||
on_save_as_menu_clicked()
|
||||
on_close_menu_clicked()
|
||||
on_undo_menu_clicked()
|
||||
on_canvas_zoom_in_menu_clicked()
|
||||
on_canvas_zoom_out_menu_clicked()
|
||||
on_about_menu_clicked()
|
||||
|
||||
Methods added here
|
||||
start_new_project()
|
||||
actual_save()
|
||||
close_window()
|
||||
undo()
|
||||
canvas_zoom_in()
|
||||
canvas_zoom_out()
|
||||
|
||||
@ Tkinter GUI Application Development Blueprints
|
||||
"""
|
||||
import math
|
||||
from PIL import Image, ImageTk, ImageDraw
|
||||
import tkinter as tk
|
||||
from tkinter import colorchooser
|
||||
from tkinter import ttk
|
||||
from tkinter import messagebox
|
||||
from tkinter import filedialog
|
||||
import framework
|
||||
import decensor
|
||||
|
||||
|
||||
class PaintApplication(framework.Framework):
|
||||
|
||||
def __init__(self, root):
|
||||
super().__init__(root)
|
||||
|
||||
self.drawn_img = None
|
||||
self.screen_width = root.winfo_screenwidth()
|
||||
self.screen_height = root.winfo_screenheight()
|
||||
self.start_x, self.start_y = 0, 0
|
||||
self.end_x, self.end_y = 0, 0
|
||||
self.current_item = None
|
||||
self.fill = "#00ff00"
|
||||
self.fill_pil = (0,255,0,255)
|
||||
self.outline = "#00ff00"
|
||||
self.brush_width = 2
|
||||
self.background = 'white'
|
||||
self.foreground = "#00ff00"
|
||||
self.file_name = "Untitled"
|
||||
self.tool_bar_functions = (
|
||||
"draw_line", "draw_irregular_line"
|
||||
)
|
||||
self.selected_tool_bar_function = self.tool_bar_functions[0]
|
||||
|
||||
self.create_gui()
|
||||
self.bind_mouse()
|
||||
|
||||
def on_new_file_menu_clicked(self, event=None):
|
||||
self.start_new_project()
|
||||
|
||||
def start_new_project(self):
|
||||
self.canvas.delete(tk.ALL)
|
||||
self.canvas.config(bg="#ffffff")
|
||||
self.root.title('untitled')
|
||||
|
||||
def on_open_image_menu_clicked(self, event=None):
|
||||
self.open_image()
|
||||
|
||||
def open_image(self):
|
||||
self.file_name = filedialog.askopenfilename(master=self.root, filetypes = [("All Files","*.*")], title="Open...")
|
||||
print(self.file_name)
|
||||
self.canvas.img = Image.open(self.file_name)
|
||||
self.canvas.img_width, self.canvas.img_height = self.canvas.img.size
|
||||
#make reference to image to prevent garbage collection
|
||||
#https://stackoverflow.com/questions/20061396/image-display-on-tkinter-canvas-not-working
|
||||
self.canvas.tk_img = ImageTk.PhotoImage(self.canvas.img)
|
||||
self.canvas.config(width=self.canvas.img_width, height=self.canvas.img_height)
|
||||
self.canvas.create_image(self.canvas.img_width/2.0,self.canvas.img_height/2.0,image=self.canvas.tk_img)
|
||||
|
||||
self.drawn_img = Image.new("RGBA", self.canvas.img.size)
|
||||
self.drawn_img_draw = ImageDraw.Draw(self.drawn_img)
|
||||
|
||||
|
||||
def on_import_mask_clicked(self, event=None):
|
||||
self.import_mask()
|
||||
|
||||
def display_canvas(self):
|
||||
composite_img = Image.alpha_composite(self.canvas.img.convert('RGBA'), self.drawn_img).convert('RGB')
|
||||
self.canvas.tk_img = ImageTk.PhotoImage(composite_img)
|
||||
|
||||
self.canvas.create_image(self.canvas.img_width/2.0,self.canvas.img_height/2.0,image=self.canvas.tk_img)
|
||||
|
||||
def import_mask(self):
|
||||
file_name_mask = filedialog.askopenfilename(master=self.root, filetypes = [("All Files","*.*")], title="Import mask...")
|
||||
mask_img = Image.open(file_name_mask)
|
||||
if (mask_img.size != self.canvas.img.size):
|
||||
messagebox.showerror("Import mask", "Mask image size does not match the original image size! Mask image not imported.")
|
||||
return
|
||||
self.drawn_img = mask_img
|
||||
self.drawn_img_draw = ImageDraw.Draw(self.drawn_img)
|
||||
self.display_canvas()
|
||||
|
||||
def on_save_menu_clicked(self, event=None):
|
||||
if self.file_name == 'untitled':
|
||||
self.on_save_as_menu_clicked()
|
||||
else:
|
||||
self.actual_save()
|
||||
|
||||
def on_save_as_menu_clicked(self):
|
||||
file_name = filedialog.asksaveasfilename(
|
||||
master=self.root, filetypes=[('All Files', ('*.ps', '*.ps'))], title="Save...")
|
||||
if not file_name:
|
||||
return
|
||||
self.file_name = file_name
|
||||
self.actual_save()
|
||||
|
||||
def actual_save(self):
|
||||
self.canvas.postscript(file=self.file_name, colormode='color')
|
||||
self.root.title(self.file_name)
|
||||
|
||||
def on_close_menu_clicked(self):
|
||||
self.close_window()
|
||||
|
||||
def close_window(self):
|
||||
if messagebox.askokcancel("Quit", "Do you really want to quit?"):
|
||||
self.root.destroy()
|
||||
|
||||
def on_undo_menu_clicked(self, event=None):
|
||||
self.undo()
|
||||
|
||||
def undo(self):
|
||||
items_stack = list(self.canvas.find("all"))
|
||||
try:
|
||||
last_item_id = items_stack.pop()
|
||||
except IndexError:
|
||||
return
|
||||
self.canvas.delete(last_item_id)
|
||||
|
||||
def on_canvas_zoom_in_menu_clicked(self):
|
||||
self.canvas_zoom_in()
|
||||
|
||||
def on_canvas_zoom_out_menu_clicked(self):
|
||||
self.canvas_zoom_out()
|
||||
|
||||
def canvas_zoom_in(self):
|
||||
self.canvas.scale("all", 0, 0, 1.2, 1.2)
|
||||
self.canvas.config(scrollregion=self.canvas.bbox(tk.ALL))
|
||||
self.canvas.pack(side=tk.RIGHT, expand=tk.YES, fill=tk.BOTH)
|
||||
|
||||
def canvas_zoom_out(self):
|
||||
self.canvas.scale("all", 0, 0, .8, .8)
|
||||
self.canvas.config(scrollregion=self.canvas.bbox(tk.ALL))
|
||||
self.canvas.pack(side=tk.RIGHT, expand=tk.YES, fill=tk.BOTH)
|
||||
|
||||
def on_decensor_menu_clicked(self, event=None):
|
||||
combined_img = Image.alpha_composite(self.canvas.img.convert('RGBA'), self.drawn_img)
|
||||
decensorer = decensor.Decensor()
|
||||
decensorer.decensor_image(combined_img.convert('RGB'), self.file_name + ".png")
|
||||
messagebox.showinfo(
|
||||
"Decensoring", "Decensoring complete!")
|
||||
|
||||
def on_about_menu_clicked(self, event=None):
|
||||
# messagebox.showinfo(
|
||||
# "Decensoring", "Decensoring in progress.")
|
||||
messagebox.showinfo(
|
||||
"About", "Tkinter GUI Application\n Development Blueprints")
|
||||
|
||||
def get_all_configurations_for_item(self):
|
||||
configuration_dict = {}
|
||||
for key, value in self.canvas.itemconfig("current").items():
|
||||
if value[-1] and value[-1] not in ["0", "0.0", "0,0", "current"]:
|
||||
configuration_dict[key] = value[-1]
|
||||
return configuration_dict
|
||||
|
||||
def canvas_function_wrapper(self, function_name, *arg, **kwargs):
|
||||
func = getattr(self.canvas, function_name)
|
||||
func(*arg, **kwargs)
|
||||
|
||||
def adjust_canvas_coords(self, x_coordinate, y_coordinate):
|
||||
# low_x, high_x = self.x_scroll.get()
|
||||
# percent_x = low_x/(1+low_x-high_x)
|
||||
|
||||
# low_y, high_y = self.y_scroll.get()
|
||||
# percent_y = low_y/(1+low_y-high_y)
|
||||
|
||||
low_x, high_x = self.x_scroll.get()
|
||||
low_y, high_y = self.y_scroll.get()
|
||||
#length_y = high_y - low_y
|
||||
return low_x * 800 + x_coordinate, low_y * 800 + y_coordinate
|
||||
|
||||
def create_circle(self, x, y, r, **kwargs):
|
||||
return self.canvas.create_oval(x-r, y-r, x+r, y+r, **kwargs)
|
||||
|
||||
def draw_irregular_line(self):
|
||||
# self.current_item = self.canvas.create_line(
|
||||
# self.start_x, self.start_y, self.end_x, self.end_y, fill=self.fill, width=self.brush_width)
|
||||
# self.current_item = self.create_circle(self.end_x, self.end_y, self.brush_width/2.0, fill=self.fill, width=0)
|
||||
|
||||
#draw in PIL
|
||||
self.drawn_img_draw.line((self.start_x, self.start_y, self.end_x, self.end_y), fill=self.fill_pil, width=int(self.brush_width))
|
||||
self.drawn_img_draw.ellipse((self.end_x - self.brush_width/2.0, self.end_y - self.brush_width/2.0, self.end_x + self.brush_width/2.0, self.end_y + self.brush_width/2.0), fill=self.fill_pil)
|
||||
|
||||
self.display_canvas()
|
||||
# composite_img = Image.alpha_composite(self.canvas.img.convert('RGBA'), self.drawn_img).convert('RGB')
|
||||
# self.canvas.tk_img = ImageTk.PhotoImage(composite_img)
|
||||
|
||||
# self.canvas.create_image(self.canvas.img_width/2.0,self.canvas.img_height/2.0,image=self.canvas.tk_img)
|
||||
|
||||
self.canvas.bind("<B1-Motion>", self.draw_irregular_line_update_x_y)
|
||||
|
||||
def draw_irregular_line_update_x_y(self, event=None):
|
||||
self.start_x, self.start_y = self.end_x, self.end_y
|
||||
self.end_x, self.end_y = self.adjust_canvas_coords(event.x, event.y)
|
||||
self.draw_irregular_line()
|
||||
|
||||
def draw_irregular_line_options(self):
|
||||
self.create_fill_options_combobox()
|
||||
self.create_width_options_combobox()
|
||||
|
||||
def on_tool_bar_button_clicked(self, button_index):
|
||||
self.selected_tool_bar_function = self.tool_bar_functions[button_index]
|
||||
self.remove_options_from_top_bar()
|
||||
self.display_options_in_the_top_bar()
|
||||
self.bind_mouse()
|
||||
|
||||
def float_range(self, x, y, step):
|
||||
while x < y:
|
||||
yield x
|
||||
x += step
|
||||
|
||||
def set_foreground_color(self, event=None):
|
||||
self.foreground = self.get_color_from_chooser(
|
||||
self.foreground, "foreground")
|
||||
self.color_palette.itemconfig(
|
||||
self.foreground_palette, width=0, fill=self.foreground)
|
||||
|
||||
def set_background_color(self, event=None):
|
||||
self.background = self.get_color_from_chooser(
|
||||
self.background, "background")
|
||||
self.color_palette.itemconfig(
|
||||
self.background_palette, width=0, fill=self.background)
|
||||
|
||||
def get_color_from_chooser(self, initial_color, color_type="a"):
|
||||
color = colorchooser.askcolor(
|
||||
color=initial_color,
|
||||
title="select {} color".format(color_type)
|
||||
)[-1]
|
||||
if color:
|
||||
return color
|
||||
# dialog has been cancelled
|
||||
else:
|
||||
return initial_color
|
||||
|
||||
def try_to_set_fill_after_palette_change(self):
|
||||
try:
|
||||
self.set_fill()
|
||||
except:
|
||||
pass
|
||||
|
||||
def try_to_set_outline_after_palette_change(self):
|
||||
try:
|
||||
self.set_outline()
|
||||
except:
|
||||
pass
|
||||
|
||||
def display_options_in_the_top_bar(self):
|
||||
self.show_selected_tool_icon_in_top_bar(
|
||||
self.selected_tool_bar_function)
|
||||
options_function_name = "{}_options".format(self.selected_tool_bar_function)
|
||||
func = getattr(self, options_function_name, self.function_not_defined)
|
||||
func()
|
||||
|
||||
def draw_line_options(self):
|
||||
self.create_fill_options_combobox()
|
||||
self.create_width_options_combobox()
|
||||
|
||||
def create_fill_options_combobox(self):
|
||||
tk.Label(self.top_bar, text='Fill:').pack(side="left")
|
||||
self.fill_combobox = ttk.Combobox(
|
||||
self.top_bar, state='readonly', width=5)
|
||||
self.fill_combobox.pack(side="left")
|
||||
self.fill_combobox['values'] = ('none', 'fg', 'bg', 'black', 'white')
|
||||
self.fill_combobox.bind('<<ComboboxSelected>>', self.set_fill)
|
||||
self.fill_combobox.set(self.fill)
|
||||
|
||||
def create_outline_options_combobox(self):
|
||||
tk.Label(self.top_bar, text='Outline:').pack(side="left")
|
||||
self.outline_combobox = ttk.Combobox(
|
||||
self.top_bar, state='readonly', width=5)
|
||||
self.outline_combobox.pack(side="left")
|
||||
self.outline_combobox['values'] = (
|
||||
'none', 'fg', 'bg', 'black', 'white')
|
||||
self.outline_combobox.bind('<<ComboboxSelected>>', self.set_outline)
|
||||
self.outline_combobox.set(self.outline)
|
||||
|
||||
def create_width_options_combobox(self):
|
||||
tk.Label(self.top_bar, text='Width:').pack(side="left")
|
||||
self.width_combobox = ttk.Combobox(
|
||||
self.top_bar, state='readonly', width=3)
|
||||
self.width_combobox.pack(side="left")
|
||||
self.width_combobox['values'] = (
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 40, 50)
|
||||
self.width_combobox.bind('<<ComboboxSelected>>', self.set_brush_width)
|
||||
self.width_combobox.set(self.brush_width)
|
||||
|
||||
def set_fill(self, event=None):
|
||||
fill_color = self.fill_combobox.get()
|
||||
if fill_color == 'none':
|
||||
self.fill = '' # transparent
|
||||
elif fill_color == 'fg':
|
||||
self.fill = self.foreground
|
||||
elif fill_color == 'bg':
|
||||
self.fill = self.background
|
||||
else:
|
||||
self.fill = fill_color
|
||||
|
||||
def set_outline(self, event=None):
|
||||
outline_color = self.outline_combobox.get()
|
||||
if outline_color == 'none':
|
||||
self.outline = '' # transparent
|
||||
elif outline_color == 'fg':
|
||||
self.outline = self.foreground
|
||||
elif outline_color == 'bg':
|
||||
self.outline = self.background
|
||||
else:
|
||||
self.outline = outline_color
|
||||
|
||||
def set_brush_width(self, event):
|
||||
self.brush_width = float(self.width_combobox.get())
|
||||
|
||||
def create_color_palette(self):
|
||||
self.color_palette = tk.Canvas(self.tool_bar, height=55, width=55)
|
||||
self.color_palette.grid(row=10, column=1, columnspan=2, pady=5, padx=3)
|
||||
self.background_palette = self.color_palette.create_rectangle(
|
||||
15, 15, 48, 48, outline=self.background, fill=self.background)
|
||||
self.foreground_palette = self.color_palette.create_rectangle(
|
||||
1, 1, 33, 33, outline=self.foreground, fill=self.foreground)
|
||||
self.bind_color_palette()
|
||||
|
||||
def bind_color_palette(self):
|
||||
self.color_palette.tag_bind(
|
||||
self.background_palette, "<Button-1>", self.set_background_color)
|
||||
self.color_palette.tag_bind(
|
||||
self.foreground_palette, "<Button-1>", self.set_foreground_color)
|
||||
|
||||
def create_current_coordinate_label(self):
|
||||
self.current_coordinate_label = tk.Label(
|
||||
self.tool_bar, text='x:0\ny: 0 ')
|
||||
self.current_coordinate_label.grid(
|
||||
row=13, column=1, columnspan=2, pady=5, padx=1, sticky='w')
|
||||
|
||||
def show_current_coordinates(self, event=None):
|
||||
x_coordinate = event.x
|
||||
y_coordinate = event.y
|
||||
coordinate_string = "x:{0}\ny:{1}".format(x_coordinate, y_coordinate)
|
||||
self.current_coordinate_label.config(text=coordinate_string)
|
||||
|
||||
def function_not_defined(self):
|
||||
pass
|
||||
|
||||
def execute_selected_method(self):
|
||||
self.current_item = None
|
||||
func = getattr(
|
||||
self, self.selected_tool_bar_function, self.function_not_defined)
|
||||
func()
|
||||
|
||||
def draw_line(self):
|
||||
self.current_item = self.canvas.create_line(
|
||||
self.start_x, self.start_y, self.end_x, self.end_y, fill=self.fill, width=self.brush_width)
|
||||
|
||||
# self.drawn_img_draw.line((self.start_x, self.start_y, self.end_x, self.end_y), fill=self.fill_pil, width=self.brush_width)
|
||||
|
||||
def create_tool_bar_buttons(self):
|
||||
for index, name in enumerate(self.tool_bar_functions):
|
||||
icon = tk.PhotoImage(file='icons/' + name + '.gif')
|
||||
self.button = tk.Button(
|
||||
self.tool_bar, image=icon, command=lambda index=index: self.on_tool_bar_button_clicked(index))
|
||||
self.button.grid(
|
||||
row=index // 2, column=1 + index % 2, sticky='nsew')
|
||||
self.button.image = icon
|
||||
|
||||
def remove_options_from_top_bar(self):
|
||||
for child in self.top_bar.winfo_children():
|
||||
child.destroy()
|
||||
|
||||
def show_selected_tool_icon_in_top_bar(self, function_name):
|
||||
display_name = function_name.replace("_", " ").capitalize() + ":"
|
||||
tk.Label(self.top_bar, text=display_name).pack(side="left")
|
||||
photo = tk.PhotoImage(
|
||||
file='icons/' + function_name + '.gif')
|
||||
label = tk.Label(self.top_bar, image=photo)
|
||||
label.image = photo
|
||||
label.pack(side="left")
|
||||
|
||||
def bind_mouse(self):
|
||||
self.canvas.bind("<Button-1>", self.on_mouse_button_pressed)
|
||||
self.canvas.bind(
|
||||
"<Button1-Motion>", self.on_mouse_button_pressed_motion)
|
||||
self.canvas.bind(
|
||||
"<Button1-ButtonRelease>", self.on_mouse_button_released)
|
||||
self.canvas.bind("<Motion>", self.on_mouse_unpressed_motion)
|
||||
|
||||
def on_mouse_button_pressed(self, event):
|
||||
self.start_x = self.end_x = self.canvas.canvasx(event.x)
|
||||
self.start_y = self.end_y = self.canvas.canvasy(event.y)
|
||||
self.execute_selected_method()
|
||||
|
||||
def on_mouse_button_pressed_motion(self, event):
|
||||
self.end_x = self.canvas.canvasx(event.x)
|
||||
self.end_y = self.canvas.canvasy(event.y)
|
||||
self.canvas.delete(self.current_item)
|
||||
self.execute_selected_method()
|
||||
|
||||
def on_mouse_button_released(self, event):
|
||||
self.end_x = self.canvas.canvasx(event.x)
|
||||
self.end_y = self.canvas.canvasy(event.y)
|
||||
|
||||
def on_mouse_unpressed_motion(self, event):
|
||||
self.show_current_coordinates(event)
|
||||
|
||||
def create_gui(self):
|
||||
self.create_menu()
|
||||
self.create_top_bar()
|
||||
self.create_tool_bar()
|
||||
self.create_tool_bar_buttons()
|
||||
self.create_drawing_canvas()
|
||||
self.create_color_palette()
|
||||
self.create_current_coordinate_label()
|
||||
self.bind_menu_accelrator_keys()
|
||||
self.show_selected_tool_icon_in_top_bar("draw_line")
|
||||
self.draw_line_options()
|
||||
|
||||
def create_menu(self):
|
||||
self.menubar = tk.Menu(self.root)
|
||||
menu_definitions = (
|
||||
'File- &New/Ctrl+N/self.on_new_file_menu_clicked, Open/Ctrl+O/self.on_open_image_menu_clicked, Import Mask/Ctrl+M/self.on_import_mask_clicked, Save/Ctrl+S/self.on_save_menu_clicked, SaveAs/ /self.on_save_as_menu_clicked, sep, Exit/Alt+F4/self.on_close_menu_clicked',
|
||||
'Edit- Undo/Ctrl+Z/self.on_undo_menu_clicked, sep',
|
||||
'View- Zoom in//self.on_canvas_zoom_in_menu_clicked,Zoom Out//self.on_canvas_zoom_out_menu_clicked',
|
||||
'Decensor- Decensor/Ctrl+D/self.on_decensor_menu_clicked',
|
||||
'About- About/F1/self.on_about_menu_clicked'
|
||||
)
|
||||
self.build_menu(menu_definitions)
|
||||
|
||||
def create_top_bar(self):
|
||||
self.top_bar = tk.Frame(self.root, height=25, relief="raised")
|
||||
self.top_bar.pack(fill="x", side="top", pady=2)
|
||||
|
||||
def create_tool_bar(self):
|
||||
self.tool_bar = tk.Frame(self.root, relief="raised", width=50)
|
||||
self.tool_bar.pack(fill="y", side="left", pady=3)
|
||||
|
||||
def create_drawing_canvas(self):
|
||||
self.canvas_frame = tk.Frame(self.root, width=900, height=900)
|
||||
self.canvas_frame.pack(side="right", expand="yes", fill="both")
|
||||
self.canvas = tk.Canvas(self.canvas_frame, background="white",
|
||||
width=512, height=512, scrollregion=(0, 0, 512, 512))
|
||||
self.create_scroll_bar()
|
||||
self.canvas.pack(side=tk.RIGHT, expand=tk.YES, fill=tk.BOTH)
|
||||
|
||||
self.canvas.img = Image.open('./icons/canvas_top_test.png').convert('RGBA')
|
||||
self.canvas.img = self.canvas.img.resize((512,512))
|
||||
self.canvas.tk_img = ImageTk.PhotoImage(self.canvas.img)
|
||||
self.canvas.create_image(256,256,image=self.canvas.tk_img)
|
||||
|
||||
def create_scroll_bar(self):
|
||||
self.x_scroll = tk.Scrollbar(self.canvas_frame, orient="horizontal")
|
||||
self.x_scroll.pack(side="bottom", fill="x")
|
||||
self.x_scroll.config(command=self.canvas.xview)
|
||||
self.y_scroll = tk.Scrollbar(self.canvas_frame, orient="vertical")
|
||||
self.y_scroll.pack(side="right", fill="y")
|
||||
self.y_scroll.config(command=self.canvas.yview)
|
||||
self.canvas.config(
|
||||
xscrollcommand=self.x_scroll.set, yscrollcommand=self.y_scroll.set)
|
||||
|
||||
def bind_menu_accelrator_keys(self):
|
||||
self.root.bind('<KeyPress-F1>', self.on_about_menu_clicked)
|
||||
self.root.bind('<Control-N>', self.on_new_file_menu_clicked)
|
||||
self.root.bind('<Control-n>', self.on_new_file_menu_clicked)
|
||||
self.root.bind('<Control-s>', self.on_save_menu_clicked)
|
||||
self.root.bind('<Control-S>', self.on_save_menu_clicked)
|
||||
self.root.bind('<Control-z>', self.on_undo_menu_clicked)
|
||||
self.root.bind('<Control-Z>', self.on_undo_menu_clicked)
|
||||
|
||||
if __name__ == '__main__':
|
||||
root = tk.Tk()
|
||||
app = PaintApplication(root)
|
||||
root.mainloop()
|
Loading…
x
Reference in New Issue
Block a user