diff options
author | Benjamin Fiske <bffiske@gmail.com> | 2022-05-04 15:40:04 -0400 |
---|---|---|
committer | Benjamin Fiske <bffiske@gmail.com> | 2022-05-04 15:40:04 -0400 |
commit | df0a9240bac34d2bda0d3c7c836dbce2ce781344 (patch) | |
tree | 4c76febae25c7deba20281406fe69ce1ad6dd4d8 /main.py | |
parent | 137d14b2ceffa95407cacfc82d1222cb9ef35072 (diff) |
backprop and main
Diffstat (limited to 'main.py')
-rw-r--r-- | main.py | 229 |
1 files changed, 25 insertions, 204 deletions
@@ -12,14 +12,12 @@ from skimage.transform import resize # from tensorboard_utils import \ # ImageLabelingLogger, ConfusionMatrixLogger, CustomModelSaver -from skimage.io import imread +from skimage.io import imread, imsave from lime import lime_image from skimage.segmentation import mark_boundaries from matplotlib import pyplot as plt import numpy as np -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' - def parse_args(): """ Perform command-line argument parsing. """ @@ -28,221 +26,44 @@ def parse_args(): description="Let's train some neural nets!", formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument( - '--task', + '--content', required=True, - choices=['1', '3'], - help='''Which task of the assignment to run - - training from scratch (1), or fine tuning VGG-16 (3).''') - parser.add_argument( - '--data', - default='..'+os.sep+'data'+os.sep, - help='Location where the dataset is stored.') - parser.add_argument( - '--load-vgg', - default='vgg16_imagenet.h5', - help='''Path to pre-trained VGG-16 file (only applicable to - task 3).''') - parser.add_argument( - '--load-checkpoint', - default=None, - help='''Path to model checkpoint file (should end with the - extension .h5). Checkpoints are automatically saved when you - train your model. If you want to continue training from where - you left off, this is how you would load your weights.''') - parser.add_argument( - '--confusion', - action='store_true', - help='''Log a confusion matrix at the end of each - epoch (viewable in Tensorboard). This is turned off - by default as it takes a little bit of time to complete.''') + help='''Content image filepath''') parser.add_argument( - '--evaluate', - action='store_true', - help='''Skips training and evaluates on the test set once. - You can use this to test an already trained model by loading - its checkpoint.''') + '--style', + required=True, + help='Style image filepath') parser.add_argument( - '--lime-image', - default='test/Bedroom/image_0003.jpg', - help='''Name of an image in the dataset to use for LIME evaluation.''') - - return parser.parse_args() - - -def LIME_explainer(model, path, preprocess_fn): - """ - This function takes in a trained model and a path to an image and outputs 5 - visual explanations using the LIME model - """ - - def image_and_mask(title, positive_only=True, num_features=5, - hide_rest=True): - temp, mask = explanation.get_image_and_mask( - explanation.top_labels[0], positive_only=positive_only, - num_features=num_features, hide_rest=hide_rest) - plt.imshow(mark_boundaries(temp / 2 + 0.5, mask)) - plt.title(title) - plt.show() - - image = imread(path) - if len(image.shape) == 2: - image = np.stack([image, image, image], axis=-1) - image = preprocess_fn(image) - image = resize(image, (hp.img_size, hp.img_size, 3)) - - explainer = lime_image.LimeImageExplainer() - - explanation = explainer.explain_instance( - image.astype('double'), model.predict, top_labels=5, hide_color=0, - num_samples=1000) - - # The top 5 superpixels that are most positive towards the class with the - # rest of the image hidden - image_and_mask("Top 5 superpixels", positive_only=True, num_features=5, - hide_rest=True) - - # The top 5 superpixels with the rest of the image present - image_and_mask("Top 5 with the rest of the image present", - positive_only=True, num_features=5, hide_rest=False) - - # The 'pros and cons' (pros in green, cons in red) - image_and_mask("Pros(green) and Cons(red)", - positive_only=False, num_features=10, hide_rest=False) - - # Select the same class explained on the figures above. - ind = explanation.top_labels[0] - # Map each explanation weight to the corresponding superpixel - dict_heatmap = dict(explanation.local_exp[ind]) - heatmap = np.vectorize(dict_heatmap.get)(explanation.segments) - plt.imshow(heatmap, cmap='RdBu', vmin=-heatmap.max(), vmax=heatmap.max()) - plt.colorbar() - plt.title("Map each explanation weight to the corresponding superpixel") - plt.show() - - -def train(model, datasets, checkpoint_path, logs_path, init_epoch): - """ Training routine. """ - - # Keras callbacks for training - callback_list = [ - tf.keras.callbacks.TensorBoard( - log_dir=logs_path, - update_freq='batch', - profile_batch=0) - # ImageLabelingLogger(logs_path, datasets), - # CustomModelSaver(checkpoint_path, ARGS.task, hp.max_num_weights) - ] - - # Include confusion logger in callbacks if flag set - if ARGS.confusion: - callback_list.append(ConfusionMatrixLogger(logs_path, datasets)) - - # Begin training - model.fit( - x=datasets.train_data, - validation_data=datasets.test_data, - epochs=hp.num_epochs, - batch_size=None, - callbacks=callback_list, - initial_epoch=init_epoch, - ) - + '--savefile', + required=True, + help='Filename to save image') -def test(model, test_data): - """ Testing routine. """ - # Run model on test set - model.evaluate( - x=test_data, - verbose=1, - ) + return parser.parse_args() +def train(model): + for _ in hp.num_epochs: + model.train_step() def main(): """ Main function. """ - - time_now = datetime.now() - timestamp = time_now.strftime("%m%d%y-%H%M%S") - init_epoch = 0 - - # If loading from a checkpoint, the loaded checkpoint's directory - # will be used for future checkpoints - if ARGS.load_checkpoint is not None: - ARGS.load_checkpoint = os.path.abspath(ARGS.load_checkpoint) - - # Get timestamp and epoch from filename - regex = r"(?:.+)(?:\.e)(\d+)(?:.+)(?:.h5)" - init_epoch = int(re.match(regex, ARGS.load_checkpoint).group(1)) + 1 - timestamp = os.path.basename(os.path.dirname(ARGS.load_checkpoint)) - - # If paths provided by program arguments are accurate, then this will - # ensure they are used. If not, these directories/files will be - # set relative to the directory of run.py - if os.path.exists(ARGS.data): - ARGS.data = os.path.abspath(ARGS.data) - if os.path.exists(ARGS.load_vgg): - ARGS.load_vgg = os.path.abspath(ARGS.load_vgg) - - # Run script from location of run.py + if os.path.exists(ARGS.content): + ARGS.content = os.path.abspath(ARGS.content) + if os.path.exists(ARGS.style): + ARGS.style = os.path.abspath(ARGS.style) os.chdir(sys.path[0]) - datasets = Datasets(ARGS.data, ARGS.task) - - if ARGS.task == '1': - model = YourModel() - model(tf.keras.Input(shape=(hp.img_size, hp.img_size, 3))) - checkpoint_path = "checkpoints" + os.sep + \ - "your_model" + os.sep + timestamp + os.sep - logs_path = "logs" + os.sep + "your_model" + \ - os.sep + timestamp + os.sep - - # Print summary of model - model.summary() - else: - model = VGGModel() - checkpoint_path = "checkpoints" + os.sep + \ - "vgg_model" + os.sep + timestamp + os.sep - logs_path = "logs" + os.sep + "vgg_model" + \ - os.sep + timestamp + os.sep - model(tf.keras.Input(shape=(224, 224, 3))) - - # Print summaries for both parts of the model - model.vgg16.summary() - model.head.summary() - - # Load base of VGG model - model.vgg16.load_weights(ARGS.load_vgg, by_name=True) + content_image = imread(ARGS.content) + style_image = imread(ARGS.style) + my_model = YourModel(content_image=content_image, style_image=style_image) + train(my_model) + + final_image = my_model.x - # Load checkpoints - if ARGS.load_checkpoint is not None: - if ARGS.task == '1': - model.load_weights(ARGS.load_checkpoint, by_name=False) - else: - model.head.load_weights(ARGS.load_checkpoint, by_name=False) + plt.imshow(final_image) - # Make checkpoint directory if needed - if not ARGS.evaluate and not os.path.exists(checkpoint_path): - os.makedirs(checkpoint_path) + imsave(ARGS.savefile, final_image) - # Compile model graph - model.compile( - optimizer=model.optimizer, - loss=model.loss_fn, - metrics=["sparse_categorical_accuracy"]) - if ARGS.evaluate: - test(model, datasets.test_data) - - # TODO: change the image path to be the image of your choice by changing - # the lime-image flag when calling run.py to investigate - # i.e. python run.py --evaluate --lime-image test/Bedroom/image_003.jpg - path = ARGS.data + os.sep + ARGS.lime_image - LIME_explainer(model, path, datasets.preprocess_fn) - else: - train(model, datasets, checkpoint_path, logs_path, init_epoch) - - -# Make arguments global ARGS = parse_args() - main() |