diff --git a/src/train.py b/src/train.py index e5c6c1b..477443b 100644 --- a/src/train.py +++ b/src/train.py @@ -21,7 +21,7 @@ def train(): local_completion = tf.placeholder(tf.float32, [BATCH_SIZE, LOCAL_SIZE, LOCAL_SIZE, 3]) is_training = tf.placeholder(tf.bool, []) - model = Network(x, mask, local_x, global_completion, local_completion, is_training, batch_size=BATCH_SIZE) + model = Model(x, mask, local_x, global_completion, local_completion, is_training, batch_size=BATCH_SIZE) sess = tf.Session() global_step = tf.Variable(0, name='global_step', trainable=False) epoch = tf.Variable(0, name='epoch', trainable=False)