mirror of
https://github.com/Deepshift/DeepCreamPy.git
synced 2025-01-07 09:54:45 +00:00
update decensor
This commit is contained in:
parent
704ff80bba
commit
13dc39afba
79
decensor.py
79
decensor.py
@ -12,11 +12,12 @@ IMAGE_SIZE = 128
|
||||
LOCAL_SIZE = 64
|
||||
HOLE_MIN = 24
|
||||
HOLE_MAX = 48
|
||||
BATCH_SIZE = 16
|
||||
BATCH_SIZE = 3
|
||||
|
||||
image_path = './lfw.npy'
|
||||
image_folder = 'decensor_input_images/'
|
||||
mask_color = [0, 255, 0]
|
||||
|
||||
def test():
|
||||
def decensor():
|
||||
x = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3])
|
||||
mask = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 1])
|
||||
local_x = tf.placeholder(tf.float32, [BATCH_SIZE, LOCAL_SIZE, LOCAL_SIZE, 3])
|
||||
@ -32,64 +33,48 @@ def test():
|
||||
saver = tf.train.Saver()
|
||||
saver.restore(sess, './models/latest')
|
||||
|
||||
x_test = np.load(test_npy)
|
||||
np.random.shuffle(x_test)
|
||||
x_test = np.array([a / 127.5 - 1 for a in x_test])
|
||||
|
||||
step_num = int(len(x_test) / BATCH_SIZE)
|
||||
x_decensor = []
|
||||
mask_decensor = []
|
||||
for subdir, dirs, files in sorted(os.walk(image_folder)):
|
||||
for file in sorted(files):
|
||||
file_path = os.path.join(subdir, file)
|
||||
if os.path.isfile(file_path) and os.path.splitext(file_path)[1] == ".png":
|
||||
print file_path
|
||||
image = Image.open(file_path).convert('RGB')
|
||||
image = np.array(image)
|
||||
image = np.array(image / 127.5 - 1)
|
||||
x_decensor.append(image)
|
||||
x_decensor = np.array(x_decensor)
|
||||
print x_decensor.shape
|
||||
step_num = int(len(x_decensor) / BATCH_SIZE)
|
||||
|
||||
cnt = 0
|
||||
for i in tqdm.tqdm(range(step_num)):
|
||||
x_batch = x_test[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
|
||||
_, mask_batch = get_points()
|
||||
x_batch = x_decensor[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
|
||||
mask_batch = get_mask(x_batch)
|
||||
completion = sess.run(model.completion, feed_dict={x: x_batch, mask: mask_batch, is_training: False})
|
||||
for i in range(BATCH_SIZE):
|
||||
cnt += 1
|
||||
raw = x_batch[i]
|
||||
raw = np.array((raw + 1) * 127.5, dtype=np.uint8)
|
||||
masked = raw * (1 - mask_batch[i]) + np.ones_like(raw) * mask_batch[i] * 255
|
||||
img = completion[i]
|
||||
img = np.array((img + 1) * 127.5, dtype=np.uint8)
|
||||
dst = './output/{}.jpg'.format("{0:06d}".format(cnt))
|
||||
output_image([['Input', masked], ['Output', img], ['Ground Truth', raw]], dst)
|
||||
output = Image.fromarray(img.astype('uint8'), 'RGB')
|
||||
dst = './decensor_output_images/{}.png'.format("{0:06d}".format(cnt))
|
||||
output.save(dst)
|
||||
|
||||
|
||||
def get_points():
|
||||
def get_mask(x_batch):
|
||||
points = []
|
||||
mask = []
|
||||
for i in range(BATCH_SIZE):
|
||||
x1, y1 = np.random.randint(0, IMAGE_SIZE - LOCAL_SIZE + 1, 2)
|
||||
x2, y2 = np.array([x1, y1]) + LOCAL_SIZE
|
||||
points.append([x1, y1, x2, y2])
|
||||
|
||||
w, h = np.random.randint(HOLE_MIN, HOLE_MAX + 1, 2)
|
||||
p1 = x1 + np.random.randint(0, LOCAL_SIZE - w)
|
||||
q1 = y1 + np.random.randint(0, LOCAL_SIZE - h)
|
||||
p2 = p1 + w
|
||||
q2 = q1 + h
|
||||
|
||||
raw = x_batch[i]
|
||||
raw = np.array((raw + 1) * 127.5, dtype=np.uint8)
|
||||
m = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 1), dtype=np.uint8)
|
||||
m[q1:q2 + 1, p1:p2 + 1] = 1
|
||||
for x in xrange(IMAGE_SIZE):
|
||||
for y in xrange(IMAGE_SIZE):
|
||||
if np.array_equal(raw[x][y], [0, 255, 0]):
|
||||
m[x, y] = 1
|
||||
mask.append(m)
|
||||
|
||||
return np.array(points), np.array(mask)
|
||||
|
||||
|
||||
def output_image(images, dst):
|
||||
fig = plt.figure()
|
||||
for i, image in enumerate(images):
|
||||
text, img = image
|
||||
fig.add_subplot(1, 3, i + 1)
|
||||
plt.imshow(img)
|
||||
plt.tick_params(labelbottom='off')
|
||||
plt.tick_params(labelleft='off')
|
||||
plt.gca().get_xaxis().set_ticks_position('none')
|
||||
plt.gca().get_yaxis().set_ticks_position('none')
|
||||
plt.xlabel(text)
|
||||
plt.savefig(dst)
|
||||
plt.close()
|
||||
|
||||
return np.array(mask)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
decensor()
|
||||
|
||||
|
0
decensor_input_images/.gitkeep
Normal file
0
decensor_input_images/.gitkeep
Normal file
0
decensor_output_images/.gitkeep
Normal file
0
decensor_output_images/.gitkeep
Normal file
0
testing_output_images/.gitkeep
Normal file
0
testing_output_images/.gitkeep
Normal file
Loading…
Reference in New Issue
Block a user