diff --git a/src/glum/_cd_fast.pyx b/src/glum/_cd_fast.pyx index 5b08eb83..bd3910cb 100644 --- a/src/glum/_cd_fast.pyx +++ b/src/glum/_cd_fast.pyx @@ -17,16 +17,15 @@ from cython.parallel import prange import warnings from sklearn.exceptions import ConvergenceWarning -from sklearn.utils._random cimport our_rand_r - ctypedef np.float64_t DOUBLE ctypedef np.uint32_t UINT32_t np.import_array() - -# The following two functions are shamelessly copied from the tree code. - +# The following two functions are shamelessly copied from the tree code. (_random.pxd) +# Authors: Arnaud Joly +# +# License: BSD-3-clause cdef enum: # Max value for our rand_r replacement (near the bottom). # We don't use RAND_MAX because it's different across platforms and @@ -34,6 +33,25 @@ cdef enum: RAND_R_MAX = 0x7FFFFFFF +# rand_r replacement using a 32bit XorShift generator +# See http://www.jstatsoft.org/v08/i14/paper for details +cdef inline UINT32_t our_rand_r(UINT32_t* seed) nogil: + """Generate a pseudo-random np.uint32 from a np.uint32 seed""" + # seed shouldn't ever be 0. + if (seed[0] == 0): + seed[0] = 1 + + seed[0] ^= (seed[0] << 13) + seed[0] ^= (seed[0] >> 17) + seed[0] ^= (seed[0] << 5) + + # Use the modulo to make sure that we don't return a values greater than the + # maximum representable value for signed 32bit integers (i.e. 2^31 - 1). + # Note that the parenthesis are needed to avoid overflow: here + # RAND_R_MAX is cast to UINT32_t before 1 is added. + return seed[0] % ((RAND_R_MAX) + 1) + + cdef inline UINT32_t rand_int(UINT32_t end, UINT32_t* random_state) nogil: """Generate a random integer in [0; end).""" return our_rand_r(random_state) % end