JAX入門ガイド:Pythonで自動微分・JIT・GPU計算を始めよう

機械学習

はじめに:JAXとは? 🤔

JAXは、高性能な数値計算、特に機械学習研究のために設計されたPythonライブラリです。Google Research(現 Google DeepMind)によって開発されました。NumPyと非常によく似たAPIを提供しつつ、自動微分、JITコンパイル、GPU/TPUサポートなどの強力な機能を追加しています。

一言で言えば、JAXは「自動微分とXLA(Accelerated Linear Algebra)を組み合わせ、高性能な機械学習研究を実現するライブラリ」です。NumPyに慣れている開発者であれば、比較的スムーズに移行できます。

JAXの主な特徴は以下の通りです:

  • NumPyライクなAPI: 既存のNumPyコードからの移行が容易です。
  • 自動微分(grad): 複雑な関数の勾配を自動的に計算します。
  • JITコンパイル(jit): Python関数をXLAでコンパイルし、CPU、GPU、TPU上で高速に実行します。
  • 自動ベクトル化(vmap): 関数を自動的にベクトル化し、バッチ処理を効率化します。
  • 並列化(pmap): 複数のデバイス(GPU/TPUコア)に計算を分散させます。

これらの機能により、研究者やエンジニアは、使い慣れたPython環境で、最新のアクセラレータを活用した高速な計算処理を記述できます。

JAXのコア機能 ✨

JAXの強力さは、そのコアとなる関数変換(Function Transformations)に由来します。主要なものをいくつか見ていきましょう。

機械学習、特にニューラルネットワークの学習において、損失関数の勾配計算は不可欠です。JAXのgrad関数は、Python関数を受け取り、その勾配を計算する新しい関数を返します。

gradは、ループ、分岐、再帰、クロージャを含むネイティブなPythonやNumPyの関数を微分できます。また、高階微分(微分の微分)も可能です。リバースモード自動微分(バックプロパゲーション)だけでなく、フォワードモード自動微分もサポートしており、これらを任意に組み合わせることもできます。


import jax
import jax.numpy as jnp

# 簡単な関数を定義
def sum_of_squares(x):
  return jnp.sum(x**2)

# x = [1.0, 2.0, 3.0] での勾配を計算
x = jnp.array([1.0, 2.0, 3.0])

# 勾配関数を取得
grad_fn = jax.grad(sum_of_squares)

# 勾配を計算して表示
gradients = grad_fn(x)
print(f"勾配: {gradients}")
# 出力: 勾配: [2. 4. 6.] (2*x に対応)
    

jit (Just-In-Time) コンパイラは、JAXの計算を高速化する鍵となります。jitでデコレートされたPython関数は、XLA(Accelerated Linear Algebra)コンパイラによって最適化され、CPU、GPU、TPUなどのアクセラレータ上で効率的に実行されるマシンコードにコンパイルされます。

通常、JAXは演算を一つずつディスパッチしますが、jitを使用すると、一連の演算をまとめてコンパイルし、Pythonのオーバーヘッドを削減し、演算の融合(fusion)などの最適化を可能にします。


import time
import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)

# JITなしで実行
start_time = time.time()
jnp.dot(x, x.T).block_until_ready() # block_until_ready() で非同期実行を待つ
print(f"JITなし実行時間: {time.time() - start_time:.4f} 秒")

# JITありで実行
@jax.jit
def optimized_dot(x):
  return jnp.dot(x, x.T)

# 最初の実行でコンパイルされる
optimized_dot(x).block_until_ready()

# 2回目の実行(コンパイル済みコードを使用)
start_time = time.time()
optimized_dot(x).block_until_ready()
print(f"JITあり実行時間: {time.time() - start_time:.4f} 秒")
    

jitを使用する際の注意点として、コンパイルは関数の入力の型と形状(shape)に基づいて行われるため、これらが変わると再コンパイルが発生します。また、副作用を持つ関数(例:print文やグローバル変数の変更)はjitと相性が悪いです。(デバッグ用のプリントにはjax.debug.printが利用できます)。

vmap (vectorizing map) は、関数を自動的にベクトル化するための変換です。ループを書く代わりに、vmapを使うと、関数がバッチ処理を行うように変換され、パフォーマンスが向上します。特に、jitと組み合わせることで、手動でバッチ次元を追加して書き直したコードと同等のパフォーマンスを発揮することがあります。

例えば、行列とベクトルの積を計算する関数を、vmapを使って行列と行列の積(複数のベクトルに同時に適用)に拡張できます。


import jax
import jax.numpy as jnp

# 行列とベクトルの積
mat = jnp.array([[1., 2.], [3., 4.]])
vec = jnp.array([5., 6.])
result = jnp.dot(mat, vec)
print(f"行列とベクトルの積: {result}")

# 複数のベクトル (バッチ)
vecs = jnp.array([[5., 6.], [7., 8.]])

# vmapを使って行列と複数のベクトルの積を計算
# jnp.dot の最初の引数 (行列) は固定し、2番目の引数 (ベクトル) の次元0に沿ってマップする
batch_dot = jax.vmap(jnp.dot, in_axes=(None, 0), out_axes=0)
batch_result = batch_dot(mat, vecs)
print(f"vmapによるバッチ処理結果:\n{batch_result}")

# NumPyでの同等の計算 (確認用)
numpy_result = jnp.dot(mat, vecs.T).T
print(f"NumPyでの確認結果:\n{numpy_result}")
    

vmapはネストすることも可能で、複雑なバッチ処理や高次元データの操作を簡潔に記述できます。

pmap (parallel map) は、複数のデバイス(例:複数のGPUやTPUコア)にまたがって計算を並列化するための変換です。SPMD (Single-Program, Multiple-Data) スタイルのプログラミングを可能にします。

pmapvmapと似ていますが、vmapが単一デバイス内でベクトル化するのに対し、pmapは関数を各デバイスに複製し、データの異なる部分をそれぞれのデバイスで並列に処理します。これにより、大規模なモデルの学習やデータ並列処理を効率的に行うことができます。

pmapは内部でjitと同様にXLAによるコンパイルを行うため、通常jitと組み合わせる必要はありません。使用する際には、マップする軸のサイズが利用可能なローカルデバイス数以下である必要があります。


import jax
import jax.numpy as jnp

# デバイス数を取得 (CPUのみの場合は1になる)
print(f"利用可能なローカルデバイス数: {jax.local_device_count()}")

# 8個のデバイスで実行することを想定したダミーデータ
data = jnp.arange(8 * 3).reshape(8, 3)

# 各デバイスで実行される関数 (ここでは単純に要素を2倍)
@jax.pmap
def parallel_double(x):
  print("コンパイル中...") # デバイスごとにコンパイルされる
  return x * 2

# pmapで関数を実行
# data の最初の次元 (サイズ8) がデバイスに分散される
result = parallel_double(data)
print(f"pmap実行結果:\n{result}")
# 注意: 実行環境に複数のデバイスがない場合、動作が異なる可能性があります。
# CPU環境では、`pmap` は `vmap` のように振る舞うことがあります。
# 複数CPUコアを使うには環境変数の設定が必要な場合があります。
    

pmapは、デバイス間の集合演算(例:all-reduce)もサポートしており、分散学習の実装を容易にします。

JAX と NumPy の比較 ⚖️

JAXはNumPyと非常に似たAPIを提供しているため、NumPyユーザーにとっては学習コストが比較的低いです。import numpy as npimport jax.numpy as jnp に置き換えるだけで、多くのコードが動作します。しかし、いくつかの重要な違いがあります。

特徴 JAX (jax.numpy) NumPy
実行環境 CPU, GPU, TPUに対応 CPUのみ
実行モデル 非同期ディスパッチ (デフォルト)、JITコンパイル (XLA) 同期的
自動微分 jax.grad などで可能 不可 (外部ライブラリが必要)
ベクトル化/並列化 jax.vmap, jax.pmap で自動化 手動での実装が必要
配列の不変性 (Immutability) 配列は不変 (変更不可) 配列は可変 (インプレース変更可能)
乱数生成 明示的なキー (PRNGKey) が必要 (状態を持たない純粋関数) グローバルな状態を持つ
データ型 デフォルトで32ビット浮動小数点数 (float32) を好む傾向 デフォルトで64ビット浮動小数点数 (float64)
プログラミングパラダイム 関数型プログラミングを推奨 手続き型/オブジェクト指向

JAXの最も大きな特徴の一つが配列の不変性です。NumPyでは arr[0] = 100 のように配列の一部を直接変更できますが、JAXではこのような操作はエラーになります。代わりに、変更を伴う操作では新しい配列が返されます。例えば、配列の一部を更新するには jax.numpy.ndarray.at[index].set(value) のような構文を使用します。これは純粋関数という関数型プログラミングの思想に基づいており、jitなどの変換を適用しやすくするためです。


import jax.numpy as jnp

arr_np = np.array([1, 2, 3])
arr_np[0] = 100 # NumPyではOK
print(f"NumPy 配列 (変更後): {arr_np}")

arr_jax = jnp.array([1, 2, 3])
# arr_jax[0] = 100 # これはエラーになる

# JAXでの更新方法
arr_jax_updated = arr_jax.at[0].set(100)
print(f"JAX 配列 (元の配列): {arr_jax}")
print(f"JAX 配列 (更新後の新しい配列): {arr_jax_updated}")
    

JAXの乱数生成は状態を持たないため、NumPyとは異なります。jax.randomモジュールを使用し、乱数を生成する際には必ず PRNGKey (Pseudo-Random Number Generator Key) を渡す必要があります。同じキーを使えば常に同じ乱数シーケンスが生成され、異なるキーを使えば異なるシーケンスが得られます。これにより、実験の再現性が保証されやすくなります。キーは使用後に分割 (split) して新しいキーを生成し、再利用しないことが推奨されます。


import jax.random as random

key = random.PRNGKey(42) # シード値からキーを生成
print(f"最初のキー: {key}")

# キーを使って乱数を生成
random_nums = random.normal(key, (3,))
print(f"生成された乱数: {random_nums}")

# キーを分割して新しいキーを生成
key, subkey = random.split(key)
print(f"分割後のキー: {key}")
print(f"使用するサブキー: {subkey}")

# サブキーを使って次の乱数を生成
more_random_nums = random.uniform(subkey, (2,))
print(f"次に生成された乱数: {more_random_nums}")
    

JAXのエコシステム 🌳

JAX自体は、コアとなる数値計算と関数変換に焦点を当てたライブラリですが、その周りには活発なエコシステムが形成されています。これにより、JAXを基盤として、より高度な機械学習タスクや特定の応用分野に対応できます。

主要なライブラリには以下のようなものがあります:

  • Flax: Googleによって開発された、柔軟性の高いニューラルネットワークライブラリ。モジュラーな設計で、研究用途に適しています。以前のLinen APIと、よりPyTorchライクな新しいNNX APIがあります。
  • Haiku: DeepMindによって開発されたニューラルネットワークライブラリ。Sonnet(TensorFlow向け)に似たAPIを持ち、シンプルさと明示的な状態管理を特徴とします。
  • Optax: DeepMindによって開発された勾配処理と最適化のためのライブラリ。様々なオプティマイザや学習率スケジューリングを提供します。
  • Equinox: JAXのためのPyTorchライクなニューラルネットワークライブラリ。PyTreeを最大限に活用し、関数変換との親和性が高い設計です。
  • NumPyro: Uberによって開発された確率的プログラミング言語。PyroのJAXバックエンドとして機能し、ベイジアンモデリングと推論を高速化します。
  • JAX MD: 分子動力学シミュレーションのためのライブラリ。JAXの自動微分やハードウェアアクセラレーションを活用します。
  • Brax: 高速で並列化された物理シミュレーションエンジン。強化学習の研究などに利用されます。
  • Chex: DeepMindによるユーティリティライブラリ。JAXコードのテストや信頼性向上のためのアサーションなどを提供します。
  • Orbax: チェックポイント(モデルの保存・読み込み)管理のためのライブラリ。大規模なモデルや分散環境での状態管理を容易にします。
  • TensorFlow Datasets (TFDS) / Hugging Face Datasets: これらのデータセットライブラリはJAXプロジェクトでも広く利用されており、データローディングパイプラインを構築するのに役立ちます。

これらのライブラリは、JAXのコア機能を補完し、特定のタスク(ニューラルネットワークの構築、最適化、確率的モデリング、物理シミュレーションなど)をより容易かつ効率的に行うことを可能にします。JAXのエコシステムは現在も活発に開発が進んでいます。

はじめの一歩 🐾

JAXを使い始めるのは簡単です。pipを使ってインストールできます。


# CPUのみサポートする場合
pip install -U "jax[cpu]"

# NVIDIA GPU (CUDA 12) をサポートする場合
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# TPUサポート (Google Cloud TPUなど)
# pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# (TPU環境に応じて適切なバージョンや手順が必要になる場合があります)
    

インストール後、NumPyライクな操作、自動微分、JITコンパイルを組み合わせてみましょう。


import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

# 1. NumPyライクな操作
key = random.PRNGKey(0)
x = random.normal(key, (10,)) # 乱数ベクトル生成
y = jnp.sin(x) + 0.1 * random.normal(key, (10,)) # ノイズを加える

# 2. モデルと損失関数を定義 (簡単な線形回帰を想定)
def model(theta, x):
  """線形モデル: y = w*x + b"""
  w, b = theta
  return w * x + b

def loss_fn(theta, x, y):
  """平均二乗誤差損失"""
  prediction = model(theta, x)
  return jnp.mean((prediction - y)**2)

# 3. 自動微分を使って損失関数の勾配関数を取得
grad_loss = grad(loss_fn)

# 4. JITコンパイルで学習ステップを高速化
@jit
def update_step(theta, x, y, learning_rate=0.1):
  """勾配降下法による1ステップ更新"""
  gradients = grad_loss(theta, x, y)
  # パラメータ更新 (JAXは不変なので新しいthetaを返す)
  new_theta = [
      theta[0] - learning_rate * gradients[0],
      theta[1] - learning_rate * gradients[1]
  ]
  return new_theta, loss_fn(new_theta, x, y) # 新しいパラメータと損失を返す

# 初期パラメータ (w=0, b=0 を想定)
initial_theta = [jnp.zeros(()), jnp.zeros(())]

# 学習ループ (簡略版)
theta = initial_theta
print("学習開始...")
for step in range(100):
  theta, loss_value = update_step(theta, x, y)
  if step % 10 == 0:
    print(f"ステップ {step}, 損失: {loss_value:.4f}")

print("学習完了!")
print(f"最終パラメータ (w, b): {theta}")
print(f"最終損失: {loss_value:.4f}")

# vmapの使用例 (もしデータがバッチ処理されていたら)
# batch_x = random.normal(key, (5, 10)) # 5つのサンプル、各10次元
# batch_y = jnp.sin(batch_x) + 0.1 * random.normal(key, (5, 10))
# バッチに対応した損失関数
# batch_loss_fn = vmap(loss_fn, in_axes=(None, 0, 0)) # thetaは固定、xとyはバッチ次元でマップ
# batch_loss = batch_loss_fn(theta, batch_x, batch_y)
# print(f"バッチ全体の損失: {jnp.mean(batch_loss)}")
    

考慮事項とまとめ 🤔💡

JAXは非常に強力なライブラリですが、いくつかの特性を理解しておくことが重要です。

  • 関数型プログラミング: JAXは純粋関数(副作用がなく、同じ入力に対して常に同じ出力を返す関数)を基本としています。これにより、jit, grad, vmap, pmap といった強力な変換が可能になりますが、状態の管理や副作用のあるコードの扱いに慣れが必要です。
  • 不変性: 配列が変更不可であるため、NumPyに慣れていると最初は戸惑うかもしれません。.at[...].set(...) 構文などを使った更新方法に慣れる必要があります。
  • JITコンパイルの制約: jitでコンパイルする関数内でのPythonの動的な制御フロー(データに依存する条件分岐やループ)には制約があります。jax.lax.condjax.lax.scan など、JAXが提供する制御フロー構文を使うことが推奨されます。
  • デバッグ: JITコンパイルされたコードや非同期実行は、デバッグを少し難しくすることがあります。jax.debug モジュールや、コンパイルを無効にする設定 (jax.disable_jit()) が役立ちます。

まとめ: JAXは、NumPyの使いやすさと、自動微分、ハードウェアアクセラレーション、関数変換といった最新の研究開発に必要な機能を融合させた画期的なライブラリです。特に機械学習の研究開発において、そのパフォーマンスと柔軟性から注目を集めています。関数型プログラミングのパラダイムに慣れれば、非常に効率的でスケーラブルな数値計算コードを記述するための強力なツールとなるでしょう。エコシステムも成長しており、今後ますます多くの分野での活用が期待されます。 😊

コメント

タイトルとURLをコピーしました