diff --git a/dfdx-core/src/nn_traits/mod.rs b/dfdx-core/src/nn_traits/mod.rs index 52203373..869e1047 100644 --- a/dfdx-core/src/nn_traits/mod.rs +++ b/dfdx-core/src/nn_traits/mod.rs @@ -173,6 +173,21 @@ pub trait LoadSafeTensors { ) -> Result<(), safetensors::SafeTensorError> { self.load_safetensors_with(path, false, &mut core::convert::identity) } + fn load_safetensors_from_bytes_with String>( + &mut self, + bytes: &[u8], + skip_missing: bool, + key_map: &mut F, + ) -> Result<(), safetensors::SafeTensorError> { + let tensors = safetensors::SafeTensors::deserialize(&bytes)?; + self.read_safetensors_with("", &tensors, skip_missing, key_map) + } + fn load_safetensors_from_bytes( + &mut self, + bytes: &[u8], + ) -> Result<(), safetensors::SafeTensorError> { + self.load_safetensors_from_bytes_with(bytes, false, &mut core::convert::identity) + } fn read_safetensors_with String>( &mut self,