GPU Development¶
Overview¶
Guidelines for GPU-accelerated development with Opifex using JAX and CUDA.
GPU Setup¶
CUDA Installation¶
# Install CUDA toolkit
wget https://developer.download.nvidia.com/compute/cuda/12.6.0/local_installers/cuda_12.6.0_560.28.03_linux.run
sudo sh cuda_12.6.0_560.28.03_linux.run
# Verify installation
nvcc --version
nvidia-smi
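The same check can be scripted from Python, which is handy in setup scripts or CI. This is a minimal sketch using only the standard library; the helper name `find_cuda_tools` is illustrative and is not part of Opifex or JAX.

```python
import shutil

def find_cuda_tools(tools=("nvcc", "nvidia-smi")):
    """Map each tool name to its resolved path, or None if it is not on PATH."""
    return {tool: shutil.which(tool) for tool in tools}

# Report which CUDA command-line tools are visible to this Python process
for tool, path in find_cuda_tools().items():
    print(f"{tool}: {path or 'not found on PATH'}")
```

A missing `nvcc` with a working `nvidia-smi` usually means the driver is installed but the toolkit's `bin` directory is not on `PATH`.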
JAX GPU Configuration¶
import jax
print("Available devices:", jax.devices())
print("Default backend:", jax.default_backend())
# Force GPU usage
jax.config.update('jax_platform_name', 'gpu')
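Forcing the GPU platform fails on machines without one, so scripts that must also run on CPU may prefer an explicit fallback. The sketch below assumes that `jax.devices("gpu")` raises `RuntimeError` when no GPU backend is available; `pick_device` is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp

def pick_device():
    """Return the first GPU if JAX can see one, otherwise the first CPU device."""
    try:
        return jax.devices("gpu")[0]
    except RuntimeError:
        # Raised when the requested backend is not available
        return jax.devices("cpu")[0]

device = pick_device()
# Place an array explicitly on the chosen device
x = jax.device_put(jnp.arange(4.0), device)
print(f"Running on {device.platform}: sum = {float(x.sum())}")
```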
Memory Management¶
GPU Memory Optimization¶
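# Disable memory preallocation so JAX allocates GPU memory on demand.
# Set this before importing jax; it is read when the backend initializes.
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
import jax
# Monitor memory usage (memory_stats() may return None on backends without statistics)
print(f"GPU memory: {jax.devices()[0].memory_stats()}")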
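The preallocation flag is one of several XLA client environment variables. As a sketch, the helper below (a hypothetical name, `configure_xla_memory`) sets the preallocation flag together with `XLA_PYTHON_CLIENT_MEM_FRACTION`, which controls the share of GPU memory claimed when preallocation is enabled. It only touches `os.environ`, so it has to run before the first `import jax`.

```python
import os

def configure_xla_memory(preallocate=False, mem_fraction=None):
    """Set XLA client memory flags; call before the first `import jax`."""
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if mem_fraction is not None:
        if not 0.0 < mem_fraction <= 1.0:
            raise ValueError("mem_fraction must be in (0, 1]")
        # Fraction of GPU memory to preallocate (only used when preallocation is on)
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"

# Example: preallocate half of the GPU memory instead of the default share
configure_xla_memory(preallocate=True, mem_fraction=0.5)
```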
Batch Size Tuning¶
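# Auto-tune batch size for memory
import jax
import jax.numpy as jnp
def find_optimal_batch_size(model, input_dim, max_size=1024):
    """Return the largest tested batch size up to max_size that runs without OOM."""
    best = None
    for batch_size in [32, 64, 128, 256, 512, 1024]:
        if batch_size > max_size:
            break
        try:
            test_batch = jnp.ones((batch_size, input_dim))
            # Block so asynchronous dispatch cannot hide a failure
            jax.block_until_ready(model(test_batch))
            print(f"Batch size {batch_size}: OK")
            best = batch_size
        except RuntimeError:  # JAX's OOM error is a RuntimeError subclass
            print(f"Batch size {batch_size}: OOM")
            break
    return best  # None if even the smallest size failed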
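Stepping through powers of two only finds the right order of magnitude. When a finer answer is worth the extra trials, a binary search narrows it down. The sketch below is framework-independent: `fits` is a hypothetical predicate that would wrap a real trial run (returning `False` on an out-of-memory error), and the search assumes that if a batch size fits, every smaller size fits too.

```python
def largest_fitting_batch(fits, low=1, high=1024):
    """Binary-search the largest batch size in [low, high] for which fits(n) is True.

    Assumes fits is monotone: if a size fits, every smaller size fits too.
    Returns None when even `low` does not fit.
    """
    best = None
    while low <= high:
        mid = (low + high) // 2
        if fits(mid):
            best = mid       # mid works; look for something larger
            low = mid + 1
        else:
            high = mid - 1   # mid fails; look for something smaller
    return best

# Stand-in predicate: pretend anything above 300 samples runs out of memory
print(largest_fitting_batch(lambda n: n <= 300))  # 300
```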
Performance Optimization¶
JIT Compilation¶
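# Assumes loss_fn(params, batch) is defined elsewhere and returns a scalar
@jax.jit
def optimized_training_step(params, batch):
    """JIT-compiled training step for speed."""
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    return loss, grads
# Equivalent without the decorator, for an existing un-jitted training_step.
# Compilation happens on the first call; later calls with the same input
# shapes and dtypes reuse the compiled code.
compiled_step = jax.jit(training_step)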
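To see when compilation actually happens, it helps to count traces. The following sketch uses a toy `loss_fn` (an assumption for illustration, not Opifex's loss) and a Python-side counter, which only increments while JAX traces the function.

```python
import jax
import jax.numpy as jnp

trace_count = 0

def loss_fn(params, batch):
    """Toy quadratic loss, used only for illustration."""
    return jnp.mean((batch @ params) ** 2)

@jax.jit
def training_step(params, batch):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces the function
    return jax.value_and_grad(loss_fn)(params, batch)

params = jnp.ones((3,))
batch = jnp.ones((8, 3))
for _ in range(3):
    loss, grads = training_step(params, batch)  # same shapes: traced once
training_step(params, jnp.ones((16, 3)))        # new batch shape: traced again
print(trace_count)  # 2
```

Keeping batch shapes fixed (for example by padding or dropping the last partial batch) avoids repeated recompilation.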
PMAP for Multi-GPU¶
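# Parallel computation across GPUs
import jax
import jax.numpy as jnp
@jax.pmap
def parallel_training_step(params, batch):
    """Multi-GPU training step."""
    return jax.value_and_grad(loss_fn)(params, batch)
# Replicate parameters across devices by adding a leading device axis.
# jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
n_devices = jax.local_device_count()
replicated_params = jax.tree_util.tree_map(
    lambda x: jnp.stack([x] * n_devices), params
)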
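The step above returns one loss and one gradient per device without combining them. For data-parallel training the gradients are usually averaged across devices. A minimal sketch of that pattern follows, assuming a toy `loss_fn` and an axis name `"devices"` chosen for illustration; it also runs on a single device.

```python
from functools import partial

import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    """Toy quadratic loss, used only for illustration."""
    return jnp.mean((batch @ params) ** 2)

@partial(jax.pmap, axis_name="devices")
def parallel_training_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # Average across devices so every replica sees the same loss and gradient
    return jax.lax.pmean(loss, "devices"), jax.lax.pmean(grads, "devices")

n_devices = jax.local_device_count()
params = jnp.ones((3,))
replicated = jax.tree_util.tree_map(lambda x: jnp.stack([x] * n_devices), params)
# The leading axis must equal the device count: (n_devices, per_device_batch, features)
batch = jnp.ones((n_devices, 8, 3))
loss, grads = parallel_training_step(replicated, batch)
print(loss.shape, grads.shape)
```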
Profiling and Debugging¶
Performance Profiling¶
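# Profile GPU kernels
with jax.profiler.trace("/tmp/tensorboard"):
    for i in range(100):
        result = model(batch)
    # Wait for asynchronously dispatched work so it is captured in the trace
    jax.block_until_ready(result)
# View in TensorBoard
# tensorboard --logdir=/tmp/tensorboard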
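For quick measurements without a full trace, wall-clock timing works as long as it accounts for JAX's asynchronous dispatch and for the compilation cost of the first call. A sketch, with `time_jitted` as a hypothetical helper name:

```python
import time

import jax
import jax.numpy as jnp

def time_jitted(fn, *args, repeats=10):
    """Return mean seconds per call, excluding the first (compiling) call."""
    jax.block_until_ready(fn(*args))  # warm-up: triggers compilation
    start = time.perf_counter()
    for _ in range(repeats):
        out = fn(*args)
    jax.block_until_ready(out)  # wait for queued device work before stopping the clock
    return (time.perf_counter() - start) / repeats

matmul = jax.jit(lambda a, b: a @ b)
a = jnp.ones((256, 256))
print(f"{time_jitted(matmul, a, a) * 1e3:.3f} ms per call")
```

Without the final `block_until_ready`, the timer would mostly measure how long it takes to enqueue the work rather than to run it.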
Memory Profiling¶
# Track memory allocation
def memory_usage():
    device = jax.devices()[0]
    stats = device.memory_stats()
    return stats['bytes_in_use'] / 1e9  # GB
print(f"Memory before: {memory_usage():.2f} GB")
result = large_computation()
print(f"Memory after: {memory_usage():.2f} GB")
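`memory_stats()` can return `None` on backends that do not report statistics (CPU is one example), in which case indexing the result fails. A more defensive variant might look like this sketch; `memory_in_use_gb` is an illustrative name.

```python
import jax

def memory_in_use_gb(device=None):
    """Return bytes_in_use in GB, or None when the backend reports no statistics."""
    device = device or jax.devices()[0]
    stats = device.memory_stats()  # may be None, e.g. on CPU backends
    if not stats or "bytes_in_use" not in stats:
        return None
    return stats["bytes_in_use"] / 1e9

usage = memory_in_use_gb()
print("Memory in use: n/a" if usage is None else f"Memory in use: {usage:.2f} GB")
```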
Best Practices¶
Code Patterns¶
- Vectorization: Use jax.vmap for batched element-wise operations
- Broadcasting: Leverage JAX broadcasting for efficiency
- Tree Operations: Use jax.tree_util.tree_map (the older jax.tree_map alias is deprecated) for nested structures
- Gradient Checkpointing: Save memory with TrainingConfig(gradient_checkpointing=True)
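The vectorization and tree-operation patterns can be illustrated with a small sketch. The linear `predict` function and the parameter dictionary are assumptions made for the example, not Opifex APIs.

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    """Single-example linear model; x has shape (features,)."""
    return jnp.dot(params["w"], x) + params["b"]

params = {"w": jnp.array([1.0, 2.0, 3.0]), "b": jnp.array(0.5)}
xs = jnp.ones((4, 3))

# vmap maps over the leading axis of xs while sharing params (in_axes=None)
batched_predict = jax.vmap(predict, in_axes=(None, 0))
ys = batched_predict(params, xs)

# tree_map applies a function to every leaf of the nested params structure
halved = jax.tree_util.tree_map(lambda p: p * 0.5, params)
print(ys.shape, float(halved["b"]))
```

Writing `predict` for one example and letting `jax.vmap` add the batch axis keeps the model code free of manual batch-dimension bookkeeping.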
Common Pitfalls¶
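- Host-device transfers: minimize copies between host and GPU, for example by not converting device arrays to Python or NumPy values inside the training loop
- Shape mismatches: check array dimensions; under jax.jit, each new input shape also triggers recompilation
- Memory leaks: drop references to arrays that are no longer needed so their device memory can be freed
- Sequential operations: replace Python loops with vectorized operations where possible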
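As an illustration of the first pitfall, the sketch below accumulates a running loss on the device and copies it to the host once at the end, rather than calling `float(...)` on every step. The `step_loss` function and the synthetic batches are assumptions for the example.

```python
import jax
import jax.numpy as jnp

@jax.jit
def step_loss(x):
    return jnp.sum(x ** 2)

batches = [jnp.full((4,), float(i)) for i in range(5)]

# Keep the running total on the device; nothing is copied to the host in the loop
total = jnp.zeros(())
for batch in batches:
    total = total + step_loss(batch)

# A single transfer at the end instead of one float(...) call per step
print(float(total))  # 120.0
```

Converting to a Python float inside the loop would force the host to wait for each step, which removes the benefit of asynchronous dispatch.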