Add simple regression training example.
Uses the 'tt' backend to execute the training graph (forward/backward graphs) and runs the rest of the code on the 'cpu' backend. Adds a simple regression weight update as a test case.
1 parent f0e5b08 · commit 80c727a
Showing 2 changed files with 125 additions and 0 deletions.
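The split described in the commit message (training graph on 'tt', everything else on 'cpu') comes down to pinning jitted functions to a backend. A minimal sketch, with illustrative function names, assuming the tt PJRT plugin has already been registered as initialize() in the first file does:

import jax
import jax.numpy as jnp

# Forward/backward computation compiled for the tt device.
train_step = jax.jit(lambda w, x: (w * x).sum(), backend="tt")

# Host-side preprocessing stays on cpu.
prep = jax.jit(lambda x: x / x.max(), backend="cpu")

x = prep(jnp.arange(1.0, 5.0))
print(train_step(jnp.ones(4), x))

JAX moves arrays between the two backends at the call boundary, so the caller does not have to insert explicit transfers.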
@@ -0,0 +1,89 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
#

from jax import grad, jit, vmap
import jax.numpy as jnp
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
import jax
import os
import sys
import jax._src.xla_bridge as xb


# Register the cpu and tt plugins. The tt plugin is registered with higher
# priority, so the program executes on the tt device unless specified otherwise.
def initialize():
    backend = "tt"
    path = os.path.join(os.path.dirname(__file__), "../build/src/tt/pjrt_plugin_tt.so")
    if not os.path.exists(path):
        raise FileNotFoundError(f"Could not find tt_pjrt C API plugin at {path}")

    print("Loading tt_pjrt C API plugin", file=sys.stderr)
    xb.discover_pjrt_plugins()

    plugin = xb.register_plugin("tt", priority=500, library_path=path, options=None)
    print("Loaded", file=sys.stderr)
    jax.config.update("jax_platforms", "tt,cpu")


# Create random inputs (weights) on cpu and move them to the tt device if requested.
def random_input_tensor(shape, key=42, on_device=False):
    def random_input(shape, key):
        return jax.random.uniform(jax.random.PRNGKey(key), shape=shape)

    jitted_tensor_creator = jax.jit(random_input, static_argnums=[0, 1], backend="cpu")
    tensor = jitted_tensor_creator(shape, key)
    if on_device:
        tensor = jax.device_put(tensor, jax.devices()[0])
    return tensor


# Predict the outcome label.
def predict(params, X):
    w, b = params
    return X.dot(w) + b


# Create a vectorized version of the predict function.
batched_predict = vmap(predict, in_axes=(None, 0))


# Calculate the mean squared error loss for a given dataset.
def loss(params, X, y):
    pred = batched_predict(params, X)
    return ((pred - y) ** 2).mean()


def test_simple_regression():
    initialize()

    X, y = make_regression(n_samples=150, n_features=2, noise=5)
    y = y.reshape((y.shape[0], 1))
    # Split data into train and test sets.
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.15)

    Weights = random_input_tensor((X_train.shape[1], 1))
    Bias = 0.0
    l_rate = 0.001
    n_iter = 6000
    params = [Weights, Bias]

    gradient = jit(grad(loss), backend="tt")
    print(gradient.lower(params, X_train, y_train).as_text())

    for i in range(n_iter):
        dW, db = gradient(params, X_train, y_train)
        if i % 10 == 0:
            print(f"iteration: {i} {loss(params, X_train, y_train)}")
        weights, bias = params
        weights -= dW * l_rate
        bias -= db * l_rate
        params = [weights, bias]

    # Model's loss on the test set.
    test_loss = loss(params, X_test, y_test)
    print(f"test loss: {test_loss}")


if __name__ == "__main__":
    test_simple_regression()
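The update loop above is plain full-batch gradient descent: each parameter p is updated as p <- p - l_rate * dL/dp. Since jax.grad returns gradients with the same structure as params, the update can also be written once over the whole parameter pytree. A sketch, not part of this commit:

import jax

def sgd_step(params, grads, l_rate=0.001):
    # Apply p <- p - l_rate * g leaf-wise across the parameter pytree,
    # mirroring the manual weights/bias updates in the loop above.
    return jax.tree_util.tree_map(lambda p, g: p - l_rate * g, params, grads)

A single call, params = sgd_step(params, gradient(params, X_train, y_train)), then replaces the manual unpack/update/repack.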
@@ -0,0 +1,36 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

import pytest
import jax
import jax.numpy as jnp

from infrastructure import verify_module


@pytest.mark.xfail
def test_gradient():
    def simple_gradient(a):
        def gradient(a):
            return (a ** 2).sum()

        return jax.grad(gradient)(a)

    verify_module(simple_gradient, [(2, 2)])


@pytest.mark.xfail
def test_simple_regression():
    def simple_regression(weights, bias, X, y):
        def loss(weights, bias, X, y):
            predict = X.dot(weights) + bias
            return ((predict - y) ** 2).sum()

        # Compute the gradient and update the weights.
        weights -= jax.grad(loss)(weights, bias, X, y)

        return weights

    verify_module(simple_regression, [(1, 2), (1, 1), (2, 1), (1, 1)])
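verify_module is imported from the repository's test infrastructure and is not part of this commit. Judging from its use here, it presumably builds random inputs with the given shapes, runs the module, and checks the device output against a golden result. A hypothetical sketch of that contract; the name verify_module_sketch, the cpu-golden comparison, and the tolerance are all assumptions, not the actual implementation:

import jax
import jax.numpy as jnp

def verify_module_sketch(module, input_shapes, rtol=1e-2):
    # Hypothetical stand-in for infrastructure.verify_module: build random
    # inputs of the requested shapes, then compare the tt result against a
    # cpu golden.
    key = jax.random.PRNGKey(0)
    inputs = []
    for shape in input_shapes:
        key, subkey = jax.random.split(key)
        inputs.append(jax.random.uniform(subkey, shape))
    golden = jax.jit(module, backend="cpu")(*inputs)
    output = jax.jit(module, backend="tt")(*inputs)
    assert jnp.allclose(output, golden, rtol=rtol)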