Skip to content

Commit

Permalink
[rtl] retime compress unit.
Browse files Browse the repository at this point in the history
  • Loading branch information
qinjun-li authored and sequencer committed Dec 20, 2024
1 parent 10ad9f4 commit 4069f28
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 54 deletions.
13 changes: 9 additions & 4 deletions omreaderlib/src/t1/T1.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,19 @@ class T1(val mlirbc: Array[Byte]) extends T1OMReaderAPI {
def sram: Seq[SRAM] =
t1("lanes").list.elements().map(_.obj("vrf").obj).flatMap(getSRAM)

def floatAdder = {
val reduceUnit = t1("permutatuon").obj("reduceUnit").obj
def permutation: Seq[Retime] = {
val permutation = t1("permutation")
val reduceUnit = permutation.obj("reduceUnit").obj
val compressUnit = permutation.obj("compress").obj
// TODO: need fieldOpt(name: String)
Option.when(reduceUnit.fieldNames().contains("floatAdder"))(reduceUnit("floatAdder").obj).flatMap(getRetime)
val floatAdder =
Option.when(reduceUnit.fieldNames().contains("floatAdder"))(reduceUnit("floatAdder").obj)

(Seq(compressUnit) ++ floatAdder).flatMap(getRetime)
}

def vfus: Seq[Retime] =
t1("lanes").list.elements().map(_.obj("vfus")).flatMap(_.list.elements().map(_.obj)).flatMap(getRetime)

def retime = (vfus ++ floatAdder).distinct
def retime = (vfus ++ permutation).distinct
}
19 changes: 13 additions & 6 deletions omreaderlib/src/t1rocketv/T1RocketTile.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,21 @@ class T1RocketTile(val mlirbc: Array[Byte]) extends T1OMReaderAPI {
t1("lanes").list.elements().map(_.obj("vrf").obj).flatMap(getSRAM)
def cache: Seq[SRAM] =
Seq(tile("frontend").obj("icache").obj, tile("hellaCache").obj).flatMap(getSRAM)
def vfu: Seq[Retime] =
t1("lanes").list.elements().map(_.obj("vfus")).flatMap(_.list.elements().map(_.obj)).flatMap(getRetime)
def floatAdder = {
val reduceUnit = t1("permutatuon").obj("reduceUnit").obj

def permutation: Seq[Retime] = {
val permutation = t1("permutation")
val reduceUnit = permutation.obj("reduceUnit").obj
val compressUnit = permutation.obj("compress").obj
// TODO: need fieldOpt(name: String)
Option.when(reduceUnit.fieldNames().contains("floatAdder"))(reduceUnit("floatAdder").obj).flatMap(getRetime)
val floatAdder =
Option.when(reduceUnit.fieldNames().contains("floatAdder"))(reduceUnit("floatAdder").obj)

(Seq(compressUnit) ++ floatAdder).flatMap(getRetime)
}

def retime = (vfu ++ floatAdder).distinct
def vfus: Seq[Retime] =
t1("lanes").list.elements().map(_.obj("vfus")).flatMap(_.list.elements().map(_.obj)).flatMap(getRetime)

def retime = (vfus ++ permutation).distinct
def sram = vrf ++ cache
}
8 changes: 4 additions & 4 deletions t1/src/T1.scala
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ class T1OM(parameter: T1Parameter) extends GeneralOM[T1Parameter, T1](parameter)
val decoderIn = IO(Input(Property[AnyClassType]()))
decoder := decoderIn

val permutatuon = IO(Output(Property[AnyClassType]()))
val permutation = IO(Output(Property[AnyClassType]()))
@public
val permutatuonIn = IO(Input(Property[AnyClassType]()))
permutatuon := permutatuonIn
val permutationIn = IO(Input(Property[AnyClassType]()))
permutation := permutationIn
}

object T1Parameter {
Expand Down Expand Up @@ -403,7 +403,7 @@ class T1(val parameter: T1Parameter)
val maskUnit: Instance[MaskUnit] = Instantiate(new MaskUnit(parameter))
maskUnit.io.clock := implicitClock
maskUnit.io.reset := implicitReset
omInstance.permutatuonIn := Property(maskUnit.io.om.asAnyClassType)
omInstance.permutationIn := Property(maskUnit.io.om.asAnyClassType)

val tokenManager: Instance[T1TokenManager] = Instantiate(new T1TokenManager(parameter))

Expand Down
70 changes: 58 additions & 12 deletions t1/src/mask/MaskCompress.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,82 @@
package org.chipsalliance.t1.rtl

import chisel3._
import chisel3.experimental.hierarchy.{instantiable, Instance, Instantiate}
import chisel3.experimental.{SerializableModule, SerializableModuleParameter}
import chisel3.properties.{AnyClassType, Path, Property}
import chisel3.util._
import org.chipsalliance.stdlib.GeneralOM

class CompressInput(parameter: T1Parameter) extends Bundle {
case class CompressParam(
datapathWidth: Int,
xLen: Int,
vLen: Int,
laneNumber: Int,
groupNumberBits: Int,
latency: Int)
extends SerializableModuleParameter

object CompressParam {
implicit def rwP = upickle.default.macroRW[CompressParam]
}

class CompressInput(parameter: CompressParam) extends Bundle {
val maskType: Bool = Bool()
val eew: UInt = UInt(2.W)
val uop: UInt = UInt(3.W)
val readFromScalar: UInt = UInt(parameter.datapathWidth.W)
val source1: UInt = UInt(parameter.datapathWidth.W)
val mask: UInt = UInt(parameter.datapathWidth.W)
val source2: UInt = UInt((parameter.laneNumber * parameter.datapathWidth).W)
val groupCounter: UInt = UInt(parameter.laneParam.groupNumberBits.W)
val groupCounter: UInt = UInt(parameter.groupNumberBits.W)
val ffoInput: UInt = UInt(parameter.laneNumber.W)
val validInput: UInt = UInt(parameter.laneNumber.W)
val lastCompress: Bool = Bool()
}

class CompressOutput(parameter: T1Parameter) extends Bundle {
class CompressOutput(parameter: CompressParam) extends Bundle {
val data: UInt = UInt((parameter.laneNumber * parameter.datapathWidth).W)
val mask: UInt = UInt((parameter.laneNumber * parameter.datapathWidth / 8).W)
val groupCounter: UInt = UInt(parameter.laneParam.groupNumberBits.W)
val groupCounter: UInt = UInt(parameter.groupNumberBits.W)
val ffoOutput: UInt = UInt(parameter.laneNumber.W)
val compressValid: Bool = Bool()
}

class MaskCompress(parameter: T1Parameter) extends Module {
val in: ValidIO[CompressInput] = IO(Flipped(Valid(new CompressInput(parameter))))
val out: CompressOutput = IO(Output(new CompressOutput(parameter)))
val newInstruction: Bool = IO(Input(Bool()))
val ffoInstruction: Bool = IO(Input(Bool()))
val writeData: UInt = IO(Output(UInt(parameter.xLen.W)))
class MaskCompressInterFace(parameter: CompressParam) extends Bundle {
val clock = Input(Clock())
val reset = Input(Reset())

val in: ValidIO[CompressInput] = Flipped(Valid(new CompressInput(parameter)))
val out: CompressOutput = Output(new CompressOutput(parameter))
val newInstruction: Bool = Input(Bool())
val ffoInstruction: Bool = Input(Bool())
val writeData: UInt = Output(UInt(parameter.xLen.W))
val om = Output(Property[AnyClassType]())
}

@instantiable
class MaskCompressOM(parameter: CompressParam) extends GeneralOM[CompressParam, MaskCompress](parameter) {
override def hasRetime: Boolean = true
}

class MaskCompress(val parameter: CompressParam)
extends FixedIORawModule(new MaskCompressInterFace(parameter))
with SerializableModule[CompressParam]
with ImplicitClock
with ImplicitReset {

protected def implicitClock = io.clock
protected def implicitReset = io.reset

val omInstance: Instance[MaskCompressOM] = Instantiate(new MaskCompressOM(parameter))
io.om := omInstance.getPropertyReference
omInstance.retimeIn.foreach(_ := Property(Path(io.clock)))

val in = io.in
val out = io.out
val newInstruction = io.newInstruction
val ffoInstruction = io.ffoInstruction
val writeData = io.writeData

val maskSize: Int = parameter.laneNumber * parameter.datapathWidth / 8

Expand Down Expand Up @@ -122,7 +168,7 @@ class MaskCompress(parameter: T1Parameter) extends Module {

val compressDataReg = RegInit(0.U((parameter.laneNumber * parameter.datapathWidth).W))
val compressTailValid: Bool = RegInit(false.B)
val compressWriteGroupCount: UInt = RegInit(0.U(parameter.laneParam.groupNumberBits.W))
val compressWriteGroupCount: UInt = RegInit(0.U(parameter.groupNumberBits.W))
val compressDataVec = Seq(1, 2, 4).map { dataByte =>
val dataBit = dataByte * 8
val elementSizePerSet = parameter.laneNumber * parameter.datapathWidth / 8 / dataByte
Expand Down Expand Up @@ -238,5 +284,5 @@ class MaskCompress(parameter: T1Parameter) extends Module {
ffoIndex := source1SigExtend
}
outWire.ffoOutput := completedLeftOr | Fill(parameter.laneNumber, ffoValid)
out := RegNext(outWire, 0.U.asTypeOf(out))
out := Pipe(true.B, outWire, parameter.latency).bits
}
72 changes: 44 additions & 28 deletions t1/src/mask/MaskUnit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,15 @@ class MaskUnitInterface(parameter: T1Parameter) extends Bundle {

@instantiable
class MaskUnitOM(parameter: T1Parameter) extends GeneralOM[T1Parameter, MaskUnit](parameter) {
@public
val reduceUnit = IO(Output(Property[AnyClassType]()))
@public
val reduceUnitIn = IO(Input(Property[AnyClassType]()))
reduceUnit := reduceUnitIn

val compress = IO(Output(Property[AnyClassType]()))
@public
val compressIn = IO(Input(Property[AnyClassType]()))
compress := compressIn
}

// TODO: no T1Parameter here.
Expand Down Expand Up @@ -898,14 +902,24 @@ class MaskUnit(val parameter: T1Parameter)
// Determine whether the data is ready
val executeEnqValid: Bool = otherTypeRequestDeq && !readType

val compressParam: CompressParam = CompressParam(
parameter.datapathWidth,
parameter.xLen,
parameter.vLen,
parameter.laneNumber,
parameter.laneParam.groupNumberBits,
1
)
// start execute
val compressUnit: MaskCompress = Module(new MaskCompress(parameter))
val reduceUnit = Instantiate(
val compressUnit = Instantiate(new MaskCompress(compressParam))
val reduceUnit = Instantiate(
new MaskReduce(
MaskReduceParameter(parameter.datapathWidth, parameter.laneNumber, parameter.fpuEnable)
)
)
omInstance.reduceUnitIn := reduceUnit.io.om.asAnyClassType
omInstance.compressIn := compressUnit.io.om.asAnyClassType

val extendUnit: MaskExtend = Module(new MaskExtend(parameter))

// todo
Expand Down Expand Up @@ -935,28 +949,30 @@ class MaskUnit(val parameter: T1Parameter)
val compressSource1: UInt = Mux1H(sew1H, vs1Split.map(_._1))
val source1Select: UInt = Mux(mv, readVS1Reg.data, compressSource1)
val source1Change: Bool = Mux1H(sew1H, vs1Split.map(_._2))
when(source1Change && compressUnit.in.fire) {
when(source1Change && compressUnit.io.in.fire) {
readVS1Reg.dataValid := false.B
readVS1Reg.requestSend := false.B
readVS1Reg.readIndex := readVS1Reg.readIndex + 1.U

}
viotaCounterAdd := compressUnit.in.fire

compressUnit.in.valid := executeEnqValid && unitType(1)
compressUnit.in.bits.maskType := instReg.maskType
compressUnit.in.bits.eew := instReg.sew
compressUnit.in.bits.uop := instReg.decodeResult(Decoder.topUop)
compressUnit.in.bits.readFromScalar := instReg.readFromScala
compressUnit.in.bits.source1 := source1Select
compressUnit.in.bits.mask := executeElementMask
compressUnit.in.bits.source2 := source2
compressUnit.in.bits.groupCounter := requestCounter
compressUnit.in.bits.lastCompress := lastGroup
compressUnit.in.bits.ffoInput := VecInit(exeReqReg.map(_.bits.ffo)).asUInt
compressUnit.in.bits.validInput := VecInit(exeReqReg.map(_.valid)).asUInt
compressUnit.newInstruction := instReq.valid
compressUnit.ffoInstruction := instReq.bits.decodeResult(Decoder.topUop)(2, 0) === BitPat("b11?")
viotaCounterAdd := compressUnit.io.in.fire

compressUnit.io.clock := implicitClock
compressUnit.io.reset := implicitReset
compressUnit.io.in.valid := executeEnqValid && unitType(1)
compressUnit.io.in.bits.maskType := instReg.maskType
compressUnit.io.in.bits.eew := instReg.sew
compressUnit.io.in.bits.uop := instReg.decodeResult(Decoder.topUop)
compressUnit.io.in.bits.readFromScalar := instReg.readFromScala
compressUnit.io.in.bits.source1 := source1Select
compressUnit.io.in.bits.mask := executeElementMask
compressUnit.io.in.bits.source2 := source2
compressUnit.io.in.bits.groupCounter := requestCounter
compressUnit.io.in.bits.lastCompress := lastGroup
compressUnit.io.in.bits.ffoInput := VecInit(exeReqReg.map(_.bits.ffo)).asUInt
compressUnit.io.in.bits.validInput := VecInit(exeReqReg.map(_.valid)).asUInt
compressUnit.io.newInstruction := instReq.valid
compressUnit.io.ffoInstruction := instReq.bits.decodeResult(Decoder.topUop)(2, 0) === BitPat("b11?")

reduceUnit.io.clock := implicitClock
reduceUnit.io.reset := implicitReset
Expand All @@ -980,7 +996,7 @@ class MaskUnit(val parameter: T1Parameter)
sink := VecInit(exeReqReg.map(_.bits.fpReduceValid.get)).asUInt
}

when(reduceUnit.io.in.fire || compressUnit.in.fire) {
when(reduceUnit.io.in.fire || compressUnit.io.in.fire) {
readVS1Reg.sendToExecution := true.B
}

Expand All @@ -1001,7 +1017,7 @@ class MaskUnit(val parameter: T1Parameter)
val executeResult: UInt = Mux1H(
unitType(3, 1),
Seq(
compressUnit.out.data,
compressUnit.io.out.data,
reduceUnit.io.out.bits.data,
extendUnit.out
)
Expand All @@ -1021,7 +1037,7 @@ class MaskUnit(val parameter: T1Parameter)
val executeValid: Bool = Mux1H(
unitType(3, 1),
Seq(
compressUnit.out.compressValid,
compressUnit.io.out.compressValid,
false.B,
executeEnqValid
)
Expand All @@ -1039,13 +1055,13 @@ class MaskUnit(val parameter: T1Parameter)
val executeDeqGroupCounter: UInt = Mux1H(
unitType(3, 1),
Seq(
compressUnit.out.groupCounter,
compressUnit.io.out.groupCounter,
requestCounter,
extendGroupCount
)
)

val executeWriteByteMask: UInt = Mux(compress || ffo || mvVd, compressUnit.out.mask, executeByteMask)
val executeWriteByteMask: UInt = Mux(compress || ffo || mvVd, compressUnit.io.out.mask, executeByteMask)
maskedWrite.needWAR := maskDestinationType
maskedWrite.vd := instReg.vd
maskedWrite.in.zipWithIndex.foreach { case (req, index) =>
Expand All @@ -1057,7 +1073,7 @@ class MaskUnit(val parameter: T1Parameter)
req.bits.pipeData := exeReqReg(index).bits.source1
req.bits.bitMask := bitMask
req.bits.groupCounter := executeDeqGroupCounter
req.bits.ffoByOther := compressUnit.out.ffoOutput(index) && ffo
req.bits.ffoByOther := compressUnit.io.out.ffoOutput(index) && ffo
if (index == 0) {
// reduce result
when(unitType(2)) {
Expand Down Expand Up @@ -1117,7 +1133,7 @@ class MaskUnit(val parameter: T1Parameter)
val executeStageInvalid: Bool = Mux1H(
unitType(3, 1),
Seq(
!compressUnit.out.compressValid,
!compressUnit.io.out.compressValid,
reduceUnit.io.in.ready,
true.B
)
Expand All @@ -1136,7 +1152,7 @@ class MaskUnit(val parameter: T1Parameter)
lastReportValid,
indexToOH(instReg.instructionIndex, parameter.chainingSize)
)
writeRDData := Mux(pop, reduceUnit.io.out.bits.data, compressUnit.writeData)
writeRDData := Mux(pop, reduceUnit.io.out.bits.data, compressUnit.io.writeData)

// gather read state
when(gatherRequestFire) {
Expand Down

0 comments on commit 4069f28

Please sign in to comment.