mirror of
https://github.com/Deepshift/DeepCreamPy.git
synced 2025-01-26 03:35:28 +00:00
update
This commit is contained in:
parent
1c89c5b153
commit
f0c84d38d1
23
README.md
23
README.md
@ -16,8 +16,31 @@ Link coming soon
|
||||
|
||||
# Usage
|
||||
|
||||
## I. Decensoring hentai
|
||||
|
||||
|
||||
|
||||
## I. Prepare the training data
|
||||
|
||||
Put the images for training the "data/images" directory and convert images to npy format.
|
||||
|
||||
```
|
||||
$ cd data
|
||||
$ python to_npy.py
|
||||
```
|
||||
|
||||
The dataset will not be released
|
||||
|
||||
## II. Train the GLCIC model
|
||||
|
||||
```
|
||||
$ cd src
|
||||
$ python train.py
|
||||
```
|
||||
|
||||
You can download the trained model file: [glcic_model.tar.gz](
|
||||
https://drive.google.com/open?id=1jvP2czv_gX8Q1l0tUPNWLV8HLacK6n_Q)
|
||||
|
||||
# To do
|
||||
- Add a user interface
|
||||
- Incorporate GAN loss into training
|
||||
|
@ -6,14 +6,13 @@ import os
|
||||
import matplotlib.pyplot as plt
|
||||
import sys
|
||||
sys.path.append('..')
|
||||
from network import Network
|
||||
from model import Model
|
||||
|
||||
IMAGE_SIZE = 128
|
||||
LOCAL_SIZE = 64
|
||||
HOLE_MIN = 24
|
||||
HOLE_MAX = 48
|
||||
BATCH_SIZE = 16
|
||||
PRETRAIN_EPOCH = 100
|
||||
|
||||
test_npy = './lfw.npy'
|
||||
|
||||
@ -25,13 +24,13 @@ def test():
|
||||
local_completion = tf.placeholder(tf.float32, [BATCH_SIZE, LOCAL_SIZE, LOCAL_SIZE, 3])
|
||||
is_training = tf.placeholder(tf.bool, [])
|
||||
|
||||
model = Network(x, mask, local_x, global_completion, local_completion, is_training, batch_size=BATCH_SIZE)
|
||||
model = Model(x, mask, local_x, global_completion, local_completion, is_training, batch_size=BATCH_SIZE)
|
||||
sess = tf.Session()
|
||||
init_op = tf.global_variables_initializer()
|
||||
sess.run(init_op)
|
||||
|
||||
saver = tf.train.Saver()
|
||||
saver.restore(sess, '../backup/latest')
|
||||
saver.restore(sess, '../saved_models/latest')
|
||||
|
||||
x_test = np.load(test_npy)
|
||||
np.random.shuffle(x_test)
|
3
src/test/.gitignore
vendored
3
src/test/.gitignore
vendored
@ -1,3 +0,0 @@
|
||||
output/*
|
||||
*.npy
|
||||
!.gitkeep
|
10
src/train.py
10
src/train.py
@ -33,9 +33,9 @@ def train():
|
||||
init_op = tf.global_variables_initializer()
|
||||
sess.run(init_op)
|
||||
|
||||
if tf.train.get_checkpoint_state('./backup'):
|
||||
if tf.train.get_checkpoint_state('./saved_model'):
|
||||
saver = tf.train.Saver()
|
||||
saver.restore(sess, './backup/latest')
|
||||
saver.restore(sess, './saved_model/latest')
|
||||
|
||||
x_train, x_test = load.load()
|
||||
x_train = np.array([a / 127.5 - 1 for a in x_train])
|
||||
@ -69,9 +69,9 @@ def train():
|
||||
|
||||
|
||||
saver = tf.train.Saver()
|
||||
saver.save(sess, './backup/latest', write_meta_graph=False)
|
||||
saver.save(sess, './saved_model/latest', write_meta_graph=False)
|
||||
if sess.run(epoch) == PRETRAIN_EPOCH:
|
||||
saver.save(sess, './backup/pretrained', write_meta_graph=False)
|
||||
saver.save(sess, './saved_model/pretrained', write_meta_graph=False)
|
||||
|
||||
|
||||
# Discrimitation
|
||||
@ -109,7 +109,7 @@ def train():
|
||||
cv2.imwrite('./output/{}.jpg'.format("{0:06d}".format(sess.run(epoch))), cv2.cvtColor(sample, cv2.COLOR_RGB2BGR))
|
||||
|
||||
saver = tf.train.Saver()
|
||||
saver.save(sess, './backup/latest', write_meta_graph=False)
|
||||
saver.save(sess, './saved_model/latest', write_meta_graph=False)
|
||||
|
||||
|
||||
def get_points():
|
||||
|
Loading…
x
Reference in New Issue
Block a user