mirror of
https://github.com/Deepshift/DeepCreamPy.git
synced 2025-04-18 02:42:22 +00:00
comment out unused functions
This commit is contained in:
parent
6348a78310
commit
0bb651b5fb
@ -6,7 +6,7 @@ from keras.models import load_model
|
|||||||
from keras.optimizers import Adam
|
from keras.optimizers import Adam
|
||||||
from keras.layers import Input, Conv2D, UpSampling2D, Dropout, LeakyReLU, BatchNormalization, Activation
|
from keras.layers import Input, Conv2D, UpSampling2D, Dropout, LeakyReLU, BatchNormalization, Activation
|
||||||
from keras.layers.merge import Concatenate
|
from keras.layers.merge import Concatenate
|
||||||
from keras.applications import VGG16
|
#from keras.applications import VGG16
|
||||||
from keras import backend as K
|
from keras import backend as K
|
||||||
from libs.pconv_layer import PConv2D
|
from libs.pconv_layer import PConv2D
|
||||||
|
|
||||||
@ -26,36 +26,36 @@ class PConvUnet(object):
|
|||||||
# Set current epoch
|
# Set current epoch
|
||||||
self.current_epoch = 0
|
self.current_epoch = 0
|
||||||
|
|
||||||
# VGG layers to extract features from (first maxpooling layers, see pp. 7 of paper)
|
# # VGG layers to extract features from (first maxpooling layers, see pp. 7 of paper)
|
||||||
self.vgg_layers = [3, 6, 10]
|
# self.vgg_layers = [3, 6, 10]
|
||||||
|
|
||||||
# Get the vgg16 model for perceptual loss
|
# # Get the vgg16 model for perceptual loss
|
||||||
self.vgg = self.build_vgg()
|
# self.vgg = self.build_vgg()
|
||||||
|
|
||||||
# Create UNet-like model
|
# Create UNet-like model
|
||||||
self.model = self.build_pconv_unet()
|
self.model = self.build_pconv_unet()
|
||||||
|
|
||||||
def build_vgg(self):
|
# def build_vgg(self):
|
||||||
"""
|
# """
|
||||||
Load pre-trained VGG16 from keras applications
|
# Load pre-trained VGG16 from keras applications
|
||||||
Extract features to be used in loss function from last conv layer, see architecture at:
|
# Extract features to be used in loss function from last conv layer, see architecture at:
|
||||||
https://github.com/keras-team/keras/blob/master/keras/applications/vgg16.py
|
# https://github.com/keras-team/keras/blob/master/keras/applications/vgg16.py
|
||||||
"""
|
# """
|
||||||
# Input image to extract features from
|
# # Input image to extract features from
|
||||||
img = Input(shape=(self.img_rows, self.img_cols, 3))
|
# img = Input(shape=(self.img_rows, self.img_cols, 3))
|
||||||
|
|
||||||
# Get the vgg network from Keras applications
|
# # Get the vgg network from Keras applications
|
||||||
vgg = VGG16(weights="imagenet", include_top=False)
|
# vgg = VGG16(weights="imagenet", include_top=False)
|
||||||
|
|
||||||
# Output the first three pooling layers
|
# # Output the first three pooling layers
|
||||||
vgg.outputs = [vgg.layers[i].output for i in self.vgg_layers]
|
# vgg.outputs = [vgg.layers[i].output for i in self.vgg_layers]
|
||||||
|
|
||||||
# Create model and compile
|
# # Create model and compile
|
||||||
model = Model(inputs=img, outputs=vgg(img))
|
# model = Model(inputs=img, outputs=vgg(img))
|
||||||
model.trainable = False
|
# model.trainable = False
|
||||||
model.compile(loss='mse', optimizer='adam')
|
# model.compile(loss='mse', optimizer='adam')
|
||||||
|
|
||||||
return model
|
# return model
|
||||||
|
|
||||||
def build_pconv_unet(self, train_bn=True, lr=0.0002):
|
def build_pconv_unet(self, train_bn=True, lr=0.0002):
|
||||||
|
|
||||||
@ -111,76 +111,78 @@ class PConvUnet(object):
|
|||||||
# Compile the model
|
# Compile the model
|
||||||
model.compile(
|
model.compile(
|
||||||
optimizer = Adam(lr=lr),
|
optimizer = Adam(lr=lr),
|
||||||
loss=self.loss_total(loss_mask)
|
loss='mse'
|
||||||
|
#loss really isn't mse, but we don't need the vgg16 model for inference so we don't to have to download the vgg16 model
|
||||||
|
#loss=self.loss_total(loss_mask)
|
||||||
)
|
)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def loss_total(self, mask):
|
# def loss_total(self, mask):
|
||||||
"""
|
# """
|
||||||
Creates a loss function which sums all the loss components
|
# Creates a loss function which sums all the loss components
|
||||||
and multiplies by their weights. See paper eq. 7.
|
# and multiplies by their weights. See paper eq. 7.
|
||||||
"""
|
# """
|
||||||
def loss(y_true, y_pred):
|
# def loss(y_true, y_pred):
|
||||||
|
|
||||||
# Compute predicted image with non-hole pixels set to ground truth
|
# # Compute predicted image with non-hole pixels set to ground truth
|
||||||
y_comp = mask * y_true + (1-mask) * y_pred
|
# y_comp = mask * y_true + (1-mask) * y_pred
|
||||||
|
|
||||||
# Compute the vgg features
|
# # Compute the vgg features
|
||||||
vgg_out = self.vgg(y_pred)
|
# vgg_out = self.vgg(y_pred)
|
||||||
vgg_gt = self.vgg(y_true)
|
# vgg_gt = self.vgg(y_true)
|
||||||
vgg_comp = self.vgg(y_comp)
|
# vgg_comp = self.vgg(y_comp)
|
||||||
|
|
||||||
# Compute loss components
|
# # Compute loss components
|
||||||
l1 = self.loss_valid(mask, y_true, y_pred)
|
# l1 = self.loss_valid(mask, y_true, y_pred)
|
||||||
l2 = self.loss_hole(mask, y_true, y_pred)
|
# l2 = self.loss_hole(mask, y_true, y_pred)
|
||||||
l3 = self.loss_perceptual(vgg_out, vgg_gt, vgg_comp)
|
# l3 = self.loss_perceptual(vgg_out, vgg_gt, vgg_comp)
|
||||||
l4 = self.loss_style(vgg_out, vgg_gt)
|
# l4 = self.loss_style(vgg_out, vgg_gt)
|
||||||
l5 = self.loss_style(vgg_comp, vgg_gt)
|
# l5 = self.loss_style(vgg_comp, vgg_gt)
|
||||||
l6 = self.loss_tv(mask, y_comp)
|
# l6 = self.loss_tv(mask, y_comp)
|
||||||
|
|
||||||
# Return loss function
|
# # Return loss function
|
||||||
return l1 + 6*l2 + 0.05*l3 + 120*(l4+l5) + 0.1*l6
|
# return l1 + 6*l2 + 0.05*l3 + 120*(l4+l5) + 0.1*l6
|
||||||
|
|
||||||
return loss
|
# return loss
|
||||||
|
|
||||||
def loss_hole(self, mask, y_true, y_pred):
|
# def loss_hole(self, mask, y_true, y_pred):
|
||||||
"""Pixel L1 loss within the hole / mask"""
|
# """Pixel L1 loss within the hole / mask"""
|
||||||
return self.l1((1-mask) * y_true, (1-mask) * y_pred)
|
# return self.l1((1-mask) * y_true, (1-mask) * y_pred)
|
||||||
|
|
||||||
def loss_valid(self, mask, y_true, y_pred):
|
# def loss_valid(self, mask, y_true, y_pred):
|
||||||
"""Pixel L1 loss outside the hole / mask"""
|
# """Pixel L1 loss outside the hole / mask"""
|
||||||
return self.l1(mask * y_true, mask * y_pred)
|
# return self.l1(mask * y_true, mask * y_pred)
|
||||||
|
|
||||||
def loss_perceptual(self, vgg_out, vgg_gt, vgg_comp):
|
# def loss_perceptual(self, vgg_out, vgg_gt, vgg_comp):
|
||||||
"""Perceptual loss based on VGG16, see. eq. 3 in paper"""
|
# """Perceptual loss based on VGG16, see. eq. 3 in paper"""
|
||||||
loss = 0
|
# loss = 0
|
||||||
for o, c, g in zip(vgg_out, vgg_comp, vgg_gt):
|
# for o, c, g in zip(vgg_out, vgg_comp, vgg_gt):
|
||||||
loss += self.l1(o, g) + self.l1(c, g)
|
# loss += self.l1(o, g) + self.l1(c, g)
|
||||||
return loss
|
# return loss
|
||||||
|
|
||||||
def loss_style(self, output, vgg_gt):
|
# def loss_style(self, output, vgg_gt):
|
||||||
"""Style loss based on output/computation, used for both eq. 4 & 5 in paper"""
|
# """Style loss based on output/computation, used for both eq. 4 & 5 in paper"""
|
||||||
loss = 0
|
# loss = 0
|
||||||
for o, g in zip(output, vgg_gt):
|
# for o, g in zip(output, vgg_gt):
|
||||||
loss += self.l1(self.gram_matrix(o), self.gram_matrix(g))
|
# loss += self.l1(self.gram_matrix(o), self.gram_matrix(g))
|
||||||
return loss
|
# return loss
|
||||||
|
|
||||||
def loss_tv(self, mask, y_comp):
|
# def loss_tv(self, mask, y_comp):
|
||||||
"""Total variation loss, used for smoothing the hole region, see. eq. 6"""
|
# """Total variation loss, used for smoothing the hole region, see. eq. 6"""
|
||||||
|
|
||||||
# Create dilated hole region using a 3x3 kernel of all 1s.
|
# # Create dilated hole region using a 3x3 kernel of all 1s.
|
||||||
kernel = K.ones(shape=(3, 3, mask.shape[3], mask.shape[3]))
|
# kernel = K.ones(shape=(3, 3, mask.shape[3], mask.shape[3]))
|
||||||
dilated_mask = K.conv2d(1-mask, kernel, data_format='channels_last', padding='same')
|
# dilated_mask = K.conv2d(1-mask, kernel, data_format='channels_last', padding='same')
|
||||||
|
|
||||||
# Cast values to be [0., 1.], and compute dilated hole region of y_comp
|
# # Cast values to be [0., 1.], and compute dilated hole region of y_comp
|
||||||
dilated_mask = K.cast(K.greater(dilated_mask, 0), 'float32')
|
# dilated_mask = K.cast(K.greater(dilated_mask, 0), 'float32')
|
||||||
P = dilated_mask * y_comp
|
# P = dilated_mask * y_comp
|
||||||
|
|
||||||
# Calculate total variation loss
|
# # Calculate total variation loss
|
||||||
a = self.l1(P[:,1:,:,:], P[:,:-1,:,:])
|
# a = self.l1(P[:,1:,:,:], P[:,:-1,:,:])
|
||||||
b = self.l1(P[:,:,1:,:], P[:,:,:-1,:])
|
# b = self.l1(P[:,:,1:,:], P[:,:,:-1,:])
|
||||||
return a+b
|
# return a+b
|
||||||
|
|
||||||
def fit(self, generator, epochs=10, plot_callback=None, *args, **kwargs):
|
def fit(self, generator, epochs=10, plot_callback=None, *args, **kwargs):
|
||||||
"""Fit the U-Net to a (images, targets) generator
|
"""Fit the U-Net to a (images, targets) generator
|
||||||
|
Loading…
x
Reference in New Issue
Block a user