generate training samples folder at runtime

This commit is contained in:
deeppomf 2018-02-26 22:42:59 -05:00
parent 05cbc95ce6
commit 44fb1a32b6
2 changed files with 4 additions and 2 deletions

View File

@ -49,7 +49,7 @@ parser.add_argument('--learning_rate', dest='learning_rate', default=0.001, help
# parser.add_argument('--checkpoints_path', dest='checkpoints_path', default='./checkpoints/', help='saved model checkpoint path')
# parser.add_argument('--graph_path', dest='graph_path', default='./graphs/', help='tensorboard graph')
# parser.add_argument('--images_path', dest='images_path', default='./images/', help='result images path')
parser.add_argument('--training_samples_path', dest='training_samples_path', default='./training_samples/', help='samples images generated during training path')
args = parser.parse_args()

View File

@ -108,7 +108,7 @@ def train(args):
completion = sess.run(model.completion, feed_dict={x: x_batch, mask: mask_batch, is_training: False})
sample = np.array((completion[0] + 1) * 127.5, dtype=np.uint8)
result = Image.fromarray(sample)
result.save('./training_output_images/{}.jpg'.format("{0:06d}".format(sess.run(epoch))))
result.save(args.training_samples_path + '{}.jpg'.format("{0:06d}".format(sess.run(epoch))))
saver = tf.train.Saver()
saver.save(sess, './models/latest', write_meta_graph=False)
@ -143,4 +143,6 @@ def get_points():
if __name__ == '__main__':
if not os.path.exists(args.training_samples_path):
os.makedirs(args.training_samples_path)
train(args)