only generate sample images during training when number of test images greater than batch size

This commit is contained in:
deep pomf 2018-02-28 16:10:33 -05:00
parent a7779d608a
commit 9aeffa0c6c

View File

@ -66,12 +66,15 @@ def train(args):
print('Completion loss: {}'.format(g_loss_value))
np.random.shuffle(x_test)
x_batch = x_test[:args.batch_size]
completion = sess.run(model.completion, feed_dict={x: x_batch, mask: mask_batch, is_training: False})
sample = np.array((completion[0] + 1) * 127.5, dtype=np.uint8)
result = Image.fromarray(sample)
result.save(args.training_samples_path + '{}.jpg'.format("{0:06d}".format(sess.run(epoch))))
print(x_test.shape[0])
#stop gap solution. sample images only generated when number of test images greater than or equal to batch size
if x_test.shape[0] >= args.batch_size:
np.random.shuffle(x_test)
x_batch = x_test[:args.batch_size]
completion = sess.run(model.completion, feed_dict={x: x_batch, mask: mask_batch, is_training: False})
sample = np.array((completion[0] + 1) * 127.5, dtype=np.uint8)
result = Image.fromarray(sample)
result.save(args.training_samples_path + '{}.jpg'.format("{0:06d}".format(sess.run(epoch))))
saver = tf.train.Saver()
saver.save(sess, './models/latest', write_meta_graph=False)