Skip to content

Commit

Permalink
Add set_offset to that curandSetGeneratorOffset can be called. (#185)
Browse files Browse the repository at this point in the history
  • Loading branch information
mneilly authored Oct 25, 2023
1 parent 1895828 commit 13e308d
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/curand/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,16 @@ pub unsafe fn set_seed(generator: sys::curandGenerator_t, seed: u64) -> Result<(
sys::curandSetPseudoRandomGeneratorSeed(generator, seed).result()
}

/// Set the offset value of the pseudo-random number generator.
///
/// See [cuRAND docs](https://docs.nvidia.com/cuda/curand/group__HOST.html#group__HOST_1gb21ba987f85486e552797206451b0939)
///
/// # Safety
/// The generator must be allocated and not already freed.
pub unsafe fn set_offset(generator: sys::curandGenerator_t, offset: u64) -> Result<(), CurandError> {
sys::curandSetGeneratorOffset(generator, offset).result()
}

/// Set the current stream for CURAND kernel launches.
///
/// See [cuRAND docs](https://docs.nvidia.com/cuda/curand/group__HOST.html#group__HOST_1gc78c8d07c7acea4242e2a62bc41ff1f5)
Expand Down
24 changes: 24 additions & 0 deletions src/curand/safe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ impl CudaRng {
unsafe { result::set_seed(self.gen, seed) }
}

pub fn set_offset(&mut self, offset: u64) -> Result<(), result::CurandError> {
unsafe { result::set_offset(self.gen, offset) }
}

/// Fill the [CudaSlice] with data from a `Uniform` distribution
pub fn fill_with_uniform<T>(&self, t: &mut CudaSlice<T>) -> Result<(), result::CurandError>
where
Expand Down Expand Up @@ -199,6 +203,26 @@ mod tests {
assert_ne!(a_host, b_host);
}

#[test]
fn test_set_offset() {
let dev = CudaDevice::new(0).unwrap();

let mut a_dev = dev.alloc_zeros::<f32>(10).unwrap();
let mut a_rng = CudaRng::new(0, dev.clone()).unwrap();

a_rng.set_seed(42).unwrap();
a_rng.set_offset(0).unwrap();
a_rng.fill_with_uniform(&mut a_dev).unwrap();
let a_host = dev.sync_reclaim(a_dev.clone()).unwrap();

a_rng.set_seed(42).unwrap();
a_rng.set_offset(0).unwrap();
a_rng.fill_with_uniform(&mut a_dev).unwrap();
let b_host = dev.sync_reclaim(a_dev).unwrap();

assert_eq!(a_host, b_host);
}

const N: usize = 1000;

#[test]
Expand Down

0 comments on commit 13e308d

Please sign in to comment.