mirror of
https://github.com/Deepshift/DeepCreamPy.git
synced 2025-03-23 13:20:54 +00:00
more args
This commit is contained in:
parent
1b28796199
commit
c0c817c10d
@ -19,11 +19,12 @@ parser.add_argument('--input_channel_size', dest='input_channel_size', default=3
|
||||
parser.add_argument('--min_mask_size', dest='min_mask_size', default=24, help='minimum mask size')
|
||||
parser.add_argument('--max_mask_size', dest='max_mask_size', default=48, help='maximum mask size')
|
||||
parser.add_argument('--rotate_chance', dest='rotate_chance', default=0.5, help='chance the mask will be randomly rotated')
|
||||
parser.add_argument('--train_mosaic', dest ='train_mosaic', default=False, help='train neural network to decensor mosaics')
|
||||
|
||||
# parser.add_argument('--input_dim', dest='input_dim', default=100, help='input z size')
|
||||
|
||||
# #Training Settings
|
||||
# parser.add_argument('--continue_training', dest='continue_training', default=False, type=str2bool, help='flag to continue training')
|
||||
parser.add_argument('--continue_training', dest='continue_training', default=False, type=str2bool, help='flag to continue training')
|
||||
|
||||
# parser.add_argument('--data', dest='data', default='../ambientGAN_TF/data', help='cats image train path')
|
||||
|
||||
|
11
train.py
11
train.py
@ -10,7 +10,6 @@ import load
|
||||
from config import *
|
||||
|
||||
PRETRAIN_EPOCH = 100
|
||||
#the chance the rectangle crop will be rotated
|
||||
|
||||
def train(args):
|
||||
x = tf.placeholder(tf.float32, [args.batch_size, args.image_size, args.image_size, args.input_channel_size])
|
||||
@ -32,9 +31,13 @@ def train(args):
|
||||
init_op = tf.global_variables_initializer()
|
||||
sess.run(init_op)
|
||||
|
||||
if tf.train.get_checkpoint_state('./models'):
|
||||
saver = tf.train.Saver()
|
||||
saver.restore(sess, './models/latest')
|
||||
if args.continue_training:
|
||||
if tf.train.get_checkpoint_state('./models'):
|
||||
print("Continuing training from checkpoint.")
|
||||
saver = tf.train.Saver()
|
||||
saver.restore(sess, './models/latest')
|
||||
else:
|
||||
print("Checkpoint not found! Training new model from scratch.")
|
||||
|
||||
x_train, x_test = load.load()
|
||||
x_train = np.array([a / 127.5 - 1 for a in x_train])
|
||||
|
Loading…
x
Reference in New Issue
Block a user