comment out unused functions

This commit is contained in:
deeppomf 2018-10-21 19:15:43 -04:00
parent 6348a78310
commit 0bb651b5fb

View File

@ -6,7 +6,7 @@ from keras.models import load_model
from keras.optimizers import Adam from keras.optimizers import Adam
from keras.layers import Input, Conv2D, UpSampling2D, Dropout, LeakyReLU, BatchNormalization, Activation from keras.layers import Input, Conv2D, UpSampling2D, Dropout, LeakyReLU, BatchNormalization, Activation
from keras.layers.merge import Concatenate from keras.layers.merge import Concatenate
from keras.applications import VGG16 #from keras.applications import VGG16
from keras import backend as K from keras import backend as K
from libs.pconv_layer import PConv2D from libs.pconv_layer import PConv2D
@ -26,36 +26,36 @@ class PConvUnet(object):
# Set current epoch # Set current epoch
self.current_epoch = 0 self.current_epoch = 0
# VGG layers to extract features from (first maxpooling layers, see pp. 7 of paper) # # VGG layers to extract features from (first maxpooling layers, see pp. 7 of paper)
self.vgg_layers = [3, 6, 10] # self.vgg_layers = [3, 6, 10]
# Get the vgg16 model for perceptual loss # # Get the vgg16 model for perceptual loss
self.vgg = self.build_vgg() # self.vgg = self.build_vgg()
# Create UNet-like model # Create UNet-like model
self.model = self.build_pconv_unet() self.model = self.build_pconv_unet()
def build_vgg(self): # def build_vgg(self):
""" # """
Load pre-trained VGG16 from keras applications # Load pre-trained VGG16 from keras applications
Extract features to be used in loss function from last conv layer, see architecture at: # Extract features to be used in loss function from last conv layer, see architecture at:
https://github.com/keras-team/keras/blob/master/keras/applications/vgg16.py # https://github.com/keras-team/keras/blob/master/keras/applications/vgg16.py
""" # """
# Input image to extract features from # # Input image to extract features from
img = Input(shape=(self.img_rows, self.img_cols, 3)) # img = Input(shape=(self.img_rows, self.img_cols, 3))
# Get the vgg network from Keras applications # # Get the vgg network from Keras applications
vgg = VGG16(weights="imagenet", include_top=False) # vgg = VGG16(weights="imagenet", include_top=False)
# Output the first three pooling layers # # Output the first three pooling layers
vgg.outputs = [vgg.layers[i].output for i in self.vgg_layers] # vgg.outputs = [vgg.layers[i].output for i in self.vgg_layers]
# Create model and compile # # Create model and compile
model = Model(inputs=img, outputs=vgg(img)) # model = Model(inputs=img, outputs=vgg(img))
model.trainable = False # model.trainable = False
model.compile(loss='mse', optimizer='adam') # model.compile(loss='mse', optimizer='adam')
return model # return model
def build_pconv_unet(self, train_bn=True, lr=0.0002): def build_pconv_unet(self, train_bn=True, lr=0.0002):
@ -111,76 +111,78 @@ class PConvUnet(object):
# Compile the model # Compile the model
model.compile( model.compile(
optimizer = Adam(lr=lr), optimizer = Adam(lr=lr),
loss=self.loss_total(loss_mask) loss='mse'
#loss really isn't mse, but we don't need the vgg16 model for inference so we don't to have to download the vgg16 model
#loss=self.loss_total(loss_mask)
) )
return model return model
def loss_total(self, mask): # def loss_total(self, mask):
""" # """
Creates a loss function which sums all the loss components # Creates a loss function which sums all the loss components
and multiplies by their weights. See paper eq. 7. # and multiplies by their weights. See paper eq. 7.
""" # """
def loss(y_true, y_pred): # def loss(y_true, y_pred):
# Compute predicted image with non-hole pixels set to ground truth # # Compute predicted image with non-hole pixels set to ground truth
y_comp = mask * y_true + (1-mask) * y_pred # y_comp = mask * y_true + (1-mask) * y_pred
# Compute the vgg features # # Compute the vgg features
vgg_out = self.vgg(y_pred) # vgg_out = self.vgg(y_pred)
vgg_gt = self.vgg(y_true) # vgg_gt = self.vgg(y_true)
vgg_comp = self.vgg(y_comp) # vgg_comp = self.vgg(y_comp)
# Compute loss components # # Compute loss components
l1 = self.loss_valid(mask, y_true, y_pred) # l1 = self.loss_valid(mask, y_true, y_pred)
l2 = self.loss_hole(mask, y_true, y_pred) # l2 = self.loss_hole(mask, y_true, y_pred)
l3 = self.loss_perceptual(vgg_out, vgg_gt, vgg_comp) # l3 = self.loss_perceptual(vgg_out, vgg_gt, vgg_comp)
l4 = self.loss_style(vgg_out, vgg_gt) # l4 = self.loss_style(vgg_out, vgg_gt)
l5 = self.loss_style(vgg_comp, vgg_gt) # l5 = self.loss_style(vgg_comp, vgg_gt)
l6 = self.loss_tv(mask, y_comp) # l6 = self.loss_tv(mask, y_comp)
# Return loss function # # Return loss function
return l1 + 6*l2 + 0.05*l3 + 120*(l4+l5) + 0.1*l6 # return l1 + 6*l2 + 0.05*l3 + 120*(l4+l5) + 0.1*l6
return loss # return loss
def loss_hole(self, mask, y_true, y_pred): # def loss_hole(self, mask, y_true, y_pred):
"""Pixel L1 loss within the hole / mask""" # """Pixel L1 loss within the hole / mask"""
return self.l1((1-mask) * y_true, (1-mask) * y_pred) # return self.l1((1-mask) * y_true, (1-mask) * y_pred)
def loss_valid(self, mask, y_true, y_pred): # def loss_valid(self, mask, y_true, y_pred):
"""Pixel L1 loss outside the hole / mask""" # """Pixel L1 loss outside the hole / mask"""
return self.l1(mask * y_true, mask * y_pred) # return self.l1(mask * y_true, mask * y_pred)
def loss_perceptual(self, vgg_out, vgg_gt, vgg_comp): # def loss_perceptual(self, vgg_out, vgg_gt, vgg_comp):
"""Perceptual loss based on VGG16, see. eq. 3 in paper""" # """Perceptual loss based on VGG16, see. eq. 3 in paper"""
loss = 0 # loss = 0
for o, c, g in zip(vgg_out, vgg_comp, vgg_gt): # for o, c, g in zip(vgg_out, vgg_comp, vgg_gt):
loss += self.l1(o, g) + self.l1(c, g) # loss += self.l1(o, g) + self.l1(c, g)
return loss # return loss
def loss_style(self, output, vgg_gt): # def loss_style(self, output, vgg_gt):
"""Style loss based on output/computation, used for both eq. 4 & 5 in paper""" # """Style loss based on output/computation, used for both eq. 4 & 5 in paper"""
loss = 0 # loss = 0
for o, g in zip(output, vgg_gt): # for o, g in zip(output, vgg_gt):
loss += self.l1(self.gram_matrix(o), self.gram_matrix(g)) # loss += self.l1(self.gram_matrix(o), self.gram_matrix(g))
return loss # return loss
def loss_tv(self, mask, y_comp): # def loss_tv(self, mask, y_comp):
"""Total variation loss, used for smoothing the hole region, see. eq. 6""" # """Total variation loss, used for smoothing the hole region, see. eq. 6"""
# Create dilated hole region using a 3x3 kernel of all 1s. # # Create dilated hole region using a 3x3 kernel of all 1s.
kernel = K.ones(shape=(3, 3, mask.shape[3], mask.shape[3])) # kernel = K.ones(shape=(3, 3, mask.shape[3], mask.shape[3]))
dilated_mask = K.conv2d(1-mask, kernel, data_format='channels_last', padding='same') # dilated_mask = K.conv2d(1-mask, kernel, data_format='channels_last', padding='same')
# Cast values to be [0., 1.], and compute dilated hole region of y_comp # # Cast values to be [0., 1.], and compute dilated hole region of y_comp
dilated_mask = K.cast(K.greater(dilated_mask, 0), 'float32') # dilated_mask = K.cast(K.greater(dilated_mask, 0), 'float32')
P = dilated_mask * y_comp # P = dilated_mask * y_comp
# Calculate total variation loss # # Calculate total variation loss
a = self.l1(P[:,1:,:,:], P[:,:-1,:,:]) # a = self.l1(P[:,1:,:,:], P[:,:-1,:,:])
b = self.l1(P[:,:,1:,:], P[:,:,:-1,:]) # b = self.l1(P[:,:,1:,:], P[:,:,:-1,:])
return a+b # return a+b
def fit(self, generator, epochs=10, plot_callback=None, *args, **kwargs): def fit(self, generator, epochs=10, plot_callback=None, *args, **kwargs):
"""Fit the U-Net to a (images, targets) generator """Fit the U-Net to a (images, targets) generator