mirror of
https://github.com/Deepshift/DeepCreamPy.git
synced 2024-11-29 05:10:43 +00:00
Initialize mask_batch and x_batch
This commit is contained in:
parent
018c86c959
commit
a7779d608a
9
train.py
9
train.py
@ -51,12 +51,15 @@ def train(args):
|
||||
|
||||
np.random.shuffle(x_train)
|
||||
|
||||
x_batch = []
|
||||
mask_batch = []
|
||||
|
||||
# Completion
|
||||
if sess.run(epoch) <= PRETRAIN_EPOCH:
|
||||
g_loss_value = 0
|
||||
for i in tqdm.tqdm(range(step_num)):
|
||||
x_batch = x_train[i * args.batch_size:(i + 1) * args.batch_size]
|
||||
points_batch, mask_batch = get_points()
|
||||
_, mask_batch = get_points()
|
||||
|
||||
_, g_loss = sess.run([g_train_op, model.g_loss], feed_dict={x: x_batch, mask: mask_batch, is_training: True})
|
||||
g_loss_value += g_loss
|
||||
@ -68,7 +71,7 @@ def train(args):
|
||||
completion = sess.run(model.completion, feed_dict={x: x_batch, mask: mask_batch, is_training: False})
|
||||
sample = np.array((completion[0] + 1) * 127.5, dtype=np.uint8)
|
||||
result = Image.fromarray(sample)
|
||||
result.save(args.training_samples_path + '/{}.jpg'.format("{0:06d}".format(sess.run(epoch))))
|
||||
result.save(args.training_samples_path + '{}.jpg'.format("{0:06d}".format(sess.run(epoch))))
|
||||
|
||||
saver = tf.train.Saver()
|
||||
saver.save(sess, './models/latest', write_meta_graph=False)
|
||||
@ -135,7 +138,7 @@ def get_points():
|
||||
if (np.random.random() < args.rotate_chance):
|
||||
#rotate random amount between 0 and 90 degrees
|
||||
m = scipy.ndimage.rotate(m, np.random.random()*90, reshape = False)
|
||||
#set all elements greater than 0 to 1
|
||||
#set all elements greater than 0.5 to 1
|
||||
m[m > 0.5] = 1
|
||||
|
||||
mask.append(m)
|
||||
|
Loading…
Reference in New Issue
Block a user