```shell
npx skills add yonesuke/skills --skill "developing-flax-models"
```
Install a specific skill from a multi-skill repository.
# Description
A comprehensive guide for developing, training, and managing neural networks using Flax NNX. Use when defining models, managing state, or writing training loops.
# SKILL.md
---
name: developing-flax-models
description: A comprehensive guide for developing, training, and managing neural networks using Flax NNX. Use when defining models, managing state, or writing training loops.
---
## Developing Flax Models
This skill helps you develop neural networks using Flax NNX, the object-oriented module system for JAX. Use this skill when you need to define models, handle state/randomness, or implement training loops.
### Workflows

#### 1. Implement a New Model

- **Define Module:** Subclass `nnx.Module`. Define layers/parameters in `__init__`.
- **Handle Randomness:** Pass `nnx.Rngs` to `__init__` for weight initialization. Pass `rngs` to `__call__` for stochastic operations (e.g., Dropout).
- **Sanity Check:** Add an `if __name__ == "__main__":` block to instantiate the model and run a dummy forward pass.
#### 2. Implement a Training Loop

- **Choose Strategy:**
  - **Automatic (`nnx.jit`):** Easiest. Use `@nnx.jit` on your update function. Handles mutable state management automatically.
  - **Functional (`nnx.split`/`nnx.merge`):** Use for advanced control or when interfacing with pure JAX transformations like `scan` or `vmap` (though `nnx.vmap` exists).
- **Define Loss:** Write a loss function `loss_fn(model, batch)`.
- **Optimizer:** Wrap the model with `nnx.Optimizer(model, tx, wrt=nnx.Param)`.
> [!WARNING]
> **Crucial Change:** As of Flax 0.11.0, the `wrt` argument (e.g., `wrt=nnx.Param`) is REQUIRED for `nnx.Optimizer`. Omitting it raises a `TypeError`.
### Core Concepts (Reference)

Flax NNX (v2.0+) replaces the immutable, functional design of `flax.linen` with standard Python classes and mutable state, while maintaining JAX compatibility.
#### Key Differences

- **Object-Oriented:** Models are standard Python classes. You assign to `self.param`.
- **Reference Semantics:** Layers hold their parameters directly.
- **Not Pytrees:** `nnx.Module` objects are not Pytrees. You cannot pass them directly to `jax.jit`. You must use `nnx.jit` or manually split/merge state.
#### Variable Types

NNX variables allow granular state management via "Collections".

* `nnx.Param`: Trainable parameters (weights, biases).
* `nnx.BatchStat`: Batch normalization statistics (running mean/var).
* `nnx.Rngs`: Key management for random number generator streams.
* `nnx.Variable`: Base class for custom state.
#### State Management

You can filter and manipulate state subsets:

```python
# Get only Parameters
params = nnx.state(model, nnx.Param)
# Get everything EXCEPT BatchStats
state = nnx.state(model, nnx.Not(nnx.BatchStat))
```
### Examples

See `examples.md` for detailed code patterns mirrored from the `scripts/` directory.

* **Defining Modules:** Basic layer structure.
* **Randomness:** Handling Dropout and stochastic layers.
* **Training:** Comparison of `nnx.jit` vs pure JAX loops.
* **Functional API:** Using `vmap` and split/merge.
# Supported AI Coding Agents
This skill is compatible with the SKILL.md standard and works with all major AI coding agents.
Learn more about the SKILL.md standard and how to use these skills with your preferred AI coding agent.