- Project Description
- Project Setup
- How to Train the Model in This Project
- How to Restore Images in This Project
- How To Check Model Architecture
A GAN (Generative Adversarial Network) is a class of machine learning frameworks designed by Ian Goodfellow and his colleagues in June 2014. It is based on game theory: two neural networks, a generator and a discriminator, contest with each other.
This project restores images using a GAN model. Here is how it works:
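For reference, the original formulation by Goodfellow et al. (2014) writes this contest as a minimax game between a generator $G$ and a discriminator $D$, where $x$ is a real sample and $z$ is the generator's input. This is the textbook objective only; the loss functions in this project are customized and may differ from it.

$$
\min_G \max_D \; V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
$$

In an image-restoration setup like this one, the generator's input is the masked image rather than pure noise, and $D$ tries to tell restored images apart from real ones.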
- Model setup
- Train model
  - Yield batches of images from `training_data`, whose shape is `(image_count, image_width, image_height, image_channel)`
  - Apply a random mask to each image in the batch
  - Customize the loss functions of the discriminator and the generator
  - Run gradient descent with respect to the variables of the discriminator and the generator
    - using `tensorflow.GradientTape` to implement gradient descent
  - Display a training progress bar in the terminal
    - using the Python package `rich` to show `epochs`, `completeness`, `generator loss`, and `discriminator loss`
  - Save the model structure and parameters once training finishes
- Image restoration
  - Load the trained model
  - Load any masked images whose shape matches `training_data`'s shape, e.g. `(image_count, image_width, image_height, image_channel)`
  - Restore the images
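The batching and masking steps of training can be sketched in NumPy as follows. This is a minimal illustration, not the code in `src/train.py`: the function names `iterate_batches` and `random_box_mask`, and the choice of hiding a single random rectangle per image, are assumptions; the project may generate its masks differently.

```python
import numpy as np


def iterate_batches(training_data, batch_size, rng=None):
    """Yield shuffled batches from an array shaped (image_count, width, height, channels)."""
    rng = np.random.default_rng() if rng is None else rng
    indices = rng.permutation(len(training_data))
    for start in range(0, len(indices), batch_size):
        yield training_data[indices[start:start + batch_size]]


def random_box_mask(batch, max_fraction=0.5, rng=None):
    """Hide one random rectangle per image (hypothetical masking scheme).

    Returns (masked_batch, mask), where mask has the same shape as batch and is
    1.0 where a pixel is kept and 0.0 where it was hidden.
    """
    rng = np.random.default_rng() if rng is None else rng
    count, dim1, dim2, _ = batch.shape
    mask = np.ones(batch.shape, dtype=batch.dtype)
    for i in range(count):
        # Box sides range from 1 pixel up to max_fraction of each spatial dimension
        box1 = rng.integers(1, max(2, int(dim1 * max_fraction) + 1))
        box2 = rng.integers(1, max(2, int(dim2 * max_fraction) + 1))
        # Random top-left corner that keeps the box inside the image
        top = rng.integers(0, dim1 - box1 + 1)
        left = rng.integers(0, dim2 - box2 + 1)
        mask[i, top:top + box1, left:left + box2, :] = 0
    return batch * mask, mask
```

Because the mask has the same shape as the batch, the same array can later be reused to blend the generator's output with the pixels that were never hidden.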
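The loss and gradient-descent steps can be illustrated with one `tf.GradientTape` training step. This sketch assumes a common inpainting setup: the discriminator outputs logits, both networks use binary cross-entropy, and the generator adds an L1 reconstruction term weighted by a hypothetical `l1_weight`. The project's customized losses, architectures, and optimizers may differ; the tiny `make_generator` and `make_discriminator` models exist only so the example runs.

```python
import tensorflow as tf


def make_generator(size=8, channels=1):
    # Hypothetical toy generator: masked image in, restored image out (values in [0, 1])
    return tf.keras.Sequential([
        tf.keras.Input(shape=(size, size, channels)),
        tf.keras.layers.Conv2D(16, 3, padding="same", activation="relu"),
        tf.keras.layers.Conv2D(channels, 3, padding="same", activation="sigmoid"),
    ])


def make_discriminator(size=8, channels=1):
    # Hypothetical toy discriminator: image in, one real/fake logit out
    return tf.keras.Sequential([
        tf.keras.Input(shape=(size, size, channels)),
        tf.keras.layers.Conv2D(16, 3, strides=2, padding="same"),
        tf.keras.layers.LeakyReLU(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(1),
    ])


bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def train_step(real, masked, generator, discriminator, g_opt, d_opt, l1_weight=100.0):
    # Two tapes record the same forward pass so each network gets its own gradients
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        fake = generator(masked, training=True)
        real_logits = discriminator(real, training=True)
        fake_logits = discriminator(fake, training=True)
        # Discriminator: label real images 1 and restored images 0
        d_loss = (bce(tf.ones_like(real_logits), real_logits)
                  + bce(tf.zeros_like(fake_logits), fake_logits))
        # Generator: fool the discriminator while staying close to the original pixels
        g_loss = (bce(tf.ones_like(fake_logits), fake_logits)
                  + l1_weight * tf.reduce_mean(tf.abs(real - fake)))
    g_grads = g_tape.gradient(g_loss, generator.trainable_variables)
    d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
    g_opt.apply_gradients(zip(g_grads, generator.trainable_variables))
    d_opt.apply_gradients(zip(d_grads, discriminator.trainable_variables))
    return float(g_loss), float(d_loss)
```

Calling `train_step` once per batch inside an epoch loop gives the outer training loop; the returned losses are the kind of values a terminal progress bar would display.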
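For the final restoration step, one common choice (an assumption here, not confirmed by the project) is to keep the known pixels from the input and take only the hidden region from the generator's output. The hypothetical helper below does that blend, assuming a `mask` that is 1.0 for known pixels and 0.0 for hidden ones.

```python
import numpy as np


def composite_restoration(masked_images, generated_images, mask):
    """Blend known input pixels with generated pixels (hypothetical helper).

    All three arrays share the shape (image_count, width, height, channels);
    mask is 1.0 where the original pixel is known and 0.0 where it was hidden.
    """
    if not (masked_images.shape == generated_images.shape == mask.shape):
        raise ValueError("masked_images, generated_images and mask must have the same shape")
    # Known pixels come from the input, hidden pixels from the generator
    return mask * masked_images + (1.0 - mask) * generated_images
```

In practice `generated_images` would come from the loaded generator, for example `tf.keras.models.load_model(path).predict(masked_images)`.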
To avoid TensorFlow version conflicts, the project uses pipenv (a Python virtual environment and package manager) to install Python packages.

Notice: Before executing the following commands, please refer to the TensorFlow Installation Source and modify the TensorFlow version in `Pipfile` and `Pipfile.lock` (or modify `Pipfile` and remove `Pipfile.lock`).
```shell
pip install pipenv
pipenv shell
pipenv install
```

During the training stage, you can modify the model architecture or hyperparameters such as `epochs`, `learning_rate`, and `learning_rate_decay` in `src/model/GAN.py`.

```shell
python src/train.py
```

You can restore images with a model you trained yourself or with one of the following models:
- The example model `generator_example.h5` at `src/model/trained_model/`
- Other trained models on Google Drive
```shell
python src/predict.py
```

You can modify `model_path` in `src/watch_model_architecture.py` to inspect the architecture of any model you want:

```shell
python src/watch_model_architecture.py
```
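A script like `src/watch_model_architecture.py` plausibly comes down to loading a saved Keras model and printing its summary. The sketch below assumes that; the function name `watch_model` is hypothetical, and `compile=False` is passed because only the architecture, not the optimizer or loss, is needed.

```python
import tensorflow as tf


def watch_model(model_path):
    """Load a saved Keras model and print its layer-by-layer architecture."""
    # compile=False skips restoring training configuration, which inspection does not need
    model = tf.keras.models.load_model(model_path, compile=False)
    model.summary()
    return model
```

For example, `watch_model("src/model/trained_model/generator_example.h5")` would print the example generator's layers, output shapes, and parameter counts.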

