stop samples from being generated with small datasets

This commit is contained in:
deep pomf 2018-02-28 19:43:55 -05:00
parent 9aeffa0c6c
commit f55ad8a087

View File

@ -66,7 +66,6 @@ def train(args):
print('Completion loss: {}'.format(g_loss_value)) print('Completion loss: {}'.format(g_loss_value))
print(x_test.shape[0])
#stop gap solution. sample images only generated when number of test images greater than or equal to batch size #stop gap solution. sample images only generated when number of test images greater than or equal to batch size
if x_test.shape[0] >= args.batch_size: if x_test.shape[0] >= args.batch_size:
np.random.shuffle(x_test) np.random.shuffle(x_test)
@ -110,6 +109,8 @@ def train(args):
print('Completion loss: {}'.format(g_loss_value)) print('Completion loss: {}'.format(g_loss_value))
print('Discriminator loss: {}'.format(d_loss_value)) print('Discriminator loss: {}'.format(d_loss_value))
#stop gap solution. sample images only generated when number of test images greater than or equal to batch size
if x_test.shape[0] >= args.batch_size:
np.random.shuffle(x_test) np.random.shuffle(x_test)
x_batch = x_test[:args.batch_size] x_batch = x_test[:args.batch_size]
completion = sess.run(model.completion, feed_dict={x: x_batch, mask: mask_batch, is_training: False}) completion = sess.run(model.completion, feed_dict={x: x_batch, mask: mask_batch, is_training: False})