mirror of
https://github.com/Deepshift/DeepCreamPy.git
synced 2024-11-29 05:10:43 +00:00
random rotation of rectangle masks
This commit is contained in:
parent
e268490bf9
commit
296bb7109e
@ -83,9 +83,11 @@ $ python train.py
|
||||
|
||||
# To do
|
||||
- ~~Add Python 3 compatibility~~
|
||||
- Add random rotations in cropping rectangles
|
||||
- Retrain for arbitrary shape censors
|
||||
- Add a user interface
|
||||
- Incorporate GAN loss into training
|
||||
- Update the model to the new version
|
||||
|
||||
Contributions are welcome!
|
||||
|
||||
|
12
train.py
12
train.py
@ -4,6 +4,7 @@ from PIL import Image
|
||||
import tqdm
|
||||
from model import Model
|
||||
import load
|
||||
import scipy.ndimage
|
||||
|
||||
IMAGE_SIZE = 128
|
||||
LOCAL_SIZE = 64
|
||||
@ -12,6 +13,8 @@ HOLE_MAX = 48
|
||||
LEARNING_RATE = 1e-3
|
||||
BATCH_SIZE = 16
|
||||
PRETRAIN_EPOCH = 100
|
||||
#the chance the rectangle crop will be rotated
|
||||
ROTATE_CHANCE = 0.5
|
||||
|
||||
def train():
|
||||
x = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3])
|
||||
@ -129,12 +132,17 @@ def get_points():
|
||||
|
||||
m = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 1), dtype=np.uint8)
|
||||
m[q1:q2 + 1, p1:p2 + 1] = 1
|
||||
mask.append(m)
|
||||
|
||||
if (np.random.random() < ROTATE_CHANCE):
|
||||
#rotate random amount between 0 and 90 degrees
|
||||
m = scipy.ndimage.rotate(m, np.random.random()*90, reshape = False)
|
||||
#set all elements greater than 0 to 1
|
||||
m[m > 0] = 1
|
||||
|
||||
mask.append(m)
|
||||
|
||||
return np.array(points), np.array(mask)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
train()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user