From 44fb1a32b68e9057ea08cfef0aebe101f03ac1e2 Mon Sep 17 00:00:00 2001 From: deeppomf Date: Mon, 26 Feb 2018 22:42:59 -0500 Subject: [PATCH] generate training samples folder at runtime --- config.py | 2 +- train.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/config.py b/config.py index cf08e8d..090e9d5 100644 --- a/config.py +++ b/config.py @@ -49,7 +49,7 @@ parser.add_argument('--learning_rate', dest='learning_rate', default=0.001, help # parser.add_argument('--checkpoints_path', dest='checkpoints_path', default='./checkpoints/', help='saved model checkpoint path') # parser.add_argument('--graph_path', dest='graph_path', default='./graphs/', help='tensorboard graph') # parser.add_argument('--images_path', dest='images_path', default='./images/', help='result images path') - +parser.add_argument('--training_samples_path', dest='training_samples_path', default='./training_samples/', help='samples images generated during training path') args = parser.parse_args() \ No newline at end of file diff --git a/train.py b/train.py index 4067923..3bd4c9f 100644 --- a/train.py +++ b/train.py @@ -108,7 +108,7 @@ def train(args): completion = sess.run(model.completion, feed_dict={x: x_batch, mask: mask_batch, is_training: False}) sample = np.array((completion[0] + 1) * 127.5, dtype=np.uint8) result = Image.fromarray(sample) - result.save('./training_output_images/{}.jpg'.format("{0:06d}".format(sess.run(epoch)))) + result.save(args.training_samples_path + '{}.jpg'.format("{0:06d}".format(sess.run(epoch)))) saver = tf.train.Saver() saver.save(sess, './models/latest', write_meta_graph=False) @@ -143,4 +143,6 @@ def get_points(): if __name__ == '__main__': + if not os.path.exists(args.training_samples_path): + os.makedirs(args.training_samples_path) train(args) \ No newline at end of file