Skip to content

Commit

Permalink
Reduce repeated codes between Archs
Browse files Browse the repository at this point in the history
  • Loading branch information
n0gu-furiosa committed Apr 16, 2024
1 parent a142b49 commit e0c359b
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 65 deletions.
48 changes: 15 additions & 33 deletions device-api/src/arch/renegade.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::path::PathBuf;

use strum::IntoEnumIterator;
use strum_macros::EnumIter;

use crate::device::{DeviceCtrl, DeviceInner, DeviceMgmt, DevicePerf};
use crate::error::DeviceResult;
use crate::sysfs::npu_mgmt;
use crate::sysfs::npu_mgmt::{self, MgmtCache, MgmtFile, MgmtFileIO};
use crate::Arch;
use crate::ClockFrequency;
use crate::DeviceError;
Expand All @@ -17,21 +17,15 @@ pub struct RenegadeInner {
device_index: u8,
sysfs: PathBuf,
mgmt_root: PathBuf,
mgmt_cache: HashMap<StaticMgmtFile, String>,
mgmt_cache: MgmtCache<StaticMgmtFile>,
}

impl RenegadeInner {
pub fn new(arch: Arch, device_index: u8, sysfs: PathBuf) -> DeviceResult<Self> {
let mgmt_root = sysfs.join(format!(
"class/renegade_mgmt/renegade!npu{device_index}mgmt"
));
let m: DeviceResult<HashMap<_, _>> = StaticMgmtFile::iter()
.map(|key| {
let value = npu_mgmt::read_mgmt_to_string(&mgmt_root, key.filename())?;
Ok((key, value))
})
.collect();
let mgmt_cache = m?;
let mgmt_cache = MgmtCache::init(StaticMgmtFile::iter(), &mgmt_root)?;

Ok(RenegadeInner {
arch,
Expand All @@ -41,23 +35,11 @@ impl RenegadeInner {
mgmt_cache,
})
}
}

fn read_mgmt_to_string<P: AsRef<Path>>(&self, file: P) -> DeviceResult<String> {
npu_mgmt::read_mgmt_to_string(&self.mgmt_root, file).map_err(|e| e.into())
}

#[allow(dead_code)]
fn write_ctrl_file<P: AsRef<Path>>(&self, file: P, contents: &str) -> DeviceResult<()> {
let path = self.mgmt_root.join(file);
std::fs::write(path, contents)?;
Ok(())
}

fn get_mgmt_cache(&self, file: StaticMgmtFile) -> String {
self.mgmt_cache
.get(&file)
.unwrap_or(&Default::default())
.clone()
impl MgmtFileIO for RenegadeInner {
fn mgmt_root(&self) -> PathBuf {
self.mgmt_root.clone()
}
}

Expand All @@ -71,7 +53,7 @@ enum StaticMgmtFile {
Version,
}

impl StaticMgmtFile {
impl MgmtFile for StaticMgmtFile {
fn filename(&self) -> &'static str {
match self {
StaticMgmtFile::BusName => npu_mgmt::file::BUS_NAME,
Expand Down Expand Up @@ -111,27 +93,27 @@ impl DeviceMgmt for RenegadeInner {
}

fn busname(&self) -> String {
self.get_mgmt_cache(StaticMgmtFile::BusName)
self.mgmt_cache.get(&StaticMgmtFile::BusName)
}

fn pci_dev(&self) -> String {
self.get_mgmt_cache(StaticMgmtFile::Dev)
self.mgmt_cache.get(&StaticMgmtFile::Dev)
}

fn device_sn(&self) -> String {
self.get_mgmt_cache(StaticMgmtFile::DeviceSN)
self.mgmt_cache.get(&StaticMgmtFile::DeviceSN)
}

fn device_uuid(&self) -> String {
self.get_mgmt_cache(StaticMgmtFile::DeviceUUID)
self.mgmt_cache.get(&StaticMgmtFile::DeviceUUID)
}

fn firmware_version(&self) -> String {
self.get_mgmt_cache(StaticMgmtFile::FWVersion)
self.mgmt_cache.get(&StaticMgmtFile::FWVersion)
}

fn driver_version(&self) -> String {
self.get_mgmt_cache(StaticMgmtFile::Version)
self.mgmt_cache.get(&StaticMgmtFile::Version)
}

fn heartbeat(&self) -> DeviceResult<u32> {
Expand Down
47 changes: 15 additions & 32 deletions device-api/src/arch/warboy.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::path::PathBuf;

use strum::IntoEnumIterator;
use strum_macros::EnumIter;

use crate::device::{DeviceCtrl, DeviceInner, DeviceMgmt, DevicePerf};
use crate::error::DeviceResult;
use crate::perf_regs::PerformanceCounter;
use crate::sysfs::npu_mgmt;
use crate::sysfs::npu_mgmt::{self, MgmtCache, MgmtFile, MgmtFileIO};
use crate::{Arch, ClockFrequency, DeviceError, DeviceFile};

#[derive(Clone)]
Expand All @@ -16,19 +16,13 @@ pub struct WarboyInner {
device_index: u8,
sysfs: PathBuf,
mgmt_root: PathBuf,
mgmt_cache: HashMap<StaticMgmtFile, String>,
mgmt_cache: MgmtCache<StaticMgmtFile>,
}

impl WarboyInner {
pub fn new(arch: Arch, device_index: u8, sysfs: PathBuf) -> DeviceResult<Self> {
let mgmt_root = sysfs.join(format!("class/npu_mgmt/npu{device_index}_mgmt"));
let m: DeviceResult<HashMap<_, _>> = StaticMgmtFile::iter()
.map(|key| {
let value = npu_mgmt::read_mgmt_to_string(&mgmt_root, key.filename())?;
Ok((key, value))
})
.collect();
let mgmt_cache = m?;
let mgmt_cache = MgmtCache::init(StaticMgmtFile::iter(), &mgmt_root)?;

Ok(WarboyInner {
arch,
Expand All @@ -38,22 +32,11 @@ impl WarboyInner {
mgmt_cache,
})
}
}

fn read_mgmt_to_string<P: AsRef<Path>>(&self, file: P) -> DeviceResult<String> {
npu_mgmt::read_mgmt_to_string(&self.mgmt_root, file).map_err(|e| e.into())
}

fn write_ctrl_file<P: AsRef<Path>>(&self, file: P, contents: &str) -> DeviceResult<()> {
let path = self.mgmt_root.join(file);
std::fs::write(path, contents)?;
Ok(())
}

fn get_mgmt_cache(&self, file: StaticMgmtFile) -> String {
self.mgmt_cache
.get(&file)
.unwrap_or(&Default::default())
.clone()
impl MgmtFileIO for WarboyInner {
fn mgmt_root(&self) -> PathBuf {
self.mgmt_root.clone()
}
}

Expand All @@ -67,7 +50,7 @@ enum StaticMgmtFile {
Version,
}

impl StaticMgmtFile {
impl MgmtFile for StaticMgmtFile {
fn filename(&self) -> &'static str {
match self {
StaticMgmtFile::BusName => npu_mgmt::file::BUS_NAME,
Expand Down Expand Up @@ -112,27 +95,27 @@ impl DeviceMgmt for WarboyInner {
}

fn busname(&self) -> String {
self.get_mgmt_cache(StaticMgmtFile::BusName)
self.mgmt_cache.get(&StaticMgmtFile::BusName)
}

fn pci_dev(&self) -> String {
self.get_mgmt_cache(StaticMgmtFile::Dev)
self.mgmt_cache.get(&StaticMgmtFile::Dev)
}

fn device_sn(&self) -> String {
self.get_mgmt_cache(StaticMgmtFile::DeviceSN)
self.mgmt_cache.get(&StaticMgmtFile::DeviceSN)
}

fn device_uuid(&self) -> String {
self.get_mgmt_cache(StaticMgmtFile::DeviceUUID)
self.mgmt_cache.get(&StaticMgmtFile::DeviceUUID)
}

fn firmware_version(&self) -> String {
self.get_mgmt_cache(StaticMgmtFile::FWVersion)
self.mgmt_cache.get(&StaticMgmtFile::FWVersion)
}

fn driver_version(&self) -> String {
self.get_mgmt_cache(StaticMgmtFile::Version)
self.mgmt_cache.get(&StaticMgmtFile::Version)
}

fn heartbeat(&self) -> DeviceResult<u32> {
Expand Down
50 changes: 50 additions & 0 deletions device-api/src/sysfs.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
pub mod npu_mgmt {
use std::collections::HashMap;
use std::fs;
use std::hash::Hash;
use std::io;
use std::path::Path;
use std::path::PathBuf;



use crate::DeviceResult;

#[allow(dead_code)]
pub mod file {
Expand Down Expand Up @@ -115,6 +121,50 @@ pub mod npu_mgmt {
_ => None,
}
}

pub(crate) trait MgmtFileIO {
fn mgmt_root(&self) -> PathBuf;

fn read_mgmt_to_string<P: AsRef<Path>>(&self, file: P) -> DeviceResult<String> {
read_mgmt_to_string(self.mgmt_root(), file).map_err(|e| e.into())
}

fn write_ctrl_file<P: AsRef<Path>>(&self, file: P, contents: &str) -> DeviceResult<()> {
let path = &self.mgmt_root().join(file);
std::fs::write(path, contents)?;
Ok(())
}
}

pub(crate) trait MgmtFile {
fn filename(&self) -> &'static str;
}

#[derive(Clone)]
pub(crate) struct MgmtCache<K: Eq + Hash + MgmtFile> {
cache: HashMap<K, String>,
}

impl<K: Eq + Hash + MgmtFile> MgmtCache<K> {
pub fn init<P: AsRef<Path>>(
keys: impl Iterator<Item = K>,
mgmt_root: P,
) -> io::Result<Self> {
let cache: io::Result<HashMap<_, _>> = keys
.map(|key| {
let value = read_mgmt_to_string(&mgmt_root, key.filename())?;
Ok((key, value))
})
.collect();

let cache = cache?;
Ok(MgmtCache { cache })
}

pub fn get(&self, key: &K) -> String {
self.cache.get(key).unwrap_or(&Default::default()).clone()
}
}
}

// XXX(n0gu): warboy and renegade share the same implementation, but this may change in the future devices.
Expand Down

0 comments on commit e0c359b

Please sign in to comment.