Skip to content

Commit 8c3dc25

Browse files
committed
Add README and comments for VAE examples (variational autoencoder folder)
1 parent 5171fd9 commit 8c3dc25

File tree

3 files changed

+80
-1
lines changed

3 files changed

+80
-1
lines changed
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Variational Autoencoder (VAE) Examples – DeepLearning4J
2+
3+
This folder contains two examples demonstrating how to use Variational Autoencoders (VAEs) in DeepLearning4J.
4+
VAEs are generative models that learn latent representations of data and can be used for visualization, sampling, and anomaly detection.
5+
6+
---
7+
8+
## 🧩 VaeMNIST2dPlots.java
9+
Trains a VAE on the MNIST digit dataset and visualizes the **2-dimensional latent space**.
10+
11+
### What this example shows
12+
- How to build a VAE in DL4J
13+
- Encoding MNIST images into a 2D latent space
14+
- Plotting how digits cluster in latent space
15+
- How VAEs learn smooth and continuous representations
16+
17+
### Why it’s useful
18+
A 2D latent space allows easy visualization of how the model separates digits.
19+
20+
---
21+
22+
## ⚠️ VaeMNISTAnomaly.java
23+
Uses a trained VAE for **anomaly detection** on MNIST.
24+
25+
### Key Concepts
26+
- VAEs reconstruct normal data well
27+
- They reconstruct anomalies poorly
28+
- Reconstruction error can be used as an anomaly score
29+
30+
### What this example demonstrates
31+
- Reconstructing MNIST digits
32+
- Computing reconstruction probability
33+
- Detecting out-of-distribution or corrupted samples
34+
35+
---
36+
37+
## ✔ How to Run Any Example
38+
39+
mvn -q exec:java -Dexec.mainClass="org.deeplearning4j.examples.quickstart.modeling.variationalautoencoder.<ClassName>"
40+
41+
42+
Example:
43+
44+
45+
46+
mvn -q exec:java -Dexec.mainClass="org.deeplearning4j.examples.quickstart.modeling.variationalautoencoder.VaeMNIST2dPlots"
47+
48+
49+
---
50+
51+
## 🙌 Why This README Helps
52+
The VAE folder previously had no documentation.
53+
This README explains:
54+
- Purpose of each example
55+
- The ML concepts involved
56+
- How to run and understand the results
57+
58+
This improves clarity for beginners working with generative models in DL4J.

dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/variationalautoencoder/VaeMNIST2dPlots.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
* SPDX-License-Identifier: Apache-2.0
1818
******************************************************************************/
1919

20+
// VaeMNIST2dPlots.java
21+
// Trains a Variational Autoencoder (VAE) on MNIST and visualizes the 2D latent space.
22+
// Shows how VAEs compress images into smooth, continuous latent variables.
23+
2024
package org.deeplearning4j.examples.quickstart.modeling.variationalautoencoder;
2125

2226
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
@@ -98,10 +102,13 @@ public static void main(String[] args) throws IOException {
98102
MultiLayerNetwork net = new MultiLayerNetwork(conf);
99103
net.init();
100104

105+
// Build a VAE with a 2-dimensional latent space for visualization
106+
101107
//Get the variational autoencoder layer
102108
org.deeplearning4j.nn.layers.variational.VariationalAutoencoder vae
103109
= (org.deeplearning4j.nn.layers.variational.VariationalAutoencoder) net.getLayer(0);
104110

111+
105112

106113
//Test data for plotting
107114
DataSet testdata = new MnistDataSetIterator(10000, false, rngSeed).next();
@@ -121,12 +128,15 @@ public static void main(String[] args) throws IOException {
121128
// (b) collect the reconstructions at each point in the grid
122129
net.setListeners(new PlottingListener(100, testFeatures, latentSpaceGrid, latentSpaceVsEpoch, digitsGrid));
123130

124-
//Perform training
131+
//Perform training
132+
// Train the VAE to encode and reconstruct MNIST digits
125133
for (int i = 0; i < nEpochs; i++) {
126134
log.info("Starting epoch {} of {}",(i+1),nEpochs);
127135
net.pretrain(trainIter); //Note use of .pretrain(DataSetIterator) not fit(DataSetIterator) for unsupervised training
128136
}
129137

138+
// Visualize how different digit classes cluster in the latent space
139+
130140
//plot by default
131141
if (visualize) {
132142
//Plot MNIST test set - latent space vs. iteration (every 100 minibatches by default)

dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/variationalautoencoder/VaeMNISTAnomaly.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
* SPDX-License-Identifier: Apache-2.0
1818
******************************************************************************/
1919

20+
// VaeMNISTAnomaly.java
21+
// Demonstrates anomaly detection using a Variational Autoencoder (VAE).
22+
// Normal digits reconstruct well, anomalies reconstruct poorly.
23+
2024
package org.deeplearning4j.examples.quickstart.modeling.variationalautoencoder;
2125

2226
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
@@ -94,6 +98,9 @@ public static void main(String[] args) throws IOException {
9498
.build())
9599
.build();
96100

101+
// Load trained VAE model and MNIST test data
102+
103+
97104
MultiLayerNetwork net = new MultiLayerNetwork(conf);
98105
net.init();
99106

@@ -151,6 +158,10 @@ public int compare(Pair<Double, INDArray> o1, Pair<Double, INDArray> o2) {
151158
Collections.sort(list, c);
152159
}
153160

161+
// Compute reconstruction probability for each test image
162+
// Low probability indicates an anomaly
163+
164+
154165
//Select the 5 best and 5 worst numbers (by reconstruction probability) for each digit
155166
List<INDArray> best = new ArrayList<>(50);
156167
List<INDArray> worst = new ArrayList<>(50);

0 commit comments

Comments
 (0)