diff --git a/train.py b/train.py index 53d850d..4067923 100644 --- a/train.py +++ b/train.py @@ -11,11 +11,11 @@ from config import * PRETRAIN_EPOCH = 100 def train(args): - x = tf.placeholder(tf.float32, [args.batch_size, args.image_size, args.image_size, args.input_channel_size]) - mask = tf.placeholder(tf.float32, [args.batch_size, args.image_size, args.image_size, 1]) - local_x = tf.placeholder(tf.float32, [args.batch_size, args.local_image_size, args.local_image_size, args.input_channel_size]) - global_completion = tf.placeholder(tf.float32, [args.batch_size, args.image_size, args.image_size, args.input_channel_size]) - local_completion = tf.placeholder(tf.float32, [args.batch_size, args.local_image_size, args.local_image_size, args.input_channel_size]) + x = tf.placeholder(tf.float32, [args.batch_size, args.input_size, args.input_size, args.input_channel_size]) + mask = tf.placeholder(tf.float32, [args.batch_size, args.input_size, args.input_size, 1]) + local_x = tf.placeholder(tf.float32, [args.batch_size, args.local_input_size, args.local_input_size, args.input_channel_size]) + global_completion = tf.placeholder(tf.float32, [args.batch_size, args.input_size, args.input_size, args.input_channel_size]) + local_completion = tf.placeholder(tf.float32, [args.batch_size, args.local_input_size, args.local_input_size, args.input_channel_size]) is_training = tf.placeholder(tf.bool, []) model = Model(x, mask, local_x, global_completion, local_completion, is_training, batch_size=args.batch_size) @@ -118,17 +118,17 @@ def get_points(): points = [] mask = [] for i in range(args.batch_size): - x1, y1 = np.random.randint(0, args.image_size - args.local_image_size + 1, 2) - x2, y2 = np.array([x1, y1]) + args.local_image_size + x1, y1 = np.random.randint(0, args.input_size - args.local_input_size + 1, 2) + x2, y2 = np.array([x1, y1]) + args.local_input_size points.append([x1, y1, x2, y2]) w, h = np.random.randint(args.min_mask_size, args.max_mask_size + 1, 2) - p1 = x1 + np.random.randint(0, args.local_image_size - w) - q1 = y1 + np.random.randint(0, args.local_image_size - h) + p1 = x1 + np.random.randint(0, args.local_input_size - w) + q1 = y1 + np.random.randint(0, args.local_input_size - h) p2 = p1 + w q2 = q1 + h - m = np.zeros((args.image_size, args.image_size, 1), dtype=np.uint8) + m = np.zeros((args.input_size, args.input_size, 1), dtype=np.uint8) m[q1:q2 + 1, p1:p2 + 1] = 1 if (np.random.random() < args.rotate_chance):