This commit is contained in:
deeppomf 2018-02-26 20:36:45 -05:00
parent 76b5903772
commit 7e8c8c2967

View File

@ -11,11 +11,11 @@ from config import *
PRETRAIN_EPOCH = 100
def train(args):
x = tf.placeholder(tf.float32, [args.batch_size, args.image_size, args.image_size, args.input_channel_size])
mask = tf.placeholder(tf.float32, [args.batch_size, args.image_size, args.image_size, 1])
local_x = tf.placeholder(tf.float32, [args.batch_size, args.local_image_size, args.local_image_size, args.input_channel_size])
global_completion = tf.placeholder(tf.float32, [args.batch_size, args.image_size, args.image_size, args.input_channel_size])
local_completion = tf.placeholder(tf.float32, [args.batch_size, args.local_image_size, args.local_image_size, args.input_channel_size])
x = tf.placeholder(tf.float32, [args.batch_size, args.input_size, args.input_size, args.input_channel_size])
mask = tf.placeholder(tf.float32, [args.batch_size, args.input_size, args.input_size, 1])
local_x = tf.placeholder(tf.float32, [args.batch_size, args.local_input_size, args.local_input_size, args.input_channel_size])
global_completion = tf.placeholder(tf.float32, [args.batch_size, args.input_size, args.input_size, args.input_channel_size])
local_completion = tf.placeholder(tf.float32, [args.batch_size, args.local_input_size, args.local_input_size, args.input_channel_size])
is_training = tf.placeholder(tf.bool, [])
model = Model(x, mask, local_x, global_completion, local_completion, is_training, batch_size=args.batch_size)
@ -118,17 +118,17 @@ def get_points():
points = []
mask = []
for i in range(args.batch_size):
x1, y1 = np.random.randint(0, args.image_size - args.local_image_size + 1, 2)
x2, y2 = np.array([x1, y1]) + args.local_image_size
x1, y1 = np.random.randint(0, args.input_size - args.local_input_size + 1, 2)
x2, y2 = np.array([x1, y1]) + args.local_input_size
points.append([x1, y1, x2, y2])
w, h = np.random.randint(args.min_mask_size, args.max_mask_size + 1, 2)
p1 = x1 + np.random.randint(0, args.local_image_size - w)
q1 = y1 + np.random.randint(0, args.local_image_size - h)
p1 = x1 + np.random.randint(0, args.local_input_size - w)
q1 = y1 + np.random.randint(0, args.local_input_size - h)
p2 = p1 + w
q2 = q1 + h
m = np.zeros((args.image_size, args.image_size, 1), dtype=np.uint8)
m = np.zeros((args.input_size, args.input_size, 1), dtype=np.uint8)
m[q1:q2 + 1, p1:p2 + 1] = 1
if (np.random.random() < args.rotate_chance):