diff --git a/README.md b/README.md index bac5e83..02ea57c 100644 --- a/README.md +++ b/README.md @@ -83,9 +83,11 @@ $ python train.py # To do - ~~Add Python 3 compatibility~~ +- Add random rotations in cropping rectangles - Retrain for arbitrary shape censors - Add a user interface - Incorporate GAN loss into training +- Update the model to the new version Contributions are welcome! diff --git a/train.py b/train.py index 4d406f5..767ab94 100644 --- a/train.py +++ b/train.py @@ -4,6 +4,7 @@ from PIL import Image import tqdm from model import Model import load +import scipy.ndimage IMAGE_SIZE = 128 LOCAL_SIZE = 64 @@ -12,6 +13,8 @@ HOLE_MAX = 48 LEARNING_RATE = 1e-3 BATCH_SIZE = 16 PRETRAIN_EPOCH = 100 +#the chance the rectangle crop will be rotated +ROTATE_CHANCE = 0.5 def train(): x = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3]) @@ -129,12 +132,17 @@ def get_points(): m = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 1), dtype=np.uint8) m[q1:q2 + 1, p1:p2 + 1] = 1 - mask.append(m) + if (np.random.random() < ROTATE_CHANCE): + #rotate random amount between 0 and 90 degrees + m = scipy.ndimage.rotate(m, np.random.random()*90, reshape = False) + #set all elements greater than 0 to 1 + m[m > 0] = 1 + + mask.append(m) return np.array(points), np.array(mask) if __name__ == '__main__': - train() - + train() \ No newline at end of file