TensorFlow 會不會被 JAX 取代,使用 JAX 來訓練第乙個機器學習模型?

Mondo 科技 更新 2024-01-30

Jax 簡介

在上一篇文章中,我們分享了 Jax 的概念,這是乙個來自 Google 的相對較新的機器學習庫。 它更像是乙個 autograd 庫,用於區分每個原生 python 和 numpy

python+numpy 程式的可組合轉換:差分、向量化、JIT 到 GPU TPU 等等”。 該庫利用 grad 函式轉換將函式轉換為返回原始函式梯度的函式。 JAX 還提供了函式轉換 JIT,用於對現有函式進行即時編譯,以及分別用於向量化和並行化的 VMAP 和 PMAP

JAX 是 Autograd 和 XLA 的結合體,JAX 本身不是乙個深度學習框架,它是乙個高效能的數值計算庫,也是乙個用於高效能機器學習研究的可組合函式變換庫。 深度學習只是故事的一部分,但您可以將自己的深度學習移植到 Jax。

自 2018 年底谷歌的 JAX 問世以來,它的受歡迎程度一直在穩步增長。 DeepMind 在 202 年宣布正在使用 JAX 來加速自己的研究,越來越多的來自 Google Brain 和其他專案的專案也在使用 JAX。 隨著 Jax 越來越受歡迎,Jax 似乎是下乙個大型深度學習框架雖然 Jax 不是乙個神經網路框架,但隨著 Jax 的發展,許多深度學習相關的研究也可以使用 Jax 來實現。

在上一篇文章中,我們還分享了 JAX 和 NUMPY 的速度對比,相比沒有 JAX 加速的 NUMPY,其速度遠遠落後於 JAX,本期我們將使用 JAX 來訓練第乙個機器學習模型。

使用 JAX 訓練您的第乙個機器學習模型

在使用 Jax 之前,我們需要安裝 Jax,幸運的是,Jax 可以用 pip 安裝,但是 Jax 目前在 Windows 平台上不可用,可以使用 Linux 虛擬機器來體驗。

pip install jaxpip install autogradpip install numpypip install jaxlib

首先,我們需要安裝jax等相關的第三方庫,並匯入相關的第三方庫。

import numpy as npimport jax.random as randomimport jaxfrom jax import numpy as jnpfrom jax import make_jaxprfrom jax import grad, jit, vmap, pmapimport matplotlib.pyplot as plt

然後我們構建乙個 y=ax+b 的線性函式,其中引數 a 是直線的斜率,b 是直線在 y 軸方向上的運動引數,並使用隨機隨機函式生成乙個隨機的 x 資料,這樣我們得到乙個完整的 y=ax+b 線性函式,我們可以使用 matplotlib 來顯示這個函式的曲線。

key = random.prngkey(56)x = random.normal(key, shape=(128, 1))a = 3.0b = 5.0ys = (a*xs) +bplt.scatter(xs, ys)plt.xlabel("xs")plt.ylabel("ys")plt.title("linear f(x)")plt.show()

執行上述 **,我們得到 y=ax+b 的線性函式。

利用上面的線性函式,我們將構建乙個線性模型,並使用機器學習來**這條直線。

def linear(theta, x): weight, bias = theta pred = x * weight + bias return pred

然後我們定義乙個線性函式,這個函式還有2個引數,乙個權重(weight),乙個偏置(bias),訓練的目的是找到乙個合適的權重和偏置引數,以便**得到上面的線性函式。 當然,我們還需要設定乙個損失函式,以便在以後的訓練中損失逐漸減少。 這裡使用均方差作為損失函式來計算**值與真實值的損失。

def p_loss(theta, x, y): pred = linear(theta, x) loss = jnp.mean((y - pred)**2) return loss@jitdef update_step(theta, x, y, lr): loss, gradient = jax.value_and_grad(p_loss)(theta, x, y) updated_theta = theta - lr * gradient return updated_theta, loss

然後使用 jaxvalue 和 grad 函式來更新損失,lr 引數是神經網路的學習效率,這裡我們可以隨機取乙個相對較小的值。 通過上述函式,我們可以訓練乙個機器學習模型。

weight = 0.0bias = 0.0theta = jnp.array([weight, bias])epochs = 20000for item in range(epochs): theta, loss_p = update_step(theta, xs, ys, 1e-4) if item % 1000 == 0 and item != 0: print(f"item |loss ")

我們初始化權重和偏置引數,並使用for迴圈訓練神經網路,使損失越來越小,這裡我們每1000步列印一次損失引數。

item 1000 | loss 23.4526item 2000 | loss 15.4000item 3000 | loss 10.1152item 4000 | loss 6.6459item 5000 | loss 4.3678item 6000 | loss 2.8714item 7000 | loss 1.8883item 8000 | loss 1.2422item 9000 | loss 0.8174item 10000 | loss 0.5380item 11000 | loss 0.3543item 12000 | loss 0.2333item 13000 | loss 0.1538item 14000 | loss 0.1013item 15000 | loss 0.0668item 16000 | loss 0.0441item 17000 | loss 0.0291item 18000 | loss 0.0192item 19000 | loss 0.0127

從上面的損失引數可以看出,模型的損失逐漸減小,說明我們設計的線性機器學習模型是有效的。 我們還可以在訓練 20,000 步後列印模型的輸出函式。

plt.scatter(xs, ys, label="true")plt.scatter(xs, linear(theta, xs), label="pred")plt.legend()plt.show()

可以看出,隨著訓練的進行,模型的損失逐漸減小,當訓練為20000步時,其Y=AX+B函式幾乎與輸入的初始函式值重合,當然也可以增加訓練步數,使損失再次縮小。

雖然 Jax 目前還不叫神經網路模型框架,但隨著 PyTorch、PaddlePaddle、Mindspore 相關框架的加入,神經網路框架的爭議愈演愈烈,谷歌是否會將 Jax 發展成下一代神經網路框架還不確定。

相關問題答案

    叉車執照會被取消嗎?

    叉車執照會被取消嗎?叉車執照是從事叉車作業的人員必須持有的工作許可證,如果持有人違反有關規定或不再從事叉車作業,其叉車執照可能會被吊銷。本文將詳細介紹吊銷叉車執照的相關規定和程式。 叉車證有效期。叉車執照的有效期通常為四年,之後需要進行審查和更新。持牌人違反有關規定或者在有效期內不再從事叉車作業的,...

    政府雇員會被解雇嗎

    根據相關法律規定,在某些法定情況下,員工可能會被解雇。有兩種型別的員工 公務員和有勞動關係的工人。.公務員 根據 公務員法 第二條第一款的規定,本法所稱公務員,是指依法履行公務,納入國家行政編制,由國庫支付工資 福利待遇的工作人員。根據 公務員法 的規定,公務員有下列情形之一的,應當予以辭退 一 經...

    政府雇員會被解雇嗎

    在我國的 工作制度中,沒有 員工 的概念。如果從一般意義上理解,它可以指 公務員 公務員很少聽說過解雇這個詞,但經常聽說解雇。在什麼情況下,公務員會被解雇?公務員是國家的重要公職人員,承擔著管理國家事務和社會公共事務的職責。因此,公務員必須遵紀守法,忠實履行職責,維護國家利益和社會公共利益。但是,如...

    連續多久會虧損

    ST 是國內的乙個特殊類別,通常是指多年虧損的上市公司。投資者應了解ST的含義 原因和風險,以避免不必要的投資風險。.什麼是st ST 是指多年虧損的上市公司,經過特殊處理後,需要加裝 ST 標識。以下是 ST 的一些功能 .ST標誌的含義 ST是Special Treatment的縮寫,在 前面加...

    以色列會被摧毀嗎?

    目前不會毀國,以色列擁有強大的軍事實力和美國的無限支援,當然也得到了美國猶太人的支援,美國的猶太勢力真的是橫掃全球,宗教和以色列的實力根本不是它的對手,雖然在二戰期間幾乎全部被殲滅,但在不到一百年的時間裡,世界上的猶太勢力也不容忽視。但這場不公正的侵略戰爭,導致猶太人在世界上的名聲很差,雖然其他國家...