DeepCreamPy/decensor.py

85 lines
3.4 KiB
Python
Raw Normal View History

2018-02-11 03:19:48 +00:00
import numpy as np
import tensorflow as tf
from PIL import Image
import tqdm
import os
import matplotlib.pyplot as plt
import sys
sys.path.append('..')
2018-02-26 15:59:14 +00:00
2018-02-11 03:19:48 +00:00
from model import Model
2018-02-11 18:57:44 +00:00
from poisson_blend import blend
2018-02-26 15:59:14 +00:00
from config import *
2018-02-11 03:19:48 +00:00
2018-02-28 04:15:57 +00:00
# TODO: support variable batch sizes when decensoring. Setting BATCH_SIZE to
# anything other than 1 will likely crash, so keep it at 1 for now.
BATCH_SIZE = 1

# Pixels whose RGB value exactly equals this colour mark the region to be
# inpainted (compared in get_mask). The channel values come from `args`,
# presumably provided by `from config import *`.
mask_color = [
    args.mask_color_red,
    args.mask_color_green,
    args.mask_color_blue,
]

# When True, decensor() Poisson-blends the model output into the original image.
poisson_blending_enabled = False
2018-02-11 03:19:48 +00:00
2018-02-27 16:34:14 +00:00
def _load_images(input_path, input_size):
    """Load every PNG under *input_path* (recursive, sorted) scaled to [-1, 1].

    Each matching file path is printed as it is loaded. Raises ValueError if an
    image's height/width differ from *input_size*, because the model's input
    placeholder has a fixed spatial size.
    """
    images = []
    for subdir, _dirs, files in sorted(os.walk(input_path)):
        for name in sorted(files):
            file_path = os.path.join(subdir, name)
            # Case-insensitive so ".PNG" files are not silently skipped.
            is_png = os.path.splitext(file_path)[1].lower() == ".png"
            if not (os.path.isfile(file_path) and is_png):
                continue
            print(file_path)
            # Context manager closes the file handle once pixels are copied out.
            with Image.open(file_path) as pil_image:
                image = np.array(pil_image.convert('RGB'))
            if image.shape[:2] != (input_size, input_size):
                raise ValueError('{}: expected a {}x{} image, got {}x{}'.format(
                    file_path, input_size, input_size, image.shape[1], image.shape[0]))
            images.append(image / 127.5 - 1)
    return np.array(images)


def decensor(args):
    """Inpaint every PNG under ``args.decensor_input_path`` and save the results.

    Builds the model graph, restores weights from ``./models/latest``, builds a
    mask per image with ``get_mask`` and runs ``model.completion``. Results are
    written as ``000000.png``, ``000001.png``, ... into
    ``args.decensor_output_path``. When ``poisson_blending_enabled`` is set, each
    result is first passed through ``blend`` together with the original image
    and its mask.

    Args:
        args: configuration object providing ``input_size``,
            ``local_input_size``, ``input_channel_size``,
            ``decensor_input_path`` and ``decensor_output_path``.

    Raises:
        ValueError: if an input image does not match ``args.input_size``.
    """
    x = tf.placeholder(tf.float32, [BATCH_SIZE, args.input_size, args.input_size, args.input_channel_size])
    mask = tf.placeholder(tf.float32, [BATCH_SIZE, args.input_size, args.input_size, 1])
    local_x = tf.placeholder(tf.float32, [BATCH_SIZE, args.local_input_size, args.local_input_size, args.input_channel_size])
    global_completion = tf.placeholder(tf.float32, [BATCH_SIZE, args.input_size, args.input_size, args.input_channel_size])
    local_completion = tf.placeholder(tf.float32, [BATCH_SIZE, args.local_input_size, args.local_input_size, args.input_channel_size])
    is_training = tf.placeholder(tf.bool, [])

    model = Model(x, mask, local_x, global_completion, local_completion, is_training, batch_size=BATCH_SIZE)

    # `with` guarantees the session (and its resources) is released.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess, './models/latest')

        x_decensor = _load_images(args.decensor_input_path, args.input_size)
        print(x_decensor.shape)

        # The placeholders have a fixed batch dimension, so with BATCH_SIZE > 1
        # a trailing partial batch is skipped.
        step_num = len(x_decensor) // BATCH_SIZE
        cnt = 0
        for step in tqdm.tqdm(range(step_num)):
            x_batch = x_decensor[step * BATCH_SIZE:(step + 1) * BATCH_SIZE]
            mask_batch = get_mask(x_batch)
            completion = sess.run(model.completion, feed_dict={x: x_batch, mask: mask_batch, is_training: False})
            for j in range(BATCH_SIZE):
                img = np.array((completion[j] + 1) * 127.5, dtype=np.uint8)
                if poisson_blending_enabled:
                    original = np.array((x_batch[j] + 1) * 127.5, dtype=np.uint8)
                    # Blend with this image's own mask (previously always mask 0).
                    img = blend(original, img, mask_batch[j, :, :, 0])
                output = Image.fromarray(img.astype('uint8'), 'RGB')
                # os.path.join works whether or not the path ends in a separator.
                dst = os.path.join(args.decensor_output_path, '{:06d}.png'.format(cnt))
                output.save(dst)
                cnt += 1
2018-02-11 03:19:48 +00:00
2018-02-11 05:45:14 +00:00
def get_mask(x_batch, color=None):
    """Build binary masks marking pixels of an exact colour in a batch of images.

    Each image is mapped from [-1, 1] back to 0-255. A pixel is marked 1 when its
    channel values all equal ``color``; every other pixel is 0.

    Args:
        x_batch: iterable of images scaled to [-1, 1], each indexed as
            image[row, col] -> channel values (assumed (H, W, 3) RGB; confirm).
        color: optional colour to match; defaults to the module-level
            ``mask_color``.

    Returns:
        uint8 array of shape (len(x_batch), H, W, 1).
    """
    if color is None:
        color = mask_color
    target = np.asarray(color)
    masks = []
    for image in x_batch:
        # Round instead of truncating: float round-trip error (e.g. 254.9999)
        # would otherwise turn 255 into 254 and break exact colour matching.
        raw = np.rint((np.asarray(image) + 1) * 127.5).astype(np.uint8)
        # Vectorized per-pixel comparison across the channel axis.
        hit = np.all(raw == target, axis=-1)
        masks.append(hit.astype(np.uint8)[..., np.newaxis])
    return np.array(masks)
2018-02-11 03:19:48 +00:00
if __name__ == '__main__':
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs
    # and is a no-op when the output directory already exists.
    os.makedirs(args.decensor_output_path, exist_ok=True)
    decensor(args)