Skip to content

Commit

Permalink
Fkyd/xdma snax integration (#199)
Browse files Browse the repository at this point in the history
* xdma wrapper template

* Upgrade xdmaTop generator

* Another bug fix in xdmaTop generator

* integrated generation of xdma

* add the xdma configuration

* Renaming major module of xdma

* xdma integration

* Compiled, not tested

* xdma runtime finished, untest ed

* Small bug fix + debug function

* Fix typo in xdma-memset.c and update xdma CSR address calculation

* Fix xdma_start() to return task ID and update printf statement

* bug-fix: Update snitch_cluster_wrapper.sv.tpl and snax-xdma-lib.c

* chore: Refactor xdma unit test job name and add system test job

* Formatting python script

* Add license header

* Fix formatting of C code

* Fix formatting of scala code

* Fix ci format

* chore: Add generated xdma files to SNAX_GEN list in Makefile

* chore: Update ci.yml to include snax-xdma-run.yaml in the run command

* chore: Remove unnecessary code in AddressGenUnit.scala

* chore: Update .gitignore to include missing newline at end of file

* Move ArgParser to utils

* chore: Update snitch_cluster_wrapper.sv.tpl to remove the incrementation of pref_snax_count at the generation of xdma

* Improve the comment

* Improve comments

* chore: Refactor xdma error handling

* chore: Update ci.yml and remove erroneous Chisel Generation CI

* Update scala unit test
  • Loading branch information
IveanEx authored Jul 23, 2024
1 parent 5ad7aa2 commit b933fce
Show file tree
Hide file tree
Showing 34 changed files with 1,548 additions and 120 deletions.
37 changes: 19 additions & 18 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -318,31 +318,32 @@ jobs:
sw/runtime.yaml \
sw/snax-streamer-gemm-conv-simd-run.yaml -j
snax-xdma-unittest:
name: Run several unit tests for xdma
snax-xdma-vlt-generic:
name: Simulate SW on xdma w/ Verilator (Generic LLVM)
runs-on: ubuntu-22.04
container:
image: ghcr.io/kuleuven-micas/snax:main
steps:
- uses: actions/checkout@v2
with:
submodules: 'recursive'
- name: Test xdma streamer
working-directory: hw/chisel
run: |-
sbt "testOnly snax.xdma.xdmaStreamer.*"
- name: Test xdma DataPath and Controller
working-directory: hw/chisel
run: |-
sbt "testOnly snax.xdma.xdmaFrontend.*"
- name: Test xdma Extension Framework + MaxPool
working-directory: hw/chisel
run: |-
sbt "testOnly snax.xdma.xdmaExtension.*"
- name: Test xdma Top Module
working-directory: hw/chisel
submodules: "recursive"
- name: Build Hardware
run: |
make CFG_OVERRIDE=cfg/snax-streamer-gemmX-xdma.hjson \
-C target/snitch_cluster bin/snitch_cluster.vlt -j$(nproc)
- name: Build Software
run: |
make -C target/snitch_cluster sw \
CFG_OVERRIDE=cfg/snax-streamer-gemmX-xdma.hjson \
SELECT_RUNTIME=rtl-generic \
SELECT_TOOLCHAIN=llvm-generic
- name: Run Tests
working-directory: target/snitch_cluster
run: |-
sbt "testOnly snax.xdma.xdmaTop.*"
./run.py --simulator verilator \
sw/runtime.yaml \
sw/snax-streamer-gemm-conv-simd-run.yaml \
sw/snax-xdma-run.yaml -j
############################################
# Build SW on Snitch Cluster w/ Banshee #
Expand Down
29 changes: 0 additions & 29 deletions .github/workflows/gen.yml

This file was deleted.

16 changes: 16 additions & 0 deletions .github/workflows/scala-unit-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,19 @@ jobs:
working-directory: hw/chisel
run: |
mill Snax.test
- name: Test xdma streamer
working-directory: hw/chisel
run: |-
sbt "testOnly snax.xdma.xdmaStreamer.*"
- name: Test xdma DataPath and Controller
working-directory: hw/chisel
run: |-
sbt "testOnly snax.xdma.xdmaFrontend.*"
- name: Test xdma Extension Framework + MaxPool
working-directory: hw/chisel
run: |-
sbt "testOnly snax.xdma.xdmaExtension.*"
- name: Test xdma Top Module
working-directory: hw/chisel
run: |-
sbt "testOnly snax.xdma.xdmaTop.*"
23 changes: 23 additions & 0 deletions Bender.yml
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,26 @@ sources:
# Level 2
- target/snitch_cluster/generated/snax_streamer_gemmX/snax_streamer_gemmX_wrapper.sv

- target: snax_streamer_gemmX_xdma
files:
# Level 0
- hw/chisel_acc/generated/gemmx/BlockGemmRescaleSIMD.sv
- hw/chisel_acc/src/snax_streamer_gemmX_shell_wrapper.sv
- target/snitch_cluster/generated/snax_streamer_gemmX/snax_streamer_gemmX_csrman_CsrManager.sv
- target/snitch_cluster/generated/snax_streamer_gemmX/snax_streamer_gemmX_streamer_StreamerTop.sv
# Level 1
- target/snitch_cluster/generated/snax_streamer_gemmX/snax_streamer_gemmX_csrman_wrapper.sv
- target/snitch_cluster/generated/snax_streamer_gemmX/snax_streamer_gemmX_streamer_wrapper.sv
# Level 2
- target/snitch_cluster/generated/snax_streamer_gemmX/snax_streamer_gemmX_wrapper.sv

# xdma source
# Level 0
- target/snitch_cluster/generated/snax_streamer_gemmX_xdma_cluster_xdma/snax_streamer_gemmX_xdma_cluster_xdma.sv
# Level 1
- target/snitch_cluster/generated/snax_streamer_gemmX_xdma_cluster_xdma/snax_streamer_gemmX_xdma_cluster_xdma_wrapper.sv


- target: test
files:
- hw/snitch_cluster/test/snitch_tcdm_interconnect_tb.sv
Expand Down Expand Up @@ -416,6 +436,9 @@ sources:
- target: snax_streamer_gemmX
files:
- target/snitch_cluster/generated/snax_streamer_gemmX_cluster_wrapper.sv
- target: snax_streamer_gemmX_xdma
files:
- target/snitch_cluster/generated/snax_streamer_gemmX_xdma_cluster_wrapper.sv
- target: snax_streamer_gemm_add_c
files:
- target/snitch_cluster/generated/snax_streamer_gemm_add_c_cluster_wrapper.sv
Expand Down
32 changes: 32 additions & 0 deletions hw/chisel/src/main/scala/snax/utils/ArgParser.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package snax.utils

object ArgParser {

/*
* Function to parse the arguments provided to the program
* Arguments are expected to be in the form of --arg_name arg_value or --arg_name
* Returns a map of argument names to their values
*/
def parse(args: Array[String]): collection.mutable.Map[String, String] = {
val parsed_args = collection.mutable.Map[String, String]()
var i = 0
while (i < args.length) {
if (args(i)(0) == '-' && args(i)(1) == '-') {
if (
i == args.length - 1 || (args(i + 1)(0) == '-' && args(i + 1)(
1
) == '-')
) {
// Last argument or next argument is also a flag
parsed_args(args(i).substring(2)) = "NoArg"
} else parsed_args(args(i).substring(2)) = args(i + 1)
}
i += 1
}
if (parsed_args.size == 0) {
println("No arguments provided. Please provide arguments")
sys.exit(1)
}
parsed_args
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import chisel3.util._
import snax.utils._
import snax.xdma.CommonCells._
import snax.xdma.DesignParams._
import os.copy.over

/** The parent (abstract) Class for the DMA Extension Generation Params This
* class template is used to isolate the definition of class (when user provide
Expand All @@ -24,7 +25,8 @@ abstract class HasDMAExtension {
implicit val extensionParam: DMAExtensionParam

def totalCsrNum = extensionParam.userCsrNum + 1
def instantiate: DMAExtension
def namePostfix = "_xdma_extension_" + extensionParam.moduleName
def instantiate(clusterName: String): DMAExtension
}

/** The parent (abstract) Class for the DMA Extension Implementation (Circuit)
Expand All @@ -45,8 +47,6 @@ abstract class DMAExtension(implicit extensionParam: DMAExtensionParam)
extends Module
with RequireAsyncReset {

override def desiredName: String = extensionParam.moduleName

val io = IO(new Bundle {
val csr_i = Input(
Vec(extensionParam.userCsrNum + 1, UInt(32.W))
Expand Down Expand Up @@ -74,7 +74,9 @@ abstract class DMAExtension(implicit extensionParam: DMAExtensionParam)

// Structure to bypass extension: Demux
private[this] val inputDemux = Module(
new DemuxDecoupled(UInt(extensionParam.dataWidth.W), numOutput = 2)
new DemuxDecoupled(UInt(extensionParam.dataWidth.W), numOutput = 2) {
override def desiredName = s"xdma_extension_inputDemux"
}
)
inputDemux.io.sel := bypass
inputDemux.io.in <> io.data_i
Expand All @@ -85,7 +87,9 @@ abstract class DMAExtension(implicit extensionParam: DMAExtensionParam)

// Structure to bypass extension: Mux
private[this] val outputMux = Module(
new MuxDecoupled(UInt(extensionParam.dataWidth.W), numInput = 2)
new MuxDecoupled(UInt(extensionParam.dataWidth.W), numInput = 2) {
override def desiredName = s"xdma_extension_outputMux"
}
)
outputMux.io.sel := bypass
outputMux.io.out <> io.data_o
Expand Down
22 changes: 13 additions & 9 deletions hw/chisel/src/main/scala/snax/xdma/xdmaExtension/MaxPool.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ object HasMaxPool extends HasDMAExtension {
userCsrNum = 1,
dataWidth = 512
)
def instantiate: MaxPool = Module(
new MaxPool(elementWidth = 8)
def instantiate(clusterName: String): MaxPool = Module(
new MaxPool(elementWidth = 8) {
override def desiredName = clusterName + namePostfix
}
)
}

Expand All @@ -43,11 +45,13 @@ class MaxPool(elementWidth: Int)(implicit extensionParam: DMAExtensionParam)

// Counter to record the steps
// 256-element MaxPool maximum
val counters = Module(new snax.xdma.xdmaStreamer.BasicCounter(8))
counters.io.ceil := ext_csr_i(0)
counters.io.reset := ext_start_i
counters.io.tick := ext_data_i.fire
ext_busy_o := counters.io.value =/= 0.U
val counter = Module(new snax.xdma.xdmaStreamer.BasicCounter(8) {
override val desiredName = "xdma_extension_MaxPoolCounter"
})
counter.io.ceil := ext_csr_i(0)
counter.io.reset := ext_start_i
counter.io.tick := ext_data_i.fire
ext_busy_o := counter.io.value =/= 0.U

// The wire to connect the output result
val ext_data_o_bits = Wire(
Expand All @@ -57,7 +61,7 @@ class MaxPool(elementWidth: Int)(implicit extensionParam: DMAExtensionParam)

val PEs = for (i <- 0 until extensionParam.dataWidth / elementWidth) yield {
val PE = Module(new MAXPoolPE(dataWidth = elementWidth))
PE.io.init_i := counters.io.value === 0.U
PE.io.init_i := counter.io.value === 0.U
PE.io.data_i.valid := ext_data_i.fire
PE.io.data_i.bits := ext_data_i
.bits((i + 1) * elementWidth - 1, i * elementWidth)
Expand All @@ -79,7 +83,7 @@ class MaxPool(elementWidth: Int)(implicit extensionParam: DMAExtensionParam)
// Under this condition, the system does not need to send the sum to the next stage
ext_data_i.ready := true.B
ext_data_o.valid := false.B
when(ext_data_i.fire && counters.io.lastVal) {
when(ext_data_i.fire && counter.io.lastVal) {
// The result is about to be ready, switching state to output
current_state := s_output
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ object HasMemset extends HasDMAExtension {
userCsrNum = 1,
dataWidth = 512
)
def instantiate: Memset = Module(new Memset)
def instantiate(clusterName: String): Memset = Module(new Memset {
override def desiredName = clusterName + namePostfix
})
}

class Memset()(implicit extensionParam: DMAExtensionParam)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ object HasTransposer extends HasDMAExtension {
dataWidth = 512
)

def instantiate: Transposer = Module(new Transposer)
def instantiate(clusterName: String): Transposer = Module(new Transposer {
override def desiredName = clusterName + namePostfix
})
}

class Transposer()(implicit extensionParam: DMAExtensionParam)
Expand Down
Loading

0 comments on commit b933fce

Please sign in to comment.