random rotation of rectangle masks

This commit is contained in:
deeppomf 2018-02-16 23:44:32 -05:00
parent e268490bf9
commit 296bb7109e
2 changed files with 13 additions and 3 deletions

View File

@ -83,9 +83,11 @@ $ python train.py
# To do # To do
- ~~Add Python 3 compatibility~~ - ~~Add Python 3 compatibility~~
- Add random rotations in cropping rectangles
- Retrain for arbitrary shape censors - Retrain for arbitrary shape censors
- Add a user interface - Add a user interface
- Incorporate GAN loss into training - Incorporate GAN loss into training
- Update the model to the new version
Contributions are welcome! Contributions are welcome!

View File

@ -4,6 +4,7 @@ from PIL import Image
import tqdm import tqdm
from model import Model from model import Model
import load import load
import scipy.ndimage
IMAGE_SIZE = 128 IMAGE_SIZE = 128
LOCAL_SIZE = 64 LOCAL_SIZE = 64
@ -12,6 +13,8 @@ HOLE_MAX = 48
LEARNING_RATE = 1e-3 LEARNING_RATE = 1e-3
BATCH_SIZE = 16 BATCH_SIZE = 16
PRETRAIN_EPOCH = 100 PRETRAIN_EPOCH = 100
#the chance the rectangle crop will be rotated
ROTATE_CHANCE = 0.5
def train(): def train():
x = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3]) x = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3])
@ -129,12 +132,17 @@ def get_points():
m = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 1), dtype=np.uint8) m = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 1), dtype=np.uint8)
m[q1:q2 + 1, p1:p2 + 1] = 1 m[q1:q2 + 1, p1:p2 + 1] = 1
mask.append(m)
if (np.random.random() < ROTATE_CHANCE):
#rotate random amount between 0 and 90 degrees
m = scipy.ndimage.rotate(m, np.random.random()*90, reshape = False)
#set all elements greater than 0 to 1
m[m > 0] = 1
mask.append(m)
return np.array(points), np.array(mask) return np.array(points), np.array(mask)
if __name__ == '__main__': if __name__ == '__main__':
train() train()