Skip to content

Commit

Permalink
feat: add transmute for slices (#282)
Browse files Browse the repository at this point in the history
Co-authored-by: Roman Walch <[email protected]>
  • Loading branch information
dkales and rw0x0 authored Aug 20, 2024
1 parent d7ac2b4 commit 5af0873
Showing 1 changed file with 59 additions and 0 deletions.
59 changes: 59 additions & 0 deletions src/driver/safe/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,21 @@ impl<'a, T> CudaView<'a, T> {
marker: PhantomData,
})
}

/// Reinterprets the slice of memory into a different type. `len` is the number
/// of elements of the new type `S` that are expected. If not enough bytes
/// are allocated in `self` for the view, then this returns `None`.
///
/// # Safety
/// This is unsafe because not the memory for the view may not be a valid interpretation
/// for the type `S`.
pub unsafe fn transmute<S>(&self, len: usize) -> Option<CudaView<'_, S>> {
(len * std::mem::size_of::<S>() <= self.num_bytes()).then_some(CudaView {
ptr: self.ptr,
len,
marker: PhantomData,
})
}
}

/// A mutable sub-view into a [CudaSlice] created by [CudaSlice::try_slice_mut()] or [CudaSlice::slice_mut()].
Expand Down Expand Up @@ -871,6 +886,21 @@ impl<'a, T> CudaViewMut<'a, T> {
})
}

/// Reinterprets the slice of memory into a different type. `len` is the number
/// of elements of the new type `S` that are expected. If not enough bytes
/// are allocated in `self` for the view, then this returns `None`.
///
/// # Safety
/// This is unsafe because not the memory for the view may not be a valid interpretation
/// for the type `S`.
pub unsafe fn transmute<S>(&self, len: usize) -> Option<CudaView<'_, S>> {
(len * std::mem::size_of::<S>() <= self.num_bytes()).then_some(CudaView {
ptr: self.ptr,
len,
marker: PhantomData,
})
}

/// Creates a [CudaViewMut] at the specified offset from the start of `self`.
///
/// Panics if `range` and `0...self.len()` are not overlapping.
Expand Down Expand Up @@ -934,6 +964,21 @@ impl<'a, T> CudaViewMut<'a, T> {
},
))
}

/// Reinterprets the slice of memory into a different type. `len` is the number
/// of elements of the new type `S` that are expected. If not enough bytes
/// are allocated in `self` for the view, then this returns `None`.
///
/// # Safety
/// This is unsafe because not the memory for the view may not be a valid interpretation
/// for the type `S`.
pub unsafe fn transmute_mut<S>(&mut self, len: usize) -> Option<CudaViewMut<'_, S>> {
(len * std::mem::size_of::<S>() <= self.num_bytes()).then_some(CudaViewMut {
ptr: self.ptr,
len,
marker: PhantomData,
})
}
}

trait RangeHelper: RangeBounds<usize> {
Expand Down Expand Up @@ -993,5 +1038,19 @@ mod tests {
assert!(unsafe { slice.transmute::<f32>(26) }.is_none());
assert!(unsafe { slice.transmute_mut::<f32>(25) }.is_some());
assert!(unsafe { slice.transmute_mut::<f32>(26) }.is_none());

{
let view = slice.slice(0..100);
assert!(unsafe { view.transmute::<f32>(25) }.is_some());
assert!(unsafe { view.transmute::<f32>(26) }.is_none());
}

{
let mut view_mut = slice.slice_mut(0..100);
assert!(unsafe { view_mut.transmute::<f32>(25) }.is_some());
assert!(unsafe { view_mut.transmute::<f32>(26) }.is_none());
assert!(unsafe { view_mut.transmute_mut::<f32>(25) }.is_some());
assert!(unsafe { view_mut.transmute_mut::<f32>(26) }.is_none());
}
}
}

0 comments on commit 5af0873

Please sign in to comment.