yonesuke

JAX

# Install this skill:
npx skills add yonesuke/skills --skill "JAX"

The --skill flag installs a specific skill from a multi-skill repository.

# Description

Essential tools for using JAX in machine learning and mathematical analysis, covering core concepts, transformations, ML specifics, control flow, and parallelism.

# SKILL.md


name: JAX
description: Essential tools for using JAX in machine learning and mathematical analysis, covering core concepts, transformations, ML specifics, control flow, and parallelism.


JAX Skill

JAX is Autograd and XLA, brought together for high-performance machine learning research.
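As a minimal sketch of what that combination looks like in practice, the snippet below differentiates a function with jax.grad and compiles the result with jax.jit. The function f is a hypothetical toy example, not part of the skill itself.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy function: sum of squares.
def f(x):
    return jnp.sum(x ** 2)

# jax.grad supplies automatic differentiation; jax.jit compiles the result with XLA.
grad_f = jax.jit(jax.grad(f))

# The gradient of sum(x^2) is 2x.
g = grad_f(jnp.array([1.0, 2.0, 3.0]))
```

Transformations compose, so the order can also be reversed (jax.grad of a jitted function) with the same result.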

Contents

  • Concepts & Theory
  • Immutability
  • The 4 Transformations
  • Pytrees
  • Code Examples
  • jit, grad, vmap, random usage
  • Control Flow (scan, cond, fori_loop)
  • Parallelism (sharding)

Common Workflows

1. Developing a new Model

  1. Define your parameters as a Pytree (e.g. a dict, or a dataclass registered as a pytree node).
  2. Define your forward pass function (pure).
  3. Define your loss function.
  4. Use jax.value_and_grad to get gradients.
  5. Use jax.jit to speed up the update step.
  6. See examples.md for snippets.
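The steps above can be sketched roughly as follows. This is an illustrative example only, assuming a toy linear-regression model; the names init_params, forward, loss_fn, and update, as well as the synthetic data and learning rate, are hypothetical and not taken from examples.md.

```python
import jax
import jax.numpy as jnp

# Step 1: parameters as a pytree (here a plain dict). Hypothetical helper.
def init_params(key, in_dim, out_dim):
    return {
        "w": jax.random.normal(key, (in_dim, out_dim)) * 0.1,
        "b": jnp.zeros((out_dim,)),
    }

# Step 2: pure forward pass; the output depends only on the arguments.
def forward(params, x):
    return x @ params["w"] + params["b"]

# Step 3: loss function (mean squared error).
def loss_fn(params, x, y):
    return jnp.mean((forward(params, x) - y) ** 2)

# Steps 4 and 5: value_and_grad inside a jitted update step.
@jax.jit
def update(params, x, y, lr):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Plain gradient descent, applied leaf by leaf over the pytree.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

# Synthetic data, assumed for illustration only.
data_key, param_key = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(data_key, (64, 3))
y = x @ jnp.array([[1.0], [-2.0], [0.5]]) + 0.3

params = init_params(param_key, 3, 1)
initial_loss = loss_fn(params, x, y)
for _ in range(200):
    params, loss = update(params, x, y, 0.1)
```

Because update returns new parameters instead of mutating the old ones, the loop rebinds params on every iteration. This is the usual pattern given JAX's immutable arrays.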

2. Debugging Shapes/NaNs

  1. Disable JIT with jax.config.update("jax_disable_jit", True) so you can debug with standard Python tools such as print or pdb.
  2. Use jax.debug.print inside JITted functions to inspect runtime values without disabling compilation.
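A short sketch of both approaches, assuming a hypothetical normalize function. It uses the scoped jax.disable_jit() context manager rather than the global config flag, so the setting does not leak into the rest of the program.

```python
import jax
import jax.numpy as jnp

# Hypothetical function used only to illustrate the debugging tools.
@jax.jit
def normalize(x):
    # Shapes are static, so an ordinary print works here; it runs at trace time.
    print("tracing with shape", x.shape)
    norm = jnp.linalg.norm(x)
    # jax.debug.print runs at execution time and shows the actual value.
    jax.debug.print("norm = {n}", n=norm)
    return x / norm

out = normalize(jnp.array([3.0, 4.0]))

# Scoped alternative to the global flag: JIT is disabled only inside this block,
# so the function runs op by op and can be stepped through with pdb.
with jax.disable_jit():
    out_eager = normalize(jnp.array([3.0, 4.0]))
```

For NaNs specifically, jax.config.update("jax_debug_nans", True) makes JAX raise an error at the operation that produced the NaN, which can help narrow down the source.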

# Supported AI Coding Agents

This skill is compatible with the SKILL.md standard and works with all major AI coding agents.
