From c0c817c10d53f52ebb8710026949ebc4db3841a4 Mon Sep 17 00:00:00 2001 From: deep pomf Date: Mon, 26 Feb 2018 10:55:48 -0500 Subject: [PATCH] more args --- config.py | 3 ++- train.py | 11 +++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/config.py b/config.py index 1b5be65..cf08e8d 100644 --- a/config.py +++ b/config.py @@ -19,11 +19,12 @@ parser.add_argument('--input_channel_size', dest='input_channel_size', default=3 parser.add_argument('--min_mask_size', dest='min_mask_size', default=24, help='minimum mask size') parser.add_argument('--max_mask_size', dest='max_mask_size', default=48, help='maximum mask size') parser.add_argument('--rotate_chance', dest='rotate_chance', default=0.5, help='chance the mask will be randomly rotated') +parser.add_argument('--train_mosaic', dest ='train_mosaic', default=False, help='train neural network to decensor mosaics') # parser.add_argument('--input_dim', dest='input_dim', default=100, help='input z size') # #Training Settings -# parser.add_argument('--continue_training', dest='continue_training', default=False, type=str2bool, help='flag to continue training') +parser.add_argument('--continue_training', dest='continue_training', default=False, type=str2bool, help='flag to continue training') # parser.add_argument('--data', dest='data', default='../ambientGAN_TF/data', help='cats image train path') diff --git a/train.py b/train.py index 8a4d376..48f2808 100644 --- a/train.py +++ b/train.py @@ -10,7 +10,6 @@ import load from config import * PRETRAIN_EPOCH = 100 -#the chance the rectangle crop will be rotated def train(args): x = tf.placeholder(tf.float32, [args.batch_size, args.image_size, args.image_size, args.input_channel_size]) @@ -32,9 +31,13 @@ def train(args): init_op = tf.global_variables_initializer() sess.run(init_op) - if tf.train.get_checkpoint_state('./models'): - saver = tf.train.Saver() - saver.restore(sess, './models/latest') + if args.continue_training: + if tf.train.get_checkpoint_state('./models'): + print("Continuing training from checkpoint.") + saver = tf.train.Saver() + saver.restore(sess, './models/latest') + else: + print("Checkpoint not found! Training new model from scratch.") x_train, x_test = load.load() x_train = np.array([a / 127.5 - 1 for a in x_train])