check missing file for non-binary users

This commit is contained in:
ccppoo 2020-01-01 15:59:10 +09:00
parent 40dbd1d294
commit 21c4dc4d71

View File

@ -6,7 +6,7 @@ import module as mm
#suppress tensorflow deprecation warnings
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
class InpaintNN():
class InpaintNN:
def __init__(self, input_height=256, input_width=256, batch_size = 1, bar_model_name=None, bar_checkpoint_name=None, mosaic_model_name=None, mosaic_checkpoint_name = None, is_mosaic=False):
self.bar_model_name = bar_model_name
@ -19,8 +19,15 @@ class InpaintNN():
self.input_width = input_width
self.batch_size = batch_size
self.check_model_file()
self.build_model()
def check_model_file(self):
if not os.path.exists(self.bar_model_name) or not os.path.exists(self.mosaic_model_name) :
print("\nMissing Train Model, download train model")
print("Read : https://github.com/deeppomf/DeepCreamPy/blob/master/docs/INSTALLATION.md#run-code-yourself \n")
exit(-1)
def build_model(self):
# ------- variables
@ -97,9 +104,9 @@ class InpaintNN():
Restore.restore(self.sess, tf.train.latest_checkpoint(self.mosaic_checkpoint_name))
else:
Restore = tf.train.import_meta_graph(self.bar_model_name)
Restore.restore(self.sess, tf.train.latest_checkpoint(self.bar_checkpoint_name))
Restore.restore(self.sess, tf.train.latest_checkpoint(self.bar_checkpoint_name))
def predict(self, censored, unused, mask):
img_sample = self.sess.run(self.image_result, feed_dict={self.X: censored, self.Y: unused, self.MASK: mask})
return img_sample
return img_sample