This commit is contained in:
deep pomf 2018-02-09 21:03:26 -05:00
parent 1c89c5b153
commit f0c84d38d1
6 changed files with 31 additions and 12 deletions

View File

@ -16,8 +16,31 @@ Link coming soon
# Usage
## I. Decensoring hentai
## I. Prepare the training data
Put the images for training the "data/images" directory and convert images to npy format.
```
$ cd data
$ python to_npy.py
```
The dataset will not be released
## II. Train the GLCIC model
```
$ cd src
$ python train.py
```
You can download the trained model file: [glcic_model.tar.gz](
https://drive.google.com/open?id=1jvP2czv_gX8Q1l0tUPNWLV8HLacK6n_Q)
# To do
- Add a user interface
- Incorporate GAN loss into training

View File

@ -6,14 +6,13 @@ import os
import matplotlib.pyplot as plt
import sys
sys.path.append('..')
from network import Network
from model import Model
IMAGE_SIZE = 128
LOCAL_SIZE = 64
HOLE_MIN = 24
HOLE_MAX = 48
BATCH_SIZE = 16
PRETRAIN_EPOCH = 100
test_npy = './lfw.npy'
@ -25,13 +24,13 @@ def test():
local_completion = tf.placeholder(tf.float32, [BATCH_SIZE, LOCAL_SIZE, LOCAL_SIZE, 3])
is_training = tf.placeholder(tf.bool, [])
model = Network(x, mask, local_x, global_completion, local_completion, is_training, batch_size=BATCH_SIZE)
model = Model(x, mask, local_x, global_completion, local_completion, is_training, batch_size=BATCH_SIZE)
sess = tf.Session()
init_op = tf.global_variables_initializer()
sess.run(init_op)
saver = tf.train.Saver()
saver.restore(sess, '../backup/latest')
saver.restore(sess, '../saved_models/latest')
x_test = np.load(test_npy)
np.random.shuffle(x_test)

3
src/test/.gitignore vendored
View File

@ -1,3 +0,0 @@
output/*
*.npy
!.gitkeep

View File

@ -33,9 +33,9 @@ def train():
init_op = tf.global_variables_initializer()
sess.run(init_op)
if tf.train.get_checkpoint_state('./backup'):
if tf.train.get_checkpoint_state('./saved_model'):
saver = tf.train.Saver()
saver.restore(sess, './backup/latest')
saver.restore(sess, './saved_model/latest')
x_train, x_test = load.load()
x_train = np.array([a / 127.5 - 1 for a in x_train])
@ -69,9 +69,9 @@ def train():
saver = tf.train.Saver()
saver.save(sess, './backup/latest', write_meta_graph=False)
saver.save(sess, './saved_model/latest', write_meta_graph=False)
if sess.run(epoch) == PRETRAIN_EPOCH:
saver.save(sess, './backup/pretrained', write_meta_graph=False)
saver.save(sess, './saved_model/pretrained', write_meta_graph=False)
# Discrimitation
@ -109,7 +109,7 @@ def train():
cv2.imwrite('./output/{}.jpg'.format("{0:06d}".format(sess.run(epoch))), cv2.cvtColor(sample, cv2.COLOR_RGB2BGR))
saver = tf.train.Saver()
saver.save(sess, './backup/latest', write_meta_graph=False)
saver.save(sess, './saved_model/latest', write_meta_graph=False)
def get_points():