From a3a96f00805f2d65143fd5a467412f7fb1e6ab39 Mon Sep 17 00:00:00 2001 From: jinlow Date: Fri, 22 Dec 2023 22:43:40 +0000 Subject: [PATCH] small test --- src/histogram.rs | 35 ++++++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/src/histogram.rs b/src/histogram.rs index bf1080f..ecaadea 100644 --- a/src/histogram.rs +++ b/src/histogram.rs @@ -34,6 +34,11 @@ impl Bin { } } + pub fn update(&mut self, gradient_sum: f32, hessian_sum: f32) { + self.gradient_sum = gradient_sum; + self.hessian_sum = hessian_sum; + } + /// Calculate a new bin, using the subtraction trick when the parent node /// has three directions, left, right, and missing. pub fn from_parent_two_children( @@ -85,25 +90,37 @@ pub fn create_feature_histogram( index: &[usize], ) -> Vec> { let mut histogram: Vec> = Vec::with_capacity(cuts.len()); - // The first value is missing, it seems to not matter that we are using - // Missing here, rather than the booster "missing" definition, because - // we just always assume the first bin of the histogram is missing. - histogram.push(Bin::new_f64(f64::NAN)); + + let mut gradient_sums: Vec = Vec::with_capacity(cuts.len()); + let mut hessian_sums: Vec = Vec::with_capacity(cuts.len()); + + histogram.push(Bin::new_f32(f64::NAN)); // The last cut value is simply the maximum possible value, so we don't need it. // This value is needed initially for binning, but we don't need to count it as // a histogram bin. - histogram.extend(cuts[..(cuts.len() - 1)].iter().map(|c| Bin::new_f64(*c))); + histogram.extend(cuts[..(cuts.len() - 1)].iter().map(|c| Bin::new_f32(*c))); + index .iter() .zip(sorted_grad) .zip(sorted_hess) .for_each(|((i, g), h)| { - if let Some(v) = histogram.get_mut(feature[*i] as usize) { - v.gradient_sum += f64::from(*g); - v.hessian_sum += f64::from(*h); + if let Some(v) = gradient_sums.get_mut(feature[*i] as usize) { + *v += f64::from(*g); + } + if let Some(v) = hessian_sums.get_mut(feature[*i] as usize) { + *v += f64::from(*h); } }); - histogram.iter().map(|b| b.as_f32_bin()).collect() + // The last cut value is simply the maximum possible value, so we don't need it. + // This value is needed initially for binning, but we don't need to count it as + // a histogram bin. + histogram + .iter_mut() + .zip(gradient_sums) + .zip(hessian_sums) + .map(|(h, (g, h))| h.update(g, h)) + .collect() } impl HistogramMatrix {