From 54c4c15c2e9c750eec861bd168e8cd7ff6249449 Mon Sep 17 00:00:00 2001 From: Marco Visaya Date: Thu, 21 Nov 2024 11:53:54 -0800 Subject: [PATCH 1/8] Add xtask command to generate and verify flash image Following the flash layout defined in https://github.com/chipsalliance/caliptra-mcu-sw/blob/main/docs/src/flash_layout.md - Add new command to specify the firmware files and soc manifest to create 1 flash image file that can be burned to the SPI flash. - Add new command to verify a flash image file Example: cargo run --bin flash-image create --caliptra-fw caliptra_fw.bin \ --soc-manifest soc_manifest.bin \ --mcu-runtime mcu_runtime.bin \ --soc-images soc_image1.bin soc_image2.bin soc_image3.bin \ --output flash_image.bin cargo run --bin flash-image verify flash_image.bin --- Cargo.lock | 10 + xtask/Cargo.toml | 1 + xtask/src/flash_image.rs | 501 +++++++++++++++++++++++++++++++++++++++ xtask/src/main.rs | 54 +++++ 4 files changed, 566 insertions(+) create mode 100644 xtask/src/flash_image.rs diff --git a/Cargo.lock b/Cargo.lock index 8d7acba..79c266f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -724,6 +724,15 @@ version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" +[[package]] +name = "crc32fast" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" +dependencies = [ + "cfg-if", +] + [[package]] name = "critical-section" version = "1.2.0" @@ -2534,6 +2543,7 @@ name = "xtask" version = "0.1.0" dependencies = [ "clap 4.5.21", + "crc32fast", "proc-macro2", "quote", "registers-generator", diff --git a/xtask/Cargo.toml b/xtask/Cargo.toml index 4b5726f..ea10608 100644 --- a/xtask/Cargo.toml +++ b/xtask/Cargo.toml @@ -13,3 +13,4 @@ clap.workspace = true registers-generator.workspace = true registers-systemrdl.workspace = true quote.workspace = true +crc32fast = "1.4.2" diff --git a/xtask/src/flash_image.rs b/xtask/src/flash_image.rs new file mode 100644 index 0000000..d1fb0a5 --- /dev/null +++ b/xtask/src/flash_image.rs @@ -0,0 +1,501 @@ +// Licensed under the Apache-2.0 license + +use crate::DynError; +use crc32fast::Hasher; +use std::fs::File; +use std::io::{self, Read, Write}; +use std::io::{Error, ErrorKind}; + +const FLASH_IMAGE_MAGIC_NUMBER: u32 = 0x464C5348; +const HEADER_VERSION: u16 = 0x0001; +const CALIPTRA_FMC_RT_IDENTIFIER: u32 = 0x00000001; +const SOC_MANIFEST_IDENTIFIER: u32 = 0x00000002; +const MCU_RT_IDENTIFIER: u32 = 0x00000002; +const SOC_IMAGES_BASE_IDENTIFIER: u32 = 0x00001000; + +pub struct FlashImage { + header: FlashImageHeader, + checksum: FlashImageChecksum, + payload: FlashImagePayload, +} + +pub struct FlashImageHeader { + magic_number: u32, + header_version: u16, + image_count: u16, // number of images +} + +pub struct FlashImageChecksum { + header: u32, // checksum of the header + payload: u32, // checksum of the payload +} + +pub struct FlashImagePayload { + image_info: Vec, + images: Vec, +} + +// Per image header +pub struct FlashImageInfo { + identifier: u32, + image_offset: u32, // Location of the image in the flash as an offset from the header + size: u32, // Size of the image +} + +#[derive(Clone)] +pub struct FirmwareImage { + identifier: u32, + data: Vec, +} + +impl FirmwareImage { + pub fn new(identifier: u32, filename: &str) -> io::Result { + let mut file = File::open(filename)?; + let mut data = Vec::new(); + file.read_to_end(&mut data)?; + + Ok(Self { identifier, data }) + } +} + +impl FlashImage { + pub fn new(images: &mut [FirmwareImage]) -> Self { + let mut image_info = Vec::new(); + let mut offset = std::mem::size_of::() as u32 + + std::mem::size_of::() as u32 + + (std::mem::size_of::() * images.len()) as u32; + + for image in images.iter_mut() { + let image_size = image.data.len() as u32; + Self::align_to_4_bytes(&mut image.data); + let padded_size = image.data.len() as u32; + image_info.push(FlashImageInfo { + identifier: image.identifier, + image_offset: offset, + size: image_size, + }); + offset += padded_size; + } + + let header = FlashImageHeader { + magic_number: FLASH_IMAGE_MAGIC_NUMBER, + header_version: HEADER_VERSION, + image_count: image_info.len() as u16, + }; + + let payload = FlashImagePayload { + image_info, + images: images.to_owned(), + }; + + let checksum = FlashImageChecksum::new(&header, &payload); + + Self { + header, + checksum, + payload, + } + } + + fn align_to_4_bytes(data: &mut Vec) { + let padding = (4 - (data.len() % 4)) % 4; + data.extend(vec![0; padding]); + } + + pub fn write_to_file(&self, filename: &str) -> io::Result<()> { + let mut file = File::create(filename)?; + + // Write header + file.write_all(&self.header.magic_number.to_le_bytes())?; + file.write_all(&self.header.header_version.to_le_bytes())?; + file.write_all(&self.header.image_count.to_le_bytes())?; + + // Write checksums + file.write_all(&self.checksum.header.to_le_bytes())?; + file.write_all(&self.checksum.payload.to_le_bytes())?; + + // Write image info + for info in &self.payload.image_info { + file.write_all(&info.identifier.to_le_bytes())?; + file.write_all(&info.image_offset.to_le_bytes())?; + file.write_all(&info.size.to_le_bytes())?; + } + + // Write images + for image in &self.payload.images { + file.write_all(&image.data)?; + } + + Ok(()) + } + + pub fn verify_flash_image(image: &[u8]) -> Result<(), DynError> { + // Parse and verify header + let magic_number = u32::from_le_bytes(image[0..4].try_into().unwrap()); + let header_version = u16::from_le_bytes(image[4..6].try_into().unwrap()); + let image_count = u16::from_le_bytes(image[6..8].try_into().unwrap()); + + if magic_number != FLASH_IMAGE_MAGIC_NUMBER { + // Return error + return Err("Invalid header: incorrect magic number or header version.")?; + } + + if header_version != HEADER_VERSION { + return Err("Unsupported header version")?; + } + + if image_count < 3 { + return Err("Expected at least 3 images")?; + } + + // Parse and verify checksums + let header_checksum = u32::from_le_bytes(image[8..12].try_into().unwrap()); + let payload_checksum = u32::from_le_bytes(image[12..16].try_into().unwrap()); + let calculated_header_checksum = FlashImageChecksum::calculate_checksum(&image[0..8]); + let calculated_payload_checksum = FlashImageChecksum::calculate_checksum(&image[16..]); + + if header_checksum != calculated_header_checksum { + return Err("Header checksum mismatch.")?; + } + + if payload_checksum != calculated_payload_checksum { + return Err("Payload checksum mismatch.")?; + } + + // Parse and verify image info and data + let mut offset = 16; // Start after header and checksums + + for i in 0..image_count as usize { + let identifier = u32::from_le_bytes(image[offset..offset + 4].try_into().unwrap()); + match i { + 0 => { + if identifier != CALIPTRA_FMC_RT_IDENTIFIER { + return Err("Image 0 is not Caliptra Identifier")?; + } + } + 1 => { + if identifier != SOC_MANIFEST_IDENTIFIER { + return Err("Image 0 is not SOC Manifest Identifier")?; + } + } + 2 => { + if identifier != MCU_RT_IDENTIFIER { + return Err("Image 0 is not MCU RT Identifier")?; + } + } + 3..255 => { + if identifier != (SOC_IMAGES_BASE_IDENTIFIER + (i as u32) - 3) { + return Err("Invalid SOC image identifier")?; + } + } + _ => return Err("Invalid image identifier")?, + } + + offset += 12; + } + + println!("Image is valid!"); + Ok(()) + } +} + +impl FlashImageHeader { + fn serialize(&self) -> Vec { + let mut buffer = Vec::new(); + buffer.extend_from_slice(&self.magic_number.to_le_bytes()); + buffer.extend_from_slice(&self.header_version.to_le_bytes()); + buffer.extend_from_slice(&self.image_count.to_le_bytes()); + buffer + } +} + +impl FlashImagePayload { + pub fn serialize(&self) -> Vec { + let mut buffer = Vec::new(); + for info in &self.image_info { + buffer.extend_from_slice(&info.identifier.to_le_bytes()); + buffer.extend_from_slice(&info.image_offset.to_le_bytes()); + buffer.extend_from_slice(&info.size.to_le_bytes()); + } + for image in &self.images { + buffer.extend_from_slice(&image.data); + } + buffer + } +} + +impl FlashImageChecksum { + pub fn new(header: &FlashImageHeader, payload: &FlashImagePayload) -> Self { + Self { + header: Self::calculate_checksum(&header.serialize()), + payload: Self::calculate_checksum(&payload.serialize()), + } + } + pub fn calculate_checksum(data: &[u8]) -> u32 { + let mut hasher = Hasher::new(); + hasher.update(data); + hasher.finalize() + } +} + +pub(crate) fn flash_image_create( + caliptra_fw_path: &str, + soc_manifest_path: &str, + mcu_runtime_path: &str, + soc_image_paths: &Option>, + output_path: &str, +) -> Result<(), DynError> { + let mut images: Vec = Vec::new(); + images.push(FirmwareImage::new( + CALIPTRA_FMC_RT_IDENTIFIER, + caliptra_fw_path, + )?); + images.push(FirmwareImage::new( + SOC_MANIFEST_IDENTIFIER, + soc_manifest_path, + )?); + images.push(FirmwareImage::new(MCU_RT_IDENTIFIER, mcu_runtime_path)?); + if let Some(soc_image_paths) = soc_image_paths { + let mut soc_image_identifer = SOC_IMAGES_BASE_IDENTIFIER; + for soc_image_path in soc_image_paths { + images.push(FirmwareImage::new(soc_image_identifer, soc_image_path)?); + soc_image_identifer += 1; + } + } + + let flash_image = FlashImage::new(&mut images); + flash_image.write_to_file(output_path)?; + + Ok(()) +} + +pub(crate) fn flash_image_verify(image_file_path: &str) -> Result<(), DynError> { + let mut file = File::open(image_file_path).map_err(|e| { + Error::new( + ErrorKind::NotFound, + format!("Failed to open file '{}': {}", image_file_path, e), + ) + })?; + + let mut data = Vec::new(); + file.read_to_end(&mut data).map_err(|e| { + Error::new( + ErrorKind::InvalidData, + format!("Failed to read file '{}': {}", image_file_path, e), + ) + })?; + file.read_to_end(&mut data)?; + FlashImage::verify_flash_image(&data) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::PROJECT_ROOT; + use std::fs::{self, File}; + use std::io::Write; + + /// Helper function to create a temporary file with specific content + fn create_temp_file(content: &[u8], file_name: &str) -> io::Result { + let tmp_directory = PROJECT_ROOT.join("target").join("tmp"); + fs::create_dir_all(tmp_directory.clone())?; + let path = tmp_directory.join(file_name); + let mut file = File::create(&path).expect("Failed to create temp file"); + file.write_all(content) + .expect("Failed to write to temp file"); + Ok(String::from(path.to_str().unwrap())) + } + + #[test] + fn test_flash_image_build() { + // Generate test contents for temporary files + let caliptra_fw_content = b"Caliptra Firmware Data - ABCDEFGH"; + let soc_manifest_content = b"Soc Manifest Data - 123456789"; + let mcu_runtime_content = b"MCU Runtime Data - QWERTYUI"; + let soc_image1_content = b"Soc Image 1 Data - ZXCVBNMLKJ"; + let soc_image2_content = b"Soc Image 2 Data - POIUYTREWQ"; + + // Create temporary files with the generated content + let caliptra_fw_path = create_temp_file(caliptra_fw_content, "caliptra_fw.bin") + .expect("Failed to create caliptra_fw.bin"); + let soc_manifest_path = create_temp_file(soc_manifest_content, "soc_manifest.bin") + .expect("Failed to create soc_manifest.bin"); + let mcu_runtime_path = create_temp_file(mcu_runtime_content, "mcu_runtime.bin") + .expect("Failed to create mcu_runtime.bin"); + let soc_image1_path = create_temp_file(soc_image1_content, "soc_image1.bin") + .expect("Failed to create soc_image1.bin"); + let soc_image2_path = create_temp_file(soc_image2_content, "soc_image2.bin") + .expect("Failed to create soc_image2.bin"); + + // Collect SoC image paths + let soc_image_paths = Some(vec![soc_image1_path.clone(), soc_image2_path.clone()]); + + // Specify the output file path + let output_path = PROJECT_ROOT + .join("target") + .join("tmp") + .join("flash_image.bin"); + let output_path = output_path.to_str().unwrap(); + + // Build the flash image + flash_image_create( + &caliptra_fw_path, + &soc_manifest_path, + &mcu_runtime_path, + &soc_image_paths, + output_path, + ) + .expect("Failed to build flash image"); + + // Read and verify the generated flash image + let mut file = File::open(output_path).expect("Failed to open generated flash image"); + let mut data = Vec::new(); + + file.read_to_end(&mut data) + .expect("Failed to read flash image"); + + // Verify header + let magic_number = u32::from_le_bytes(data[0..4].try_into().unwrap()); + let header_version = u16::from_le_bytes(data[4..6].try_into().unwrap()); + let image_count = u16::from_le_bytes(data[6..8].try_into().unwrap()); + + assert_eq!(magic_number, FLASH_IMAGE_MAGIC_NUMBER); + assert_eq!(header_version, HEADER_VERSION); + assert_eq!(image_count, 5); // 3 main images + 2 SoC images + + // Verify checksums + let header_checksum = u32::from_le_bytes(data[8..12].try_into().unwrap()); + let payload_checksum = u32::from_le_bytes(data[12..16].try_into().unwrap()); + let calculated_header_checksum = FlashImageChecksum::calculate_checksum(&data[0..8]); + let calculated_payload_checksum = FlashImageChecksum::calculate_checksum(&data[16..]); + assert_eq!(header_checksum, calculated_header_checksum); + assert_eq!(payload_checksum, calculated_payload_checksum); + + let expected_images: Vec<(u32, &[u8])> = vec![ + (CALIPTRA_FMC_RT_IDENTIFIER, caliptra_fw_content), + (SOC_MANIFEST_IDENTIFIER, soc_manifest_content), + (MCU_RT_IDENTIFIER, mcu_runtime_content), + (SOC_IMAGES_BASE_IDENTIFIER, soc_image1_content), + (SOC_IMAGES_BASE_IDENTIFIER + 1, soc_image2_content), + ]; + let mut image_offsets = Vec::new(); + let mut offset = 16; // Start after header and checksums + + for i in 0..image_count as usize { + let identifier = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()); + let image_offset = u32::from_le_bytes(data[offset + 4..offset + 8].try_into().unwrap()); + let size = u32::from_le_bytes(data[offset + 8..offset + 12].try_into().unwrap()); + + // Verify identifier and size + assert_eq!(identifier, expected_images[i].0); + assert_eq!(size as usize, expected_images[i].1.len()); + + image_offsets.push((image_offset as usize, size as usize)); + offset += 12; + } + + // Verify image data using offsets + for (i, (start_offset, size)) in image_offsets.iter().enumerate() { + let actual_data = &data[*start_offset..*start_offset + size]; + assert_eq!(actual_data, expected_images[i].1); + } + + // Cleanup temporary files + fs::remove_file(caliptra_fw_path).unwrap(); + fs::remove_file(soc_manifest_path).unwrap(); + fs::remove_file(mcu_runtime_path).unwrap(); + fs::remove_file(soc_image1_path).unwrap(); + fs::remove_file(soc_image2_path).unwrap(); + fs::remove_file(output_path).unwrap(); + } + + #[test] + fn test_flash_image_verify_happy_path() { + let image_path = PROJECT_ROOT + .join("target") + .join("tmp") + .join("flash_image.bin"); + let image_path = image_path.to_str().unwrap(); + + // Create a valid firmware image + let mut expected_images = [ + FirmwareImage { + identifier: CALIPTRA_FMC_RT_IDENTIFIER, + data: b"Caliptra Firmware Data - ABCDEFGH".to_vec(), + }, + FirmwareImage { + identifier: SOC_MANIFEST_IDENTIFIER, + data: b"Soc Manifest Data - 123456789".to_vec(), + }, + FirmwareImage { + identifier: MCU_RT_IDENTIFIER, + data: b"MCU Runtime Data - QWERTYUI".to_vec(), + }, + FirmwareImage { + identifier: SOC_IMAGES_BASE_IDENTIFIER, + data: b"Soc Image 1 Data - ZXCVBNMLKJ".to_vec(), + }, + FirmwareImage { + identifier: SOC_IMAGES_BASE_IDENTIFIER + 1, + data: b"Soc Image 2 Data - POIUYTREWQ".to_vec(), + }, + ]; + // Create a flash image from the mutable slice + let flash_image = FlashImage::new(&mut expected_images); + flash_image + .write_to_file(image_path) + .expect("Failed to write flash image"); + + // Verify the firmware image + let result = flash_image_verify(image_path); + assert!(result.is_ok(), "Expected verification to succeed"); + + // Cleanup + fs::remove_file(image_path).expect("Failed to clean up test file"); + } + + #[test] + fn test_flash_image_verify_corrupted_case() { + let image_path = PROJECT_ROOT + .join("target") + .join("tmp") + .join("flash_image.bin"); + let image_path = image_path.to_str().unwrap(); + + // Create a corrupted firmware image (tamper with the header or data) + FlashImage::new(&mut vec![ + FirmwareImage { + identifier: CALIPTRA_FMC_RT_IDENTIFIER, + data: b"Valid Caliptra Firmware Data".to_vec(), + }, + FirmwareImage { + identifier: SOC_MANIFEST_IDENTIFIER, + data: b"Valid SOC Manifest Data".to_vec(), + }, + ]) + .write_to_file(image_path) + .expect("Failed to write flash image"); + + // Corrupt the file by modifying the data + let mut file = File::options() + .write(true) + .open(image_path) + .expect("Failed to open firmware image for tampering"); + file.write_all(b"Corrupted Data") + .expect("Failed to corrupt data"); + + // Verify the corrupted firmware image + let result = flash_image_verify(image_path); + assert!( + result.is_err(), + "Expected verification to fail for corrupted firmware image" + ); + + if let Err(e) = result { + println!("Expected error: {}", e); + } + + // Cleanup + fs::remove_file(image_path).expect("Failed to clean up test file"); + } +} diff --git a/xtask/src/main.rs b/xtask/src/main.rs index c025f28..07eba08 100644 --- a/xtask/src/main.rs +++ b/xtask/src/main.rs @@ -11,6 +11,7 @@ mod apps_build; mod cargo_lock; mod clippy; mod docs; +mod flash_image; mod format; mod header; mod precheckin; @@ -46,6 +47,11 @@ enum Commands { #[arg(short, long, default_value_t = false)] trace: bool, }, + /// Commands related to flash images + FlashImage { + #[command(subcommand)] + subcommand: FlashImageCommands, + }, /// Run clippy on all targets Clippy, /// Build docs @@ -70,6 +76,38 @@ enum Commands { }, } +#[derive(Subcommand)] +enum FlashImageCommands { + /// Create a new flash image + Create { + /// Path to the Caliptra firmware file + #[arg(long, value_name = "CALIPTRA_FW", required = true)] + caliptra_fw: String, + + /// Path to the SoC manifest file + #[arg(long, value_name = "SOC_MANIFEST", required = true)] + soc_manifest: String, + + /// Path to the MCU runtime file + #[arg(long, value_name = "MCU_RUNTIME", required = true)] + mcu_runtime: String, + + /// Paths to optional SoC images + #[arg(long, value_name = "SOC_IMAGE", num_args=1.., required = false)] + soc_images: Option>, + + /// Paths to the output image file + #[arg(long, value_name = "OUTPUT", required = true)] + output: String, + }, + /// Verify an existing flash image + Verify { + /// Path to the flash image file + #[arg(value_name = "FILE")] + file: String, + }, +} + pub type DynError = Box; pub const TARGET: &str = "riscv32imc-unknown-none-elf"; @@ -87,6 +125,22 @@ fn main() { Commands::RuntimeBuild => runtime_build::runtime_build_with_apps(), Commands::Rom { trace } => rom::rom_run(*trace), Commands::RomBuild => rom::rom_build(), + Commands::FlashImage { subcommand } => match subcommand { + FlashImageCommands::Create { + caliptra_fw, + soc_manifest, + mcu_runtime, + soc_images, + output, + } => flash_image::flash_image_create( + caliptra_fw, + soc_manifest, + mcu_runtime, + soc_images, + output, + ), + FlashImageCommands::Verify { file } => flash_image::flash_image_verify(file), + }, Commands::Clippy => clippy::clippy(), Commands::Docs => docs::docs(), Commands::Precheckin => precheckin::precheckin(), From 56d27e357a5b1e3ba7c594a79332d3e6619a0268 Mon Sep 17 00:00:00 2001 From: Marco Visaya Date: Thu, 21 Nov 2024 14:28:51 -0800 Subject: [PATCH 2/8] Print error message in unit test when verification fails --- xtask/src/flash_image.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/xtask/src/flash_image.rs b/xtask/src/flash_image.rs index d1fb0a5..2e272d5 100644 --- a/xtask/src/flash_image.rs +++ b/xtask/src/flash_image.rs @@ -448,7 +448,10 @@ mod tests { // Verify the firmware image let result = flash_image_verify(image_path); - assert!(result.is_ok(), "Expected verification to succeed"); + result.unwrap_or_else(|e| { + eprintln!("Error: {}", e); + std::process::exit(1); + }); // Cleanup fs::remove_file(image_path).expect("Failed to clean up test file"); From 9e3599fd99173b242f4ee79b13075e46896e60bd Mon Sep 17 00:00:00 2001 From: Marco Visaya Date: Thu, 21 Nov 2024 14:41:30 -0800 Subject: [PATCH 3/8] Make temporary file names different for each test case Tests can run simultaneously and can access the same temporary files at the same time. --- xtask/src/flash_image.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xtask/src/flash_image.rs b/xtask/src/flash_image.rs index 2e272d5..55ff795 100644 --- a/xtask/src/flash_image.rs +++ b/xtask/src/flash_image.rs @@ -414,7 +414,7 @@ mod tests { let image_path = PROJECT_ROOT .join("target") .join("tmp") - .join("flash_image.bin"); + .join("flash_image_happy_path.bin"); let image_path = image_path.to_str().unwrap(); // Create a valid firmware image @@ -462,7 +462,7 @@ mod tests { let image_path = PROJECT_ROOT .join("target") .join("tmp") - .join("flash_image.bin"); + .join("flash_image_corrupted.bin"); let image_path = image_path.to_str().unwrap(); // Create a corrupted firmware image (tamper with the header or data) From d9548c8f4533de02d14a7dd412b8ba84112eef75 Mon Sep 17 00:00:00 2001 From: mlvisaya <38512415+mlvisaya@users.noreply.github.com> Date: Fri, 22 Nov 2024 10:55:36 -0800 Subject: [PATCH 4/8] Update xtask/src/flash_image.rs Co-authored-by: Christopher Swenson --- xtask/src/flash_image.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xtask/src/flash_image.rs b/xtask/src/flash_image.rs index 55ff795..b6efd0a 100644 --- a/xtask/src/flash_image.rs +++ b/xtask/src/flash_image.rs @@ -98,7 +98,7 @@ impl FlashImage { } fn align_to_4_bytes(data: &mut Vec) { - let padding = (4 - (data.len() % 4)) % 4; + let padding = data.len().next_multiple_of(4) - data.len(); data.extend(vec![0; padding]); } From 364d9cfc2853b51aefb645c4aa1ecfb117367b2f Mon Sep 17 00:00:00 2001 From: Marco Visaya Date: Fri, 22 Nov 2024 16:05:47 -0800 Subject: [PATCH 5/8] Address feedback - use zerocopy - convert all Vecs in structs to slices - various code optimization --- Cargo.lock | 21 +++ xtask/Cargo.toml | 2 + xtask/src/flash_image.rs | 345 ++++++++++++++++++++------------------- xtask/src/main.rs | 9 +- 4 files changed, 205 insertions(+), 172 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 79c266f..25fdd81 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1045,6 +1045,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "fastrand" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4" + [[package]] name = "ff" version = "0.13.0" @@ -2176,6 +2182,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tempfile" +version = "3.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c" +dependencies = [ + "cfg-if", + "fastrand", + "once_cell", + "rustix", + "windows-sys 0.59.0", +] + [[package]] name = "terminal_size" version = "0.4.0" @@ -2549,7 +2568,9 @@ dependencies = [ "registers-generator", "registers-systemrdl", "semver", + "tempfile", "walkdir", + "zerocopy 0.6.6", ] [[package]] diff --git a/xtask/Cargo.toml b/xtask/Cargo.toml index ea10608..e18d58d 100644 --- a/xtask/Cargo.toml +++ b/xtask/Cargo.toml @@ -14,3 +14,5 @@ registers-generator.workspace = true registers-systemrdl.workspace = true quote.workspace = true crc32fast = "1.4.2" +zerocopy.workspace = true +tempfile = "3.14.0" diff --git a/xtask/src/flash_image.rs b/xtask/src/flash_image.rs index b6efd0a..012d760 100644 --- a/xtask/src/flash_image.rs +++ b/xtask/src/flash_image.rs @@ -3,39 +3,44 @@ use crate::DynError; use crc32fast::Hasher; use std::fs::File; -use std::io::{self, Read, Write}; -use std::io::{Error, ErrorKind}; +use std::io::{self, Error, ErrorKind, Read, Write}; +use zerocopy::{byteorder::U32, AsBytes, FromBytes}; -const FLASH_IMAGE_MAGIC_NUMBER: u32 = 0x464C5348; +const FLASH_IMAGE_MAGIC_NUMBER: u32 = u32::from_be_bytes([b'F', b'L', b'S', b'H']); const HEADER_VERSION: u16 = 0x0001; const CALIPTRA_FMC_RT_IDENTIFIER: u32 = 0x00000001; const SOC_MANIFEST_IDENTIFIER: u32 = 0x00000002; const MCU_RT_IDENTIFIER: u32 = 0x00000002; const SOC_IMAGES_BASE_IDENTIFIER: u32 = 0x00001000; -pub struct FlashImage { +pub struct FlashImage<'a> { header: FlashImageHeader, checksum: FlashImageChecksum, - payload: FlashImagePayload, + payload: FlashImagePayload<'a>, } +#[repr(C)] +#[derive(AsBytes, FromBytes)] pub struct FlashImageHeader { - magic_number: u32, + magic_number: U32, header_version: u16, - image_count: u16, // number of images + image_count: u16, } +#[repr(C)] +#[derive(AsBytes, FromBytes)] pub struct FlashImageChecksum { - header: u32, // checksum of the header - payload: u32, // checksum of the payload + header: u32, + payload: u32, } -pub struct FlashImagePayload { - image_info: Vec, - images: Vec, +pub struct FlashImagePayload<'a> { + image_info: &'a [FlashImageInfo], + images: &'a [FirmwareImage<'a>], } -// Per image header +#[repr(C)] +#[derive(AsBytes, FromBytes)] pub struct FlashImageInfo { identifier: u32, image_offset: u32, // Location of the image in the flash as an offset from the header @@ -43,50 +48,29 @@ pub struct FlashImageInfo { } #[derive(Clone)] -pub struct FirmwareImage { - identifier: u32, - data: Vec, +pub struct FirmwareImage<'a> { + pub identifier: u32, + pub data: &'a [u8], } -impl FirmwareImage { - pub fn new(identifier: u32, filename: &str) -> io::Result { - let mut file = File::open(filename)?; - let mut data = Vec::new(); - file.read_to_end(&mut data)?; - - Ok(Self { identifier, data }) +impl<'a> FirmwareImage<'a> { + pub fn new(identifier: u32, content: &'a [u8]) -> io::Result { + Ok(Self { + identifier, + data: content, + }) } } -impl FlashImage { - pub fn new(images: &mut [FirmwareImage]) -> Self { - let mut image_info = Vec::new(); - let mut offset = std::mem::size_of::() as u32 - + std::mem::size_of::() as u32 - + (std::mem::size_of::() * images.len()) as u32; - - for image in images.iter_mut() { - let image_size = image.data.len() as u32; - Self::align_to_4_bytes(&mut image.data); - let padded_size = image.data.len() as u32; - image_info.push(FlashImageInfo { - identifier: image.identifier, - image_offset: offset, - size: image_size, - }); - offset += padded_size; - } - +impl<'a> FlashImage<'a> { + pub fn new(images: &'a [FirmwareImage<'a>], image_info: &'a [FlashImageInfo]) -> Self { let header = FlashImageHeader { - magic_number: FLASH_IMAGE_MAGIC_NUMBER, + magic_number: FLASH_IMAGE_MAGIC_NUMBER.into(), header_version: HEADER_VERSION, image_count: image_info.len() as u16, }; - let payload = FlashImagePayload { - image_info, - images: images.to_owned(), - }; + let payload = FlashImagePayload { image_info, images }; let checksum = FlashImageChecksum::new(&header, &payload); @@ -97,33 +81,16 @@ impl FlashImage { } } - fn align_to_4_bytes(data: &mut Vec) { - let padding = data.len().next_multiple_of(4) - data.len(); - data.extend(vec![0; padding]); - } - - pub fn write_to_file(&self, filename: &str) -> io::Result<()> { - let mut file = File::create(filename)?; - - // Write header - file.write_all(&self.header.magic_number.to_le_bytes())?; - file.write_all(&self.header.header_version.to_le_bytes())?; - file.write_all(&self.header.image_count.to_le_bytes())?; - - // Write checksums - file.write_all(&self.checksum.header.to_le_bytes())?; - file.write_all(&self.checksum.payload.to_le_bytes())?; - - // Write image info - for info in &self.payload.image_info { - file.write_all(&info.identifier.to_le_bytes())?; - file.write_all(&info.image_offset.to_le_bytes())?; - file.write_all(&info.size.to_le_bytes())?; + pub fn write_to_file(&self, filename: &str) -> Result<(), DynError> { + let mut file = File::create(filename) + .map_err(|e| format!("Unable to create file {}: {}", filename, e))?; + file.write_all(self.header.as_bytes())?; + file.write_all(self.checksum.as_bytes())?; + for info in self.payload.image_info { + file.write_all(info.as_bytes())?; } - - // Write images - for image in &self.payload.images { - file.write_all(&image.data)?; + for image in self.payload.images { + file.write_all(image.data)?; } Ok(()) @@ -131,12 +98,11 @@ impl FlashImage { pub fn verify_flash_image(image: &[u8]) -> Result<(), DynError> { // Parse and verify header - let magic_number = u32::from_le_bytes(image[0..4].try_into().unwrap()); + let magic_number = u32::from_be_bytes(image[0..4].try_into().unwrap()); let header_version = u16::from_le_bytes(image[4..6].try_into().unwrap()); let image_count = u16::from_le_bytes(image[6..8].try_into().unwrap()); if magic_number != FLASH_IMAGE_MAGIC_NUMBER { - // Return error return Err("Invalid header: incorrect magic number or header version.")?; } @@ -151,8 +117,8 @@ impl FlashImage { // Parse and verify checksums let header_checksum = u32::from_le_bytes(image[8..12].try_into().unwrap()); let payload_checksum = u32::from_le_bytes(image[12..16].try_into().unwrap()); - let calculated_header_checksum = FlashImageChecksum::calculate_checksum(&image[0..8]); - let calculated_payload_checksum = FlashImageChecksum::calculate_checksum(&image[16..]); + let calculated_header_checksum = calculate_checksum(&image[0..8]); + let calculated_payload_checksum = calculate_checksum(&image[16..]); if header_checksum != calculated_header_checksum { return Err("Header checksum mismatch.")?; @@ -163,9 +129,10 @@ impl FlashImage { } // Parse and verify image info and data - let mut offset = 16; // Start after header and checksums - for i in 0..image_count as usize { + let offset = std::mem::size_of::() + + std::mem::size_of::() + + (std::mem::size_of::() * i); let identifier = u32::from_le_bytes(image[offset..offset + 4].try_into().unwrap()); match i { 0 => { @@ -190,8 +157,6 @@ impl FlashImage { } _ => return Err("Invalid image identifier")?, } - - offset += 12; } println!("Image is valid!"); @@ -199,76 +164,109 @@ impl FlashImage { } } -impl FlashImageHeader { - fn serialize(&self) -> Vec { - let mut buffer = Vec::new(); - buffer.extend_from_slice(&self.magic_number.to_le_bytes()); - buffer.extend_from_slice(&self.header_version.to_le_bytes()); - buffer.extend_from_slice(&self.image_count.to_le_bytes()); - buffer - } +pub fn calculate_checksum(data: &[u8]) -> u32 { + let mut hasher = Hasher::new(); + hasher.update(data); + hasher.finalize() } -impl FlashImagePayload { - pub fn serialize(&self) -> Vec { - let mut buffer = Vec::new(); - for info in &self.image_info { - buffer.extend_from_slice(&info.identifier.to_le_bytes()); - buffer.extend_from_slice(&info.image_offset.to_le_bytes()); - buffer.extend_from_slice(&info.size.to_le_bytes()); +impl FlashImagePayload<'_> { + pub fn calculate_checksum(&self) -> u32 { + let mut hasher = Hasher::new(); + for info in self.image_info { + hasher.update(info.as_bytes()); } - for image in &self.images { - buffer.extend_from_slice(&image.data); + for image in self.images { + hasher.update(image.data); } - buffer + hasher.finalize() } } impl FlashImageChecksum { pub fn new(header: &FlashImageHeader, payload: &FlashImagePayload) -> Self { Self { - header: Self::calculate_checksum(&header.serialize()), - payload: Self::calculate_checksum(&payload.serialize()), + header: calculate_checksum(header.as_bytes()), + payload: payload.calculate_checksum(), } } - pub fn calculate_checksum(data: &[u8]) -> u32 { - let mut hasher = Hasher::new(); - hasher.update(data); - hasher.finalize() - } +} + +fn load_file(filename: &str) -> Result, DynError> { + let mut buffer = Vec::new(); + + // Open the file, map errors to a custom error message + let mut file = + File::open(filename).map_err(|e| format!("Cannot open file '{}': {}", filename, e))?; + + // Read the file into the buffer, map errors similarly + file.read_to_end(&mut buffer) + .map_err(|e| format!("Cannot read file '{}': {}", filename, e))?; + + let padding = buffer.len().next_multiple_of(4) - buffer.len(); // Calculate padding size + buffer.extend(vec![0; padding]); // Append padding bytes + + Ok(buffer) } pub(crate) fn flash_image_create( caliptra_fw_path: &str, soc_manifest_path: &str, mcu_runtime_path: &str, - soc_image_paths: &Option>, + soc_image_paths: &Option>, output_path: &str, ) -> Result<(), DynError> { let mut images: Vec = Vec::new(); - images.push(FirmwareImage::new( - CALIPTRA_FMC_RT_IDENTIFIER, - caliptra_fw_path, - )?); - images.push(FirmwareImage::new( - SOC_MANIFEST_IDENTIFIER, - soc_manifest_path, - )?); - images.push(FirmwareImage::new(MCU_RT_IDENTIFIER, mcu_runtime_path)?); + + let content = load_file(caliptra_fw_path)?; + images.push(FirmwareImage::new(CALIPTRA_FMC_RT_IDENTIFIER, &content)?); + + let content = load_file(soc_manifest_path)?; + images.push(FirmwareImage::new(SOC_MANIFEST_IDENTIFIER, &content)?); + + let content = load_file(mcu_runtime_path)?; + images.push(FirmwareImage::new(MCU_RT_IDENTIFIER, &content)?); + + // Load SOC images into a buffer + let mut soc_img_buffers: Vec> = Vec::new(); if let Some(soc_image_paths) = soc_image_paths { - let mut soc_image_identifer = SOC_IMAGES_BASE_IDENTIFIER; for soc_image_path in soc_image_paths { - images.push(FirmwareImage::new(soc_image_identifer, soc_image_path)?); - soc_image_identifer += 1; + let soc_image_data = load_file(soc_image_path)?; // Store the buffer + soc_img_buffers.push(soc_image_data); } } - let flash_image = FlashImage::new(&mut images); + // Generate FirmwareImage from soc image buffer + let mut soc_image_identifer = SOC_IMAGES_BASE_IDENTIFIER; + for soc_img in soc_img_buffers.iter() { + images.push(FirmwareImage::new(soc_image_identifer, soc_img)?); + soc_image_identifer += 1; + } + + let image_info = generate_image_info(images.clone()); + + let flash_image = FlashImage::new(&images, &image_info); flash_image.write_to_file(output_path)?; Ok(()) } +pub fn generate_image_info(images: Vec) -> Vec { + let mut info = Vec::new(); + let mut offset = std::mem::size_of::() as u32 + + std::mem::size_of::() as u32 + + (std::mem::size_of::() * images.len()) as u32; + for image in images.iter() { + info.push(FlashImageInfo { + identifier: image.identifier, + image_offset: offset, + size: image.data.len() as u32, + }); + offset += image.data.len() as u32; + } + info +} + pub(crate) fn flash_image_verify(image_file_path: &str) -> Result<(), DynError> { let mut file = File::open(image_file_path).map_err(|e| { Error::new( @@ -294,16 +292,16 @@ mod tests { use crate::PROJECT_ROOT; use std::fs::{self, File}; use std::io::Write; + use tempfile::NamedTempFile; /// Helper function to create a temporary file with specific content - fn create_temp_file(content: &[u8], file_name: &str) -> io::Result { - let tmp_directory = PROJECT_ROOT.join("target").join("tmp"); - fs::create_dir_all(tmp_directory.clone())?; - let path = tmp_directory.join(file_name); - let mut file = File::create(&path).expect("Failed to create temp file"); - file.write_all(content) + fn create_temp_file(content: &[u8]) -> io::Result { + let mut temp_file = NamedTempFile::new().expect("Failed to create temp file"); + + temp_file + .write_all(content) .expect("Failed to write to temp file"); - Ok(String::from(path.to_str().unwrap())) + Ok(temp_file) } #[test] @@ -316,19 +314,20 @@ mod tests { let soc_image2_content = b"Soc Image 2 Data - POIUYTREWQ"; // Create temporary files with the generated content - let caliptra_fw_path = create_temp_file(caliptra_fw_content, "caliptra_fw.bin") - .expect("Failed to create caliptra_fw.bin"); - let soc_manifest_path = create_temp_file(soc_manifest_content, "soc_manifest.bin") - .expect("Failed to create soc_manifest.bin"); - let mcu_runtime_path = create_temp_file(mcu_runtime_content, "mcu_runtime.bin") - .expect("Failed to create mcu_runtime.bin"); - let soc_image1_path = create_temp_file(soc_image1_content, "soc_image1.bin") - .expect("Failed to create soc_image1.bin"); - let soc_image2_path = create_temp_file(soc_image2_content, "soc_image2.bin") - .expect("Failed to create soc_image2.bin"); + let caliptra_fw = + create_temp_file(caliptra_fw_content).expect("Failed to create caliptra_fw"); + let soc_manifest = + create_temp_file(soc_manifest_content).expect("Failed to create soc_manifest"); + let mcu_runtime = + create_temp_file(mcu_runtime_content).expect("Failed to create mcu_runtime"); + let soc_image1 = create_temp_file(soc_image1_content).expect("Failed to create soc_image1"); + let soc_image2 = create_temp_file(soc_image2_content).expect("Failed to create soc_image2"); // Collect SoC image paths - let soc_image_paths = Some(vec![soc_image1_path.clone(), soc_image2_path.clone()]); + let soc_image_paths = Some(vec![ + soc_image1.path().to_str().unwrap(), + soc_image2.path().to_str().unwrap(), + ]); // Specify the output file path let output_path = PROJECT_ROOT @@ -339,9 +338,9 @@ mod tests { // Build the flash image flash_image_create( - &caliptra_fw_path, - &soc_manifest_path, - &mcu_runtime_path, + caliptra_fw.path().to_str().unwrap(), + soc_manifest.path().to_str().unwrap(), + mcu_runtime.path().to_str().unwrap(), &soc_image_paths, output_path, ) @@ -355,7 +354,7 @@ mod tests { .expect("Failed to read flash image"); // Verify header - let magic_number = u32::from_le_bytes(data[0..4].try_into().unwrap()); + let magic_number = u32::from_be_bytes(data[0..4].try_into().unwrap()); let header_version = u16::from_le_bytes(data[4..6].try_into().unwrap()); let image_count = u16::from_le_bytes(data[6..8].try_into().unwrap()); @@ -366,8 +365,8 @@ mod tests { // Verify checksums let header_checksum = u32::from_le_bytes(data[8..12].try_into().unwrap()); let payload_checksum = u32::from_le_bytes(data[12..16].try_into().unwrap()); - let calculated_header_checksum = FlashImageChecksum::calculate_checksum(&data[0..8]); - let calculated_payload_checksum = FlashImageChecksum::calculate_checksum(&data[16..]); + let calculated_header_checksum = calculate_checksum(&data[0..8]); + let calculated_payload_checksum = calculate_checksum(&data[16..]); assert_eq!(header_checksum, calculated_header_checksum); assert_eq!(payload_checksum, calculated_payload_checksum); @@ -379,34 +378,34 @@ mod tests { (SOC_IMAGES_BASE_IDENTIFIER + 1, soc_image2_content), ]; let mut image_offsets = Vec::new(); - let mut offset = 16; // Start after header and checksums - for i in 0..image_count as usize { + for (i, _item) in expected_images + .iter() + .enumerate() + .take(image_count as usize) + { + let offset = std::mem::size_of::() + + std::mem::size_of::() + + (std::mem::size_of::() * i); let identifier = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()); let image_offset = u32::from_le_bytes(data[offset + 4..offset + 8].try_into().unwrap()); let size = u32::from_le_bytes(data[offset + 8..offset + 12].try_into().unwrap()); // Verify identifier and size assert_eq!(identifier, expected_images[i].0); - assert_eq!(size as usize, expected_images[i].1.len()); + assert_eq!( + size as usize, + expected_images[i].1.len().next_multiple_of(4) + ); image_offsets.push((image_offset as usize, size as usize)); - offset += 12; } // Verify image data using offsets - for (i, (start_offset, size)) in image_offsets.iter().enumerate() { - let actual_data = &data[*start_offset..*start_offset + size]; + for (i, (start_offset, _size)) in image_offsets.iter().enumerate() { + let actual_data = &data[*start_offset..*start_offset + expected_images[i].1.len()]; assert_eq!(actual_data, expected_images[i].1); } - - // Cleanup temporary files - fs::remove_file(caliptra_fw_path).unwrap(); - fs::remove_file(soc_manifest_path).unwrap(); - fs::remove_file(mcu_runtime_path).unwrap(); - fs::remove_file(soc_image1_path).unwrap(); - fs::remove_file(soc_image2_path).unwrap(); - fs::remove_file(output_path).unwrap(); } #[test] @@ -418,30 +417,31 @@ mod tests { let image_path = image_path.to_str().unwrap(); // Create a valid firmware image - let mut expected_images = [ + let expected_images = [ FirmwareImage { identifier: CALIPTRA_FMC_RT_IDENTIFIER, - data: b"Caliptra Firmware Data - ABCDEFGH".to_vec(), + data: b"Caliptra Firmware Data - ABCDEFGH", }, FirmwareImage { identifier: SOC_MANIFEST_IDENTIFIER, - data: b"Soc Manifest Data - 123456789".to_vec(), + data: b"Soc Manifest Data - 123456789", }, FirmwareImage { identifier: MCU_RT_IDENTIFIER, - data: b"MCU Runtime Data - QWERTYUI".to_vec(), + data: b"MCU Runtime Data - QWERTYUI", }, FirmwareImage { identifier: SOC_IMAGES_BASE_IDENTIFIER, - data: b"Soc Image 1 Data - ZXCVBNMLKJ".to_vec(), + data: b"Soc Image 1 Data - ZXCVBNMLKJ", }, FirmwareImage { identifier: SOC_IMAGES_BASE_IDENTIFIER + 1, - data: b"Soc Image 2 Data - POIUYTREWQ".to_vec(), + data: b"Soc Image 2 Data - POIUYTREWQ", }, ]; // Create a flash image from the mutable slice - let flash_image = FlashImage::new(&mut expected_images); + let image_info = generate_image_info(expected_images.to_vec()); + let flash_image = FlashImage::new(&expected_images, &image_info); flash_image .write_to_file(image_path) .expect("Failed to write flash image"); @@ -466,18 +466,21 @@ mod tests { let image_path = image_path.to_str().unwrap(); // Create a corrupted firmware image (tamper with the header or data) - FlashImage::new(&mut vec![ + let images = [ FirmwareImage { identifier: CALIPTRA_FMC_RT_IDENTIFIER, - data: b"Valid Caliptra Firmware Data".to_vec(), + data: b"Valid Caliptra Firmware Data", }, FirmwareImage { identifier: SOC_MANIFEST_IDENTIFIER, - data: b"Valid SOC Manifest Data".to_vec(), + data: b"Valid SOC Manifest Data", }, - ]) - .write_to_file(image_path) - .expect("Failed to write flash image"); + ]; + let image_info = generate_image_info(images.to_vec()); + let flash_image = FlashImage::new(&images, &image_info); + flash_image + .write_to_file(image_path) + .expect("Failed to write flash image"); // Corrupt the file by modifying the data let mut file = File::options() diff --git a/xtask/src/main.rs b/xtask/src/main.rs index 07eba08..fd6d082 100644 --- a/xtask/src/main.rs +++ b/xtask/src/main.rs @@ -118,6 +118,13 @@ static PROJECT_ROOT: LazyLock = LazyLock::new(|| { .to_path_buf() }); +fn convert_option_vec(input: &Option>) -> Option> { + // Convert a series of String arguments into &str + input + .as_ref() + .map(|vec| vec.iter().map(|s| s.as_str()).collect()) +} + fn main() { let cli = Xtask::parse(); let result = match &cli.xtask { @@ -136,7 +143,7 @@ fn main() { caliptra_fw, soc_manifest, mcu_runtime, - soc_images, + &convert_option_vec(soc_images), output, ), FlashImageCommands::Verify { file } => flash_image::flash_image_verify(file), From aa38d971cd5c9a878b620cfbb79c2eb22aa35405 Mon Sep 17 00:00:00 2001 From: Marco Visaya Date: Fri, 22 Nov 2024 16:31:07 -0800 Subject: [PATCH 6/8] rearrange tempfile dependency --- xtask/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xtask/Cargo.toml b/xtask/Cargo.toml index e18d58d..02a768e 100644 --- a/xtask/Cargo.toml +++ b/xtask/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" [dependencies] semver = "1.0.23" +tempfile = "3.14.0" walkdir = "2.5.0" proc-macro2.workspace = true clap.workspace = true @@ -15,4 +16,3 @@ registers-systemrdl.workspace = true quote.workspace = true crc32fast = "1.4.2" zerocopy.workspace = true -tempfile = "3.14.0" From afd6cdbae7dadec1fbe616b430bda8176ea5535b Mon Sep 17 00:00:00 2001 From: Marco Visaya Date: Fri, 22 Nov 2024 16:46:28 -0800 Subject: [PATCH 7/8] Fix dependency issues --- Cargo.lock | 2 +- xtask/src/flash_image.rs | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e19274f..e35f44b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2529,7 +2529,7 @@ dependencies = [ "semver", "tempfile", "walkdir", - "zerocopy 0.6.6", + "zerocopy 0.8.10", ] [[package]] diff --git a/xtask/src/flash_image.rs b/xtask/src/flash_image.rs index 012d760..9fb98db 100644 --- a/xtask/src/flash_image.rs +++ b/xtask/src/flash_image.rs @@ -4,7 +4,8 @@ use crate::DynError; use crc32fast::Hasher; use std::fs::File; use std::io::{self, Error, ErrorKind, Read, Write}; -use zerocopy::{byteorder::U32, AsBytes, FromBytes}; +use zerocopy::Immutable; +use zerocopy::{byteorder::U32, FromBytes, IntoBytes}; const FLASH_IMAGE_MAGIC_NUMBER: u32 = u32::from_be_bytes([b'F', b'L', b'S', b'H']); const HEADER_VERSION: u16 = 0x0001; @@ -20,7 +21,7 @@ pub struct FlashImage<'a> { } #[repr(C)] -#[derive(AsBytes, FromBytes)] +#[derive(IntoBytes, FromBytes, Immutable)] pub struct FlashImageHeader { magic_number: U32, header_version: u16, @@ -28,7 +29,7 @@ pub struct FlashImageHeader { } #[repr(C)] -#[derive(AsBytes, FromBytes)] +#[derive(IntoBytes, FromBytes, Immutable)] pub struct FlashImageChecksum { header: u32, payload: u32, @@ -40,7 +41,7 @@ pub struct FlashImagePayload<'a> { } #[repr(C)] -#[derive(AsBytes, FromBytes)] +#[derive(IntoBytes, FromBytes, Immutable)] pub struct FlashImageInfo { identifier: u32, image_offset: u32, // Location of the image in the flash as an offset from the header From 435be3825322546db7e697ee7e15ff4aed18644a Mon Sep 17 00:00:00 2001 From: Marco Visaya Date: Mon, 25 Nov 2024 14:06:03 -0800 Subject: [PATCH 8/8] Address comments - use zerocopy to parse image buffer - add packed trait to structs - use String instead of &str in CLI args --- xtask/src/flash_image.rs | 66 +++++++++++++++++++++------------------- xtask/src/main.rs | 9 +----- 2 files changed, 36 insertions(+), 39 deletions(-) diff --git a/xtask/src/flash_image.rs b/xtask/src/flash_image.rs index 9fb98db..cbfdc25 100644 --- a/xtask/src/flash_image.rs +++ b/xtask/src/flash_image.rs @@ -4,11 +4,14 @@ use crate::DynError; use crc32fast::Hasher; use std::fs::File; use std::io::{self, Error, ErrorKind, Read, Write}; -use zerocopy::Immutable; use zerocopy::{byteorder::U32, FromBytes, IntoBytes}; +use zerocopy::{Immutable, KnownLayout}; -const FLASH_IMAGE_MAGIC_NUMBER: u32 = u32::from_be_bytes([b'F', b'L', b'S', b'H']); +const FLASH_IMAGE_MAGIC_NUMBER: u32 = u32::from_be_bytes(*b"FLSH"); const HEADER_VERSION: u16 = 0x0001; +const HEADER_SIZE: usize = std::mem::size_of::(); +const CHECKSUM_SIZE: usize = std::mem::size_of::(); +const IMAGE_INFO_SIZE: usize = std::mem::size_of::(); const CALIPTRA_FMC_RT_IDENTIFIER: u32 = 0x00000001; const SOC_MANIFEST_IDENTIFIER: u32 = 0x00000002; const MCU_RT_IDENTIFIER: u32 = 0x00000002; @@ -20,16 +23,16 @@ pub struct FlashImage<'a> { payload: FlashImagePayload<'a>, } -#[repr(C)] -#[derive(IntoBytes, FromBytes, Immutable)] +#[repr(C, packed)] +#[derive(IntoBytes, FromBytes, Immutable, KnownLayout)] pub struct FlashImageHeader { magic_number: U32, header_version: u16, image_count: u16, } -#[repr(C)] -#[derive(IntoBytes, FromBytes, Immutable)] +#[repr(C, packed)] +#[derive(IntoBytes, FromBytes, Immutable, KnownLayout)] pub struct FlashImageChecksum { header: u32, payload: u32, @@ -40,7 +43,7 @@ pub struct FlashImagePayload<'a> { images: &'a [FirmwareImage<'a>], } -#[repr(C)] +#[repr(C, packed)] #[derive(IntoBytes, FromBytes, Immutable)] pub struct FlashImageInfo { identifier: u32, @@ -99,60 +102,61 @@ impl<'a> FlashImage<'a> { pub fn verify_flash_image(image: &[u8]) -> Result<(), DynError> { // Parse and verify header - let magic_number = u32::from_be_bytes(image[0..4].try_into().unwrap()); - let header_version = u16::from_le_bytes(image[4..6].try_into().unwrap()); - let image_count = u16::from_le_bytes(image[6..8].try_into().unwrap()); - - if magic_number != FLASH_IMAGE_MAGIC_NUMBER { + if image.len() < HEADER_SIZE { + return Err("Image too small to contain the header.".into()); + } + let header = FlashImageHeader::read_from_bytes(&image[..HEADER_SIZE]) + .map_err(|_| "Failed to parse header: invalid format or size")?; + if header.magic_number != FLASH_IMAGE_MAGIC_NUMBER { return Err("Invalid header: incorrect magic number or header version.")?; } - if header_version != HEADER_VERSION { + if header.header_version != HEADER_VERSION { return Err("Unsupported header version")?; } - if image_count < 3 { + if header.image_count < 3 { return Err("Expected at least 3 images")?; } // Parse and verify checksums - let header_checksum = u32::from_le_bytes(image[8..12].try_into().unwrap()); - let payload_checksum = u32::from_le_bytes(image[12..16].try_into().unwrap()); - let calculated_header_checksum = calculate_checksum(&image[0..8]); + let checksum = + FlashImageChecksum::read_from_bytes(&image[HEADER_SIZE..(HEADER_SIZE + CHECKSUM_SIZE)]) + .map_err(|_| "Failed to parse checksum field")?; + let calculated_header_checksum = calculate_checksum(header.as_bytes()); let calculated_payload_checksum = calculate_checksum(&image[16..]); - if header_checksum != calculated_header_checksum { + if checksum.header != calculated_header_checksum { return Err("Header checksum mismatch.")?; } - if payload_checksum != calculated_payload_checksum { + if checksum.payload != calculated_payload_checksum { return Err("Payload checksum mismatch.")?; } // Parse and verify image info and data - for i in 0..image_count as usize { - let offset = std::mem::size_of::() - + std::mem::size_of::() - + (std::mem::size_of::() * i); - let identifier = u32::from_le_bytes(image[offset..offset + 4].try_into().unwrap()); + for i in 0..header.image_count as usize { + let offset = HEADER_SIZE + CHECKSUM_SIZE + (IMAGE_INFO_SIZE * i); + let info = FlashImageInfo::read_from_bytes(&image[offset..offset + IMAGE_INFO_SIZE]) + .map_err(|_| "Failed to read image info")?; match i { 0 => { - if identifier != CALIPTRA_FMC_RT_IDENTIFIER { + if info.identifier != CALIPTRA_FMC_RT_IDENTIFIER { return Err("Image 0 is not Caliptra Identifier")?; } } 1 => { - if identifier != SOC_MANIFEST_IDENTIFIER { + if info.identifier != SOC_MANIFEST_IDENTIFIER { return Err("Image 0 is not SOC Manifest Identifier")?; } } 2 => { - if identifier != MCU_RT_IDENTIFIER { + if info.identifier != MCU_RT_IDENTIFIER { return Err("Image 0 is not MCU RT Identifier")?; } } 3..255 => { - if identifier != (SOC_IMAGES_BASE_IDENTIFIER + (i as u32) - 3) { + if info.identifier != (SOC_IMAGES_BASE_IDENTIFIER + (i as u32) - 3) { return Err("Invalid SOC image identifier")?; } } @@ -214,7 +218,7 @@ pub(crate) fn flash_image_create( caliptra_fw_path: &str, soc_manifest_path: &str, mcu_runtime_path: &str, - soc_image_paths: &Option>, + soc_image_paths: &Option>, output_path: &str, ) -> Result<(), DynError> { let mut images: Vec = Vec::new(); @@ -326,8 +330,8 @@ mod tests { // Collect SoC image paths let soc_image_paths = Some(vec![ - soc_image1.path().to_str().unwrap(), - soc_image2.path().to_str().unwrap(), + soc_image1.path().to_str().unwrap().to_string(), + soc_image2.path().to_str().unwrap().to_string(), ]); // Specify the output file path diff --git a/xtask/src/main.rs b/xtask/src/main.rs index 76bf105..ed6c94f 100644 --- a/xtask/src/main.rs +++ b/xtask/src/main.rs @@ -127,13 +127,6 @@ static PROJECT_ROOT: LazyLock = LazyLock::new(|| { .to_path_buf() }); -fn convert_option_vec(input: &Option>) -> Option> { - // Convert a series of String arguments into &str - input - .as_ref() - .map(|vec| vec.iter().map(|s| s.as_str()).collect()) -} - fn main() { let cli = Xtask::parse(); let result = match &cli.xtask { @@ -152,7 +145,7 @@ fn main() { caliptra_fw, soc_manifest, mcu_runtime, - &convert_option_vec(soc_images), + soc_images, output, ), FlashImageCommands::Verify { file } => flash_image::flash_image_verify(file),