mirror of
https://github.com/Deepshift/DeepCreamPy.git
synced 2024-11-29 05:10:43 +00:00
fix args
This commit is contained in:
parent
76b5903772
commit
7e8c8c2967
20
train.py
20
train.py
@ -11,11 +11,11 @@ from config import *
|
|||||||
PRETRAIN_EPOCH = 100
|
PRETRAIN_EPOCH = 100
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
x = tf.placeholder(tf.float32, [args.batch_size, args.image_size, args.image_size, args.input_channel_size])
|
x = tf.placeholder(tf.float32, [args.batch_size, args.input_size, args.input_size, args.input_channel_size])
|
||||||
mask = tf.placeholder(tf.float32, [args.batch_size, args.image_size, args.image_size, 1])
|
mask = tf.placeholder(tf.float32, [args.batch_size, args.input_size, args.input_size, 1])
|
||||||
local_x = tf.placeholder(tf.float32, [args.batch_size, args.local_image_size, args.local_image_size, args.input_channel_size])
|
local_x = tf.placeholder(tf.float32, [args.batch_size, args.local_input_size, args.local_input_size, args.input_channel_size])
|
||||||
global_completion = tf.placeholder(tf.float32, [args.batch_size, args.image_size, args.image_size, args.input_channel_size])
|
global_completion = tf.placeholder(tf.float32, [args.batch_size, args.input_size, args.input_size, args.input_channel_size])
|
||||||
local_completion = tf.placeholder(tf.float32, [args.batch_size, args.local_image_size, args.local_image_size, args.input_channel_size])
|
local_completion = tf.placeholder(tf.float32, [args.batch_size, args.local_input_size, args.local_input_size, args.input_channel_size])
|
||||||
is_training = tf.placeholder(tf.bool, [])
|
is_training = tf.placeholder(tf.bool, [])
|
||||||
|
|
||||||
model = Model(x, mask, local_x, global_completion, local_completion, is_training, batch_size=args.batch_size)
|
model = Model(x, mask, local_x, global_completion, local_completion, is_training, batch_size=args.batch_size)
|
||||||
@ -118,17 +118,17 @@ def get_points():
|
|||||||
points = []
|
points = []
|
||||||
mask = []
|
mask = []
|
||||||
for i in range(args.batch_size):
|
for i in range(args.batch_size):
|
||||||
x1, y1 = np.random.randint(0, args.image_size - args.local_image_size + 1, 2)
|
x1, y1 = np.random.randint(0, args.input_size - args.local_input_size + 1, 2)
|
||||||
x2, y2 = np.array([x1, y1]) + args.local_image_size
|
x2, y2 = np.array([x1, y1]) + args.local_input_size
|
||||||
points.append([x1, y1, x2, y2])
|
points.append([x1, y1, x2, y2])
|
||||||
|
|
||||||
w, h = np.random.randint(args.min_mask_size, args.max_mask_size + 1, 2)
|
w, h = np.random.randint(args.min_mask_size, args.max_mask_size + 1, 2)
|
||||||
p1 = x1 + np.random.randint(0, args.local_image_size - w)
|
p1 = x1 + np.random.randint(0, args.local_input_size - w)
|
||||||
q1 = y1 + np.random.randint(0, args.local_image_size - h)
|
q1 = y1 + np.random.randint(0, args.local_input_size - h)
|
||||||
p2 = p1 + w
|
p2 = p1 + w
|
||||||
q2 = q1 + h
|
q2 = q1 + h
|
||||||
|
|
||||||
m = np.zeros((args.image_size, args.image_size, 1), dtype=np.uint8)
|
m = np.zeros((args.input_size, args.input_size, 1), dtype=np.uint8)
|
||||||
m[q1:q2 + 1, p1:p2 + 1] = 1
|
m[q1:q2 + 1, p1:p2 + 1] = 1
|
||||||
|
|
||||||
if (np.random.random() < args.rotate_chance):
|
if (np.random.random() < args.rotate_chance):
|
||||||
|
Loading…
Reference in New Issue
Block a user