
Commit b87cb41

Merge pull request #1094 from NA-V10/add-vae-readme
Add README and comments for VAE examples (variational autoencoder folder)
2 parents 9d64fbc + 8c3dc25 commit b87cb41

7 files changed: 174 additions, 2 deletions

dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/recurrent/MemorizeSequence.java

Lines changed: 11 additions & 1 deletion
@@ -17,6 +17,10 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/
 
+// MemorizeSequence.java
+// A simple RNN example where the network learns to memorize and reproduce a short sequence.
+// Demonstrates basic RNN training and backpropagation-through-time.
+
 package org.deeplearning4j.examples.quickstart.modeling.recurrent;
 
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;

@@ -87,6 +91,8 @@ public static void main(String[] args) {
 hiddenLayerBuilder.activation(Activation.TANH);
 listBuilder.layer(i, hiddenLayerBuilder.build());
 }
+
+// Build a simple RNN with one recurrent layer to memorize the sequence
 
 // we need to use RnnOutputLayer for our RNN
 RnnOutputLayer.Builder outputLayerBuilder = new RnnOutputLayer.Builder(LossFunction.MCXENT);

@@ -96,7 +102,8 @@ public static void main(String[] args) {
 outputLayerBuilder.nIn(HIDDEN_LAYER_WIDTH);
 outputLayerBuilder.nOut(LEARNSTRING_CHARS.size());
 listBuilder.layer(HIDDEN_LAYER_CONT, outputLayerBuilder.build());
-
+
+
 // create network
 MultiLayerConfiguration conf = listBuilder.build();
 MultiLayerNetwork net = new MultiLayerNetwork(conf);

@@ -122,6 +129,9 @@ public static void main(String[] args) {
 labels.putScalar(new int[] { 0, LEARNSTRING_CHARS_LIST.indexOf(nextChar), samplePos }, 1);
 samplePos++;
 }
+
+// Train the RNN to output the same sequence it receives as input
+
 DataSet trainingData = new DataSet(input, labels);
 
 // some epochs
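
The comments added above describe building a small recurrent network and training it to reproduce its input sequence. For readers new to the API, here is a minimal, hedged sketch of an LSTM plus RnnOutputLayer configuration in DL4J; the class name, layer sizes, and seed are illustrative placeholders rather than the values used in MemorizeSequence.java.

```java
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;

public class MinimalMemorizeRnnSketch {
    public static void main(String[] args) {
        int nChars = 26;       // placeholder: size of the character vocabulary
        int hiddenWidth = 50;  // placeholder: width of the recurrent layer

        // One recurrent (LSTM) layer followed by an RnnOutputLayer, as the added comments describe
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(123)
            .list()
            .layer(new LSTM.Builder()
                .nIn(nChars).nOut(hiddenWidth)
                .activation(Activation.TANH)
                .build())
            .layer(new RnnOutputLayer.Builder(LossFunction.MCXENT)
                .activation(Activation.SOFTMAX)
                .nIn(hiddenWidth).nOut(nChars)
                .build())
            .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        // Training would call net.fit(trainingData) repeatedly, one pass per epoch
    }
}
```

The real example builds its hidden layers in a loop and sizes the output from LEARNSTRING_CHARS; the sketch only shows the overall shape of such a configuration.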
dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/recurrent/README.md

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
# Recurrent Neural Network (RNN) Examples – DeepLearning4J

This folder contains simple recurrent neural network (RNN) examples using LSTM and embedding layers.
These examples demonstrate how to work with sequential data such as signals, sequences, and text-like inputs.

---

## 🔁 MemorizeSequence.java
A minimal RNN example where the network learns to memorize and reproduce a fixed sequence.

### What this example teaches
- How RNNs store information across time steps
- How backpropagation-through-time works
- How sequence learning differs from feedforward networks

### Expected Behavior
The network eventually outputs the same sequence it was trained on.

---

## 🔡 RNNEmbedding.java
Demonstrates the use of an **EmbeddingLayer** followed by RNN layers.

### Key Concepts
- Turning integer-encoded inputs into dense vectors
- Word/token embedding
- Passing embedded sequences into RNN layers

This is a useful template for NLP-style models.

---

## 📊 UCISequenceClassification.java
Sequence classification on a dataset from the UCI machine learning repository.

### What this example shows
- Loading sequential datasets
- Recurrent classification (predict a label for a whole sequence)
- Time-series preprocessing and normalization

---

## ✔ How to Run Any Example

Use the following command:

    mvn -q exec:java -Dexec.mainClass="org.deeplearning4j.examples.quickstart.modeling.recurrent.<ClassName>"

Example:

    mvn -q exec:java -Dexec.mainClass="org.deeplearning4j.examples.quickstart.modeling.recurrent.MemorizeSequence"

---

## 🙌 Why This README Helps

This folder previously had no documentation.
This README explains:
- What each RNN example does
- What concepts it teaches
- How to run each file

This improves clarity for beginners working with sequential models in DL4J.
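
The MemorizeSequence section above notes that sequence learning differs from feedforward learning; concretely, DL4J RNNs consume rank-3 arrays shaped [miniBatch, featureSize, timeSteps]. The sketch below (not code from the example) one-hot encodes a made-up string as input and next-character labels in that layout; the string, alphabet, and class name are assumptions for illustration only.

```java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

public class SequenceEncodingSketch {
    public static void main(String[] args) {
        String learnString = "hello world";                            // made-up training sequence
        String chars = "abcdefghijklmnopqrstuvwxyz ";                  // made-up character set

        // RNN inputs and labels are rank 3: [miniBatch, featureSize, timeSteps]
        INDArray input  = Nd4j.zeros(1, chars.length(), learnString.length());
        INDArray labels = Nd4j.zeros(1, chars.length(), learnString.length());

        for (int t = 0; t < learnString.length(); t++) {
            int current = chars.indexOf(learnString.charAt(t));
            int next = chars.indexOf(learnString.charAt((t + 1) % learnString.length()));
            input.putScalar(new int[]{0, current, t}, 1.0);   // one-hot current character
            labels.putScalar(new int[]{0, next, t}, 1.0);     // one-hot next character (the target)
        }

        DataSet trainingData = new DataSet(input, labels);
        System.out.println(trainingData.getFeatures().shapeInfoToString());
    }
}
```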

dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/recurrent/RNNEmbedding.java

Lines changed: 12 additions & 0 deletions
@@ -17,6 +17,11 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/
 
+
+// RNNEmbedding.java
+// Demonstrates how to use an EmbeddingLayer + RNN layers for sequence modeling.
+// Useful for NLP-style integer token inputs.
+
 package org.deeplearning4j.examples.quickstart.modeling.recurrent;
 
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;

@@ -45,6 +50,9 @@
 *
 * @author Alex Black
 */
+
+// Convert integer token IDs into dense embedding vectors
+
 public class RNNEmbedding {
 public static void main(String[] args) {
 

@@ -64,6 +72,10 @@ public static void main(String[] args) {
 }
 }
 
+
+// Feed embedded vectors into an LSTM/RNN to capture sequence structure
+
+
 MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
 .activation(Activation.RELU)
 .list()
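
The comments added to RNNEmbedding.java describe converting integer token IDs into dense vectors and feeding them to a recurrent layer. One way to express that idea in DL4J, shown here as an illustrative sketch rather than the example's actual configuration (which may instead use EmbeddingLayer with input preprocessors), is an EmbeddingSequenceLayer in front of an LSTM; the vocabulary size, embedding width, and other numbers below are placeholders.

```java
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;

public class EmbeddingRnnSketch {
    public static void main(String[] args) {
        int vocabSize = 1000;   // placeholder: number of distinct token IDs
        int embeddingDim = 32;  // placeholder: dense vector size per token
        int hiddenSize = 64;    // placeholder: LSTM width
        int nClasses = 4;       // placeholder: output classes per time step

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .list()
            // Integer token IDs -> dense vectors, one per time step
            .layer(new EmbeddingSequenceLayer.Builder()
                .nIn(vocabSize).nOut(embeddingDim)
                .build())
            // Recurrent layer consumes the embedded sequence
            .layer(new LSTM.Builder()
                .nIn(embeddingDim).nOut(hiddenSize)
                .activation(Activation.TANH)
                .build())
            .layer(new RnnOutputLayer.Builder(LossFunction.MCXENT)
                .activation(Activation.SOFTMAX)
                .nIn(hiddenSize).nOut(nClasses)
                .build())
            .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
    }
}
```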

dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/recurrent/UCISequenceClassification.java

Lines changed: 4 additions & 0 deletions
@@ -17,6 +17,10 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/
 
+
+// UCISequenceClassification.java
+// Demonstrates sequence classification using an RNN on UCI dataset sequences.
+
 package org.deeplearning4j.examples.quickstart.modeling.recurrent;
 
 import org.apache.commons.io.FileUtils;
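
UCISequenceClassification, as described in the README above, loads sequences from CSV files and normalizes them before training. Below is a hedged sketch of what such a DataVec pipeline typically looks like; the directory layout, file counts, and class name are assumptions, not the example's actual paths or values.

```java
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;

public class SequenceDataPipelineSketch {
    public static void main(String[] args) throws Exception {
        // Assumed layout: one CSV per sequence (features/0.csv ...) and one per label (labels/0.csv ...)
        String baseDir = "/tmp/uci-demo";   // placeholder path
        int numSequences = 100;             // placeholder count
        int miniBatchSize = 10;
        int numLabelClasses = 6;

        SequenceRecordReader features = new CSVSequenceRecordReader();
        features.initialize(new NumberedFileInputSplit(baseDir + "/features/%d.csv", 0, numSequences - 1));
        SequenceRecordReader labels = new CSVSequenceRecordReader();
        labels.initialize(new NumberedFileInputSplit(baseDir + "/labels/%d.csv", 0, numSequences - 1));

        DataSetIterator trainIter = new SequenceRecordReaderDataSetIterator(
            features, labels, miniBatchSize, numLabelClasses, false,
            SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

        // Standardize features using statistics gathered from the training data only
        NormalizerStandardize normalizer = new NormalizerStandardize();
        normalizer.fit(trainIter);
        trainIter.reset();
        trainIter.setPreProcessor(normalizer);
    }
}
```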
dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/variationalautoencoder/README.md

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
# Variational Autoencoder (VAE) Examples – DeepLearning4J

This folder contains two examples demonstrating how to use Variational Autoencoders (VAEs) in DeepLearning4J.
VAEs are generative models that learn latent representations of data and can be used for visualization, sampling, and anomaly detection.

---

## 🧩 VaeMNIST2dPlots.java
Trains a VAE on the MNIST digit dataset and visualizes the **2-dimensional latent space**.

### What this example shows
- How to build a VAE in DL4J
- Encoding MNIST images into a 2D latent space
- Plotting how digits cluster in latent space
- How VAEs learn smooth and continuous representations

### Why it’s useful
A 2D latent space allows easy visualization of how the model separates digits.

---

## ⚠️ VaeMNISTAnomaly.java
Uses a trained VAE for **anomaly detection** on MNIST.

### Key Concepts
- VAEs reconstruct normal data well
- They reconstruct anomalies poorly
- Reconstruction error can be used as an anomaly score

### What this example demonstrates
- Reconstructing MNIST digits
- Computing reconstruction probability
- Detecting out-of-distribution or corrupted samples

---

## ✔ How to Run Any Example

    mvn -q exec:java -Dexec.mainClass="org.deeplearning4j.examples.quickstart.modeling.variationalautoencoder.<ClassName>"

Example:

    mvn -q exec:java -Dexec.mainClass="org.deeplearning4j.examples.quickstart.modeling.variationalautoencoder.VaeMNIST2dPlots"

---

## 🙌 Why This README Helps
The VAE folder previously had no documentation.
This README explains:
- Purpose of each example
- The ML concepts involved
- How to run and understand the results

This improves clarity for beginners working with generative models in DL4J.

dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/variationalautoencoder/VaeMNIST2dPlots.java

Lines changed: 11 additions & 1 deletion
@@ -17,6 +17,10 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/
 
+// VaeMNIST2dPlots.java
+// Trains a Variational Autoencoder (VAE) on MNIST and visualizes the 2D latent space.
+// Shows how VAEs compress images into smooth, continuous latent variables.
+
 package org.deeplearning4j.examples.quickstart.modeling.variationalautoencoder;
 
 import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;

@@ -98,10 +102,13 @@ public static void main(String[] args) throws IOException {
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
 net.init();
 
+// Build a VAE with a 2-dimensional latent space for visualization
+
 //Get the variational autoencoder layer
 org.deeplearning4j.nn.layers.variational.VariationalAutoencoder vae
 = (org.deeplearning4j.nn.layers.variational.VariationalAutoencoder) net.getLayer(0);
 
+
 
 //Test data for plotting
 DataSet testdata = new MnistDataSetIterator(10000, false, rngSeed).next();

@@ -121,12 +128,15 @@ public static void main(String[] args) throws IOException {
 // (b) collect the reconstructions at each point in the grid
 net.setListeners(new PlottingListener(100, testFeatures, latentSpaceGrid, latentSpaceVsEpoch, digitsGrid));
 
-//Perform training
+//Perform training
+// Train the VAE to encode and reconstruct MNIST digits
 for (int i = 0; i < nEpochs; i++) {
 log.info("Starting epoch {} of {}",(i+1),nEpochs);
 net.pretrain(trainIter); //Note use of .pretrain(DataSetIterator) not fit(DataSetIterator) for unsupervised training
 }
 
+// Visualize how different digit classes cluster in the latent space
+
 //plot by default
 if (visualize) {
 //Plot MNIST test set - latent space vs. iteration (every 100 minibatches by default)
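
The comments added above emphasize that the latent space has only two dimensions so it can be plotted directly. For orientation, here is a minimal sketch of configuring a DL4J VariationalAutoencoder layer with a 2D latent space; the encoder/decoder sizes, activations, and reconstruction distribution are assumptions for illustration, not necessarily the settings used in VaeMNIST2dPlots.java.

```java
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.variational.BernoulliReconstructionDistribution;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;

public class Vae2dLatentSketch {
    public static void main(String[] args) {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(12345)
            .list()
            .layer(new VariationalAutoencoder.Builder()
                .nIn(28 * 28)                    // flattened MNIST image
                .nOut(2)                         // 2D latent space, easy to plot
                .encoderLayerSizes(256, 256)     // assumed encoder sizes
                .decoderLayerSizes(256, 256)     // assumed decoder sizes
                .activation(Activation.LEAKYRELU)
                .pzxActivationFunction(Activation.IDENTITY)  // identity activation for p(z|x) parameters
                .reconstructionDistribution(
                    new BernoulliReconstructionDistribution(Activation.SIGMOID.getActivationFunction()))
                .build())
            .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        // Unsupervised training uses net.pretrain(iterator), as the diff above notes
    }
}
```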

dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/variationalautoencoder/VaeMNISTAnomaly.java

Lines changed: 11 additions & 0 deletions
@@ -17,6 +17,10 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/
 
+// VaeMNISTAnomaly.java
+// Demonstrates anomaly detection using a Variational Autoencoder (VAE).
+// Normal digits reconstruct well, anomalies reconstruct poorly.
+
 package org.deeplearning4j.examples.quickstart.modeling.variationalautoencoder;
 
 import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;

@@ -94,6 +98,9 @@ public static void main(String[] args) throws IOException {
 .build())
 .build();
 
+// Load trained VAE model and MNIST test data
+
+
 MultiLayerNetwork net = new MultiLayerNetwork(conf);
 net.init();
 

@@ -151,6 +158,10 @@ public int compare(Pair<Double, INDArray> o1, Pair<Double, INDArray> o2) {
 Collections.sort(list, c);
 }
 
+// Compute reconstruction probability for each test image
+// Low probability indicates an anomaly
+
+
 //Select the 5 best and 5 worst numbers (by reconstruction probability) for each digit
 List<INDArray> best = new ArrayList<>(50);
 List<INDArray> worst = new ArrayList<>(50);
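
The comments above describe ranking test images by reconstruction probability, with low probability indicating an anomaly. The core API call is the VAE layer's reconstructionProbability method, shown in this small sketch; the helper class, batch handling, and sample count are illustrative rather than taken from the example.

```java
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;

public class ReconstructionScoreSketch {
    // Scores a batch of flattened images; lower probability suggests an anomaly
    static INDArray scoreBatch(MultiLayerNetwork net, INDArray features) {
        // The VAE is assumed to be the first layer of the network, as in the example above
        VariationalAutoencoder vae = (VariationalAutoencoder) net.getLayer(0);

        int numSamples = 32; // Monte Carlo samples per image; more samples give a smoother estimate
        return vae.reconstructionProbability(features, numSamples);
    }
}
```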
