aboutsummaryrefslogtreecommitdiff
path: root/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'main.py')
-rw-r--r--main.py13
1 files changed, 4 insertions, 9 deletions
diff --git a/main.py b/main.py
index b4c1b228..527394cb 100644
--- a/main.py
+++ b/main.py
@@ -5,12 +5,8 @@ import argparse
import tensorflow as tf
from skimage import transform
-import PIL.Image
-
import hyperparameters as hp
-from losses import YourModel
-# from tensorboard_utils import \
-# ImageLabelingLogger, ConfusionMatrixLogger, CustomModelSaver
+from model import YourModel
from skimage.io import imread, imsave
from matplotlib import pyplot as plt
@@ -42,9 +38,9 @@ def parse_args():
help='Y if you want to load the most recent weights'
)
-
return parser.parse_args()
+
def train(model: YourModel):
for i in range(hp.num_epochs):
if i % 50 == 0:
@@ -55,6 +51,7 @@ def train(model: YourModel):
np.save('checkpoint.npy', model.x)
model.train_step(i)
+
def main():
""" Main function. """
if os.path.exists(ARGS.content):
@@ -63,13 +60,11 @@ def main():
ARGS.style = os.path.abspath(ARGS.style)
os.chdir(sys.path[0])
print('this is', ARGS.content)
-
-
content_image = imread(ARGS.content)
style_image = imread(ARGS.style)
style_image = transform.resize(style_image, content_image.shape, anti_aliasing=True)
-
+
my_model = YourModel(content_image=content_image, style_image=style_image)
if (ARGS.load == 'Y'):