diff --git a/train.py b/train.py index 767ab94..86d5be3 100644 --- a/train.py +++ b/train.py @@ -137,7 +137,7 @@ def get_points(): #rotate random amount between 0 and 90 degrees m = scipy.ndimage.rotate(m, np.random.random()*90, reshape = False) #set all elements greater than 0 to 1 - m[m > 0] = 1 + m[m > 0.5] = 1 mask.append(m)