mirror of
https://github.com/Deepshift/DeepCreamPy.git
synced 2024-11-29 05:10:43 +00:00
Initialize mask_batch and x_batch
This commit is contained in:
parent
018c86c959
commit
a7779d608a
9
train.py
9
train.py
@ -51,12 +51,15 @@ def train(args):
|
|||||||
|
|
||||||
np.random.shuffle(x_train)
|
np.random.shuffle(x_train)
|
||||||
|
|
||||||
|
x_batch = []
|
||||||
|
mask_batch = []
|
||||||
|
|
||||||
# Completion
|
# Completion
|
||||||
if sess.run(epoch) <= PRETRAIN_EPOCH:
|
if sess.run(epoch) <= PRETRAIN_EPOCH:
|
||||||
g_loss_value = 0
|
g_loss_value = 0
|
||||||
for i in tqdm.tqdm(range(step_num)):
|
for i in tqdm.tqdm(range(step_num)):
|
||||||
x_batch = x_train[i * args.batch_size:(i + 1) * args.batch_size]
|
x_batch = x_train[i * args.batch_size:(i + 1) * args.batch_size]
|
||||||
points_batch, mask_batch = get_points()
|
_, mask_batch = get_points()
|
||||||
|
|
||||||
_, g_loss = sess.run([g_train_op, model.g_loss], feed_dict={x: x_batch, mask: mask_batch, is_training: True})
|
_, g_loss = sess.run([g_train_op, model.g_loss], feed_dict={x: x_batch, mask: mask_batch, is_training: True})
|
||||||
g_loss_value += g_loss
|
g_loss_value += g_loss
|
||||||
@ -68,7 +71,7 @@ def train(args):
|
|||||||
completion = sess.run(model.completion, feed_dict={x: x_batch, mask: mask_batch, is_training: False})
|
completion = sess.run(model.completion, feed_dict={x: x_batch, mask: mask_batch, is_training: False})
|
||||||
sample = np.array((completion[0] + 1) * 127.5, dtype=np.uint8)
|
sample = np.array((completion[0] + 1) * 127.5, dtype=np.uint8)
|
||||||
result = Image.fromarray(sample)
|
result = Image.fromarray(sample)
|
||||||
result.save(args.training_samples_path + '/{}.jpg'.format("{0:06d}".format(sess.run(epoch))))
|
result.save(args.training_samples_path + '{}.jpg'.format("{0:06d}".format(sess.run(epoch))))
|
||||||
|
|
||||||
saver = tf.train.Saver()
|
saver = tf.train.Saver()
|
||||||
saver.save(sess, './models/latest', write_meta_graph=False)
|
saver.save(sess, './models/latest', write_meta_graph=False)
|
||||||
@ -135,7 +138,7 @@ def get_points():
|
|||||||
if (np.random.random() < args.rotate_chance):
|
if (np.random.random() < args.rotate_chance):
|
||||||
#rotate random amount between 0 and 90 degrees
|
#rotate random amount between 0 and 90 degrees
|
||||||
m = scipy.ndimage.rotate(m, np.random.random()*90, reshape = False)
|
m = scipy.ndimage.rotate(m, np.random.random()*90, reshape = False)
|
||||||
#set all elements greater than 0 to 1
|
#set all elements greater than 0.5 to 1
|
||||||
m[m > 0.5] = 1
|
m[m > 0.5] = 1
|
||||||
|
|
||||||
mask.append(m)
|
mask.append(m)
|
||||||
|
Loading…
Reference in New Issue
Block a user