Skip to content

Commit

Permalink
[rtl] refactor mask unit.
Browse files Browse the repository at this point in the history
  • Loading branch information
qinjun-li committed Oct 5, 2024
1 parent d7395a2 commit d4ef7e0
Show file tree
Hide file tree
Showing 14 changed files with 1,107 additions and 962 deletions.
70 changes: 70 additions & 0 deletions t1/src/Bundles.scala
Original file line number Diff line number Diff line change
Expand Up @@ -698,3 +698,73 @@ class T1Retire(xLen: Int) extends Bundle {
val csr: ValidIO[T1CSRRetire] = Valid(new T1CSRRetire)
val mem: ValidIO[EmptyBundle] = Valid(new EmptyBundle)
}

class MaskUnitExecuteState(parameter: T1Parameter) extends Bundle {
val groupReadState: UInt = UInt(parameter.laneNumber.W)
val needRead: UInt = UInt(parameter.laneNumber.W)
val readMask: UInt = UInt((parameter.laneNumber * parameter.datapathWidth / 8).W)
val readDataOffset = UInt((parameter.laneNumber * parameter.laneParam.vrfOffsetBits).W)
}

class MaskUnitInstReq(parameter: T1Parameter) extends Bundle {
val instructionIndex: UInt = UInt(parameter.instructionIndexBits.W)
val decodeResult: DecodeBundle = Decoder.bundle(parameter.decoderParam)
val readFromScala: UInt = UInt(parameter.datapathWidth.W)
val eew: UInt = UInt(2.W)
val vlmul: UInt = UInt(2.W)
val vm: Bool = Bool()
val vxrm: UInt = UInt(3.W)
val vs2: UInt = UInt(5.W)
}

class MaskUnitExeReq(parameter: LaneParameter) extends Bundle {
// source1, read vs
val source1: UInt = UInt(parameter.datapathWidth.W)
// source2, read offset
val source2: UInt = UInt(parameter.datapathWidth.W)
val groupCounter: UInt = UInt(parameter.groupNumberBits.W)
val index: UInt = UInt(parameter.instructionIndexBits.W)
}

class MaskUnitExeResponse(parameter: LaneParameter) extends Bundle {
val ffoByOther: Bool = Bool()
val writeData = new MaskUnitWriteBundle(parameter)
val index: UInt = UInt(parameter.instructionIndexBits.W)
}

class MaskUnitReadReq(parameter: T1Parameter) extends Bundle {
val vs: UInt = UInt(5.W)
// source2, read offset
val offset: UInt = UInt(parameter.laneParam.vrfOffsetBits.W)
// Read which lane
val readLane: UInt = UInt(log2Ceil(parameter.laneNumber).W)
}

class MaskUnitReadQueue(parameter: T1Parameter) extends Bundle {
val vs: UInt = UInt(5.W)
// source2, read offset
val offset: UInt = UInt(parameter.laneParam.vrfOffsetBits.W)
// Which channel will this read request be written to?
val writeIndex: UInt = UInt(log2Ceil(parameter.laneNumber).W)
}

class MaskUnitWaitReadQueue(parameter: T1Parameter) extends Bundle {
// source1
val source1: Vec[UInt] = Vec(parameter.laneNumber, UInt(parameter.datapathWidth.W))
// source2
val source2: Vec[UInt] = Vec(parameter.laneNumber, UInt(parameter.datapathWidth.W))

val groupCounter: UInt = UInt(parameter.laneParam.groupNumberBits.W)
val executeIndex: UInt = UInt(2.W)
val lastGroup: Bool = Bool()
val sourceValid: UInt = UInt(parameter.laneNumber.W)
val writeMask: Vec[UInt] = Vec(parameter.laneNumber, UInt((parameter.datapathWidth / 4).W))

val needRead: UInt = UInt(parameter.laneNumber.W)
}

class MaskUnitWriteBundle(parameter: LaneParameter) extends Bundle {
val data: UInt = UInt(parameter.datapathWidth.W)
val mask: UInt = UInt((parameter.datapathWidth / 8).W)
val groupCounter: UInt = UInt(parameter.groupNumberBits.W)
}
74 changes: 41 additions & 33 deletions t1/src/Lane.scala
Original file line number Diff line number Diff line change
Expand Up @@ -229,13 +229,14 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
@public
val csrInterface: CSRInterface = IO(Input(new CSRInterface(parameter.vlMaxBits)))

/** response to [[T1.lsu]] or mask unit in [[T1]] */
@public
val laneResponse: ValidIO[LaneResponse] = IO(Valid(new LaneResponse(parameter)))
val maskUnitRequest: DecoupledIO[MaskUnitExeReq] = IO(Decoupled(new MaskUnitExeReq(parameter)))

/** feedback from [[T1]] to [[Lane]] for [[laneResponse]] */
@public
val laneResponseFeedback: ValidIO[LaneResponseFeedback] = IO(Flipped(Valid(new LaneResponseFeedback(parameter))))
val maskRequestToLSU: Bool = IO(Output(Bool()))

@public
val maskUnitResponse: ValidIO[MaskUnitExeResponse] = IO(Flipped(Valid(new MaskUnitExeResponse(parameter))))

/** for LSU and V accessing lane, this is not a part of ring, but a direct connection. */
@public
Expand Down Expand Up @@ -571,7 +572,17 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
val executionUnit: Instance[LaneExecutionBridge] = Instantiate(
new LaneExecutionBridge(parameter, isLastSlot, index)
)
val maskStage: Option[Instance[MaskExchangeUnit]] = Option.when(isLastSlot)(Instantiate(new MaskExchangeUnit(parameter)))
val stage3: Instance[LaneStage3] = Instantiate(new LaneStage3(parameter, isLastSlot))
val stage3EnqWire: DecoupledIO[LaneStage3Enqueue] = Wire(Decoupled(new LaneStage3Enqueue(parameter, isLastSlot)))
val stage3EnqSelect: DecoupledIO[LaneStage3Enqueue] = maskStage.map { mask =>
mask.enqueue <> stage3EnqWire
maskUnitRequest <> mask.maskReq
maskRequestToLSU <> mask.maskRequestToLSU
mask.maskUnitResponse := maskUnitResponse
mask.dequeue
}.getOrElse(stage3EnqWire)
stage3.enqueue <> stage3EnqSelect

// slot state
laneState.vSew1H := vSew1H
Expand Down Expand Up @@ -753,50 +764,47 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
0.U(parameter.chainingSize.W)
)
AssertProperty(BoolSequence(!executionUnit.dequeue.valid || stage2.dequeue.valid))
stage3.enqueue.valid := executionUnit.dequeue.valid
executionUnit.dequeue.ready := stage3.enqueue.ready
stage3EnqWire.valid := executionUnit.dequeue.valid
executionUnit.dequeue.ready := stage3EnqWire.ready
stage2.dequeue.ready := executionUnit.dequeue.fire

if (!isLastSlot) {
stage3.enqueue.bits := DontCare
stage3EnqWire.bits := DontCare
}

// pipe state from stage0
stage3.enqueue.bits.decodeResult := stage2.dequeue.bits.decodeResult
stage3.enqueue.bits.instructionIndex := stage2.dequeue.bits.instructionIndex
stage3.enqueue.bits.loadStore := stage2.dequeue.bits.loadStore
stage3.enqueue.bits.vd := stage2.dequeue.bits.vd
stage3.enqueue.bits.ffoByOtherLanes := ffoRecord.ffoByOtherLanes
stage3.enqueue.bits.groupCounter := stage2.dequeue.bits.groupCounter
stage3.enqueue.bits.mask := stage2.dequeue.bits.mask
stage3EnqWire.bits.decodeResult := stage2.dequeue.bits.decodeResult
stage3EnqWire.bits.instructionIndex := stage2.dequeue.bits.instructionIndex
stage3EnqWire.bits.loadStore := stage2.dequeue.bits.loadStore
stage3EnqWire.bits.vd := stage2.dequeue.bits.vd
stage3EnqWire.bits.ffoByOtherLanes := ffoRecord.ffoByOtherLanes
stage3EnqWire.bits.groupCounter := stage2.dequeue.bits.groupCounter
stage3EnqWire.bits.mask := stage2.dequeue.bits.mask
if (isLastSlot) {
stage3.enqueue.bits.sSendResponse := stage2.dequeue.bits.sSendResponse.get
stage3.enqueue.bits.ffoSuccess := executionUnit.dequeue.bits.ffoSuccess.get
stage3.enqueue.bits.fpReduceValid.zip(executionUnit.dequeue.bits.fpReduceValid).foreach { case (sink, source) =>
stage3EnqWire.bits.sSendResponse := stage2.dequeue.bits.sSendResponse.get
stage3EnqWire.bits.ffoSuccess := executionUnit.dequeue.bits.ffoSuccess.get
stage3EnqWire.bits.fpReduceValid.zip(executionUnit.dequeue.bits.fpReduceValid).foreach { case (sink, source) =>
sink := source
}
}
stage3.enqueue.bits.data := executionUnit.dequeue.bits.data
stage3.enqueue.bits.pipeData := stage2.dequeue.bits.pipeData.getOrElse(DontCare)
stage3.enqueue.bits.ffoIndex := executionUnit.dequeue.bits.ffoIndex
executionUnit.dequeue.bits.crossWriteData.foreach(data => stage3.enqueue.bits.crossWriteData := data)
stage2.dequeue.bits.sSendResponse.foreach(_ => stage3.enqueue.bits.sSendResponse := _)
executionUnit.dequeue.bits.ffoSuccess.foreach(_ => stage3.enqueue.bits.ffoSuccess := _)
stage3EnqWire.bits.data := executionUnit.dequeue.bits.data
stage3EnqWire.bits.pipeData := stage2.dequeue.bits.pipeData.getOrElse(DontCare)
stage3EnqWire.bits.ffoIndex := executionUnit.dequeue.bits.ffoIndex
executionUnit.dequeue.bits.crossWriteData.foreach(data => stage3EnqWire.bits.crossWriteData := data)
stage2.dequeue.bits.sSendResponse.foreach(_ => stage3EnqWire.bits.sSendResponse := _)
executionUnit.dequeue.bits.ffoSuccess.foreach(_ => stage3EnqWire.bits.ffoSuccess := _)

if (isLastSlot) {
when(laneResponseFeedback.valid) {
when(laneResponseFeedback.bits.complete) {
when(maskUnitResponse.valid) {
when(maskUnitResponse.bits.ffoByOther) {
ffoRecord.ffoByOtherLanes := true.B
}
}
when(stage3.enqueue.fire) {
when(stage3EnqWire.fire) {
executionUnit.dequeue.bits.ffoSuccess.foreach(ffoRecord.selfCompleted := _)
// This group found means the next group ended early
ffoRecord.ffoByOtherLanes := ffoRecord.ffoByOtherLanes || ffoRecord.selfCompleted
}

laneResponse <> stage3.laneResponse.get
stage3.laneResponseFeedback.get <> laneResponseFeedback
}

// --- stage 3 end & stage 4 start ---
Expand Down Expand Up @@ -1168,10 +1176,10 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
rpt.bits := allVrfWriteAfterCheck(parameter.chainingSize + 1 + rptIndex).instructionIndex
}
// todo: add mask unit write token
tokenManager.responseReport.valid := laneResponse.valid
tokenManager.responseReport.bits := laneResponse.bits.instructionIndex
tokenManager.responseFeedbackReport.valid := laneResponseFeedback.valid
tokenManager.responseFeedbackReport.bits := laneResponseFeedback.bits.instructionIndex
tokenManager.responseReport.valid := maskUnitRequest.valid
tokenManager.responseReport.bits := maskUnitRequest.bits.index
tokenManager.responseFeedbackReport.valid := maskUnitResponse.valid
tokenManager.responseFeedbackReport.bits := maskUnitResponse.bits.index
val instInSlot: UInt = slotControl
.zip(slotOccupied)
.map { case (slotState, occupied) =>
Expand Down
Loading

0 comments on commit d4ef7e0

Please sign in to comment.