From a7779d608a9c2d6a0d61f9a8de2460d8eaad7cd4 Mon Sep 17 00:00:00 2001 From: deep pomf Date: Wed, 28 Feb 2018 14:49:57 -0500 Subject: [PATCH] Initialize mask_batch and x_batch --- train.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index e00f3d4..ed806b4 100644 --- a/train.py +++ b/train.py @@ -51,12 +51,15 @@ def train(args): np.random.shuffle(x_train) + x_batch = [] + mask_batch = [] + # Completion if sess.run(epoch) <= PRETRAIN_EPOCH: g_loss_value = 0 for i in tqdm.tqdm(range(step_num)): x_batch = x_train[i * args.batch_size:(i + 1) * args.batch_size] - points_batch, mask_batch = get_points() + _, mask_batch = get_points() _, g_loss = sess.run([g_train_op, model.g_loss], feed_dict={x: x_batch, mask: mask_batch, is_training: True}) g_loss_value += g_loss @@ -68,7 +71,7 @@ def train(args): completion = sess.run(model.completion, feed_dict={x: x_batch, mask: mask_batch, is_training: False}) sample = np.array((completion[0] + 1) * 127.5, dtype=np.uint8) result = Image.fromarray(sample) - result.save(args.training_samples_path + '/{}.jpg'.format("{0:06d}".format(sess.run(epoch)))) + result.save(args.training_samples_path + '{}.jpg'.format("{0:06d}".format(sess.run(epoch)))) saver = tf.train.Saver() saver.save(sess, './models/latest', write_meta_graph=False) @@ -135,7 +138,7 @@ def get_points(): if (np.random.random() < args.rotate_chance): #rotate random amount between 0 and 90 degrees m = scipy.ndimage.rotate(m, np.random.random()*90, reshape = False) - #set all elements greater than 0 to 1 + #set all elements greater than 0.5 to 1 m[m > 0.5] = 1 mask.append(m)