From f0c84d38d14a40b39bbc517d8329601930284895 Mon Sep 17 00:00:00 2001 From: deep pomf Date: Fri, 9 Feb 2018 21:03:26 -0500 Subject: [PATCH] update --- README.md | 23 +++++++++++++++++++++++ src/{backup => saved_models}/.gitkeep | 0 src/{test => }/test.py | 7 +++---- src/test/.gitignore | 3 --- src/test/output/.gitkeep | 0 src/train.py | 10 +++++----- 6 files changed, 31 insertions(+), 12 deletions(-) rename src/{backup => saved_models}/.gitkeep (100%) rename src/{test => }/test.py (93%) delete mode 100644 src/test/.gitignore delete mode 100644 src/test/output/.gitkeep diff --git a/README.md b/README.md index ceab481..dcf7754 100644 --- a/README.md +++ b/README.md @@ -16,8 +16,31 @@ Link coming soon # Usage +## I. Decensoring hentai + +## I. Prepare the training data + +Put the images for training the "data/images" directory and convert images to npy format. + +``` +$ cd data +$ python to_npy.py +``` + +The dataset will not be released + +## II. Train the GLCIC model + +``` +$ cd src +$ python train.py +``` + +You can download the trained model file: [glcic_model.tar.gz]( +https://drive.google.com/open?id=1jvP2czv_gX8Q1l0tUPNWLV8HLacK6n_Q) + # To do - Add a user interface - Incorporate GAN loss into training diff --git a/src/backup/.gitkeep b/src/saved_models/.gitkeep similarity index 100% rename from src/backup/.gitkeep rename to src/saved_models/.gitkeep diff --git a/src/test/test.py b/src/test.py similarity index 93% rename from src/test/test.py rename to src/test.py index 7dc4ad8..8393a55 100644 --- a/src/test/test.py +++ b/src/test.py @@ -6,14 +6,13 @@ import os import matplotlib.pyplot as plt import sys sys.path.append('..') -from network import Network +from model import Model IMAGE_SIZE = 128 LOCAL_SIZE = 64 HOLE_MIN = 24 HOLE_MAX = 48 BATCH_SIZE = 16 -PRETRAIN_EPOCH = 100 test_npy = './lfw.npy' @@ -25,13 +24,13 @@ def test(): local_completion = tf.placeholder(tf.float32, [BATCH_SIZE, LOCAL_SIZE, LOCAL_SIZE, 3]) is_training = tf.placeholder(tf.bool, []) - model = Network(x, mask, local_x, global_completion, local_completion, is_training, batch_size=BATCH_SIZE) + model = Model(x, mask, local_x, global_completion, local_completion, is_training, batch_size=BATCH_SIZE) sess = tf.Session() init_op = tf.global_variables_initializer() sess.run(init_op) saver = tf.train.Saver() - saver.restore(sess, '../backup/latest') + saver.restore(sess, '../saved_models/latest') x_test = np.load(test_npy) np.random.shuffle(x_test) diff --git a/src/test/.gitignore b/src/test/.gitignore deleted file mode 100644 index 476f49e..0000000 --- a/src/test/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -output/* -*.npy -!.gitkeep diff --git a/src/test/output/.gitkeep b/src/test/output/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/src/train.py b/src/train.py index 477443b..5f14497 100644 --- a/src/train.py +++ b/src/train.py @@ -33,9 +33,9 @@ def train(): init_op = tf.global_variables_initializer() sess.run(init_op) - if tf.train.get_checkpoint_state('./backup'): + if tf.train.get_checkpoint_state('./saved_model'): saver = tf.train.Saver() - saver.restore(sess, './backup/latest') + saver.restore(sess, './saved_model/latest') x_train, x_test = load.load() x_train = np.array([a / 127.5 - 1 for a in x_train]) @@ -69,9 +69,9 @@ def train(): saver = tf.train.Saver() - saver.save(sess, './backup/latest', write_meta_graph=False) + saver.save(sess, './saved_model/latest', write_meta_graph=False) if sess.run(epoch) == PRETRAIN_EPOCH: - saver.save(sess, './backup/pretrained', write_meta_graph=False) + saver.save(sess, './saved_model/pretrained', write_meta_graph=False) # Discrimitation @@ -109,7 +109,7 @@ def train(): cv2.imwrite('./output/{}.jpg'.format("{0:06d}".format(sess.run(epoch))), cv2.cvtColor(sample, cv2.COLOR_RGB2BGR)) saver = tf.train.Saver() - saver.save(sess, './backup/latest', write_meta_graph=False) + saver.save(sess, './saved_model/latest', write_meta_graph=False) def get_points():