diff --git a/src/driver/safe/core.rs b/src/driver/safe/core.rs index 919d36f..95b8311 100644 --- a/src/driver/safe/core.rs +++ b/src/driver/safe/core.rs @@ -703,6 +703,21 @@ impl<'a, T> CudaView<'a, T> { marker: PhantomData, }) } + + /// Reinterprets the slice of memory into a different type. `len` is the number + /// of elements of the new type `S` that are expected. If not enough bytes + /// are allocated in `self` for the view, then this returns `None`. + /// + /// # Safety + /// This is unsafe because not the memory for the view may not be a valid interpretation + /// for the type `S`. + pub unsafe fn transmute(&self, len: usize) -> Option> { + (len * std::mem::size_of::() <= self.num_bytes()).then_some(CudaView { + ptr: self.ptr, + len, + marker: PhantomData, + }) + } } /// A mutable sub-view into a [CudaSlice] created by [CudaSlice::try_slice_mut()] or [CudaSlice::slice_mut()]. @@ -871,6 +886,21 @@ impl<'a, T> CudaViewMut<'a, T> { }) } + /// Reinterprets the slice of memory into a different type. `len` is the number + /// of elements of the new type `S` that are expected. If not enough bytes + /// are allocated in `self` for the view, then this returns `None`. + /// + /// # Safety + /// This is unsafe because not the memory for the view may not be a valid interpretation + /// for the type `S`. + pub unsafe fn transmute(&self, len: usize) -> Option> { + (len * std::mem::size_of::() <= self.num_bytes()).then_some(CudaView { + ptr: self.ptr, + len, + marker: PhantomData, + }) + } + /// Creates a [CudaViewMut] at the specified offset from the start of `self`. /// /// Panics if `range` and `0...self.len()` are not overlapping. @@ -934,6 +964,21 @@ impl<'a, T> CudaViewMut<'a, T> { }, )) } + + /// Reinterprets the slice of memory into a different type. `len` is the number + /// of elements of the new type `S` that are expected. If not enough bytes + /// are allocated in `self` for the view, then this returns `None`. + /// + /// # Safety + /// This is unsafe because not the memory for the view may not be a valid interpretation + /// for the type `S`. + pub unsafe fn transmute_mut(&mut self, len: usize) -> Option> { + (len * std::mem::size_of::() <= self.num_bytes()).then_some(CudaViewMut { + ptr: self.ptr, + len, + marker: PhantomData, + }) + } } trait RangeHelper: RangeBounds { @@ -993,5 +1038,19 @@ mod tests { assert!(unsafe { slice.transmute::(26) }.is_none()); assert!(unsafe { slice.transmute_mut::(25) }.is_some()); assert!(unsafe { slice.transmute_mut::(26) }.is_none()); + + { + let view = slice.slice(0..100); + assert!(unsafe { view.transmute::(25) }.is_some()); + assert!(unsafe { view.transmute::(26) }.is_none()); + } + + { + let mut view_mut = slice.slice_mut(0..100); + assert!(unsafe { view_mut.transmute::(25) }.is_some()); + assert!(unsafe { view_mut.transmute::(26) }.is_none()); + assert!(unsafe { view_mut.transmute_mut::(25) }.is_some()); + assert!(unsafe { view_mut.transmute_mut::(26) }.is_none()); + } } }