# Patch summary (leading text before the first "diff --git" header; patch tools
# skip it). Everything below is derived from the +/- lines in the hunks.
#
# hw/ip/spatz/src/spatz_mxu.sv
#   * Declares a new internal signal `result_ready`; `result_ready_o` is now a
#     plain copy of it. The previous expression built from enable_mx_i,
#     mx_read_enable, operands_ready_i and vrf_wvalid_i is removed.
#   * `result_ready` defaults to ~accu_result_valid_q[part_acc] and is forced
#     to 1'b1 in the same branch that clears accu_result_valid_d[part_col],
#     i.e. when (~load_vd & accu_result_valid_q[part_col] & ipu_en).
#   * `wdata_q` is no longer written from result_i on every clk_i edge by its
#     own always_ff. It is now zeroed in the branch that zeroes the other *_q
#     registers and otherwise loaded from the new `wdata_d`. `wdata_d` holds
#     wdata_q unless (result_valid_i & result_ready_o), in which case it takes
#     result_i.
#   * `waddr_onehot[accreg]` additionally requires result_ready_o.
#   * `accu_result_valid` is dropped from the declaration list; only the _d/_q
#     pair remains.
#   * `num_cols` / `num_rows` are assigned before the `if (enable_mx_i)` block
#     and `part_acc` is computed after it, so all three are driven regardless
#     of enable_mx_i.
#   * `acc_counter_d` increments on
#     result_valid_i & result_ready_o & (enable_mx_i | write_enable_o).
#   * `col_counter_d` increments on ipu_en alone (was enable_mx_i & ipu_en).
#
# hw/ip/spatz/src/spatz_vfu.sv
#   * `op_is_mx` no longer ORs in mx_write_enable.
#
# NOTE(review): as stored here, each run of hunk lines sits on a single
#   physical line; the original line breaks and indentation presumably need to
#   be restored before this can be applied -- confirm against the original.
# NOTE(review): result_ready_o now depends on ipu_en through result_ready, and
#   also gates wdata_d and waddr_onehot. The derivation of ipu_en is not
#   visible in these hunks -- confirm it does not depend combinationally on
#   result_ready_o.
# NOTE(review): col_counter_d is no longer gated by enable_mx_i -- confirm
#   ipu_en cannot be asserted while enable_mx_i is low, or that counting in
#   that case is intended.
diff --git a/hw/ip/spatz/src/spatz_mxu.sv b/hw/ip/spatz/src/spatz_mxu.sv index b8b2357..10c0d9e 100644 --- a/hw/ip/spatz/src/spatz_mxu.sv +++ b/hw/ip/spatz/src/spatz_mxu.sv @@ -47,6 +47,7 @@ module spatz_mxu logic [2:0] mx_read_enable; logic mx_write_enable_d, mx_write_enable_q; + logic result_ready; vrf_data_t operand1; vrf_data_t operand2; @@ -55,8 +56,8 @@ module spatz_mxu //Accumulator signal vrf_data_t [NrACCBanks-1:0] accu_result_q; - logic [NrACCBanks-1:0] accu_result_valid, accu_result_valid_d, accu_result_valid_q; - vrf_data_t wdata_q; + logic [NrACCBanks-1:0] accu_result_valid_d, accu_result_valid_q; + vrf_data_t wdata_d, wdata_q; logic [NrACCBanks-1:0] waddr_onehot; logic load_vd; @@ -106,12 +107,14 @@ module spatz_mxu accu_result_valid_q <= '0; mx_to_write_vrf_q <= 1'b0; write_cnt_q <= '0; + wdata_q <= '0; end else begin mxu_cnt_en_q <= mxu_cnt_en_d; mxu_cnt_q <= mxu_cnt_d; accu_result_valid_q <= accu_result_valid_d; mx_to_write_vrf_q <= mx_to_write_vrf_d; write_cnt_q <= write_cnt_d; + wdata_q <= wdata_d; end end @@ -128,15 +131,16 @@ module spatz_mxu // Save operands_i as previous operands every time we get new operands and we are starting a new col `FFL(previous_operands_q, previous_operands_d, enable_mx_i && &operands_ready_i[1:0] && part_col == 0, '0) - always_ff @(posedge clk_i) begin: proc_wdata_q - // Save the FPU result in a FF before going into the latch - wdata_q <= result_i; - end: proc_wdata_q + always_comb begin + wdata_d = wdata_q; + if (result_valid_i & result_ready_o) + wdata_d = result_i; + end // Select which destination bytes to write into for (genvar accreg = 0; accreg < NrACCBanks; accreg++) begin: gen_waddr_onehot // Create latch clock signal - assign waddr_onehot[accreg] = enable_mx_i && result_valid_i && accreg == part_acc; + assign waddr_onehot[accreg] = enable_mx_i && result_valid_i && result_ready_o && accreg == part_acc; end: gen_waddr_onehot // Store result into accumulator @@ -168,24 +172,27 @@ module spatz_mxu 
acc_counter_d = acc_counter_q; col_counter_d = col_counter_q; + num_cols = tile_dimN; + num_rows = tile_dimM; if(enable_mx_i) begin - num_cols = tile_dimN; - num_rows = tile_dimM; // Todo: parametrize me // Load as many vd as (M / OperandsPerVRFFetch) load_vd = num_rows == 4 ? vl_i <= 0 : vl_i <= 4; // mtx_A row counter from 0 to M part_col = num_cols == 4 ? num_rows == 4 ? col_counter_q[1:0] : col_counter_q[2:0] : col_counter_q; - // Accumulator counter from 0 to M - part_acc = num_cols == 4 ? num_rows == 4 ? acc_counter_q[1:0] : acc_counter_q[2:0] : acc_counter_q; end + // Accumulator counter from 0 to M + // Can get result for accumulator even when all instructions are done due to latency, hence outside the enable_mx_i block + part_acc = num_cols == 4 ? num_rows == 4 ? acc_counter_q[1:0] : acc_counter_q[2:0] : acc_counter_q; + // Update acc and col counters - if (result_valid_i & result_ready_o) begin + // For every accumulation result handshake update acc_counter + if (result_valid_i & result_ready_o & (enable_mx_i | write_enable_o)) begin acc_counter_d = acc_counter_q + 1'b1; end - - if (enable_mx_i & ipu_en) begin + // For every useful FPU cycle, go to the next element of A column + if (ipu_en) begin col_counter_d = col_counter_q + 1'b1; end @@ -253,6 +260,9 @@ module spatz_mxu mx_to_write_vrf_d = mx_to_write_vrf_q; write_cnt_d = write_cnt_q; + // Check if accumulator data is valid if so cannot receive new result + result_ready = ~accu_result_valid_q[part_acc]; + // Write back into the VRF if we have processed all the words // and we got a valid result if (~mx_to_write_vrf_q) begin @@ -286,8 +296,11 @@ module spatz_mxu // Enable a read if we need an operand // If the accumulators are to be used i.e. load_vd=1'b0 then also check that the accumulators have a valid result mx_read_enable[1:0] = (load_vd | (~load_vd & accu_result_valid_q[part_col])) ? 
{2{part_col == 0 || part_col == 4}} : 2'b0; - if ((~load_vd & accu_result_valid_q[part_col] & ipu_en)) + if ((~load_vd & accu_result_valid_q[part_col] & ipu_en)) begin + // If accumulator data is being used, become available to take in new data accu_result_valid_d[part_col] = 1'b0; + result_ready = 1'b1; + end mx_read_enable[2] = load_vd; end end @@ -320,7 +333,7 @@ module spatz_mxu // result_i : fpu_result_o // result_valid_i : fpu_valid_o // result_ready_o : fpu_ready_i - assign result_ready_o = enable_mx_i ? (|mx_read_enable ? |operands_ready_i : '1) || vrf_wvalid_i : '0; + assign result_ready_o = result_ready; //////////////// // MXU -> VRF // diff --git a/hw/ip/spatz/src/spatz_vfu.sv b/hw/ip/spatz/src/spatz_vfu.sv index 86d1eff..244b684 100644 --- a/hw/ip/spatz/src/spatz_vfu.sv +++ b/hw/ip/spatz/src/spatz_vfu.sv @@ -101,7 +101,7 @@ module spatz_vfu vlen_t mx_offset ; logic mx_result_ready ; - assign op_is_mx = (spatz_req.op_arith.is_mx && spatz_req_valid) || mx_write_enable; + assign op_is_mx = (spatz_req.op_arith.is_mx && spatz_req_valid); assign clear_mxu_state = spatz_req_valid & spatz_req_ready; // Vector length counter