From 4069f286072bb59dca5797df84839f1d991bde65 Mon Sep 17 00:00:00 2001 From: qinjun-li Date: Fri, 20 Dec 2024 11:55:53 +0800 Subject: [PATCH] [rtl] retime compress unit. --- omreaderlib/src/t1/T1.scala | 13 ++-- omreaderlib/src/t1rocketv/T1RocketTile.scala | 19 ++++-- t1/src/T1.scala | 8 +-- t1/src/mask/MaskCompress.scala | 70 +++++++++++++++---- t1/src/mask/MaskUnit.scala | 72 ++++++++++++-------- 5 files changed, 128 insertions(+), 54 deletions(-) diff --git a/omreaderlib/src/t1/T1.scala b/omreaderlib/src/t1/T1.scala index 2c7ee4666..a935163dd 100644 --- a/omreaderlib/src/t1/T1.scala +++ b/omreaderlib/src/t1/T1.scala @@ -19,14 +19,19 @@ class T1(val mlirbc: Array[Byte]) extends T1OMReaderAPI { def sram: Seq[SRAM] = t1("lanes").list.elements().map(_.obj("vrf").obj).flatMap(getSRAM) - def floatAdder = { - val reduceUnit = t1("permutatuon").obj("reduceUnit").obj + def permutation: Seq[Retime] = { + val permutation = t1("permutation") + val reduceUnit = permutation.obj("reduceUnit").obj + val compressUnit = permutation.obj("compress").obj // TODO: need fieldOpt(name: String) - Option.when(reduceUnit.fieldNames().contains("floatAdder"))(reduceUnit("floatAdder").obj).flatMap(getRetime) + val floatAdder = + Option.when(reduceUnit.fieldNames().contains("floatAdder"))(reduceUnit("floatAdder").obj) + + (Seq(compressUnit) ++ floatAdder).flatMap(getRetime) } def vfus: Seq[Retime] = t1("lanes").list.elements().map(_.obj("vfus")).flatMap(_.list.elements().map(_.obj)).flatMap(getRetime) - def retime = (vfus ++ floatAdder).distinct + def retime = (vfus ++ permutation).distinct } diff --git a/omreaderlib/src/t1rocketv/T1RocketTile.scala b/omreaderlib/src/t1rocketv/T1RocketTile.scala index 3187adec6..0c33bd3f6 100644 --- a/omreaderlib/src/t1rocketv/T1RocketTile.scala +++ b/omreaderlib/src/t1rocketv/T1RocketTile.scala @@ -20,14 +20,21 @@ class T1RocketTile(val mlirbc: Array[Byte]) extends T1OMReaderAPI { t1("lanes").list.elements().map(_.obj("vrf").obj).flatMap(getSRAM) def cache: Seq[SRAM] = Seq(tile("frontend").obj("icache").obj, tile("hellaCache").obj).flatMap(getSRAM) - def vfu: Seq[Retime] = - t1("lanes").list.elements().map(_.obj("vfus")).flatMap(_.list.elements().map(_.obj)).flatMap(getRetime) - def floatAdder = { - val reduceUnit = t1("permutatuon").obj("reduceUnit").obj + + def permutation: Seq[Retime] = { + val permutation = t1("permutation") + val reduceUnit = permutation.obj("reduceUnit").obj + val compressUnit = permutation.obj("compress").obj // TODO: need fieldOpt(name: String) - Option.when(reduceUnit.fieldNames().contains("floatAdder"))(reduceUnit("floatAdder").obj).flatMap(getRetime) + val floatAdder = + Option.when(reduceUnit.fieldNames().contains("floatAdder"))(reduceUnit("floatAdder").obj) + + (Seq(compressUnit) ++ floatAdder).flatMap(getRetime) } - def retime = (vfu ++ floatAdder).distinct + def vfus: Seq[Retime] = + t1("lanes").list.elements().map(_.obj("vfus")).flatMap(_.list.elements().map(_.obj)).flatMap(getRetime) + + def retime = (vfus ++ permutation).distinct def sram = vrf ++ cache } diff --git a/t1/src/T1.scala b/t1/src/T1.scala index 6813c9b05..8466cd7ba 100644 --- a/t1/src/T1.scala +++ b/t1/src/T1.scala @@ -66,10 +66,10 @@ class T1OM(parameter: T1Parameter) extends GeneralOM[T1Parameter, T1](parameter) val decoderIn = IO(Input(Property[AnyClassType]())) decoder := decoderIn - val permutatuon = IO(Output(Property[AnyClassType]())) + val permutation = IO(Output(Property[AnyClassType]())) @public - val permutatuonIn = IO(Input(Property[AnyClassType]())) - permutatuon := permutatuonIn + val permutationIn = IO(Input(Property[AnyClassType]())) + permutation := permutationIn } object T1Parameter { @@ -403,7 +403,7 @@ class T1(val parameter: T1Parameter) val maskUnit: Instance[MaskUnit] = Instantiate(new MaskUnit(parameter)) maskUnit.io.clock := implicitClock maskUnit.io.reset := implicitReset - omInstance.permutatuonIn := Property(maskUnit.io.om.asAnyClassType) + omInstance.permutationIn := Property(maskUnit.io.om.asAnyClassType) val tokenManager: Instance[T1TokenManager] = Instantiate(new T1TokenManager(parameter)) diff --git a/t1/src/mask/MaskCompress.scala b/t1/src/mask/MaskCompress.scala index aef7421f5..2ece87ca9 100644 --- a/t1/src/mask/MaskCompress.scala +++ b/t1/src/mask/MaskCompress.scala @@ -4,9 +4,26 @@ package org.chipsalliance.t1.rtl import chisel3._ +import chisel3.experimental.hierarchy.{instantiable, Instance, Instantiate} +import chisel3.experimental.{SerializableModule, SerializableModuleParameter} +import chisel3.properties.{AnyClassType, Path, Property} import chisel3.util._ +import org.chipsalliance.stdlib.GeneralOM -class CompressInput(parameter: T1Parameter) extends Bundle { +case class CompressParam( + datapathWidth: Int, + xLen: Int, + vLen: Int, + laneNumber: Int, + groupNumberBits: Int, + latency: Int) + extends SerializableModuleParameter + +object CompressParam { + implicit def rwP = upickle.default.macroRW[CompressParam] +} + +class CompressInput(parameter: CompressParam) extends Bundle { val maskType: Bool = Bool() val eew: UInt = UInt(2.W) val uop: UInt = UInt(3.W) @@ -14,26 +31,55 @@ class CompressInput(parameter: T1Parameter) extends Bundle { val source1: UInt = UInt(parameter.datapathWidth.W) val mask: UInt = UInt(parameter.datapathWidth.W) val source2: UInt = UInt((parameter.laneNumber * parameter.datapathWidth).W) - val groupCounter: UInt = UInt(parameter.laneParam.groupNumberBits.W) + val groupCounter: UInt = UInt(parameter.groupNumberBits.W) val ffoInput: UInt = UInt(parameter.laneNumber.W) val validInput: UInt = UInt(parameter.laneNumber.W) val lastCompress: Bool = Bool() } -class CompressOutput(parameter: T1Parameter) extends Bundle { +class CompressOutput(parameter: CompressParam) extends Bundle { val data: UInt = UInt((parameter.laneNumber * parameter.datapathWidth).W) val mask: UInt = UInt((parameter.laneNumber * parameter.datapathWidth / 8).W) - val groupCounter: UInt = UInt(parameter.laneParam.groupNumberBits.W) + val groupCounter: UInt = UInt(parameter.groupNumberBits.W) val ffoOutput: UInt = UInt(parameter.laneNumber.W) val compressValid: Bool = Bool() } -class MaskCompress(parameter: T1Parameter) extends Module { - val in: ValidIO[CompressInput] = IO(Flipped(Valid(new CompressInput(parameter)))) - val out: CompressOutput = IO(Output(new CompressOutput(parameter))) - val newInstruction: Bool = IO(Input(Bool())) - val ffoInstruction: Bool = IO(Input(Bool())) - val writeData: UInt = IO(Output(UInt(parameter.xLen.W))) +class MaskCompressInterFace(parameter: CompressParam) extends Bundle { + val clock = Input(Clock()) + val reset = Input(Reset()) + + val in: ValidIO[CompressInput] = Flipped(Valid(new CompressInput(parameter))) + val out: CompressOutput = Output(new CompressOutput(parameter)) + val newInstruction: Bool = Input(Bool()) + val ffoInstruction: Bool = Input(Bool()) + val writeData: UInt = Output(UInt(parameter.xLen.W)) + val om = Output(Property[AnyClassType]()) +} + +@instantiable +class MaskCompressOM(parameter: CompressParam) extends GeneralOM[CompressParam, MaskCompress](parameter) { + override def hasRetime: Boolean = true +} + +class MaskCompress(val parameter: CompressParam) + extends FixedIORawModule(new MaskCompressInterFace(parameter)) + with SerializableModule[CompressParam] + with ImplicitClock + with ImplicitReset { + + protected def implicitClock = io.clock + protected def implicitReset = io.reset + + val omInstance: Instance[MaskCompressOM] = Instantiate(new MaskCompressOM(parameter)) + io.om := omInstance.getPropertyReference + omInstance.retimeIn.foreach(_ := Property(Path(io.clock))) + + val in = io.in + val out = io.out + val newInstruction = io.newInstruction + val ffoInstruction = io.ffoInstruction + val writeData = io.writeData val maskSize: Int = parameter.laneNumber * parameter.datapathWidth / 8 @@ -122,7 +168,7 @@ class MaskCompress(parameter: T1Parameter) extends Module { val compressDataReg = RegInit(0.U((parameter.laneNumber * parameter.datapathWidth).W)) val compressTailValid: Bool = RegInit(false.B) - val compressWriteGroupCount: UInt = RegInit(0.U(parameter.laneParam.groupNumberBits.W)) + val compressWriteGroupCount: UInt = RegInit(0.U(parameter.groupNumberBits.W)) val compressDataVec = Seq(1, 2, 4).map { dataByte => val dataBit = dataByte * 8 val elementSizePerSet = parameter.laneNumber * parameter.datapathWidth / 8 / dataByte @@ -238,5 +284,5 @@ class MaskCompress(parameter: T1Parameter) extends Module { ffoIndex := source1SigExtend } outWire.ffoOutput := completedLeftOr | Fill(parameter.laneNumber, ffoValid) - out := RegNext(outWire, 0.U.asTypeOf(out)) + out := Pipe(true.B, outWire, parameter.latency).bits } diff --git a/t1/src/mask/MaskUnit.scala b/t1/src/mask/MaskUnit.scala index 76cd556d2..af21abaec 100644 --- a/t1/src/mask/MaskUnit.scala +++ b/t1/src/mask/MaskUnit.scala @@ -84,11 +84,15 @@ class MaskUnitInterface(parameter: T1Parameter) extends Bundle { @instantiable class MaskUnitOM(parameter: T1Parameter) extends GeneralOM[T1Parameter, MaskUnit](parameter) { - @public val reduceUnit = IO(Output(Property[AnyClassType]())) @public val reduceUnitIn = IO(Input(Property[AnyClassType]())) reduceUnit := reduceUnitIn + + val compress = IO(Output(Property[AnyClassType]())) + @public + val compressIn = IO(Input(Property[AnyClassType]())) + compress := compressIn } // TODO: no T1Parameter here. @@ -898,14 +902,24 @@ class MaskUnit(val parameter: T1Parameter) // Determine whether the data is ready val executeEnqValid: Bool = otherTypeRequestDeq && !readType + val compressParam: CompressParam = CompressParam( + parameter.datapathWidth, + parameter.xLen, + parameter.vLen, + parameter.laneNumber, + parameter.laneParam.groupNumberBits, + 1 + ) // start execute - val compressUnit: MaskCompress = Module(new MaskCompress(parameter)) - val reduceUnit = Instantiate( + val compressUnit = Instantiate(new MaskCompress(compressParam)) + val reduceUnit = Instantiate( new MaskReduce( MaskReduceParameter(parameter.datapathWidth, parameter.laneNumber, parameter.fpuEnable) ) ) omInstance.reduceUnitIn := reduceUnit.io.om.asAnyClassType + omInstance.compressIn := compressUnit.io.om.asAnyClassType + val extendUnit: MaskExtend = Module(new MaskExtend(parameter)) // todo @@ -935,28 +949,30 @@ class MaskUnit(val parameter: T1Parameter) val compressSource1: UInt = Mux1H(sew1H, vs1Split.map(_._1)) val source1Select: UInt = Mux(mv, readVS1Reg.data, compressSource1) val source1Change: Bool = Mux1H(sew1H, vs1Split.map(_._2)) - when(source1Change && compressUnit.in.fire) { + when(source1Change && compressUnit.io.in.fire) { readVS1Reg.dataValid := false.B readVS1Reg.requestSend := false.B readVS1Reg.readIndex := readVS1Reg.readIndex + 1.U } - viotaCounterAdd := compressUnit.in.fire - - compressUnit.in.valid := executeEnqValid && unitType(1) - compressUnit.in.bits.maskType := instReg.maskType - compressUnit.in.bits.eew := instReg.sew - compressUnit.in.bits.uop := instReg.decodeResult(Decoder.topUop) - compressUnit.in.bits.readFromScalar := instReg.readFromScala - compressUnit.in.bits.source1 := source1Select - compressUnit.in.bits.mask := executeElementMask - compressUnit.in.bits.source2 := source2 - compressUnit.in.bits.groupCounter := requestCounter - compressUnit.in.bits.lastCompress := lastGroup - compressUnit.in.bits.ffoInput := VecInit(exeReqReg.map(_.bits.ffo)).asUInt - compressUnit.in.bits.validInput := VecInit(exeReqReg.map(_.valid)).asUInt - compressUnit.newInstruction := instReq.valid - compressUnit.ffoInstruction := instReq.bits.decodeResult(Decoder.topUop)(2, 0) === BitPat("b11?") + viotaCounterAdd := compressUnit.io.in.fire + + compressUnit.io.clock := implicitClock + compressUnit.io.reset := implicitReset + compressUnit.io.in.valid := executeEnqValid && unitType(1) + compressUnit.io.in.bits.maskType := instReg.maskType + compressUnit.io.in.bits.eew := instReg.sew + compressUnit.io.in.bits.uop := instReg.decodeResult(Decoder.topUop) + compressUnit.io.in.bits.readFromScalar := instReg.readFromScala + compressUnit.io.in.bits.source1 := source1Select + compressUnit.io.in.bits.mask := executeElementMask + compressUnit.io.in.bits.source2 := source2 + compressUnit.io.in.bits.groupCounter := requestCounter + compressUnit.io.in.bits.lastCompress := lastGroup + compressUnit.io.in.bits.ffoInput := VecInit(exeReqReg.map(_.bits.ffo)).asUInt + compressUnit.io.in.bits.validInput := VecInit(exeReqReg.map(_.valid)).asUInt + compressUnit.io.newInstruction := instReq.valid + compressUnit.io.ffoInstruction := instReq.bits.decodeResult(Decoder.topUop)(2, 0) === BitPat("b11?") reduceUnit.io.clock := implicitClock reduceUnit.io.reset := implicitReset @@ -980,7 +996,7 @@ class MaskUnit(val parameter: T1Parameter) sink := VecInit(exeReqReg.map(_.bits.fpReduceValid.get)).asUInt } - when(reduceUnit.io.in.fire || compressUnit.in.fire) { + when(reduceUnit.io.in.fire || compressUnit.io.in.fire) { readVS1Reg.sendToExecution := true.B } @@ -1001,7 +1017,7 @@ class MaskUnit(val parameter: T1Parameter) val executeResult: UInt = Mux1H( unitType(3, 1), Seq( - compressUnit.out.data, + compressUnit.io.out.data, reduceUnit.io.out.bits.data, extendUnit.out ) @@ -1021,7 +1037,7 @@ class MaskUnit(val parameter: T1Parameter) val executeValid: Bool = Mux1H( unitType(3, 1), Seq( - compressUnit.out.compressValid, + compressUnit.io.out.compressValid, false.B, executeEnqValid ) @@ -1039,13 +1055,13 @@ class MaskUnit(val parameter: T1Parameter) val executeDeqGroupCounter: UInt = Mux1H( unitType(3, 1), Seq( - compressUnit.out.groupCounter, + compressUnit.io.out.groupCounter, requestCounter, extendGroupCount ) ) - val executeWriteByteMask: UInt = Mux(compress || ffo || mvVd, compressUnit.out.mask, executeByteMask) + val executeWriteByteMask: UInt = Mux(compress || ffo || mvVd, compressUnit.io.out.mask, executeByteMask) maskedWrite.needWAR := maskDestinationType maskedWrite.vd := instReg.vd maskedWrite.in.zipWithIndex.foreach { case (req, index) => @@ -1057,7 +1073,7 @@ class MaskUnit(val parameter: T1Parameter) req.bits.pipeData := exeReqReg(index).bits.source1 req.bits.bitMask := bitMask req.bits.groupCounter := executeDeqGroupCounter - req.bits.ffoByOther := compressUnit.out.ffoOutput(index) && ffo + req.bits.ffoByOther := compressUnit.io.out.ffoOutput(index) && ffo if (index == 0) { // reduce result when(unitType(2)) { @@ -1117,7 +1133,7 @@ class MaskUnit(val parameter: T1Parameter) val executeStageInvalid: Bool = Mux1H( unitType(3, 1), Seq( - !compressUnit.out.compressValid, + !compressUnit.io.out.compressValid, reduceUnit.io.in.ready, true.B ) @@ -1136,7 +1152,7 @@ class MaskUnit(val parameter: T1Parameter) lastReportValid, indexToOH(instReg.instructionIndex, parameter.chainingSize) ) - writeRDData := Mux(pop, reduceUnit.io.out.bits.data, compressUnit.writeData) + writeRDData := Mux(pop, reduceUnit.io.out.bits.data, compressUnit.io.writeData) // gather read state when(gatherRequestFire) {