more args

This commit is contained in:
deep pomf 2018-02-26 10:59:14 -05:00
parent c0c817c10d
commit 405456a8f0

View File

@ -6,26 +6,25 @@ import os
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import sys import sys
sys.path.append('..') sys.path.append('..')
from model import Model from model import Model
from poisson_blend import blend from poisson_blend import blend
from config import *
IMAGE_SIZE = 128 #size of input of local discrimnator. do not change this value.
LOCAL_SIZE = 64 LOCAL_SIZE = 64
HOLE_MIN = 24
HOLE_MAX = 48
BATCH_SIZE = 1 BATCH_SIZE = 1
image_folder = 'decensor_input_images/' image_folder = 'decensor_input_images/'
mask_color = [0, 255, 0] mask_color = [0, 255, 0]
poisson_blending_enabled = False poisson_blending_enabled = False
def decensor(): def decensor():
x = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3]) x = tf.placeholder(tf.float32, [args.batch_size, args.image_size, args.image_size, args.input_channel_size])
mask = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 1]) mask = tf.placeholder(tf.float32, [args.batch_size, args.image_size, args.image_size, 1])
local_x = tf.placeholder(tf.float32, [BATCH_SIZE, LOCAL_SIZE, LOCAL_SIZE, 3]) local_x = tf.placeholder(tf.float32, [args.batch_size, args.local_image_size, args.local_image_size, args.input_channel_size])
global_completion = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3]) global_completion = tf.placeholder(tf.float32, [args.batch_size, args.image_size, args.image_size, args.input_channel_size])
local_completion = tf.placeholder(tf.float32, [BATCH_SIZE, LOCAL_SIZE, LOCAL_SIZE, 3]) local_completion = tf.placeholder(tf.float32, [args.batch_size, args.local_image_size, args.local_image_size, args.input_channel_size])
is_training = tf.placeholder(tf.bool, []) is_training = tf.placeholder(tf.bool, [])
model = Model(x, mask, local_x, global_completion, local_completion, is_training, batch_size=BATCH_SIZE) model = Model(x, mask, local_x, global_completion, local_completion, is_training, batch_size=BATCH_SIZE)
@ -74,9 +73,9 @@ def get_mask(x_batch):
for i in range(BATCH_SIZE): for i in range(BATCH_SIZE):
raw = x_batch[i] raw = x_batch[i]
raw = np.array((raw + 1) * 127.5, dtype=np.uint8) raw = np.array((raw + 1) * 127.5, dtype=np.uint8)
m = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 1), dtype=np.uint8) m = np.zeros((args.image_size, args.image_size, 1), dtype=np.uint8)
for x in range(IMAGE_SIZE): for x in range(args.image_size):
for y in range(IMAGE_SIZE): for y in range(args.image_size):
if np.array_equal(raw[x][y], [0, 255, 0]): if np.array_equal(raw[x][y], [0, 255, 0]):
m[x, y] = 1 m[x, y] = 1
mask.append(m) mask.append(m)