Skip to content

Commit

Permalink
[fix] Fix CI
Browse files Browse the repository at this point in the history
- Run rustfmt
- Fix regex version to be compatible with rustc 1.63.0
  • Loading branch information
Xeratec committed Oct 10, 2023
1 parent d3d351b commit 45e7a44
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 14 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pest = "2.1.3"
pest_derive = "2.1.0"
ndarray = "0.13"
pretty_env_logger = "0.4"
regex = "~1.9.6"
rev_slice = "0.1.5"
serde = { version = "1.0.123", features = ["derive"] }
serde_json = "1.0.63"
Expand Down
77 changes: 63 additions & 14 deletions src/peripherals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,8 +357,14 @@ impl Peripheral for MemPoolITA {
0x0C => self.eps_mul_1.store(value, Ordering::SeqCst),
0x10 => self.right_shift_0.store(value, Ordering::SeqCst),
0x14 => self.right_shift_1.store(value, Ordering::SeqCst),
0x18 => unsafe {self.add_0.store(std::mem::transmute::<u32, i32>(value), Ordering::SeqCst)},
0x1C => unsafe {self.add_1.store(std::mem::transmute::<u32, i32>(value), Ordering::SeqCst)},
0x18 => unsafe {
self.add_0
.store(std::mem::transmute::<u32, i32>(value), Ordering::SeqCst)
},
0x1C => unsafe {
self.add_1
.store(std::mem::transmute::<u32, i32>(value), Ordering::SeqCst)
},
_ => unimplemented!(),
}
}
Expand Down Expand Up @@ -514,9 +520,12 @@ impl MemPoolITA {
let rqs_shift = u64::to_le_bytes((_right_shift_1 as u64) << 32 | (_right_shift_0 as u64));
let rqs_add = u64::to_le_bytes((_add_1 as u64) << 32 | (_add_0 as u64)).map(|c| c as i8);

debug!("[ITA] Start Address 0x{:x}, Out Address 0x{:x}", start, out_address);
debug!(
"[ITA] Start Address 0x{:x}, Out Address 0x{:x}",
start, out_address
);
debug!("[ITA] RQS Mult {:?}", rqs_mult);
debug!("[ITA] RQS Shift {:?}",rqs_shift);
debug!("[ITA] RQS Shift {:?}", rqs_shift);
debug!("[ITA] RQS Add {:?}", rqs_add);

let mut q = Array2::<i8>::zeros((64, 64));
Expand Down Expand Up @@ -572,7 +581,13 @@ impl MemPoolITA {
MemPoolITA::projection_space_transformation(&mut q_p, &mut q, &mut w_q, &mut b_q, 1);
// requantization of q_p
let mut q_p_requant = Array3::<i8>::zeros((1, 64, 64));
MemPoolITA::requantization_3d(&mut q_p, &mut q_p_requant, rqs_mult[0], rqs_shift[0], rqs_add[0]);
MemPoolITA::requantization_3d(
&mut q_p,
&mut q_p_requant,
rqs_mult[0],
rqs_shift[0],
rqs_add[0],
);
// debug!("q_p_requant: {}", q_p_requant);

// key_projection_space_transformation
Expand All @@ -583,14 +598,26 @@ impl MemPoolITA {
MemPoolITA::projection_space_transformation(&mut k_p, &mut k, &mut w_k, &mut b_k, 1);
// requantization of k_p
let mut k_p_requant = Array3::<i8>::zeros((1, 64, 64));
MemPoolITA::requantization_3d(&mut k_p, &mut k_p_requant, rqs_mult[1], rqs_shift[1], rqs_add[1]);
MemPoolITA::requantization_3d(
&mut k_p,
&mut k_p_requant,
rqs_mult[1],
rqs_shift[1],
rqs_add[1],
);
// debug!("k_p_requant: {}", k_p_requant);

// query_key_correlation
let mut qk = Array3::<i32>::zeros((1, 64, 64));
MemPoolITA::query_key_correlation(&mut q_p_requant, &mut k_p_requant, &mut qk);
// requantization of qk
MemPoolITA::requantization_3d(&mut qk, &mut a_requant, rqs_mult[2], rqs_shift[2], rqs_add[2]);
MemPoolITA::requantization_3d(
&mut qk,
&mut a_requant,
rqs_mult[2],
rqs_shift[2],
rqs_add[2],
);
// debug!("a_requant: {}", a_requant);

// streaming_partial_softmax
Expand All @@ -601,7 +628,13 @@ impl MemPoolITA {
MemPoolITA::projection_space_transformation(&mut v_p, &mut v, &mut w_v, &mut b_v, 1);
// requantization of v_p
let mut v_p_requant = Array3::<i8>::zeros((1, 64, 64));
MemPoolITA::requantization_3d(&mut v_p, &mut v_p_requant, rqs_mult[3], rqs_shift[3], rqs_add[3]);
MemPoolITA::requantization_3d(
&mut v_p,
&mut v_p_requant,
rqs_mult[3],
rqs_shift[3],
rqs_add[3],
);
// debug!("v_p_requant: {}", v_p_requant);

// single_head_computation
Expand All @@ -613,14 +646,26 @@ impl MemPoolITA {
);
// requantization of o_softmax
let mut o_softmax_requant = Array3::<i8>::zeros((1, 64, 64));
MemPoolITA::requantization_3d(&mut o_softmax, &mut o_softmax_requant, rqs_mult[4], rqs_shift[4], rqs_add[4]);
MemPoolITA::requantization_3d(
&mut o_softmax,
&mut o_softmax_requant,
rqs_mult[4],
rqs_shift[4],
rqs_add[4],
);
// debug!("o_softmax_requant: {}", o_softmax_requant);

// multi_head_computation
MemPoolITA::multi_head_computation(&mut o_softmax_requant, &mut out, &mut w_o, &mut b_o, 1);
// parallel requantization of out
let mut out_requant = Array2::<i8>::zeros((64, 64));
MemPoolITA::parallel_requantize3d(&mut out, &mut out_requant, rqs_mult[5], rqs_shift[5], rqs_add[5]);
MemPoolITA::parallel_requantize3d(
&mut out,
&mut out_requant,
rqs_mult[5],
rqs_shift[5],
rqs_add[5],
);
// debug!("out_requant: {}", out_requant);

// for j in 0..out_requant.shape()[1] {
Expand All @@ -636,7 +681,9 @@ impl MemPoolITA {
let mut shifted = ((element * (eps_mult as i32)) >> (right_shift as i32)) + (add as i32);

// Perform rounding half away from zero
if right_shift > 0 && ((element * (eps_mult as i32)) >> ((right_shift-1) as i32)) & 0x1 == 1 {
if right_shift > 0
&& ((element * (eps_mult as i32)) >> ((right_shift - 1) as i32)) & 0x1 == 1
{
shifted = shifted.saturating_add(1);
}
if shifted > 127 {
Expand Down Expand Up @@ -686,13 +733,15 @@ impl MemPoolITA {
let row = m.slice(s![i, j, ..]);
for k in 0..row.len() {
let mut shifted = ((row[k] * (eps_mult as i32)) >> (right_shift as i32))
+ m_requant[[i * m.shape()[1] + j, k]] as i32;
+ m_requant[[i * m.shape()[1] + j, k]] as i32;

// Perform rounding half away from zero
if right_shift > 0 && ((row[k] * (eps_mult as i32)) >> ((right_shift-1) as i32)) & 0x1 == 1 {
if right_shift > 0
&& ((row[k] * (eps_mult as i32)) >> ((right_shift - 1) as i32)) & 0x1 == 1
{
shifted = shifted.saturating_add(1);
}
m_requant[[i * m.shape()[1] + j, k]] =
m_requant[[i * m.shape()[1] + j, k]] =
MemPoolITA::requantize_row(shifted, 1, 0, 0);
}
}
Expand Down

0 comments on commit 45e7a44

Please sign in to comment.