Skip to content

Commit

Permalink
[mxu] fix handshaking for accumulator latches
Browse files Browse the repository at this point in the history
  • Loading branch information
Navaneeth-KunhiPurayil committed Dec 14, 2024
1 parent 296ecf1 commit 85deafe
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 17 deletions.
45 changes: 29 additions & 16 deletions hw/ip/spatz/src/spatz_mxu.sv
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ module spatz_mxu

logic [2:0] mx_read_enable;
logic mx_write_enable_d, mx_write_enable_q;
logic result_ready;

vrf_data_t operand1;
vrf_data_t operand2;
Expand All @@ -55,8 +56,8 @@ module spatz_mxu

//Accumulator signal
vrf_data_t [NrACCBanks-1:0] accu_result_q;
logic [NrACCBanks-1:0] accu_result_valid, accu_result_valid_d, accu_result_valid_q;
vrf_data_t wdata_q;
logic [NrACCBanks-1:0] accu_result_valid_d, accu_result_valid_q;
vrf_data_t wdata_d, wdata_q;
logic [NrACCBanks-1:0] waddr_onehot;

logic load_vd;
Expand Down Expand Up @@ -106,12 +107,14 @@ module spatz_mxu
accu_result_valid_q <= '0;
mx_to_write_vrf_q <= 1'b0;
write_cnt_q <= '0;
wdata_q <= '0;
end else begin
mxu_cnt_en_q <= mxu_cnt_en_d;
mxu_cnt_q <= mxu_cnt_d;
accu_result_valid_q <= accu_result_valid_d;
mx_to_write_vrf_q <= mx_to_write_vrf_d;
write_cnt_q <= write_cnt_d;
wdata_q <= wdata_d;
end
end

Expand All @@ -128,15 +131,16 @@ module spatz_mxu
// Save operands_i as previous operands every time we get new operands and we are starting a new col
`FFL(previous_operands_q, previous_operands_d, enable_mx_i && &operands_ready_i[1:0] && part_col == 0, '0)

always_ff @(posedge clk_i) begin: proc_wdata_q
// Save the FPU result in a FF before going into the latch
wdata_q <= result_i;
end: proc_wdata_q
always_comb begin
wdata_d = wdata_q;
if (result_valid_i & result_ready_o)
wdata_d = result_i;
end

// Select which destination bytes to write into
for (genvar accreg = 0; accreg < NrACCBanks; accreg++) begin: gen_waddr_onehot
// Create latch clock signal
assign waddr_onehot[accreg] = enable_mx_i && result_valid_i && accreg == part_acc;
assign waddr_onehot[accreg] = enable_mx_i && result_valid_i && result_ready_o && accreg == part_acc;
end: gen_waddr_onehot

// Store result into accumulator
Expand Down Expand Up @@ -168,24 +172,27 @@ module spatz_mxu
acc_counter_d = acc_counter_q;
col_counter_d = col_counter_q;

num_cols = tile_dimN;
num_rows = tile_dimM;
if(enable_mx_i) begin
num_cols = tile_dimN;
num_rows = tile_dimM;
// Todo: parametrize me
// Load as many vd as (M / OperandsPerVRFFetch)
load_vd = num_rows == 4 ? vl_i <= 0 : vl_i <= 4;
// mtx_A row counter from 0 to M
part_col = num_cols == 4 ? num_rows == 4 ? col_counter_q[1:0] : col_counter_q[2:0] : col_counter_q;
// Accumulator counter from 0 to M
part_acc = num_cols == 4 ? num_rows == 4 ? acc_counter_q[1:0] : acc_counter_q[2:0] : acc_counter_q;
end

// Accumulator counter from 0 to M
// Can get result for accumulator even when all instructions are done due to latency, hence outside the enable_mx_i block
part_acc = num_cols == 4 ? num_rows == 4 ? acc_counter_q[1:0] : acc_counter_q[2:0] : acc_counter_q;

// Update acc and col counters
if (result_valid_i & result_ready_o) begin
// For every accumulation result handshake update acc_counter
if (result_valid_i & result_ready_o & (enable_mx_i | write_enable_o)) begin
acc_counter_d = acc_counter_q + 1'b1;
end

if (enable_mx_i & ipu_en) begin
// For every useful FPU cycle, go to the next element of A column
if (ipu_en) begin
col_counter_d = col_counter_q + 1'b1;
end

Expand Down Expand Up @@ -253,6 +260,9 @@ module spatz_mxu
mx_to_write_vrf_d = mx_to_write_vrf_q;
write_cnt_d = write_cnt_q;

// Check if accumulator data is valid if so cannot receive new result
result_ready = ~accu_result_valid_q[part_acc];

// Write back into the VRF if we have processed all the words
// and we got a valid result
if (~mx_to_write_vrf_q) begin
Expand Down Expand Up @@ -286,8 +296,11 @@ module spatz_mxu
// Enable a read if we need an operand
// If the accumulators are to be used i.e. load_vd=1'b0 then also check that the accumulators have a valid result
mx_read_enable[1:0] = (load_vd | (~load_vd & accu_result_valid_q[part_col])) ? {2{part_col == 0 || part_col == 4}} : 2'b0;
if ((~load_vd & accu_result_valid_q[part_col] & ipu_en))
if ((~load_vd & accu_result_valid_q[part_col] & ipu_en)) begin
// If accumulator data is being used, become available to take in new data
accu_result_valid_d[part_col] = 1'b0;
result_ready = 1'b1;
end
mx_read_enable[2] = load_vd;
end
end
Expand Down Expand Up @@ -320,7 +333,7 @@ module spatz_mxu
// result_i : fpu_result_o
// result_valid_i : fpu_valid_o
// result_ready_o : fpu_ready_i
assign result_ready_o = enable_mx_i ? (|mx_read_enable ? |operands_ready_i : '1) || vrf_wvalid_i : '0;
assign result_ready_o = result_ready;

////////////////
// MXU -> VRF //
Expand Down
2 changes: 1 addition & 1 deletion hw/ip/spatz/src/spatz_vfu.sv
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ module spatz_vfu
vlen_t mx_offset ;
logic mx_result_ready ;

assign op_is_mx = (spatz_req.op_arith.is_mx && spatz_req_valid) || mx_write_enable;
assign op_is_mx = (spatz_req.op_arith.is_mx && spatz_req_valid);
assign clear_mxu_state = spatz_req_valid & spatz_req_ready;

// Vector length counter
Expand Down

0 comments on commit 85deafe

Please sign in to comment.