JAX (ライブラリ)のソースを表示
←
JAX (ライブラリ)
ナビゲーションに移動
検索に移動
あなたには「このページの編集」を行う権限がありません。理由は以下の通りです:
この操作は、次のグループに属する利用者のみが実行できます:
登録利用者
。
このページのソースの閲覧やコピーができます。
{{Infobox Software | 名称 = JAX | ロゴ = <!-- ロゴ画像。[[ファイル:example.png|50px]]のようにウィキ構文で指定する --> | スクリーンショット = <!-- スクリーンショット。[[ファイル:example.png|100px]]のようにウィキ構文で指定する --> | 説明文 = <!-- スクリーンショットの説明文 --> | 開発者 = <!-- 人物の名前 --> | 開発元 = [[Google]]、[[NVIDIA]]<ref>{{cite web |url=https://github.com/jax-ml/jax/blob/main/AUTHORS |title=jax/AUTHORS at main · jax-ml/jax |newspaper=Https: |date= |author= |accessdate= December 21, 2024}}</ref><!-- 組織の名前 --> | 初版 = {{Start date and age|2018|12}}<ref>{{Cite web |title=JAX: Accelerating Machine-Learning Research with Composable Function Transformations in Python {{!}} GTC Digital March 2020 {{!}} NVIDIA On-Demand |author= |work=NVIDIA |date= |access-date=23 December 2024 |url= https://www.nvidia.com/en-us/on-demand/session/gtcsj20-s21989/}}</ref><!-- 初版の発表日。{{Start date and age|年|月|日}} 等のテンプレートが便利 --> | 最新版 = 0.5.0 | 最新版発表日 = {{Start date and age|2025|01|18}}<ref>{{cite web |url=https://github.com/jax-ml/jax/releases |title=Releases · jax-ml/jax |newspaper=Https: |date= |author= |accessdate= February 2, 2025}}</ref><!-- 最新版(安定版)の発表日。{{Start date and age|年|月|日}} 等のテンプレートが便利 --> | 最新評価版 = | 最新評価版発表日 = <!-- 最新評価版の発表日。{{Start date and age|年|月|日}} 等のテンプレートが便利 --> | リポジトリ = {{GitHub|jax-ml/jax}}<!-- リポジトリのURL --> | プログラミング言語 = [[Python]] | 対応OS = [[Windows]]、[[macOS]]、[[Linux]] | エンジン = <!-- ソフトが使用しているエンジン。ウェブブラウザにおけるレンタリングエンジン(Gecko、WebKit)など --> | 対応プラットフォーム = <!-- CPUアーキテクチャなど --> * [[x86-64]] * [[ARM64]] * [[NVIDIA]] GPU * [[アドバンスト・マイクロ・デバイセズ|AMD]] GPU * [[Intel]] GPU * [[Apple]] GPU * [[Google]] TPU * [[Amazon Web Services|AWS]] Trainium, Inferentia<ref>{{Cite web |title=Installation — JAX documentation |author= |work=jax.readthedocs.io |date= |access-date=21 December 2024 |url= https://jax.readthedocs.io/en/latest/installation.html}}</ref><ref>{{cite web |url=https://aws.amazon.com/jp/about-aws/whats-new/2024/09/aws-neuron-nki-nxd-training-jax/ |title=AWS Neuron がトレーニング向け Neuron Kernel Interface (NKI)、NxD Training、JAX のサポートを提供 - AWS |newspaper=Https: |date= |author= |accessdate= December 21, 2024}}</ref> | サイズ = <!-- バイナリのサイズ --> | 対応言語 = | サポート状況 = | 種別 = 数値計算ライブラリ | ライセンス = [[Apache License 2.0]] | 公式サイト = {{url|https://jax.readthedocs.io/}} | 前身 = <!-- 前身となったソフトウェアの名称 --> | 後継 = <!-- 後継ソフトウェアの名称 --> | 業種 = <!-- Industry --> | 会員登録 = <!-- Registration --> }} '''JAX'''は、高速な[[数値計算]]と大規模な[[機械学習]]のために設計された[[Python]]の[[オープンソース]]のライブラリ<ref>{{cite web |url=https://github.com/jax-ml/jax/blob/main/README.md |title=jax/README.md at main · jax-ml/jax |newspaper=Https: |date= |author= |accessdate= December 21, 2024}}</ref>。[[NumPy]]風の構文で書かれたPythonの[[ソースコード]]を[[CPU]]・[[Graphics Processing Unit|GPU]]・[[AIアクセラレータ]]<ref>{{Cite web |title=Installation — JAX documentation |author= |work=jax.readthedocs.io |date= |access-date=21 December 2024 |url= https://jax.readthedocs.io/en/latest/installation.html}}</ref>へ[[コンパイラ|コンパイル]]する[[実行時コンパイラ]]や[[自動微分]]などを含む。 実行時コンパイラは、JAXからOpenXLAの[[XLA (コンパイラ)|XLA]]にコンパイルし、そこから先はハードウェア次第だが、多くのCPUとGPUは[[LLVM]]を経由してコンパイルされる<ref>{{cite web |url=https://openxla.org/xla/architecture |title=XLA architecture |newspaper=Https: |date= |author= |accessdate= December 21, 2024}}</ref>。 == 基本的な使用方法 == 下記のソースコードのように、関数に @jit を付けることにより、その部分が実行時コンパイルされる。同一のソースコードで、CPUだけでなく、GPUやAIアクセラレータでも動作させることが可能である。詳細は後述するが、@jitの中に書けるのは普通のPythonのプログラムではなく、Pythonの構文を使用した純粋[[関数型プログラミング|関数型言語]]である。 <syntaxhighlight lang="python"> import jax.numpy as jnp from jax import jit @jit def f(a, b): return a + b x = jnp.array([1, 2, 3], dtype=jnp.float32) print(f(x, x)) </syntaxhighlight> map を自動[[ベクトル化]]した vmap があり、{{code|code=a * 2|lang=python}} をあえて vmap を使用して書いた場合、下記のように書ける。[[SIMD]]を活用したプログラムにコンパイルされる。<ref>{{Cite web |title=Automatic vectorization — JAX documentation |author= |work=jax.readthedocs.io |date= |access-date=21 December 2024 |url= https://jax.readthedocs.io/en/latest/automatic-vectorization.html}}</ref> <syntaxhighlight lang="python"> from jax import jit, vmap @jit def f(a): return vmap(lambda x: x * 2)(a) </syntaxhighlight> == Numbaとの違い == 似たようなライブラリとして[[Numba]]があるが、以下の違いがある。純粋関数型にすることにより色々な[[コンパイラ最適化|最適化]]がかかっている。関数型言語としての分類は、純粋、[[先行評価|正格評価]]、型を明示する必要が無い[[静的型付け]]である。 {| class="wikitable" |+ ! 相違点 ! JAX ! Numba |- ! 設計思想 | 純粋[[関数型プログラミング|関数型]]。配列は[[イミュータブル|不変]]で、形状(shape)はコンパイル時に静的に確定してないといけない。<ref>{{Cite web |title=🔪 JAX - The Sharp Bits 🔪 — JAX documentation |author= |work=jax.readthedocs.io |date= |access-date=21 December 2024 |url= https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html}}</ref><ref>{{Cite web |title=How to think in JAX — JAX documentation |author= |work=jax.readthedocs.io |date= |access-date=28 December 2024 |url= https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html |quote=Not all JAX code can be JIT compiled, as it requires array shapes to be static & known at compile time. }}</ref> | [[手続き型プログラミング|手続き型]]。配列の破壊的操作が可能。 |- ! nowrap | if,match,while,for文 | 利用不可。代用関数が用意されている。 | 利用可能<ref>{{Cite web |title=Supported Python features — Numba documentation |author= |work=numba.readthedocs.io |date= |access-date=22 December 2024 |url= https://numba.readthedocs.io/en/stable/reference/pysupported.html}}</ref> |- ! nowrap | 対象ハードウェア | CPU・GPU・AIアクセラレータ全てで同一のソースコードで可能。 | CPUとNVIDIA [[CUDA]]に対応しているが、全く異なるソースコードが必要。<ref>{{Cite web |title=Writing CUDA Kernels — Numba documentation |author= |work=numba.readthedocs.io |date= |access-date=21 December 2024 |url= https://numba.readthedocs.io/en/stable/cuda/kernels.html}}</ref> |- ! [[自動微分]] | 対応<ref>{{Cite web |title=Automatic differentiation — JAX documentation |author= |work=jax.readthedocs.io |date= |access-date=21 December 2024 |url= https://jax.readthedocs.io/en/latest/automatic-differentiation.html}}</ref> | 非対応 |} 純粋関数型であるため、乱数を使用する際に、下記のように、乱数生成のキーを明示的に作り直さないといけない。<ref>{{Cite web |title=Pseudorandom numbers — JAX documentation |author= |work=jax.readthedocs.io |date= |access-date=21 December 2024 |url= https://jax.readthedocs.io/en/latest/random-numbers.html}}</ref> <syntaxhighlight lang="python"> key, subkey = jax.random.split(key) x = jax.random.normal(subkey) </syntaxhighlight> 配列を書き換える際は、手続き型では {{code|code=x[10] = 20|lang=python}} で良い場合も、 {{code|code=y = x.at[10].set(20)|lang=python}} という構文になり、x と y は異なるインスタンスになる。ただし、以後 x を使用しない場合は、x に破壊的書き換えして y とする最適化が実行される。<ref>{{Cite web |title=jax.numpy.ndarray.at — JAX documentation |author= |work=jax.readthedocs.io |date= |access-date=21 December 2024 |url= https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at}}</ref> === if文とmatch文 === JAXではPythonのif文とmatch文は基本的にはそのままでは使用できない。下記が用意されている。 * jax.lax.cond: Pythonのif文に対応するもので、例えば {{code|code=cond(x == 0, lambda: 10, lambda: 20)|lang=python}} の様に使用し、True/Falseに応じてlambda式が実行される。JAXは[[先行評価|正格評価]]の関数型言語のため、True/Falseが決まった後に分岐先の値を[[遅延評価]]するためにlambda式の中に入れる。<ref>{{Cite web |title=jax.lax.cond — JAX documentation |author= |work=jax.readthedocs.io |date= |access-date=22 December 2024 |url= https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.cond.html}}</ref> * jax.lax.switch: condを3択以上に出来るようにした物で、例えば {{code|code=switch(x, (lambda: 10, lambda: 20, lambda: 30))|lang=python}} の様に使用する。<ref>{{Cite web |title=jax.lax.switch — JAX documentation |author= |work=jax.readthedocs.io |date= |access-date=22 December 2024 |url= https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html}}</ref> * jax.lax.select: boolean配列に対してif文を使用する物で、例えば、xが配列の時 {{code|code=select(x == 0, jnp.array([1, 2]), jnp.array([3, 4]))|lang=python}} の様に使用し、{{code|code=x == 0|lang=python}} が True/False に応じて各要素が振り分けられる。<ref>{{Cite web |title=jax.lax.select — JAX documentation |author= |work=jax.readthedocs.io |date= |access-date=22 December 2024 |url= https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html}}</ref> * jax.lax.select_n: select を swtich の様に3択以上に出来るようにした物。<ref>{{Cite web |title=jax.lax.select_n — JAX documentation |author= |work=jax.readthedocs.io |date= |access-date=22 December 2024 |url= https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select_n.html}}</ref> === while文とfor文 === JAXではPythonのwhile文とfor文は基本的にはそのままでは使用できず、ループ回数が定数の場合でPythonのfor文をそのまま使用した場合は、ループアンロールされる。<ref>{{Cite web |title=Control flow and logical operators with JIT — JAX documentation |author= |work=jax.readthedocs.io |date= |access-date=22 December 2024 |url= https://jax.readthedocs.io/en/latest/control-flow.html}}</ref> ループ構造を作るものとして下記が用意されている。 * 関数型言語の [[高階関数#fold|fold]] 相当:jax.lax.fori_loop<ref>{{Cite web |title=jax.lax.fori_loop — JAX documentation |author= |work=jax.readthedocs.io |date= |access-date=22 December 2024 |url= https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html}}</ref> と jax.lax.scan<ref>{{Cite web |title=jax.lax.scan — JAX documentation |author= |work=jax.readthedocs.io |date= |access-date=21 December 2024 |url= https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html}}</ref> * 関数型言語の [[高階関数#unfold|unfold]] 相当:jax.lax.while_loop<ref>{{Cite web |title=jax.lax.while_loop — JAX documentation |author= |work=jax.readthedocs.io |date= |access-date=22 December 2024 |url= https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html}}</ref> * 関数型言語の [[高階関数#map|map]] 相当:jax.vmap と jax.lax.map<ref>{{Cite web |title=jax.lax.map — JAX documentation |author= |work=jax.readthedocs.io |date= |access-date=23 December 2024 |url= https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.map.html}}</ref> 純粋関数型のため、scan, fori_loop, while_loop は全て前の計算結果を次に渡すという形となっている。 == 自動微分 == jax.grad にて[[自動微分]]できる。例えば、[[最急降下法]]は下記で実装できる。init_x から始めて、fori_loop にて iter_count 回、計算を反復している。<math>(x - 1)^2</math> が最小となるx、つまり1を求めている。 <syntaxhighlight lang="python"> from jax import jit, grad from jax.lax import fori_loop f = lambda x: (x - 1) ** 2 @jit def gradient_descent(init_x, iter_count, learn_rate): return fori_loop(0, iter_count, lambda i, x: x - learn_rate * grad(f)(x), init_x) print(gradient_descent(0.0, 30, 0.3)) </syntaxhighlight> == 参照 == {{reflist}} == 関連項目 == * [[Numba]] == 外部リンク == * {{official|https://jax.readthedocs.io/}} {{Python}} [[Category:Pythonライブラリ]] [[Category:オープンソース人工知能]] [[Category:数値解析ソフトウェア]] [[Category:ディープラーニング]]
このページで使用されているテンプレート:
テンプレート:Cite web
(
ソースを閲覧
)
テンプレート:Code
(
ソースを閲覧
)
テンプレート:Infobox Software
(
ソースを閲覧
)
テンプレート:Official
(
ソースを閲覧
)
テンプレート:Python
(
ソースを閲覧
)
テンプレート:Reflist
(
ソースを閲覧
)
JAX (ライブラリ)
に戻る。
ナビゲーション メニュー
個人用ツール
ログイン
名前空間
ページ
議論
日本語
表示
閲覧
ソースを閲覧
履歴表示
その他
検索
案内
メインページ
最近の更新
おまかせ表示
MediaWiki についてのヘルプ
特別ページ
ツール
リンク元
関連ページの更新状況
ページ情報