mirror of
https://github.com/Deepshift/DeepCreamPy.git
synced 2024-11-29 05:10:43 +00:00
random rotation of rectangle masks
This commit is contained in:
parent
e268490bf9
commit
296bb7109e
@ -83,9 +83,11 @@ $ python train.py
|
|||||||
|
|
||||||
# To do
|
# To do
|
||||||
- ~~Add Python 3 compatibility~~
|
- ~~Add Python 3 compatibility~~
|
||||||
|
- Add random rotations in cropping rectangles
|
||||||
- Retrain for arbitrary shape censors
|
- Retrain for arbitrary shape censors
|
||||||
- Add a user interface
|
- Add a user interface
|
||||||
- Incorporate GAN loss into training
|
- Incorporate GAN loss into training
|
||||||
|
- Update the model to the new version
|
||||||
|
|
||||||
Contributions are welcome!
|
Contributions are welcome!
|
||||||
|
|
||||||
|
12
train.py
12
train.py
@ -4,6 +4,7 @@ from PIL import Image
|
|||||||
import tqdm
|
import tqdm
|
||||||
from model import Model
|
from model import Model
|
||||||
import load
|
import load
|
||||||
|
import scipy.ndimage
|
||||||
|
|
||||||
IMAGE_SIZE = 128
|
IMAGE_SIZE = 128
|
||||||
LOCAL_SIZE = 64
|
LOCAL_SIZE = 64
|
||||||
@ -12,6 +13,8 @@ HOLE_MAX = 48
|
|||||||
LEARNING_RATE = 1e-3
|
LEARNING_RATE = 1e-3
|
||||||
BATCH_SIZE = 16
|
BATCH_SIZE = 16
|
||||||
PRETRAIN_EPOCH = 100
|
PRETRAIN_EPOCH = 100
|
||||||
|
#the chance the rectangle crop will be rotated
|
||||||
|
ROTATE_CHANCE = 0.5
|
||||||
|
|
||||||
def train():
|
def train():
|
||||||
x = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3])
|
x = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3])
|
||||||
@ -129,12 +132,17 @@ def get_points():
|
|||||||
|
|
||||||
m = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 1), dtype=np.uint8)
|
m = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 1), dtype=np.uint8)
|
||||||
m[q1:q2 + 1, p1:p2 + 1] = 1
|
m[q1:q2 + 1, p1:p2 + 1] = 1
|
||||||
mask.append(m)
|
|
||||||
|
|
||||||
|
if (np.random.random() < ROTATE_CHANCE):
|
||||||
|
#rotate random amount between 0 and 90 degrees
|
||||||
|
m = scipy.ndimage.rotate(m, np.random.random()*90, reshape = False)
|
||||||
|
#set all elements greater than 0 to 1
|
||||||
|
m[m > 0] = 1
|
||||||
|
|
||||||
|
mask.append(m)
|
||||||
|
|
||||||
return np.array(points), np.array(mask)
|
return np.array(points), np.array(mask)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
train()
|
train()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user