diff --git a/src/curand/result.rs b/src/curand/result.rs index 48a0cf0..fab2a57 100644 --- a/src/curand/result.rs +++ b/src/curand/result.rs @@ -62,6 +62,16 @@ pub unsafe fn set_seed(generator: sys::curandGenerator_t, seed: u64) -> Result<( sys::curandSetPseudoRandomGeneratorSeed(generator, seed).result() } +/// Set the offset value of the pseudo-random number generator. +/// +/// See [cuRAND docs](https://docs.nvidia.com/cuda/curand/group__HOST.html#group__HOST_1gb21ba987f85486e552797206451b0939) +/// +/// # Safety +/// The generator must be allocated and not already freed. +pub unsafe fn set_offset(generator: sys::curandGenerator_t, offset: u64) -> Result<(), CurandError> { + sys::curandSetGeneratorOffset(generator, offset).result() +} + /// Set the current stream for CURAND kernel launches. /// /// See [cuRAND docs](https://docs.nvidia.com/cuda/curand/group__HOST.html#group__HOST_1gc78c8d07c7acea4242e2a62bc41ff1f5) diff --git a/src/curand/safe.rs b/src/curand/safe.rs index d908c05..db68028 100644 --- a/src/curand/safe.rs +++ b/src/curand/safe.rs @@ -46,6 +46,10 @@ impl CudaRng { unsafe { result::set_seed(self.gen, seed) } } + pub fn set_offset(&mut self, offset: u64) -> Result<(), result::CurandError> { + unsafe { result::set_offset(self.gen, offset) } + } + /// Fill the [CudaSlice] with data from a `Uniform` distribution pub fn fill_with_uniform(&self, t: &mut CudaSlice) -> Result<(), result::CurandError> where @@ -199,6 +203,26 @@ mod tests { assert_ne!(a_host, b_host); } + #[test] + fn test_set_offset() { + let dev = CudaDevice::new(0).unwrap(); + + let mut a_dev = dev.alloc_zeros::(10).unwrap(); + let mut a_rng = CudaRng::new(0, dev.clone()).unwrap(); + + a_rng.set_seed(42).unwrap(); + a_rng.set_offset(0).unwrap(); + a_rng.fill_with_uniform(&mut a_dev).unwrap(); + let a_host = dev.sync_reclaim(a_dev.clone()).unwrap(); + + a_rng.set_seed(42).unwrap(); + a_rng.set_offset(0).unwrap(); + a_rng.fill_with_uniform(&mut a_dev).unwrap(); + let b_host = dev.sync_reclaim(a_dev).unwrap(); + + assert_eq!(a_host, b_host); + } + const N: usize = 1000; #[test]