Skip to content

Commit

Permalink
[rtl] refactor mask unit.
Browse files Browse the repository at this point in the history
  • Loading branch information
qinjun-li committed Oct 31, 2024
1 parent 3b6d557 commit 2319671
Show file tree
Hide file tree
Showing 19 changed files with 2,057 additions and 1,112 deletions.
80 changes: 79 additions & 1 deletion t1/src/Bundles.scala
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class InstructionState extends Bundle {
val idle: Bool = Bool()

/** used for mask unit, schedule mask unit to execute. */
val sMaskUnitExecution: Bool = Bool()
val wMaskUnitLast: Bool = Bool()

/** wait for vrf write finish. */
val wVRFWrite: Bool = Bool()
Expand Down Expand Up @@ -698,3 +698,81 @@ class T1Retire(xLen: Int) extends Bundle {
val csr: ValidIO[T1CSRRetire] = Valid(new T1CSRRetire)
val mem: ValidIO[EmptyBundle] = Valid(new EmptyBundle)
}

class MaskUnitExecuteState(parameter: T1Parameter) extends Bundle {
val groupReadState: UInt = UInt(parameter.laneNumber.W)
val needRead: UInt = UInt(parameter.laneNumber.W)
val elementValid: UInt = UInt(parameter.laneNumber.W)
val readOffset: UInt = UInt((parameter.laneNumber * parameter.laneParam.vrfOffsetBits).W)
val accessLane: Vec[UInt] = Vec(parameter.laneNumber, UInt(log2Ceil(parameter.laneNumber).W))
// 3: log2Ceil(8); 8: Use up to 8 registers
val vsGrowth: Vec[UInt] = Vec(parameter.laneNumber, UInt(3.W))
val groupCount: UInt = UInt(parameter.laneParam.groupNumberBits.W)
val executeIndex: UInt = UInt(2.W)
val readDataOffset: UInt = UInt((log2Ceil(parameter.datapathWidth / 8) * parameter.laneNumber).W)
val last: Bool = Bool()
}

class MaskUnitInstReq(parameter: T1Parameter) extends Bundle {
val instructionIndex: UInt = UInt(parameter.instructionIndexBits.W)
val decodeResult: DecodeBundle = Decoder.bundle(parameter.decoderParam)
val readFromScala: UInt = UInt(parameter.datapathWidth.W)
val sew: UInt = UInt(2.W)
val vlmul: UInt = UInt(3.W)
val maskType: Bool = Bool()
val vxrm: UInt = UInt(3.W)
val vs2: UInt = UInt(5.W)
val vs1: UInt = UInt(5.W)
val vd: UInt = UInt(5.W)
val vl: UInt = UInt(parameter.laneParam.vlMaxBits.W)
}

class MaskUnitExeReq(parameter: LaneParameter) extends Bundle {
// source1, read vs
val source1: UInt = UInt(parameter.datapathWidth.W)
// source2, read offset
val source2: UInt = UInt(parameter.datapathWidth.W)
val groupCounter: UInt = UInt(parameter.groupNumberBits.W)
val index: UInt = UInt(parameter.instructionIndexBits.W)
}

class MaskUnitExeResponse(parameter: LaneParameter) extends Bundle {
val ffoByOther: Bool = Bool()
val writeData = new MaskUnitWriteBundle(parameter)
val index: UInt = UInt(parameter.instructionIndexBits.W)
}

class MaskUnitReadReq(parameter: T1Parameter) extends Bundle {
val vs: UInt = UInt(5.W)
// source2, read offset
val offset: UInt = UInt(parameter.laneParam.vrfOffsetBits.W)
// Read which lane
val readLane: UInt = UInt(log2Ceil(parameter.laneNumber).W)
// from which request
val requestIndex: UInt = UInt(log2Ceil(parameter.laneNumber).W)
// data position in data path
val dataOffset: UInt = UInt(log2Ceil(parameter.datapathWidth / 8).W)
}

class MaskUnitReadQueue(parameter: T1Parameter) extends Bundle {
val vs: UInt = UInt(5.W)
// source2, read offset
val offset: UInt = UInt(parameter.laneParam.vrfOffsetBits.W)
// Which channel will this read request be written to?
val writeIndex: UInt = UInt(log2Ceil(parameter.laneNumber).W)
val dataOffset: UInt = UInt(log2Ceil(parameter.datapathWidth / 8).W)
}

class MaskUnitWaitReadQueue(parameter: T1Parameter) extends Bundle {
val groupCounter: UInt = UInt(parameter.laneParam.groupNumberBits.W)
val executeIndex: UInt = UInt(2.W)
val sourceValid: UInt = UInt(parameter.laneNumber.W)
val needRead: UInt = UInt(parameter.laneNumber.W)
val last: Bool = Bool()
}

class MaskUnitWriteBundle(parameter: LaneParameter) extends Bundle {
val data: UInt = UInt(parameter.datapathWidth.W)
val mask: UInt = UInt((parameter.datapathWidth / 8).W)
val groupCounter: UInt = UInt(parameter.groupNumberBits.W)
}
89 changes: 50 additions & 39 deletions t1/src/Lane.scala
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,14 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
@public
val csrInterface: CSRInterface = IO(Input(new CSRInterface(parameter.vlMaxBits)))

/** response to [[T1.lsu]] or mask unit in [[T1]] */
@public
val laneResponse: ValidIO[LaneResponse] = IO(Valid(new LaneResponse(parameter)))
val maskUnitRequest: DecoupledIO[MaskUnitExeReq] = IO(Decoupled(new MaskUnitExeReq(parameter)))

/** feedback from [[T1]] to [[Lane]] for [[laneResponse]] */
@public
val laneResponseFeedback: ValidIO[LaneResponseFeedback] = IO(Flipped(Valid(new LaneResponseFeedback(parameter))))
val maskRequestToLSU: Bool = IO(Output(Bool()))

@public
val maskUnitResponse: ValidIO[MaskUnitExeResponse] = IO(Flipped(Valid(new MaskUnitExeResponse(parameter))))

/** for LSU and V accessing lane, this is not a part of ring, but a direct connection. */
@public
Expand Down Expand Up @@ -570,14 +571,25 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
slotCanShift(index) := true.B
}

val laneState: LaneState = Wire(new LaneState(parameter))
val stage0: Instance[LaneStage0] = Instantiate(new LaneStage0(parameter, isLastSlot))
val stage1: Instance[LaneStage1] = Instantiate(new LaneStage1(parameter, isLastSlot))
val stage2: Instance[LaneStage2] = Instantiate(new LaneStage2(parameter, isLastSlot))
val executionUnit: Instance[LaneExecutionBridge] = Instantiate(
val laneState: LaneState = Wire(new LaneState(parameter))
val stage0: Instance[LaneStage0] = Instantiate(new LaneStage0(parameter, isLastSlot))
val stage1: Instance[LaneStage1] = Instantiate(new LaneStage1(parameter, isLastSlot))
val stage2: Instance[LaneStage2] = Instantiate(new LaneStage2(parameter, isLastSlot))
val executionUnit: Instance[LaneExecutionBridge] = Instantiate(
new LaneExecutionBridge(parameter, isLastSlot, index)
)
val stage3: Instance[LaneStage3] = Instantiate(new LaneStage3(parameter, isLastSlot))
val maskStage: Option[Instance[MaskExchangeUnit]] =
Option.when(isLastSlot)(Instantiate(new MaskExchangeUnit(parameter)))
val stage3: Instance[LaneStage3] = Instantiate(new LaneStage3(parameter, isLastSlot))
val stage3EnqWire: DecoupledIO[LaneStage3Enqueue] = Wire(Decoupled(new LaneStage3Enqueue(parameter, isLastSlot)))
val stage3EnqSelect: DecoupledIO[LaneStage3Enqueue] = maskStage.map { mask =>
mask.enqueue <> stage3EnqWire
maskUnitRequest <> mask.maskReq
maskRequestToLSU <> mask.maskRequestToLSU
mask.maskUnitResponse := maskUnitResponse
mask.dequeue
}.getOrElse(stage3EnqWire)
stage3.enqueue <> stage3EnqSelect

// slot state
laneState.vSew1H := vSew1H
Expand Down Expand Up @@ -759,50 +771,47 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
0.U(parameter.chainingSize.W)
)
AssertProperty(BoolSequence(!executionUnit.dequeue.valid || stage2.dequeue.valid))
stage3.enqueue.valid := executionUnit.dequeue.valid
executionUnit.dequeue.ready := stage3.enqueue.ready
stage3EnqWire.valid := executionUnit.dequeue.valid
executionUnit.dequeue.ready := stage3EnqWire.ready
stage2.dequeue.ready := executionUnit.dequeue.fire

if (!isLastSlot) {
stage3.enqueue.bits := DontCare
stage3EnqWire.bits := DontCare
}

// pipe state from stage0
stage3.enqueue.bits.decodeResult := stage2.dequeue.bits.decodeResult
stage3.enqueue.bits.instructionIndex := stage2.dequeue.bits.instructionIndex
stage3.enqueue.bits.loadStore := stage2.dequeue.bits.loadStore
stage3.enqueue.bits.vd := stage2.dequeue.bits.vd
stage3.enqueue.bits.ffoByOtherLanes := ffoRecord.ffoByOtherLanes
stage3.enqueue.bits.groupCounter := stage2.dequeue.bits.groupCounter
stage3.enqueue.bits.mask := stage2.dequeue.bits.mask
stage3EnqWire.bits.decodeResult := stage2.dequeue.bits.decodeResult
stage3EnqWire.bits.instructionIndex := stage2.dequeue.bits.instructionIndex
stage3EnqWire.bits.loadStore := stage2.dequeue.bits.loadStore
stage3EnqWire.bits.vd := stage2.dequeue.bits.vd
stage3EnqWire.bits.ffoByOtherLanes := ffoRecord.ffoByOtherLanes
stage3EnqWire.bits.groupCounter := stage2.dequeue.bits.groupCounter
stage3EnqWire.bits.mask := stage2.dequeue.bits.mask
if (isLastSlot) {
stage3.enqueue.bits.sSendResponse := stage2.dequeue.bits.sSendResponse.get
stage3.enqueue.bits.ffoSuccess := executionUnit.dequeue.bits.ffoSuccess.get
stage3.enqueue.bits.fpReduceValid.zip(executionUnit.dequeue.bits.fpReduceValid).foreach { case (sink, source) =>
stage3EnqWire.bits.sSendResponse := stage2.dequeue.bits.sSendResponse.get
stage3EnqWire.bits.ffoSuccess := executionUnit.dequeue.bits.ffoSuccess.get
stage3EnqWire.bits.fpReduceValid.zip(executionUnit.dequeue.bits.fpReduceValid).foreach { case (sink, source) =>
sink := source
}
}
stage3.enqueue.bits.data := executionUnit.dequeue.bits.data
stage3.enqueue.bits.pipeData := stage2.dequeue.bits.pipeData.getOrElse(DontCare)
stage3.enqueue.bits.ffoIndex := executionUnit.dequeue.bits.ffoIndex
executionUnit.dequeue.bits.crossWriteData.foreach(data => stage3.enqueue.bits.crossWriteData := data)
stage2.dequeue.bits.sSendResponse.foreach(_ => stage3.enqueue.bits.sSendResponse := _)
executionUnit.dequeue.bits.ffoSuccess.foreach(_ => stage3.enqueue.bits.ffoSuccess := _)
stage3EnqWire.bits.data := executionUnit.dequeue.bits.data
stage3EnqWire.bits.pipeData := stage2.dequeue.bits.pipeData.getOrElse(DontCare)
stage3EnqWire.bits.ffoIndex := executionUnit.dequeue.bits.ffoIndex
executionUnit.dequeue.bits.crossWriteData.foreach(data => stage3EnqWire.bits.crossWriteData := data)
stage2.dequeue.bits.sSendResponse.foreach(_ => stage3EnqWire.bits.sSendResponse := _)
executionUnit.dequeue.bits.ffoSuccess.foreach(_ => stage3EnqWire.bits.ffoSuccess := _)

if (isLastSlot) {
when(laneResponseFeedback.valid) {
when(laneResponseFeedback.bits.complete) {
when(maskUnitResponse.valid) {
when(maskUnitResponse.bits.ffoByOther) {
ffoRecord.ffoByOtherLanes := true.B
}
}
when(stage3.enqueue.fire) {
when(stage3EnqWire.fire) {
executionUnit.dequeue.bits.ffoSuccess.foreach(ffoRecord.selfCompleted := _)
// This group found means the next group ended early
ffoRecord.ffoByOtherLanes := ffoRecord.ffoByOtherLanes || ffoRecord.selfCompleted
}

laneResponse <> stage3.laneResponse.get
stage3.laneResponseFeedback.get <> laneResponseFeedback
}

// --- stage 3 end & stage 4 start ---
Expand Down Expand Up @@ -1175,10 +1184,10 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
rpt.bits := allVrfWriteAfterCheck(parameter.chainingSize + 1 + rptIndex).instructionIndex
}
// todo: add mask unit write token
tokenManager.responseReport.valid := laneResponse.valid
tokenManager.responseReport.bits := laneResponse.bits.instructionIndex
tokenManager.responseFeedbackReport.valid := laneResponseFeedback.valid
tokenManager.responseFeedbackReport.bits := laneResponseFeedback.bits.instructionIndex
tokenManager.responseReport.valid := maskUnitRequest.valid
tokenManager.responseReport.bits := maskUnitRequest.bits.index
tokenManager.responseFeedbackReport.valid := maskUnitResponse.valid
tokenManager.responseFeedbackReport.bits := maskUnitResponse.bits.index
val instInSlot: UInt = slotControl
.zip(slotOccupied)
.map { case (slotState, occupied) =>
Expand Down Expand Up @@ -1211,6 +1220,8 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
tokenManager.topWriteDeq.valid := afterCheckDequeueFire(parameter.chainingSize)
tokenManager.topWriteDeq.bits := allVrfWriteAfterCheck(parameter.chainingSize).instructionIndex

tokenManager.maskUnitLastReport := lsuLastReport

layer.block(layers.Verification) {
val probeWire = Wire(new LaneProbe(parameter))
define(laneProbe, ProbeValue(probeWire))
Expand Down
Loading

0 comments on commit 2319671

Please sign in to comment.