Skip to content

Commit

Permalink
[rtl] retime compress unit.
Browse files Browse the repository at this point in the history
  • Loading branch information
qinjun-li committed Dec 20, 2024
1 parent 10ad9f4 commit 79814c6
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 39 deletions.
70 changes: 58 additions & 12 deletions t1/src/mask/MaskCompress.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,82 @@
package org.chipsalliance.t1.rtl

import chisel3._
import chisel3.experimental.hierarchy.{instantiable, Instance, Instantiate}
import chisel3.experimental.{SerializableModule, SerializableModuleParameter}
import chisel3.properties.{AnyClassType, Path, Property}
import chisel3.util._
import org.chipsalliance.stdlib.GeneralOM

class CompressInput(parameter: T1Parameter) extends Bundle {
case class CompressParam(
datapathWidth: Int,
xLen: Int,
vLen: Int,
laneNumber: Int,
groupNumberBits: Int,
latency: Int)
extends SerializableModuleParameter

object CompressParam {
implicit def rwP = upickle.default.macroRW[CompressParam]
}

class CompressInput(parameter: CompressParam) extends Bundle {
val maskType: Bool = Bool()
val eew: UInt = UInt(2.W)
val uop: UInt = UInt(3.W)
val readFromScalar: UInt = UInt(parameter.datapathWidth.W)
val source1: UInt = UInt(parameter.datapathWidth.W)
val mask: UInt = UInt(parameter.datapathWidth.W)
val source2: UInt = UInt((parameter.laneNumber * parameter.datapathWidth).W)
val groupCounter: UInt = UInt(parameter.laneParam.groupNumberBits.W)
val groupCounter: UInt = UInt(parameter.groupNumberBits.W)
val ffoInput: UInt = UInt(parameter.laneNumber.W)
val validInput: UInt = UInt(parameter.laneNumber.W)
val lastCompress: Bool = Bool()
}

class CompressOutput(parameter: T1Parameter) extends Bundle {
class CompressOutput(parameter: CompressParam) extends Bundle {
val data: UInt = UInt((parameter.laneNumber * parameter.datapathWidth).W)
val mask: UInt = UInt((parameter.laneNumber * parameter.datapathWidth / 8).W)
val groupCounter: UInt = UInt(parameter.laneParam.groupNumberBits.W)
val groupCounter: UInt = UInt(parameter.groupNumberBits.W)
val ffoOutput: UInt = UInt(parameter.laneNumber.W)
val compressValid: Bool = Bool()
}

class MaskCompress(parameter: T1Parameter) extends Module {
val in: ValidIO[CompressInput] = IO(Flipped(Valid(new CompressInput(parameter))))
val out: CompressOutput = IO(Output(new CompressOutput(parameter)))
val newInstruction: Bool = IO(Input(Bool()))
val ffoInstruction: Bool = IO(Input(Bool()))
val writeData: UInt = IO(Output(UInt(parameter.xLen.W)))
class MaskCompressInterFace(parameter: CompressParam) extends Bundle {
val clock = Input(Clock())
val reset = Input(Reset())

val in: ValidIO[CompressInput] = Flipped(Valid(new CompressInput(parameter)))
val out: CompressOutput = Output(new CompressOutput(parameter))
val newInstruction: Bool = Input(Bool())
val ffoInstruction: Bool = Input(Bool())
val writeData: UInt = Output(UInt(parameter.xLen.W))
val om = Output(Property[AnyClassType]())
}

@instantiable
class MaskCompressOM(parameter: CompressParam) extends GeneralOM[CompressParam, MaskCompress](parameter) {
override def hasRetime: Boolean = true
}

class MaskCompress(val parameter: CompressParam)
extends FixedIORawModule(new MaskCompressInterFace(parameter))
with SerializableModule[CompressParam]
with ImplicitClock
with ImplicitReset {

protected def implicitClock = io.clock
protected def implicitReset = io.reset

val omInstance: Instance[MaskCompressOM] = Instantiate(new MaskCompressOM(parameter))
io.om := omInstance.getPropertyReference
omInstance.retimeIn.foreach(_ := Property(Path(io.clock)))

val in = io.in
val out = io.out
val newInstruction = io.newInstruction
val ffoInstruction = io.ffoInstruction
val writeData = io.writeData

val maskSize: Int = parameter.laneNumber * parameter.datapathWidth / 8

Expand Down Expand Up @@ -122,7 +168,7 @@ class MaskCompress(parameter: T1Parameter) extends Module {

val compressDataReg = RegInit(0.U((parameter.laneNumber * parameter.datapathWidth).W))
val compressTailValid: Bool = RegInit(false.B)
val compressWriteGroupCount: UInt = RegInit(0.U(parameter.laneParam.groupNumberBits.W))
val compressWriteGroupCount: UInt = RegInit(0.U(parameter.groupNumberBits.W))
val compressDataVec = Seq(1, 2, 4).map { dataByte =>
val dataBit = dataByte * 8
val elementSizePerSet = parameter.laneNumber * parameter.datapathWidth / 8 / dataByte
Expand Down Expand Up @@ -238,5 +284,5 @@ class MaskCompress(parameter: T1Parameter) extends Module {
ffoIndex := source1SigExtend
}
outWire.ffoOutput := completedLeftOr | Fill(parameter.laneNumber, ffoValid)
out := RegNext(outWire, 0.U.asTypeOf(out))
out := Pipe(true.B, outWire, parameter.latency).bits
}
72 changes: 45 additions & 27 deletions t1/src/mask/MaskUnit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ class MaskUnitOM(parameter: T1Parameter) extends GeneralOM[T1Parameter, MaskUnit
@public
val reduceUnitIn = IO(Input(Property[AnyClassType]()))
reduceUnit := reduceUnitIn

@public
val compress = IO(Output(Property[AnyClassType]()))
@public
val compressIn = IO(Input(Property[AnyClassType]()))
compress := compressIn
}

// TODO: no T1Parameter here.
Expand Down Expand Up @@ -898,14 +904,24 @@ class MaskUnit(val parameter: T1Parameter)
// Determine whether the data is ready
val executeEnqValid: Bool = otherTypeRequestDeq && !readType

val compressParam: CompressParam = CompressParam(
parameter.datapathWidth,
parameter.xLen,
parameter.vLen,
parameter.laneNumber,
parameter.laneParam.groupNumberBits,
1
)
// start execute
val compressUnit: MaskCompress = Module(new MaskCompress(parameter))
val reduceUnit = Instantiate(
val compressUnit = Instantiate(new MaskCompress(compressParam))
val reduceUnit = Instantiate(
new MaskReduce(
MaskReduceParameter(parameter.datapathWidth, parameter.laneNumber, parameter.fpuEnable)
)
)
omInstance.reduceUnitIn := reduceUnit.io.om.asAnyClassType
omInstance.compressIn := compressUnit.io.om.asAnyClassType

val extendUnit: MaskExtend = Module(new MaskExtend(parameter))

// todo
Expand Down Expand Up @@ -935,28 +951,30 @@ class MaskUnit(val parameter: T1Parameter)
val compressSource1: UInt = Mux1H(sew1H, vs1Split.map(_._1))
val source1Select: UInt = Mux(mv, readVS1Reg.data, compressSource1)
val source1Change: Bool = Mux1H(sew1H, vs1Split.map(_._2))
when(source1Change && compressUnit.in.fire) {
when(source1Change && compressUnit.io.in.fire) {
readVS1Reg.dataValid := false.B
readVS1Reg.requestSend := false.B
readVS1Reg.readIndex := readVS1Reg.readIndex + 1.U

}
viotaCounterAdd := compressUnit.in.fire

compressUnit.in.valid := executeEnqValid && unitType(1)
compressUnit.in.bits.maskType := instReg.maskType
compressUnit.in.bits.eew := instReg.sew
compressUnit.in.bits.uop := instReg.decodeResult(Decoder.topUop)
compressUnit.in.bits.readFromScalar := instReg.readFromScala
compressUnit.in.bits.source1 := source1Select
compressUnit.in.bits.mask := executeElementMask
compressUnit.in.bits.source2 := source2
compressUnit.in.bits.groupCounter := requestCounter
compressUnit.in.bits.lastCompress := lastGroup
compressUnit.in.bits.ffoInput := VecInit(exeReqReg.map(_.bits.ffo)).asUInt
compressUnit.in.bits.validInput := VecInit(exeReqReg.map(_.valid)).asUInt
compressUnit.newInstruction := instReq.valid
compressUnit.ffoInstruction := instReq.bits.decodeResult(Decoder.topUop)(2, 0) === BitPat("b11?")
viotaCounterAdd := compressUnit.io.in.fire

compressUnit.io.clock := implicitClock
compressUnit.io.reset := implicitReset
compressUnit.io.in.valid := executeEnqValid && unitType(1)
compressUnit.io.in.bits.maskType := instReg.maskType
compressUnit.io.in.bits.eew := instReg.sew
compressUnit.io.in.bits.uop := instReg.decodeResult(Decoder.topUop)
compressUnit.io.in.bits.readFromScalar := instReg.readFromScala
compressUnit.io.in.bits.source1 := source1Select
compressUnit.io.in.bits.mask := executeElementMask
compressUnit.io.in.bits.source2 := source2
compressUnit.io.in.bits.groupCounter := requestCounter
compressUnit.io.in.bits.lastCompress := lastGroup
compressUnit.io.in.bits.ffoInput := VecInit(exeReqReg.map(_.bits.ffo)).asUInt
compressUnit.io.in.bits.validInput := VecInit(exeReqReg.map(_.valid)).asUInt
compressUnit.io.newInstruction := instReq.valid
compressUnit.io.ffoInstruction := instReq.bits.decodeResult(Decoder.topUop)(2, 0) === BitPat("b11?")

reduceUnit.io.clock := implicitClock
reduceUnit.io.reset := implicitReset
Expand All @@ -980,7 +998,7 @@ class MaskUnit(val parameter: T1Parameter)
sink := VecInit(exeReqReg.map(_.bits.fpReduceValid.get)).asUInt
}

when(reduceUnit.io.in.fire || compressUnit.in.fire) {
when(reduceUnit.io.in.fire || compressUnit.io.in.fire) {
readVS1Reg.sendToExecution := true.B
}

Expand All @@ -1001,7 +1019,7 @@ class MaskUnit(val parameter: T1Parameter)
val executeResult: UInt = Mux1H(
unitType(3, 1),
Seq(
compressUnit.out.data,
compressUnit.io.out.data,
reduceUnit.io.out.bits.data,
extendUnit.out
)
Expand All @@ -1021,7 +1039,7 @@ class MaskUnit(val parameter: T1Parameter)
val executeValid: Bool = Mux1H(
unitType(3, 1),
Seq(
compressUnit.out.compressValid,
compressUnit.io.out.compressValid,
false.B,
executeEnqValid
)
Expand All @@ -1039,13 +1057,13 @@ class MaskUnit(val parameter: T1Parameter)
val executeDeqGroupCounter: UInt = Mux1H(
unitType(3, 1),
Seq(
compressUnit.out.groupCounter,
compressUnit.io.out.groupCounter,
requestCounter,
extendGroupCount
)
)

val executeWriteByteMask: UInt = Mux(compress || ffo || mvVd, compressUnit.out.mask, executeByteMask)
val executeWriteByteMask: UInt = Mux(compress || ffo || mvVd, compressUnit.io.out.mask, executeByteMask)
maskedWrite.needWAR := maskDestinationType
maskedWrite.vd := instReg.vd
maskedWrite.in.zipWithIndex.foreach { case (req, index) =>
Expand All @@ -1057,7 +1075,7 @@ class MaskUnit(val parameter: T1Parameter)
req.bits.pipeData := exeReqReg(index).bits.source1
req.bits.bitMask := bitMask
req.bits.groupCounter := executeDeqGroupCounter
req.bits.ffoByOther := compressUnit.out.ffoOutput(index) && ffo
req.bits.ffoByOther := compressUnit.io.out.ffoOutput(index) && ffo
if (index == 0) {
// reduce result
when(unitType(2)) {
Expand Down Expand Up @@ -1117,7 +1135,7 @@ class MaskUnit(val parameter: T1Parameter)
val executeStageInvalid: Bool = Mux1H(
unitType(3, 1),
Seq(
!compressUnit.out.compressValid,
!compressUnit.io.out.compressValid,
reduceUnit.io.in.ready,
true.B
)
Expand All @@ -1136,7 +1154,7 @@ class MaskUnit(val parameter: T1Parameter)
lastReportValid,
indexToOH(instReg.instructionIndex, parameter.chainingSize)
)
writeRDData := Mux(pop, reduceUnit.io.out.bits.data, compressUnit.writeData)
writeRDData := Mux(pop, reduceUnit.io.out.bits.data, compressUnit.io.writeData)

// gather read state
when(gatherRequestFire) {
Expand Down

0 comments on commit 79814c6

Please sign in to comment.