Initialize mask_batch and x_batch

This commit is contained in:
deep pomf 2018-02-28 14:49:57 -05:00
parent 018c86c959
commit a7779d608a

View File

@ -51,12 +51,15 @@ def train(args):
np.random.shuffle(x_train) np.random.shuffle(x_train)
x_batch = []
mask_batch = []
# Completion # Completion
if sess.run(epoch) <= PRETRAIN_EPOCH: if sess.run(epoch) <= PRETRAIN_EPOCH:
g_loss_value = 0 g_loss_value = 0
for i in tqdm.tqdm(range(step_num)): for i in tqdm.tqdm(range(step_num)):
x_batch = x_train[i * args.batch_size:(i + 1) * args.batch_size] x_batch = x_train[i * args.batch_size:(i + 1) * args.batch_size]
points_batch, mask_batch = get_points() _, mask_batch = get_points()
_, g_loss = sess.run([g_train_op, model.g_loss], feed_dict={x: x_batch, mask: mask_batch, is_training: True}) _, g_loss = sess.run([g_train_op, model.g_loss], feed_dict={x: x_batch, mask: mask_batch, is_training: True})
g_loss_value += g_loss g_loss_value += g_loss
@ -68,7 +71,7 @@ def train(args):
completion = sess.run(model.completion, feed_dict={x: x_batch, mask: mask_batch, is_training: False}) completion = sess.run(model.completion, feed_dict={x: x_batch, mask: mask_batch, is_training: False})
sample = np.array((completion[0] + 1) * 127.5, dtype=np.uint8) sample = np.array((completion[0] + 1) * 127.5, dtype=np.uint8)
result = Image.fromarray(sample) result = Image.fromarray(sample)
result.save(args.training_samples_path + '/{}.jpg'.format("{0:06d}".format(sess.run(epoch)))) result.save(args.training_samples_path + '{}.jpg'.format("{0:06d}".format(sess.run(epoch))))
saver = tf.train.Saver() saver = tf.train.Saver()
saver.save(sess, './models/latest', write_meta_graph=False) saver.save(sess, './models/latest', write_meta_graph=False)
@ -135,7 +138,7 @@ def get_points():
if (np.random.random() < args.rotate_chance): if (np.random.random() < args.rotate_chance):
#rotate random amount between 0 and 90 degrees #rotate random amount between 0 and 90 degrees
m = scipy.ndimage.rotate(m, np.random.random()*90, reshape = False) m = scipy.ndimage.rotate(m, np.random.random()*90, reshape = False)
#set all elements greater than 0 to 1 #set all elements greater than 0.5 to 1
m[m > 0.5] = 1 m[m > 0.5] = 1
mask.append(m) mask.append(m)