diff --git a/config.py b/config.py index cf08e8d..090e9d5 100644 --- a/config.py +++ b/config.py @@ -49,7 +49,7 @@ parser.add_argument('--learning_rate', dest='learning_rate', default=0.001, help # parser.add_argument('--checkpoints_path', dest='checkpoints_path', default='./checkpoints/', help='saved model checkpoint path') # parser.add_argument('--graph_path', dest='graph_path', default='./graphs/', help='tensorboard graph') # parser.add_argument('--images_path', dest='images_path', default='./images/', help='result images path') - +parser.add_argument('--training_samples_path', dest='training_samples_path', default='./training_samples/', help='samples images generated during training path') args = parser.parse_args() \ No newline at end of file diff --git a/train.py b/train.py index 4067923..3bd4c9f 100644 --- a/train.py +++ b/train.py @@ -108,7 +108,7 @@ def train(args): completion = sess.run(model.completion, feed_dict={x: x_batch, mask: mask_batch, is_training: False}) sample = np.array((completion[0] + 1) * 127.5, dtype=np.uint8) result = Image.fromarray(sample) - result.save('./training_output_images/{}.jpg'.format("{0:06d}".format(sess.run(epoch)))) + result.save(args.training_samples_path + '{}.jpg'.format("{0:06d}".format(sess.run(epoch)))) saver = tf.train.Saver() saver.save(sess, './models/latest', write_meta_graph=False) @@ -143,4 +143,6 @@ def get_points(): if __name__ == '__main__': + if not os.path.exists(args.training_samples_path): + os.makedirs(args.training_samples_path) train(args) \ No newline at end of file