mirror of
https://github.com/Deepshift/DeepCreamPy.git
synced 2025-03-22 10:20:57 +00:00
generate training samples folder at runtime
This commit is contained in:
parent
05cbc95ce6
commit
44fb1a32b6
@ -49,7 +49,7 @@ parser.add_argument('--learning_rate', dest='learning_rate', default=0.001, help
|
|||||||
# parser.add_argument('--checkpoints_path', dest='checkpoints_path', default='./checkpoints/', help='saved model checkpoint path')
|
# parser.add_argument('--checkpoints_path', dest='checkpoints_path', default='./checkpoints/', help='saved model checkpoint path')
|
||||||
# parser.add_argument('--graph_path', dest='graph_path', default='./graphs/', help='tensorboard graph')
|
# parser.add_argument('--graph_path', dest='graph_path', default='./graphs/', help='tensorboard graph')
|
||||||
# parser.add_argument('--images_path', dest='images_path', default='./images/', help='result images path')
|
# parser.add_argument('--images_path', dest='images_path', default='./images/', help='result images path')
|
||||||
|
parser.add_argument('--training_samples_path', dest='training_samples_path', default='./training_samples/', help='samples images generated during training path')
|
||||||
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
4
train.py
4
train.py
@ -108,7 +108,7 @@ def train(args):
|
|||||||
completion = sess.run(model.completion, feed_dict={x: x_batch, mask: mask_batch, is_training: False})
|
completion = sess.run(model.completion, feed_dict={x: x_batch, mask: mask_batch, is_training: False})
|
||||||
sample = np.array((completion[0] + 1) * 127.5, dtype=np.uint8)
|
sample = np.array((completion[0] + 1) * 127.5, dtype=np.uint8)
|
||||||
result = Image.fromarray(sample)
|
result = Image.fromarray(sample)
|
||||||
result.save('./training_output_images/{}.jpg'.format("{0:06d}".format(sess.run(epoch))))
|
result.save(args.training_samples_path + '{}.jpg'.format("{0:06d}".format(sess.run(epoch))))
|
||||||
|
|
||||||
saver = tf.train.Saver()
|
saver = tf.train.Saver()
|
||||||
saver.save(sess, './models/latest', write_meta_graph=False)
|
saver.save(sess, './models/latest', write_meta_graph=False)
|
||||||
@ -143,4 +143,6 @@ def get_points():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
if not os.path.exists(args.training_samples_path):
|
||||||
|
os.makedirs(args.training_samples_path)
|
||||||
train(args)
|
train(args)
|
Loading…
x
Reference in New Issue
Block a user