Skip to content

Commit

Permalink
[hardware] REFACTORING
Browse files Browse the repository at this point in the history
  • Loading branch information
mp-17 committed Dec 2, 2024
1 parent ef0e531 commit 3cb8614
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 79 deletions.
151 changes: 89 additions & 62 deletions hardware/src/masku/masku.sv
Original file line number Diff line number Diff line change
Expand Up @@ -58,22 +58,23 @@ module masku import ara_pkg::*; import rvv_pkg::*; #(

import cf_math_pkg::idx_width;

// Pointers
// Predication
//
// Remaining elements of the current instruction in the read operand phase
vlen_t masku_pred_cnt_d, masku_pred_cnt_q;
// We need a pointer to which bit on the full VRF word we are reading mask operands from.
logic [idx_width(DataWidth*NrLanes):0] mask_pnt_d, mask_pnt_q;
// We need a pointer to which bit on the full VRF word we are writing results to.
logic [idx_width(DataWidth*NrLanes):0] vrf_pnt_d, vrf_pnt_q;
logic [idx_width(DataWidth*NrLanes):0] masku_pred_pnt_d, masku_pred_pnt_q;

// Remaining elements of the current instruction in the read operand phase
vlen_t read_cnt_d, read_cnt_q;
// Remaining elements of the current instruction in the issue phase
vlen_t issue_cnt_d, issue_cnt_q;
// Remaining elements of the current instruction to be validated in the result queue
vlen_t processing_cnt_d, processing_cnt_q;
// Remaining elements of the current instruction in the commit phase
vlen_t commit_cnt_d, commit_cnt_q;

// We need a pointer to which bit on the full VRF word we are writing during operand compression
logic [idx_width(DataWidth*NrLanes):0] masku_alu_compress_cnt_d, masku_alu_compress_cnt_q;

////////////////
// Operands //
////////////////
Expand Down Expand Up @@ -124,7 +125,6 @@ module masku import ara_pkg::*; import rvv_pkg::*; #(
// Control logic
.masku_fu_i ( masku_operand_fu ),
.vinsn_issue_i ( vinsn_issue ),
.vrf_pnt_i ( vrf_pnt_q ),
// Operands coming from lanes
.masku_operand_valid_i ( masku_operand_valid_i ),
.masku_operand_ready_o ( masku_operand_ready_o ),
Expand All @@ -148,8 +148,7 @@ module masku import ara_pkg::*; import rvv_pkg::*; #(
.masku_operand_m_seq_o ( masku_operand_m_seq ),
.masku_operand_m_seq_valid_o ( ),
.masku_operand_m_seq_ready_i ( '0 ),
.bit_enable_mask_o ( bit_enable_mask ),
.alu_result_compressed_seq_o ( alu_result_compressed_seq )
.bit_enable_mask_o ( bit_enable_mask )
);

// Local Parameter for mask logical instructions
Expand Down Expand Up @@ -230,7 +229,7 @@ module masku import ara_pkg::*; import rvv_pkg::*; #(
logic found_one, found_one_d, found_one_q;

// How many elements we are processing per cycle
logic [idx_width(NrLanes*DataWidth):0] delta_elm_d, delta_elm_q;
logic [idx_width(NrLanes*DataWidth):0] issue_cnt_delta_d, issue_cnt_delta_q;

// MASKU Alu: is a VRF word result or a scalar result fully valid?
logic out_vrf_word_valid, out_scalar_valid;
Expand Down Expand Up @@ -683,6 +682,7 @@ module masku import ara_pkg::*; import rvv_pkg::*; #(
alu_result_vmsbf_vm = '1;
alu_result_vmsof_vm = '1;
alu_result_vm = '1;
alu_result_compressed_seq = '1;

vcpop_operand = '0;

Expand Down Expand Up @@ -714,11 +714,38 @@ module masku import ara_pkg::*; import rvv_pkg::*; #(
// This operation always writes to multiple of VRF words, and it does not need vd
// This operation can overwrite the destination register without constraints on tail elements
[VMANDNOT:VMXNOR]: alu_result_vm_m = masku_operand_alu_seq;
// Comparisons: mask out the masked out bits of this pre-computed slice
[VMFEQ:VMSGT]: alu_result_vm_m = alu_result_compressed_seq
| ~(masku_operand_m_seq | {NrLanes*DataWidth{vinsn_issue.vm}});
// Add/sub-with-carry/borrow: the masks are all 1 since these operations are NOT masked
[VMADC:VMSBC]: alu_result_vm_m = alu_result_compressed_seq;
// Compress the alu_result from ALU/FPU format to MASKU format
[VMFEQ:VMSGT],
[VMADC:VMSBC]: begin
unique case (vinsn_issue.eew_vs2)
EW8: begin
for (int i = 0; i < NrLanes * DataWidth / 8; i++)
alu_result_compressed_seq[masku_alu_compress_cnt_q[idx_width(NrLanes * DataWidth/8)-1:0] * NrLanes * DataWidth / 8 + i] =
masku_operand_alu_seq[i * DataWidth / 8];
end
EW16: begin
for (int i = 0; i < NrLanes * DataWidth / 16; i++)
alu_result_compressed_seq[masku_alu_compress_cnt_q[idx_width(NrLanes * DataWidth/16)-1:0] * NrLanes * DataWidth / 16 + i] =
masku_operand_alu_seq[i * DataWidth / 4];
end
EW32: begin
for (int i = 0; i < NrLanes * DataWidth / 32; i++)
alu_result_compressed_seq[masku_alu_compress_cnt_q[idx_width(NrLanes * DataWidth/32)-1:0] * NrLanes * DataWidth / 32 + i] =
masku_operand_alu_seq[i * DataWidth / 2];
end
default: begin // EW64
for (int i = 0; i < NrLanes * DataWidth / 64; i++)
alu_result_compressed_seq[masku_alu_compress_cnt_q[idx_width(NrLanes * DataWidth/64)-1:0] * NrLanes * DataWidth / 64 + i] =
masku_operand_alu_seq[i * DataWidth / 1];
end
endcase

// Comparisons: mask out the masked out bits of this pre-computed slice
// Add/sub-with-carry/borrow: the masks are all 1 since these operations are NOT masked
alu_result_vm_m = vinsn_issue.op inside {[VMFEQ:VMSGT]}
? alu_result_compressed_seq | ~(masku_operand_m_seq | {NrLanes*DataWidth{vinsn_issue.vm}})
: alu_result_compressed_seq;
end
// VMSBF, VMSOF, VMSIF: compute a slice of the output and mask out the masked out bits
[VMSBF:VMSIF] : begin
vmsbf_buffer[0] = ~(masku_operand_alu_seq_m[in_ready_cnt_q[idx_width(NrLanes*DataWidth/VmsxfParallelism)-1:0] * VmsxfParallelism] | found_one_q);
Expand Down Expand Up @@ -894,13 +921,13 @@ module masku import ara_pkg::*; import rvv_pkg::*; #(

// Maintain state
vinsn_queue_d = vinsn_queue_q;
read_cnt_d = read_cnt_q;
masku_pred_cnt_d = masku_pred_cnt_q;
issue_cnt_d = issue_cnt_q;
processing_cnt_d = processing_cnt_q;
commit_cnt_d = commit_cnt_q;

mask_pnt_d = mask_pnt_q;
vrf_pnt_d = vrf_pnt_q;
masku_pred_pnt_d = masku_pred_pnt_q;
masku_alu_compress_cnt_d = masku_alu_compress_cnt_q;

popcount_d = popcount_q;
vfirst_count_d = vfirst_count_q;
Expand Down Expand Up @@ -952,7 +979,7 @@ module masku import ara_pkg::*; import rvv_pkg::*; #(
result_queue_background_data[lane] = result_queue_q[result_queue_write_pnt_q][lane].wdata;

// Maintain state
delta_elm_d = delta_elm_q;
issue_cnt_delta_d = issue_cnt_delta_q;
in_ready_threshold_d = in_ready_threshold_q;
in_m_ready_threshold_d = in_m_ready_threshold_q;
out_valid_threshold_d = out_valid_threshold_q;
Expand Down Expand Up @@ -982,8 +1009,8 @@ module masku import ara_pkg::*; import rvv_pkg::*; #(
automatic int vrf_offset = vrf_byte[idx_width(StrbWidth)-1:0];

// The VRF pointer can be broken into a byte offset, and a bit offset
automatic int vrf_pnt_byte_offset = mask_pnt_q >> $clog2(StrbWidth);
automatic int vrf_pnt_bit_offset = mask_pnt_q[idx_width(StrbWidth)-1:0];
automatic int vrf_pnt_byte_offset = masku_pred_pnt_q >> $clog2(StrbWidth);
automatic int vrf_pnt_bit_offset = masku_pred_pnt_q[idx_width(StrbWidth)-1:0];

// A single bit from the mask operands can be used several times, depending on the eew.
automatic int mask_seq_bit = vrf_seq_byte >> int'(vinsn_issue.vtype.vsew);
Expand All @@ -1004,13 +1031,13 @@ module masku import ara_pkg::*; import rvv_pkg::*; #(
masku_operand_m[mask_lane][mask_offset];
end

// Is there an instruction ready to be issued?
// We need to send mask strobes outside of the MASKU in the case of VMADC/VMSBC or non-MASKU masked instructions
if (vinsn_issue_valid && ((vinsn_issue.vfu != VFU_MaskUnit) || (vinsn_issue.op inside {[VMADC:VMSBC]}))) begin
// Is there place in the mask queue to write the mask operands?
// Did we receive the mask bits on the MaskM channel?
if (!vinsn_issue.vm && &masku_operand_m_valid) begin
// Account for the used operands
mask_pnt_d += NrLanes * (1 << (int'(EW64) - vinsn_issue.vtype.vsew));
masku_pred_pnt_d += NrLanes * (1 << (int'(EW64) - vinsn_issue.vtype.vsew));

// Increment result queue pointers and counters
mask_queue_cnt_d += 1;
Expand All @@ -1020,24 +1047,24 @@ module masku import ara_pkg::*; import rvv_pkg::*; #(
mask_queue_write_pnt_d = mask_queue_write_pnt_q + 1;

// Account for the operands that were issued
read_cnt_d = read_cnt_q - NrLanes * (1 << (int'(EW64) - vinsn_issue.vtype.vsew));
if (read_cnt_q < NrLanes * (1 << (int'(EW64) - vinsn_issue.vtype.vsew)))
read_cnt_d = '0;
masku_pred_cnt_d = masku_pred_cnt_q - NrLanes * (1 << (int'(EW64) - vinsn_issue.vtype.vsew));
if (masku_pred_cnt_q < NrLanes * (1 << (int'(EW64) - vinsn_issue.vtype.vsew)))
masku_pred_cnt_d = '0;

// Trigger the request signal
mask_queue_valid_d[mask_queue_write_pnt_q] = {NrLanes{1'b1}};

// Are there lanes with no valid elements?
// If so, mute their request signal
if (read_cnt_q < NrLanes)
mask_queue_valid_d[mask_queue_write_pnt_q] = (1 << read_cnt_q) - 1;
if (masku_pred_cnt_q < NrLanes)
mask_queue_valid_d[mask_queue_write_pnt_q] = (1 << masku_pred_cnt_q) - 1;

// Consumed all valid bytes from the lane operands
if (mask_pnt_d == NrLanes*DataWidth || read_cnt_d == '0) begin
if (masku_pred_pnt_d == NrLanes*DataWidth || masku_pred_cnt_d == '0) begin
// Request another beat
masku_operand_m_ready = '1;
// Reset the pointer
mask_pnt_d = '0;
masku_pred_pnt_d = '0;
end
end
end
Expand Down Expand Up @@ -1288,7 +1315,7 @@ module masku import ara_pkg::*; import rvv_pkg::*; #(
// Bump MASKU ALU state
found_one_d = found_one;
viota_acc_d = viota_acc;
vrf_pnt_d = vrf_pnt_q + delta_elm_q;
masku_alu_compress_cnt_d = masku_alu_compress_cnt_q + 1;
if (vinsn_issue.op inside {[VRGATHER:VCOMPRESS]}) vrgat_idx_fifo_pop = 1'b1;

// Increment the input, input-mask, and output slice counters
Expand All @@ -1297,8 +1324,8 @@ module masku import ara_pkg::*; import rvv_pkg::*; #(
out_valid_cnt_en = 1'b1;

// Account for the elements that have been processed
issue_cnt_d = issue_cnt_q - delta_elm_q;
if (issue_cnt_q < delta_elm_q)
issue_cnt_d = issue_cnt_q - issue_cnt_delta_q;
if (issue_cnt_q < issue_cnt_delta_q)
issue_cnt_d = '0;

// Request new input (by completing ready-valid handshake) once all slices have been processed
Expand Down Expand Up @@ -1391,7 +1418,7 @@ module masku import ara_pkg::*; import rvv_pkg::*; #(
// Finished issuing results
if (vinsn_issue_valid && (
( (vinsn_issue.vm || vinsn_issue.vfu == VFU_MaskUnit) && issue_cnt_d == '0) || vcompress_issue_end_q ||
(!(vinsn_issue.vm || vinsn_issue.vfu == VFU_MaskUnit) && read_cnt_d == '0))) begin
(!(vinsn_issue.vm || vinsn_issue.vfu == VFU_MaskUnit) && masku_pred_cnt_d == '0))) begin
// The instruction finished its issue phase
vinsn_queue_d.issue_cnt -= 1;
end
Expand Down Expand Up @@ -1458,7 +1485,7 @@ module masku import ara_pkg::*; import rvv_pkg::*; #(
out_valid_cnt_clr = 1'b1;

// Clear the vrf pointer for comparisons
vrf_pnt_d = '0;
masku_alu_compress_cnt_d = '0;

// Clear the vcompress issue-end indicator
vcompress_cnt_d = '0;
Expand Down Expand Up @@ -1516,28 +1543,28 @@ module masku import ara_pkg::*; import rvv_pkg::*; #(
if (vinsn_queue_d.issue_cnt == '0) begin
issue_cnt_d = pe_req_i.vl;
processing_cnt_d = pe_req_i.vl;
read_cnt_d = pe_req_i.vl;
masku_pred_cnt_d = pe_req_i.vl;

// Trim skipped words
if (pe_req_i.op == VSLIDEUP) begin
issue_cnt_d -= vlen_t'(trimmed_stride);
processing_cnt_d -= vlen_t'(trimmed_stride);
case (pe_req_i.vtype.vsew)
EW8: begin
read_cnt_d -= (vlen_t'(trimmed_stride) >> $clog2(NrLanes << 3)) << $clog2(NrLanes << 3);
mask_pnt_d = (vlen_t'(trimmed_stride) >> $clog2(NrLanes << 3)) << $clog2(NrLanes << 3);
masku_pred_cnt_d -= (vlen_t'(trimmed_stride) >> $clog2(NrLanes << 3)) << $clog2(NrLanes << 3);
masku_pred_pnt_d = (vlen_t'(trimmed_stride) >> $clog2(NrLanes << 3)) << $clog2(NrLanes << 3);
end
EW16: begin
read_cnt_d -= (vlen_t'(trimmed_stride) >> $clog2(NrLanes << 2)) << $clog2(NrLanes << 2);
mask_pnt_d = (vlen_t'(trimmed_stride) >> $clog2(NrLanes << 2)) << $clog2(NrLanes << 2);
masku_pred_cnt_d -= (vlen_t'(trimmed_stride) >> $clog2(NrLanes << 2)) << $clog2(NrLanes << 2);
masku_pred_pnt_d = (vlen_t'(trimmed_stride) >> $clog2(NrLanes << 2)) << $clog2(NrLanes << 2);
end
EW32: begin
read_cnt_d -= (vlen_t'(trimmed_stride) >> $clog2(NrLanes << 1)) << $clog2(NrLanes << 1);
mask_pnt_d = (vlen_t'(trimmed_stride) >> $clog2(NrLanes << 1)) << $clog2(NrLanes << 1);
masku_pred_cnt_d -= (vlen_t'(trimmed_stride) >> $clog2(NrLanes << 1)) << $clog2(NrLanes << 1);
masku_pred_pnt_d = (vlen_t'(trimmed_stride) >> $clog2(NrLanes << 1)) << $clog2(NrLanes << 1);
end
EW64: begin
read_cnt_d -= (vlen_t'(trimmed_stride) >> $clog2(NrLanes)) << $clog2(NrLanes);
mask_pnt_d = (vlen_t'(trimmed_stride) >> $clog2(NrLanes)) << $clog2(NrLanes);
masku_pred_cnt_d -= (vlen_t'(trimmed_stride) >> $clog2(NrLanes)) << $clog2(NrLanes);
masku_pred_pnt_d = (vlen_t'(trimmed_stride) >> $clog2(NrLanes)) << $clog2(NrLanes);
end
default:;
endcase
Expand All @@ -1547,62 +1574,62 @@ module masku import ara_pkg::*; import rvv_pkg::*; #(
unique case (pe_req_i.op) inside
[VMFEQ:VMSGT]: begin
// Mask to mask - encoded
delta_elm_d = NrLanes << (EW64 - pe_req_i.eew_vs2[1:0]);
issue_cnt_delta_d = NrLanes << (EW64 - pe_req_i.eew_vs2[1:0]);

in_ready_threshold_d = 0;
in_ready_threshold_d = '0;
in_m_ready_threshold_d = (DataWidth >> (EW64 - pe_req_i.eew_vs2[1:0]))-1;
out_valid_threshold_d = (DataWidth >> (EW64 - pe_req_i.eew_vs2[1:0]))-1;
end
[VMADC:VMSBC]: begin
// Mask to mask - encoded
delta_elm_d = NrLanes << (EW64 - pe_req_i.eew_vs2[1:0]);
issue_cnt_delta_d = NrLanes << (EW64 - pe_req_i.eew_vs2[1:0]);

in_ready_threshold_d = 0;
in_ready_threshold_d = '0;
in_m_ready_threshold_d = (DataWidth >> (EW64 - pe_req_i.eew_vs2[1:0]))-1;
out_valid_threshold_d = (DataWidth >> (EW64 - pe_req_i.eew_vs2[1:0]))-1;
end
[VMANDNOT:VMXNOR]: begin
// Mask to mask
delta_elm_d = VmLogicalParallelism;
issue_cnt_delta_d = VmLogicalParallelism;

in_ready_threshold_d = NrLanes*DataWidth/VmLogicalParallelism-1;
in_m_ready_threshold_d = NrLanes*DataWidth/VmLogicalParallelism-1;
out_valid_threshold_d = NrLanes*DataWidth/VmLogicalParallelism-1;
end
[VMSBF:VMSIF]: begin
// Mask to mask
delta_elm_d = VmsxfParallelism;
issue_cnt_delta_d = VmsxfParallelism;

in_ready_threshold_d = NrLanes*DataWidth/VmsxfParallelism-1;
in_m_ready_threshold_d = NrLanes*DataWidth/VmsxfParallelism-1;
out_valid_threshold_d = NrLanes*DataWidth/VmsxfParallelism-1;
end
[VIOTA:VID]: begin
// Mask to non-mask
delta_elm_d = ViotaParallelism;
issue_cnt_delta_d = ViotaParallelism;

in_ready_threshold_d = NrLanes*DataWidth/ViotaParallelism-1;
in_m_ready_threshold_d = NrLanes*DataWidth/ViotaParallelism-1;
out_valid_threshold_d = ((NrLanes*DataWidth/8/ViotaParallelism) >> pe_req_i.vtype.vsew[1:0])-1;
end
VCPOP: begin
// Mask to scalar
delta_elm_d = VcpopParallelism;
issue_cnt_delta_d = VcpopParallelism;

in_ready_threshold_d = NrLanes*DataWidth/VcpopParallelism-1;
in_m_ready_threshold_d = NrLanes*DataWidth/VcpopParallelism-1;
out_valid_threshold_d = '0;
end
VFIRST: begin
// Mask to scalar
delta_elm_d = VfirstParallelism;
issue_cnt_delta_d = VfirstParallelism;

in_ready_threshold_d = NrLanes*DataWidth/VfirstParallelism-1;
in_m_ready_threshold_d = NrLanes*DataWidth/VfirstParallelism-1;
out_valid_threshold_d = '0;
end
default: begin // VRGATHER, VRGATHEREI16, VCOMPRESS
delta_elm_d = 1;
issue_cnt_delta_d = 1;

in_ready_threshold_d = pe_req_i.op == VCOMPRESS ? NrLanes*DataWidth-1 : ((NrLanes*DataWidth/8) >> vrgat_eff_vsew)-1;
in_m_ready_threshold_d = NrLanes*DataWidth-1;
Expand Down Expand Up @@ -1632,17 +1659,17 @@ module masku import ara_pkg::*; import rvv_pkg::*; #(
always_ff @(posedge clk_i or negedge rst_ni) begin
if (!rst_ni) begin
vinsn_running_q <= '0;
read_cnt_q <= '0;
masku_pred_cnt_q <= '0;
issue_cnt_q <= '0;
processing_cnt_q <= '0;
commit_cnt_q <= '0;
vrf_pnt_q <= '0;
mask_pnt_q <= '0;
masku_alu_compress_cnt_q <= '0;
masku_pred_pnt_q <= '0;
pe_resp_o <= '0;
result_final_gnt_q <= '0;
popcount_q <= '0;
vfirst_count_q <= '0;
delta_elm_q <= '0;
issue_cnt_delta_q <= '0;
in_ready_threshold_q <= '0;
in_m_ready_threshold_q <= '0;
out_valid_threshold_q <= '0;
Expand All @@ -1656,17 +1683,17 @@ module masku import ara_pkg::*; import rvv_pkg::*; #(
vcompress_cnt_q <= '0;
end else begin
vinsn_running_q <= vinsn_running_d;
read_cnt_q <= read_cnt_d;
masku_pred_cnt_q <= masku_pred_cnt_d;
issue_cnt_q <= issue_cnt_d;
processing_cnt_q <= processing_cnt_d;
commit_cnt_q <= commit_cnt_d;
vrf_pnt_q <= vrf_pnt_d;
mask_pnt_q <= mask_pnt_d;
masku_alu_compress_cnt_q <= masku_alu_compress_cnt_d;
masku_pred_pnt_q <= masku_pred_pnt_d;
pe_resp_o <= pe_resp;
result_final_gnt_q <= result_final_gnt_d;
popcount_q <= popcount_d;
vfirst_count_q <= vfirst_count_d;
delta_elm_q <= delta_elm_d;
issue_cnt_delta_q <= issue_cnt_delta_d;
in_ready_threshold_q <= in_ready_threshold_d;
in_m_ready_threshold_q <= in_m_ready_threshold_d;
out_valid_threshold_q <= out_valid_threshold_d;
Expand Down
Loading

0 comments on commit 3cb8614

Please sign in to comment.