mirror of
https://github.com/Deepshift/DeepCreamPy.git
synced 2025-01-28 00:35:16 +00:00
update
This commit is contained in:
parent
1c89c5b153
commit
f0c84d38d1
23
README.md
23
README.md
@ -16,8 +16,31 @@ Link coming soon
|
|||||||
|
|
||||||
# Usage
|
# Usage
|
||||||
|
|
||||||
|
## I. Decensoring hentai
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## I. Prepare the training data
|
||||||
|
|
||||||
|
Put the images for training the "data/images" directory and convert images to npy format.
|
||||||
|
|
||||||
|
```
|
||||||
|
$ cd data
|
||||||
|
$ python to_npy.py
|
||||||
|
```
|
||||||
|
|
||||||
|
The dataset will not be released
|
||||||
|
|
||||||
|
## II. Train the GLCIC model
|
||||||
|
|
||||||
|
```
|
||||||
|
$ cd src
|
||||||
|
$ python train.py
|
||||||
|
```
|
||||||
|
|
||||||
|
You can download the trained model file: [glcic_model.tar.gz](
|
||||||
|
https://drive.google.com/open?id=1jvP2czv_gX8Q1l0tUPNWLV8HLacK6n_Q)
|
||||||
|
|
||||||
# To do
|
# To do
|
||||||
- Add a user interface
|
- Add a user interface
|
||||||
- Incorporate GAN loss into training
|
- Incorporate GAN loss into training
|
||||||
|
@ -6,14 +6,13 @@ import os
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import sys
|
import sys
|
||||||
sys.path.append('..')
|
sys.path.append('..')
|
||||||
from network import Network
|
from model import Model
|
||||||
|
|
||||||
IMAGE_SIZE = 128
|
IMAGE_SIZE = 128
|
||||||
LOCAL_SIZE = 64
|
LOCAL_SIZE = 64
|
||||||
HOLE_MIN = 24
|
HOLE_MIN = 24
|
||||||
HOLE_MAX = 48
|
HOLE_MAX = 48
|
||||||
BATCH_SIZE = 16
|
BATCH_SIZE = 16
|
||||||
PRETRAIN_EPOCH = 100
|
|
||||||
|
|
||||||
test_npy = './lfw.npy'
|
test_npy = './lfw.npy'
|
||||||
|
|
||||||
@ -25,13 +24,13 @@ def test():
|
|||||||
local_completion = tf.placeholder(tf.float32, [BATCH_SIZE, LOCAL_SIZE, LOCAL_SIZE, 3])
|
local_completion = tf.placeholder(tf.float32, [BATCH_SIZE, LOCAL_SIZE, LOCAL_SIZE, 3])
|
||||||
is_training = tf.placeholder(tf.bool, [])
|
is_training = tf.placeholder(tf.bool, [])
|
||||||
|
|
||||||
model = Network(x, mask, local_x, global_completion, local_completion, is_training, batch_size=BATCH_SIZE)
|
model = Model(x, mask, local_x, global_completion, local_completion, is_training, batch_size=BATCH_SIZE)
|
||||||
sess = tf.Session()
|
sess = tf.Session()
|
||||||
init_op = tf.global_variables_initializer()
|
init_op = tf.global_variables_initializer()
|
||||||
sess.run(init_op)
|
sess.run(init_op)
|
||||||
|
|
||||||
saver = tf.train.Saver()
|
saver = tf.train.Saver()
|
||||||
saver.restore(sess, '../backup/latest')
|
saver.restore(sess, '../saved_models/latest')
|
||||||
|
|
||||||
x_test = np.load(test_npy)
|
x_test = np.load(test_npy)
|
||||||
np.random.shuffle(x_test)
|
np.random.shuffle(x_test)
|
3
src/test/.gitignore
vendored
3
src/test/.gitignore
vendored
@ -1,3 +0,0 @@
|
|||||||
output/*
|
|
||||||
*.npy
|
|
||||||
!.gitkeep
|
|
10
src/train.py
10
src/train.py
@ -33,9 +33,9 @@ def train():
|
|||||||
init_op = tf.global_variables_initializer()
|
init_op = tf.global_variables_initializer()
|
||||||
sess.run(init_op)
|
sess.run(init_op)
|
||||||
|
|
||||||
if tf.train.get_checkpoint_state('./backup'):
|
if tf.train.get_checkpoint_state('./saved_model'):
|
||||||
saver = tf.train.Saver()
|
saver = tf.train.Saver()
|
||||||
saver.restore(sess, './backup/latest')
|
saver.restore(sess, './saved_model/latest')
|
||||||
|
|
||||||
x_train, x_test = load.load()
|
x_train, x_test = load.load()
|
||||||
x_train = np.array([a / 127.5 - 1 for a in x_train])
|
x_train = np.array([a / 127.5 - 1 for a in x_train])
|
||||||
@ -69,9 +69,9 @@ def train():
|
|||||||
|
|
||||||
|
|
||||||
saver = tf.train.Saver()
|
saver = tf.train.Saver()
|
||||||
saver.save(sess, './backup/latest', write_meta_graph=False)
|
saver.save(sess, './saved_model/latest', write_meta_graph=False)
|
||||||
if sess.run(epoch) == PRETRAIN_EPOCH:
|
if sess.run(epoch) == PRETRAIN_EPOCH:
|
||||||
saver.save(sess, './backup/pretrained', write_meta_graph=False)
|
saver.save(sess, './saved_model/pretrained', write_meta_graph=False)
|
||||||
|
|
||||||
|
|
||||||
# Discrimitation
|
# Discrimitation
|
||||||
@ -109,7 +109,7 @@ def train():
|
|||||||
cv2.imwrite('./output/{}.jpg'.format("{0:06d}".format(sess.run(epoch))), cv2.cvtColor(sample, cv2.COLOR_RGB2BGR))
|
cv2.imwrite('./output/{}.jpg'.format("{0:06d}".format(sess.run(epoch))), cv2.cvtColor(sample, cv2.COLOR_RGB2BGR))
|
||||||
|
|
||||||
saver = tf.train.Saver()
|
saver = tf.train.Saver()
|
||||||
saver.save(sess, './backup/latest', write_meta_graph=False)
|
saver.save(sess, './saved_model/latest', write_meta_graph=False)
|
||||||
|
|
||||||
|
|
||||||
def get_points():
|
def get_points():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user