From e0c359b1604d9b35eaf8342d18192a2e45dbf692 Mon Sep 17 00:00:00 2001 From: Sungyoon Jeong Date: Mon, 15 Apr 2024 21:49:32 +0900 Subject: [PATCH] Reduce repeated codes between Archs --- device-api/src/arch/renegade.rs | 48 ++++++++++--------------------- device-api/src/arch/warboy.rs | 47 ++++++++++--------------------- device-api/src/sysfs.rs | 50 +++++++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 65 deletions(-) diff --git a/device-api/src/arch/renegade.rs b/device-api/src/arch/renegade.rs index 411731c..f3f970f 100644 --- a/device-api/src/arch/renegade.rs +++ b/device-api/src/arch/renegade.rs @@ -1,12 +1,12 @@ use std::collections::HashMap; -use std::path::{Path, PathBuf}; +use std::path::PathBuf; use strum::IntoEnumIterator; use strum_macros::EnumIter; use crate::device::{DeviceCtrl, DeviceInner, DeviceMgmt, DevicePerf}; use crate::error::DeviceResult; -use crate::sysfs::npu_mgmt; +use crate::sysfs::npu_mgmt::{self, MgmtCache, MgmtFile, MgmtFileIO}; use crate::Arch; use crate::ClockFrequency; use crate::DeviceError; @@ -17,7 +17,7 @@ pub struct RenegadeInner { device_index: u8, sysfs: PathBuf, mgmt_root: PathBuf, - mgmt_cache: HashMap, + mgmt_cache: MgmtCache, } impl RenegadeInner { @@ -25,13 +25,7 @@ impl RenegadeInner { let mgmt_root = sysfs.join(format!( "class/renegade_mgmt/renegade!npu{device_index}mgmt" )); - let m: DeviceResult> = StaticMgmtFile::iter() - .map(|key| { - let value = npu_mgmt::read_mgmt_to_string(&mgmt_root, key.filename())?; - Ok((key, value)) - }) - .collect(); - let mgmt_cache = m?; + let mgmt_cache = MgmtCache::init(StaticMgmtFile::iter(), &mgmt_root)?; Ok(RenegadeInner { arch, @@ -41,23 +35,11 @@ impl RenegadeInner { mgmt_cache, }) } +} - fn read_mgmt_to_string>(&self, file: P) -> DeviceResult { - npu_mgmt::read_mgmt_to_string(&self.mgmt_root, file).map_err(|e| e.into()) - } - - #[allow(dead_code)] - fn write_ctrl_file>(&self, file: P, contents: &str) -> DeviceResult<()> { - let path = self.mgmt_root.join(file); - std::fs::write(path, contents)?; - Ok(()) - } - - fn get_mgmt_cache(&self, file: StaticMgmtFile) -> String { - self.mgmt_cache - .get(&file) - .unwrap_or(&Default::default()) - .clone() +impl MgmtFileIO for RenegadeInner { + fn mgmt_root(&self) -> PathBuf { + self.mgmt_root.clone() } } @@ -71,7 +53,7 @@ enum StaticMgmtFile { Version, } -impl StaticMgmtFile { +impl MgmtFile for StaticMgmtFile { fn filename(&self) -> &'static str { match self { StaticMgmtFile::BusName => npu_mgmt::file::BUS_NAME, @@ -111,27 +93,27 @@ impl DeviceMgmt for RenegadeInner { } fn busname(&self) -> String { - self.get_mgmt_cache(StaticMgmtFile::BusName) + self.mgmt_cache.get(&StaticMgmtFile::BusName) } fn pci_dev(&self) -> String { - self.get_mgmt_cache(StaticMgmtFile::Dev) + self.mgmt_cache.get(&StaticMgmtFile::Dev) } fn device_sn(&self) -> String { - self.get_mgmt_cache(StaticMgmtFile::DeviceSN) + self.mgmt_cache.get(&StaticMgmtFile::DeviceSN) } fn device_uuid(&self) -> String { - self.get_mgmt_cache(StaticMgmtFile::DeviceUUID) + self.mgmt_cache.get(&StaticMgmtFile::DeviceUUID) } fn firmware_version(&self) -> String { - self.get_mgmt_cache(StaticMgmtFile::FWVersion) + self.mgmt_cache.get(&StaticMgmtFile::FWVersion) } fn driver_version(&self) -> String { - self.get_mgmt_cache(StaticMgmtFile::Version) + self.mgmt_cache.get(&StaticMgmtFile::Version) } fn heartbeat(&self) -> DeviceResult { diff --git a/device-api/src/arch/warboy.rs b/device-api/src/arch/warboy.rs index ea9baed..6f7cd1d 100644 --- a/device-api/src/arch/warboy.rs +++ b/device-api/src/arch/warboy.rs @@ -1,5 +1,5 @@ use std::collections::HashMap; -use std::path::{Path, PathBuf}; +use std::path::PathBuf; use strum::IntoEnumIterator; use strum_macros::EnumIter; @@ -7,7 +7,7 @@ use strum_macros::EnumIter; use crate::device::{DeviceCtrl, DeviceInner, DeviceMgmt, DevicePerf}; use crate::error::DeviceResult; use crate::perf_regs::PerformanceCounter; -use crate::sysfs::npu_mgmt; +use crate::sysfs::npu_mgmt::{self, MgmtCache, MgmtFile, MgmtFileIO}; use crate::{Arch, ClockFrequency, DeviceError, DeviceFile}; #[derive(Clone)] @@ -16,19 +16,13 @@ pub struct WarboyInner { device_index: u8, sysfs: PathBuf, mgmt_root: PathBuf, - mgmt_cache: HashMap, + mgmt_cache: MgmtCache, } impl WarboyInner { pub fn new(arch: Arch, device_index: u8, sysfs: PathBuf) -> DeviceResult { let mgmt_root = sysfs.join(format!("class/npu_mgmt/npu{device_index}_mgmt")); - let m: DeviceResult> = StaticMgmtFile::iter() - .map(|key| { - let value = npu_mgmt::read_mgmt_to_string(&mgmt_root, key.filename())?; - Ok((key, value)) - }) - .collect(); - let mgmt_cache = m?; + let mgmt_cache = MgmtCache::init(StaticMgmtFile::iter(), &mgmt_root)?; Ok(WarboyInner { arch, @@ -38,22 +32,11 @@ impl WarboyInner { mgmt_cache, }) } +} - fn read_mgmt_to_string>(&self, file: P) -> DeviceResult { - npu_mgmt::read_mgmt_to_string(&self.mgmt_root, file).map_err(|e| e.into()) - } - - fn write_ctrl_file>(&self, file: P, contents: &str) -> DeviceResult<()> { - let path = self.mgmt_root.join(file); - std::fs::write(path, contents)?; - Ok(()) - } - - fn get_mgmt_cache(&self, file: StaticMgmtFile) -> String { - self.mgmt_cache - .get(&file) - .unwrap_or(&Default::default()) - .clone() +impl MgmtFileIO for WarboyInner { + fn mgmt_root(&self) -> PathBuf { + self.mgmt_root.clone() } } @@ -67,7 +50,7 @@ enum StaticMgmtFile { Version, } -impl StaticMgmtFile { +impl MgmtFile for StaticMgmtFile { fn filename(&self) -> &'static str { match self { StaticMgmtFile::BusName => npu_mgmt::file::BUS_NAME, @@ -112,27 +95,27 @@ impl DeviceMgmt for WarboyInner { } fn busname(&self) -> String { - self.get_mgmt_cache(StaticMgmtFile::BusName) + self.mgmt_cache.get(&StaticMgmtFile::BusName) } fn pci_dev(&self) -> String { - self.get_mgmt_cache(StaticMgmtFile::Dev) + self.mgmt_cache.get(&StaticMgmtFile::Dev) } fn device_sn(&self) -> String { - self.get_mgmt_cache(StaticMgmtFile::DeviceSN) + self.mgmt_cache.get(&StaticMgmtFile::DeviceSN) } fn device_uuid(&self) -> String { - self.get_mgmt_cache(StaticMgmtFile::DeviceUUID) + self.mgmt_cache.get(&StaticMgmtFile::DeviceUUID) } fn firmware_version(&self) -> String { - self.get_mgmt_cache(StaticMgmtFile::FWVersion) + self.mgmt_cache.get(&StaticMgmtFile::FWVersion) } fn driver_version(&self) -> String { - self.get_mgmt_cache(StaticMgmtFile::Version) + self.mgmt_cache.get(&StaticMgmtFile::Version) } fn heartbeat(&self) -> DeviceResult { diff --git a/device-api/src/sysfs.rs b/device-api/src/sysfs.rs index 589740e..8123bb1 100644 --- a/device-api/src/sysfs.rs +++ b/device-api/src/sysfs.rs @@ -1,8 +1,14 @@ pub mod npu_mgmt { use std::collections::HashMap; use std::fs; + use std::hash::Hash; use std::io; use std::path::Path; + use std::path::PathBuf; + + + + use crate::DeviceResult; #[allow(dead_code)] pub mod file { @@ -115,6 +121,50 @@ pub mod npu_mgmt { _ => None, } } + + pub(crate) trait MgmtFileIO { + fn mgmt_root(&self) -> PathBuf; + + fn read_mgmt_to_string>(&self, file: P) -> DeviceResult { + read_mgmt_to_string(self.mgmt_root(), file).map_err(|e| e.into()) + } + + fn write_ctrl_file>(&self, file: P, contents: &str) -> DeviceResult<()> { + let path = &self.mgmt_root().join(file); + std::fs::write(path, contents)?; + Ok(()) + } + } + + pub(crate) trait MgmtFile { + fn filename(&self) -> &'static str; + } + + #[derive(Clone)] + pub(crate) struct MgmtCache { + cache: HashMap, + } + + impl MgmtCache { + pub fn init>( + keys: impl Iterator, + mgmt_root: P, + ) -> io::Result { + let cache: io::Result> = keys + .map(|key| { + let value = read_mgmt_to_string(&mgmt_root, key.filename())?; + Ok((key, value)) + }) + .collect(); + + let cache = cache?; + Ok(MgmtCache { cache }) + } + + pub fn get(&self, key: &K) -> String { + self.cache.get(key).unwrap_or(&Default::default()).clone() + } + } } // XXX(n0gu): warboy and renegade share the same implementation, but this may change in the future devices.