mirror of
https://github.com/Deepshift/DeepCreamPy.git
synced 2024-11-29 05:10:43 +00:00
fix args
This commit is contained in:
parent
76b5903772
commit
7e8c8c2967
20
train.py
20
train.py
@ -11,11 +11,11 @@ from config import *
|
||||
PRETRAIN_EPOCH = 100
|
||||
|
||||
def train(args):
|
||||
x = tf.placeholder(tf.float32, [args.batch_size, args.image_size, args.image_size, args.input_channel_size])
|
||||
mask = tf.placeholder(tf.float32, [args.batch_size, args.image_size, args.image_size, 1])
|
||||
local_x = tf.placeholder(tf.float32, [args.batch_size, args.local_image_size, args.local_image_size, args.input_channel_size])
|
||||
global_completion = tf.placeholder(tf.float32, [args.batch_size, args.image_size, args.image_size, args.input_channel_size])
|
||||
local_completion = tf.placeholder(tf.float32, [args.batch_size, args.local_image_size, args.local_image_size, args.input_channel_size])
|
||||
x = tf.placeholder(tf.float32, [args.batch_size, args.input_size, args.input_size, args.input_channel_size])
|
||||
mask = tf.placeholder(tf.float32, [args.batch_size, args.input_size, args.input_size, 1])
|
||||
local_x = tf.placeholder(tf.float32, [args.batch_size, args.local_input_size, args.local_input_size, args.input_channel_size])
|
||||
global_completion = tf.placeholder(tf.float32, [args.batch_size, args.input_size, args.input_size, args.input_channel_size])
|
||||
local_completion = tf.placeholder(tf.float32, [args.batch_size, args.local_input_size, args.local_input_size, args.input_channel_size])
|
||||
is_training = tf.placeholder(tf.bool, [])
|
||||
|
||||
model = Model(x, mask, local_x, global_completion, local_completion, is_training, batch_size=args.batch_size)
|
||||
@ -118,17 +118,17 @@ def get_points():
|
||||
points = []
|
||||
mask = []
|
||||
for i in range(args.batch_size):
|
||||
x1, y1 = np.random.randint(0, args.image_size - args.local_image_size + 1, 2)
|
||||
x2, y2 = np.array([x1, y1]) + args.local_image_size
|
||||
x1, y1 = np.random.randint(0, args.input_size - args.local_input_size + 1, 2)
|
||||
x2, y2 = np.array([x1, y1]) + args.local_input_size
|
||||
points.append([x1, y1, x2, y2])
|
||||
|
||||
w, h = np.random.randint(args.min_mask_size, args.max_mask_size + 1, 2)
|
||||
p1 = x1 + np.random.randint(0, args.local_image_size - w)
|
||||
q1 = y1 + np.random.randint(0, args.local_image_size - h)
|
||||
p1 = x1 + np.random.randint(0, args.local_input_size - w)
|
||||
q1 = y1 + np.random.randint(0, args.local_input_size - h)
|
||||
p2 = p1 + w
|
||||
q2 = q1 + h
|
||||
|
||||
m = np.zeros((args.image_size, args.image_size, 1), dtype=np.uint8)
|
||||
m = np.zeros((args.input_size, args.input_size, 1), dtype=np.uint8)
|
||||
m[q1:q2 + 1, p1:p2 + 1] = 1
|
||||
|
||||
if (np.random.random() < args.rotate_chance):
|
||||
|
Loading…
Reference in New Issue
Block a user