From f55ad8a0875209459b4426e126b81dc08ffe3e98 Mon Sep 17 00:00:00 2001 From: deep pomf Date: Wed, 28 Feb 2018 19:43:55 -0500 Subject: [PATCH] stop samples from being generated with small datasets --- train.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/train.py b/train.py index 15d28eb..d577a22 100644 --- a/train.py +++ b/train.py @@ -66,7 +66,6 @@ def train(args): print('Completion loss: {}'.format(g_loss_value)) - print(x_test.shape[0]) #stop gap solution. sample images only generated when number of test images greater than or equal to batch size if x_test.shape[0] >= args.batch_size: np.random.shuffle(x_test) @@ -110,12 +109,14 @@ def train(args): print('Completion loss: {}'.format(g_loss_value)) print('Discriminator loss: {}'.format(d_loss_value)) - np.random.shuffle(x_test) - x_batch = x_test[:args.batch_size] - completion = sess.run(model.completion, feed_dict={x: x_batch, mask: mask_batch, is_training: False}) - sample = np.array((completion[0] + 1) * 127.5, dtype=np.uint8) - result = Image.fromarray(sample) - result.save(args.training_samples_path + '{}.jpg'.format("{0:06d}".format(sess.run(epoch)))) + #stop gap solution. sample images only generated when number of test images greater than or equal to batch size + if x_test.shape[0] >= args.batch_size: + np.random.shuffle(x_test) + x_batch = x_test[:args.batch_size] + completion = sess.run(model.completion, feed_dict={x: x_batch, mask: mask_batch, is_training: False}) + sample = np.array((completion[0] + 1) * 127.5, dtype=np.uint8) + result = Image.fromarray(sample) + result.save(args.training_samples_path + '{}.jpg'.format("{0:06d}".format(sess.run(epoch)))) saver = tf.train.Saver() saver.save(sess, './models/latest', write_meta_graph=False)