more args

This commit is contained in:
deep pomf 2018-02-26 10:55:48 -05:00
parent 1b28796199
commit c0c817c10d
2 changed files with 9 additions and 5 deletions

View File

@ -19,11 +19,12 @@ parser.add_argument('--input_channel_size', dest='input_channel_size', default=3
parser.add_argument('--min_mask_size', dest='min_mask_size', default=24, help='minimum mask size')
parser.add_argument('--max_mask_size', dest='max_mask_size', default=48, help='maximum mask size')
parser.add_argument('--rotate_chance', dest='rotate_chance', default=0.5, help='chance the mask will be randomly rotated')
parser.add_argument('--train_mosaic', dest ='train_mosaic', default=False, help='train neural network to decensor mosaics')
# parser.add_argument('--input_dim', dest='input_dim', default=100, help='input z size')
# #Training Settings
# parser.add_argument('--continue_training', dest='continue_training', default=False, type=str2bool, help='flag to continue training')
parser.add_argument('--continue_training', dest='continue_training', default=False, type=str2bool, help='flag to continue training')
# parser.add_argument('--data', dest='data', default='../ambientGAN_TF/data', help='cats image train path')

View File

@ -10,7 +10,6 @@ import load
from config import *
PRETRAIN_EPOCH = 100
#the chance the rectangle crop will be rotated
def train(args):
x = tf.placeholder(tf.float32, [args.batch_size, args.image_size, args.image_size, args.input_channel_size])
@ -32,9 +31,13 @@ def train(args):
init_op = tf.global_variables_initializer()
sess.run(init_op)
if tf.train.get_checkpoint_state('./models'):
saver = tf.train.Saver()
saver.restore(sess, './models/latest')
if args.continue_training:
if tf.train.get_checkpoint_state('./models'):
print("Continuing training from checkpoint.")
saver = tf.train.Saver()
saver.restore(sess, './models/latest')
else:
print("Checkpoint not found! Training new model from scratch.")
x_train, x_test = load.load()
x_train = np.array([a / 127.5 - 1 for a in x_train])