Jax 簡介
在上一篇文章中,我們分享了 Jax 的概念,這是乙個來自 Google 的相對較新的機器學習庫。 它更像是乙個 autograd 庫,用於區分每個原生 python 和 numpy
python+numpy 程式的可組合轉換:差分、向量化、JIT 到 GPU TPU 等等”。 該庫利用 grad 函式轉換將函式轉換為返回原始函式梯度的函式。 JAX 還提供了函式轉換 JIT,用於對現有函式進行即時編譯,以及分別用於向量化和並行化的 VMAP 和 PMAP
JAX 是 Autograd 和 XLA 的結合體,JAX 本身不是乙個深度學習框架,它是乙個高效能的數值計算庫,也是乙個用於高效能機器學習研究的可組合函式變換庫。 深度學習只是故事的一部分,但您可以將自己的深度學習移植到 Jax。
自 2018 年底谷歌的 JAX 問世以來,它的受歡迎程度一直在穩步增長。 DeepMind 在 202 年宣布正在使用 JAX 來加速自己的研究,越來越多的來自 Google Brain 和其他專案的專案也在使用 JAX。 隨著 Jax 越來越受歡迎,Jax 似乎是下乙個大型深度學習框架雖然 Jax 不是乙個神經網路框架,但隨著 Jax 的發展,許多深度學習相關的研究也可以使用 Jax 來實現。
在上一篇文章中,我們還分享了 JAX 和 NUMPY 的速度對比,相比沒有 JAX 加速的 NUMPY,其速度遠遠落後於 JAX,本期我們將使用 JAX 來訓練第乙個機器學習模型。
使用 JAX 訓練您的第乙個機器學習模型
在使用 Jax 之前,我們需要安裝 Jax,幸運的是,Jax 可以用 pip 安裝,但是 Jax 目前在 Windows 平台上不可用,可以使用 Linux 虛擬機器來體驗。
pip install jaxpip install autogradpip install numpypip install jaxlib
首先,我們需要安裝jax等相關的第三方庫,並匯入相關的第三方庫。
import numpy as npimport jax.random as randomimport jaxfrom jax import numpy as jnpfrom jax import make_jaxprfrom jax import grad, jit, vmap, pmapimport matplotlib.pyplot as plt
然後我們構建乙個 y=ax+b 的線性函式,其中引數 a 是直線的斜率,b 是直線在 y 軸方向上的運動引數,並使用隨機隨機函式生成乙個隨機的 x 資料,這樣我們得到乙個完整的 y=ax+b 線性函式,我們可以使用 matplotlib 來顯示這個函式的曲線。
key = random.prngkey(56)x = random.normal(key, shape=(128, 1))a = 3.0b = 5.0ys = (a*xs) +bplt.scatter(xs, ys)plt.xlabel("xs")plt.ylabel("ys")plt.title("linear f(x)")plt.show()
執行上述 **,我們得到 y=ax+b 的線性函式。
利用上面的線性函式,我們將構建乙個線性模型,並使用機器學習來**這條直線。
def linear(theta, x): weight, bias = theta pred = x * weight + bias return pred
然後我們定義乙個線性函式,這個函式還有2個引數,乙個權重(weight),乙個偏置(bias),訓練的目的是找到乙個合適的權重和偏置引數,以便**得到上面的線性函式。 當然,我們還需要設定乙個損失函式,以便在以後的訓練中損失逐漸減少。 這裡使用均方差作為損失函式來計算**值與真實值的損失。
def p_loss(theta, x, y): pred = linear(theta, x) loss = jnp.mean((y - pred)**2) return loss@jitdef update_step(theta, x, y, lr): loss, gradient = jax.value_and_grad(p_loss)(theta, x, y) updated_theta = theta - lr * gradient return updated_theta, loss
然後使用 jaxvalue 和 grad 函式來更新損失,lr 引數是神經網路的學習效率,這裡我們可以隨機取乙個相對較小的值。 通過上述函式,我們可以訓練乙個機器學習模型。
weight = 0.0bias = 0.0theta = jnp.array([weight, bias])epochs = 20000for item in range(epochs): theta, loss_p = update_step(theta, xs, ys, 1e-4) if item % 1000 == 0 and item != 0: print(f"item |loss ")
我們初始化權重和偏置引數,並使用for迴圈訓練神經網路,使損失越來越小,這裡我們每1000步列印一次損失引數。
item 1000 | loss 23.4526item 2000 | loss 15.4000item 3000 | loss 10.1152item 4000 | loss 6.6459item 5000 | loss 4.3678item 6000 | loss 2.8714item 7000 | loss 1.8883item 8000 | loss 1.2422item 9000 | loss 0.8174item 10000 | loss 0.5380item 11000 | loss 0.3543item 12000 | loss 0.2333item 13000 | loss 0.1538item 14000 | loss 0.1013item 15000 | loss 0.0668item 16000 | loss 0.0441item 17000 | loss 0.0291item 18000 | loss 0.0192item 19000 | loss 0.0127
從上面的損失引數可以看出,模型的損失逐漸減小,說明我們設計的線性機器學習模型是有效的。 我們還可以在訓練 20,000 步後列印模型的輸出函式。
plt.scatter(xs, ys, label="true")plt.scatter(xs, linear(theta, xs), label="pred")plt.legend()plt.show()
可以看出,隨著訓練的進行,模型的損失逐漸減小,當訓練為20000步時,其Y=AX+B函式幾乎與輸入的初始函式值重合,當然也可以增加訓練步數,使損失再次縮小。
雖然 Jax 目前還不叫神經網路模型框架,但隨著 PyTorch、PaddlePaddle、Mindspore 相關框架的加入,神經網路框架的爭議愈演愈烈,谷歌是否會將 Jax 發展成下一代神經網路框架還不確定。