NNX-Control

High-performance JAX control environments with end-to-end PPO training in a single file.

Features

  • Pure JAX Implementation: Fully JIT-compilable, vmap-able, and scan-able control environments with no Gym/Gymnax dependencies
  • JAX-Native Rendering: Hardware-accelerated rendering using pure JAX operations
  • NNX-Based PPO: Proximal Policy Optimization implemented with Flax NNX (not Flax Linen) for modern JAX neural networks
  • One-File Implementation: Complete environment, training loop, and rendering in a single, self-contained file
  • Parallel Training: Vectorized rollouts across multiple parallel environments for maximum throughput

Supported Environments

Currently supports:

  • CartPole-v1: Classic cart-pole balancing task with pure-JAX physics simulation

Performance

The implementation leverages JAX's compilation and parallelization capabilities:

  • JIT compilation for fast environment steps and policy updates
  • vmap for batched environment rollouts (64 parallel envs)
  • scan for efficient sequential operations in training loops
  • Pure-JAX rendering with hardware acceleration
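
The jit/vmap/scan composition above can be sketched with a toy pure-JAX step function (illustrative only; the dynamics and names here are placeholders, not the repo's actual CartPole code):

```python
import jax
import jax.numpy as jnp

# Toy 1-D environment step: state is (position, velocity),
# action in {0, 1} pushes left or right.
def step(state, action):
    pos, vel = state
    force = jnp.where(action == 1, 1.0, -1.0)
    vel = vel + 0.02 * force          # Euler update, dt = 0.02
    pos = pos + 0.02 * vel
    reward = 1.0 - jnp.abs(pos)       # reward for staying near the origin
    return (pos, vel), reward

# vmap batches the step over 64 parallel environments; jit compiles it.
batched_step = jax.jit(jax.vmap(step))
states = (jnp.zeros(64), jnp.zeros(64))
actions = jnp.ones(64, dtype=jnp.int32)

# scan unrolls a sequential rollout without a Python loop.
def rollout_step(state, _):
    state, reward = batched_step(state, actions)
    return state, reward

final_states, reward_trace = jax.lax.scan(rollout_step, states, None, length=10)
# reward_trace: shape (10, 64), i.e. 10 timesteps x 64 parallel environments
```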

Training Results

(Figure: training curve showing average reward per episode)

(Animation: trained policy evaluated on 4 parallel environments)

Usage

import jax
from cartpole import CartPole

# Create environment
env = CartPole()
params = env.default_params

# Reset environment
key = jax.random.PRNGKey(0)
obs, state = env.reset(key, params)

# Step environment
action = 1  # 0 = left, 1 = right
obs_next, state_next, reward, done, info = env.step(key, state, action, params)

# Render state (returns RGB array)
rgb_array = env.render(state, params)

Training is fully integrated; just run:

python cartpole.py
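
To vectorize the single-environment API above across parallel environments, the usual JAX pattern is to split one PRNG key per environment and vmap over them (the batched reset is left commented out here because it assumes the cartpole module from this repo is on your path):

```python
import jax

key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 64)   # one independent PRNG key per environment
# Batched reset over 64 environments; params is broadcast (in_axes=None):
# obs_batch, state_batch = jax.vmap(env.reset, in_axes=(0, None))(keys, params)
```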

Related Projects

For a similar high-performance gridworld implementation, check out NNX-Gridworld: a super-fast, JIT-compatible, vmapped gridworld with JAX-based rendering and a one-file PPO for both vision-based (JAX-rendered observations) and state-based tasks.

Technical Details

Environment

  • Pure functional API with no mutable state
  • Physics simulation using Euler integration
  • Auto-reset on episode termination for continuous training
  • Compatible with standard RL benchmarking protocols
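
The Euler-integration step can be sketched with the classic cart-pole dynamics (constants follow the standard CartPole-v1 task; the repo's exact values and function names are assumptions here):

```python
import jax.numpy as jnp

# Standard cart-pole constants (as in the classic CartPole-v1 task).
GRAVITY, MASS_CART, MASS_POLE = 9.8, 1.0, 0.1
TOTAL_MASS = MASS_CART + MASS_POLE
LENGTH = 0.5                  # half the pole length
POLEMASS_LENGTH = MASS_POLE * LENGTH
FORCE_MAG, DT = 10.0, 0.02    # applied force and Euler step size

def euler_step(state, action):
    """One explicit-Euler update of the cart-pole ODE."""
    x, x_dot, theta, theta_dot = state
    force = jnp.where(action == 1, FORCE_MAG, -FORCE_MAG)
    cos_t, sin_t = jnp.cos(theta), jnp.sin(theta)
    temp = (force + POLEMASS_LENGTH * theta_dot**2 * sin_t) / TOTAL_MASS
    theta_acc = (GRAVITY * sin_t - cos_t * temp) / (
        LENGTH * (4.0 / 3.0 - MASS_POLE * cos_t**2 / TOTAL_MASS))
    x_acc = temp - POLEMASS_LENGTH * theta_acc * cos_t / TOTAL_MASS
    # Euler integration: state += dt * derivative.
    return jnp.array([x + DT * x_dot,
                      x_dot + DT * x_acc,
                      theta + DT * theta_dot,
                      theta_dot + DT * theta_acc])
```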

PPO Implementation

  • Actor-critic architecture with separate networks
  • Generalized Advantage Estimation (GAE)
  • Clipped surrogate objective (ε=0.2)
  • Value function loss with coefficient weighting
  • 4 update iterations per rollout
  • Normalized advantages and returns for training stability
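
The two core pieces, GAE and the clipped surrogate objective, can be sketched as follows (a readable reference version with the stated ε=0.2 and assumed γ/λ defaults; a fully JIT-friendly version would replace the Python loop with lax.scan):

```python
import jax.numpy as jnp

def gae(rewards, values, dones, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation over one rollout.

    values has T+1 entries: one per step plus a bootstrap value."""
    T = rewards.shape[0]
    adv = jnp.zeros_like(rewards)
    last = 0.0
    for t in reversed(range(T)):        # backward recursion over the rollout
        nonterm = 1.0 - dones[t]
        delta = rewards[t] + gamma * values[t + 1] * nonterm - values[t]
        last = delta + gamma * lam * nonterm * last
        adv = adv.at[t].set(last)
    return adv

def clipped_surrogate(log_prob, old_log_prob, advantage, eps=0.2):
    """PPO clipped objective for one sample (averaged over a batch in practice)."""
    ratio = jnp.exp(log_prob - old_log_prob)
    return jnp.minimum(ratio * advantage,
                       jnp.clip(ratio, 1 - eps, 1 + eps) * advantage)
```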

Architecture

  • Feed-forward networks: 4 → 64 → 64 → {2, 1} (actor/critic outputs)
  • ReLU activations
  • Adam optimizer (lr=3e-4)
  • Batch size: 64 parallel environments × 500 timesteps

Requirements

  • JAX
  • Flax (NNX)
  • Optax
  • Matplotlib (for plotting)
  • Imageio (for GIF generation)
