"""Run a trained image-completion model on a held-out ``.npy`` image set.

For each full batch of test images a random rectangular hole is masked out,
the model's completion is computed, and an Input / Output / Ground Truth
comparison figure is written to ``OUTPUT_DIR``.
"""
import numpy as np
import tensorflow as tf
from PIL import Image
import tqdm
import os
import matplotlib.pyplot as plt
import sys
sys.path.append('..')
from model import Model

IMAGE_SIZE = 128
LOCAL_SIZE = 64
HOLE_MIN = 24
HOLE_MAX = 48
BATCH_SIZE = 16
test_npy = './lfw.npy'
CHECKPOINT = './models/latest'
OUTPUT_DIR = './testing_output_images'


def test():
    """Restore ``CHECKPOINT`` and save comparison images for the test set.

    Images are loaded from ``test_npy``, shuffled, scaled to [-1, 1] and fed
    in batches of ``BATCH_SIZE``. Trailing images that do not fill a whole
    batch are skipped, because every placeholder has a fixed batch dimension.
    """
    x = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3])
    mask = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 1])
    # The next three placeholders are required by Model's constructor but are
    # never fed below; presumably only the training path reads them.
    local_x = tf.placeholder(tf.float32, [BATCH_SIZE, LOCAL_SIZE, LOCAL_SIZE, 3])
    global_completion = tf.placeholder(
        tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3])
    local_completion = tf.placeholder(
        tf.float32, [BATCH_SIZE, LOCAL_SIZE, LOCAL_SIZE, 3])
    is_training = tf.placeholder(tf.bool, [])
    model = Model(x, mask, local_x, global_completion, local_completion,
                  is_training, batch_size=BATCH_SIZE)

    # Create the output directory up front so savefig cannot fail on a
    # fresh checkout.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Context manager guarantees the session's resources are released.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess, CHECKPOINT)

        x_test = np.load(test_npy)
        np.random.shuffle(x_test)
        # Assumes pixels are stored in [0, 255] (TODO confirm); map to [-1, 1]
        # in one vectorised step rather than a per-image Python list.
        x_test = x_test / 127.5 - 1
        step_num = len(x_test) // BATCH_SIZE

        cnt = 0
        for step in tqdm.tqdm(range(step_num)):
            x_batch = x_test[step * BATCH_SIZE:(step + 1) * BATCH_SIZE]
            _, mask_batch = get_points()
            completion = sess.run(
                model.completion,
                feed_dict={x: x_batch, mask: mask_batch, is_training: False})

            for raw_f, m, out_f in zip(x_batch, mask_batch, completion):
                cnt += 1
                raw = _to_uint8(raw_f)
                # Paint the hole white (255) so the Input panel shows what
                # was removed.
                masked = raw * (1 - m) + np.ones_like(raw) * m * 255
                img = _to_uint8(out_f)
                dst = os.path.join(OUTPUT_DIR, '{:06d}.jpg'.format(cnt))
                output_image(
                    [['Input', masked], ['Output', img], ['Ground Truth', raw]],
                    dst)


def _to_uint8(img):
    """Map an image from [-1, 1] back to uint8 pixels in [0, 255].

    Rounds rather than truncating, so a [0, 255] -> [-1, 1] -> [0, 255] round
    trip reproduces the original pixels. Clips so values outside [-1, 1]
    saturate instead of wrapping around in the uint8 cast.
    """
    return np.clip(np.rint((img + 1) * 127.5), 0, 255).astype(np.uint8)


def get_points():
    """Sample one local window and one rectangular hole per batch element.

    Returns:
        points: integer array of shape (BATCH_SIZE, 4). Each row is
            [x1, y1, x2, y2] for a LOCAL_SIZE x LOCAL_SIZE window.
        mask: uint8 array of shape (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 1).
            It is 1 inside the hole and 0 elsewhere.
    """
    points = []
    mask = []
    for _ in range(BATCH_SIZE):
        x1, y1 = np.random.randint(0, IMAGE_SIZE - LOCAL_SIZE + 1, 2)
        x2, y2 = np.array([x1, y1]) + LOCAL_SIZE
        points.append([x1, y1, x2, y2])

        w, h = np.random.randint(HOLE_MIN, HOLE_MAX + 1, 2)
        p1 = x1 + np.random.randint(0, LOCAL_SIZE - w)
        q1 = y1 + np.random.randint(0, LOCAL_SIZE - h)
        p2 = p1 + w
        q2 = q1 + h

        m = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 1), dtype=np.uint8)
        # End indices are inclusive, so the hole covers (h + 1) x (w + 1)
        # pixels. It still lies inside the local window, because the offsets
        # above are drawn from [0, LOCAL_SIZE - size).
        # NOTE(review): presumably mirrors the training-time sampler; confirm
        # before changing.
        m[q1:q2 + 1, p1:p2 + 1] = 1
        mask.append(m)
    return np.array(points), np.array(mask)


def output_image(images, dst):
    """Save a one-row figure with one labelled panel per image.

    Args:
        images: sequence of ``[label, image]`` pairs. Each image is shown
            with ``imshow``, with its label drawn underneath.
        dst: output path passed to ``savefig``.
    """
    fig = plt.figure()
    for i, (text, img) in enumerate(images):
        ax = fig.add_subplot(1, len(images), i + 1)
        ax.imshow(img)
        # Matplotlib expects booleans here. The string 'off' is truthy and
        # is not treated as False, so tick labels would still be drawn.
        ax.tick_params(bottom=False, left=False,
                       labelbottom=False, labelleft=False)
        ax.set_xlabel(text)
    fig.savefig(dst)
    plt.close(fig)


if __name__ == '__main__':
    test()