From 9aeffa0c6c54afaa54cd654d5c0451983389e555 Mon Sep 17 00:00:00 2001 From: deep pomf Date: Wed, 28 Feb 2018 16:10:33 -0500 Subject: [PATCH] only generate sample images during training when number of test images greater than batch size --- train.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/train.py b/train.py index ed806b4..15d28eb 100644 --- a/train.py +++ b/train.py @@ -66,12 +66,15 @@ def train(args): print('Completion loss: {}'.format(g_loss_value)) - np.random.shuffle(x_test) - x_batch = x_test[:args.batch_size] - completion = sess.run(model.completion, feed_dict={x: x_batch, mask: mask_batch, is_training: False}) - sample = np.array((completion[0] + 1) * 127.5, dtype=np.uint8) - result = Image.fromarray(sample) - result.save(args.training_samples_path + '{}.jpg'.format("{0:06d}".format(sess.run(epoch)))) + print(x_test.shape[0]) + #stop gap solution. sample images only generated when number of test images greater than or equal to batch size + if x_test.shape[0] >= args.batch_size: + np.random.shuffle(x_test) + x_batch = x_test[:args.batch_size] + completion = sess.run(model.completion, feed_dict={x: x_batch, mask: mask_batch, is_training: False}) + sample = np.array((completion[0] + 1) * 127.5, dtype=np.uint8) + result = Image.fromarray(sample) + result.save(args.training_samples_path + '{}.jpg'.format("{0:06d}".format(sess.run(epoch)))) saver = tf.train.Saver() saver.save(sess, './models/latest', write_meta_graph=False)