Skip to content

Commit

Permalink
Adding blas dep for accelerate.
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil committed Mar 4, 2023
1 parent 409c640 commit 7ad72e4
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 0 deletions.
16 changes: 16 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ thiserror = "1.0"
tokenizers = { git = "https://github.com/huggingface/tokenizers", branch="main", default-features=false, features=["onig"] }
tokio = { version = "1.23.0", features = ["rt-multi-thread", "macros"] }
cblas-sys = { version = "0.1.4", default-features = false, optional = true }
blas-src = { version = "0.7", default-features = false, optional = true }
libc = { version = "0.2", default-features = false, optional = true }
tracing-subscriber = "0.3.16"
axum = "0.6.3"
Expand All @@ -34,6 +35,7 @@ lazy_static = "1.4.0"
[features]
default = ["gpt2"]
cblas = ["dep:cblas-sys", "dep:libc"]
accelerate = ["blas-src/accelerate", "dep:libc"]
intel-mkl = ["cblas"]
dfdx_intel = ["dfdx/intel-mkl"]
dfdx_nightly = ["dfdx/nightly"]
Expand Down
23 changes: 23 additions & 0 deletions src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ use cblas_sys::{
CblasRowMajor as RowMajor, CblasTrans as Tr,
};

#[cfg(feature = "blas")]
use blas_src::{
sgemm, CblasColMajor as ColMajor, CblasNoTrans as NoTr, CblasRowMajor as RowMajor,
CblasTrans as Tr,
};

#[inline]
pub fn addmm<X: Tensor, A: Tensor, B: Tensor, TM: TensorMut>(x: &X, a: &A, b: &B, out: &mut TM) {
let m = x.shape()[0];
Expand Down Expand Up @@ -108,6 +114,23 @@ pub fn g_matmul<const TRANSPOSE: bool, A: Tensor, B: Tensor, TM: TensorMut>(
layout, a_tr, b_tr, m, n, k, 1.0, ap, lda, bp, ldb, 1.0, cp, ldc,
)
}

#[cfg(feature = "blas")]
unsafe {
let (m, n, k) = (m as libc::c_int, n as libc::c_int, k as libc::c_int);
let (layout, a_tr, b_tr, lda, ldb, ldc) = if cr < cc {
let (lda, a_tr) = if ar < ac { (m, NoTr) } else { (k, Tr) };
let (ldb, b_tr) = if br < bc { (k, NoTr) } else { (n, Tr) };
(ColMajor, a_tr, b_tr, lda, ldb, m)
} else {
let (lda, a_tr) = if ar < ac { (m, Tr) } else { (k, NoTr) };
let (ldb, b_tr) = if br < bc { (k, Tr) } else { (n, NoTr) };
(RowMajor, a_tr, b_tr, lda, ldb, n)
};
sgemm(
layout, a_tr, b_tr, m, n, k, 1.0, ap, lda, bp, ldb, 1.0, cp, ldc,
)
}
});
}

Expand Down

0 comments on commit 7ad72e4

Please sign in to comment.