mirror of
https://github.com/Deepshift/DeepCreamPy.git
synced 2024-11-24 08:09:53 +00:00
check missing file for non-binary users
This commit is contained in:
parent
40dbd1d294
commit
21c4dc4d71
13
model.py
13
model.py
@ -6,7 +6,7 @@ import module as mm
|
||||
#suppress tensorflow deprecation warnings
|
||||
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
|
||||
|
||||
class InpaintNN():
|
||||
class InpaintNN:
|
||||
|
||||
def __init__(self, input_height=256, input_width=256, batch_size = 1, bar_model_name=None, bar_checkpoint_name=None, mosaic_model_name=None, mosaic_checkpoint_name = None, is_mosaic=False):
|
||||
self.bar_model_name = bar_model_name
|
||||
@ -19,8 +19,15 @@ class InpaintNN():
|
||||
self.input_width = input_width
|
||||
self.batch_size = batch_size
|
||||
|
||||
self.check_model_file()
|
||||
self.build_model()
|
||||
|
||||
def check_model_file(self):
|
||||
if not os.path.exists(self.bar_model_name) or not os.path.exists(self.mosaic_model_name) :
|
||||
print("\nMissing Train Model, download train model")
|
||||
print("Read : https://github.com/deeppomf/DeepCreamPy/blob/master/docs/INSTALLATION.md#run-code-yourself \n")
|
||||
exit(-1)
|
||||
|
||||
def build_model(self):
|
||||
# ------- variables
|
||||
|
||||
@ -97,9 +104,9 @@ class InpaintNN():
|
||||
Restore.restore(self.sess, tf.train.latest_checkpoint(self.mosaic_checkpoint_name))
|
||||
else:
|
||||
Restore = tf.train.import_meta_graph(self.bar_model_name)
|
||||
Restore.restore(self.sess, tf.train.latest_checkpoint(self.bar_checkpoint_name))
|
||||
Restore.restore(self.sess, tf.train.latest_checkpoint(self.bar_checkpoint_name))
|
||||
|
||||
def predict(self, censored, unused, mask):
|
||||
img_sample = self.sess.run(self.image_result, feed_dict={self.X: censored, self.Y: unused, self.MASK: mask})
|
||||
|
||||
return img_sample
|
||||
return img_sample
|
||||
|
Loading…
Reference in New Issue
Block a user