Skip to content

Commit

Permalink
small test
Browse files Browse the repository at this point in the history
  • Loading branch information
jinlow committed Dec 22, 2023
1 parent 3e2ebe5 commit a3a96f0
Showing 1 changed file with 26 additions and 9 deletions.
35 changes: 26 additions & 9 deletions src/histogram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ impl Bin<f32> {
}
}

/// Overwrite the sums stored in this bin.
///
/// Both `gradient_sum` and `hessian_sum` replace the current field values
/// outright; nothing is accumulated onto the previous contents.
pub fn update(&mut self, gradient_sum: f32, hessian_sum: f32) {
    (self.gradient_sum, self.hessian_sum) = (gradient_sum, hessian_sum);
}

/// Calculate a new bin, using the subtraction trick when the parent node
/// has three directions, left, right, and missing.
pub fn from_parent_two_children(
Expand Down Expand Up @@ -85,25 +90,37 @@ pub fn create_feature_histogram(
index: &[usize],
) -> Vec<Bin<f32>> {
let mut histogram: Vec<Bin<f64>> = Vec::with_capacity(cuts.len());
// The first value is the missing bin. It seems not to matter that we are using
// Missing here, rather than the booster "missing" definition, because
// we always assume the first bin of the histogram is missing.
histogram.push(Bin::new_f64(f64::NAN));

let mut gradient_sums: Vec<f64> = Vec::with_capacity(cuts.len());
let mut hessian_sums: Vec<f64> = Vec::with_capacity(cuts.len());

histogram.push(Bin::new_f32(f64::NAN));
// The last cut value is simply the maximum possible value, so we don't need it.
// This value is needed initially for binning, but we don't need to count it as
// a histogram bin.
histogram.extend(cuts[..(cuts.len() - 1)].iter().map(|c| Bin::new_f64(*c)));
histogram.extend(cuts[..(cuts.len() - 1)].iter().map(|c| Bin::new_f32(*c)));

index
.iter()
.zip(sorted_grad)
.zip(sorted_hess)
.for_each(|((i, g), h)| {
if let Some(v) = histogram.get_mut(feature[*i] as usize) {
v.gradient_sum += f64::from(*g);
v.hessian_sum += f64::from(*h);
if let Some(v) = gradient_sums.get_mut(feature[*i] as usize) {
*v += f64::from(*g);
}
if let Some(v) = hessian_sums.get_mut(feature[*i] as usize) {
*v += f64::from(*h);
}
});
histogram.iter().map(|b| b.as_f32_bin()).collect()
// The last cut value is simply the maximum possible value, so we don't need it.
// This value is needed initially for binning, but we don't need to count it as
// a histogram bin.
histogram
.iter_mut()
.zip(gradient_sums)
.zip(hessian_sums)
.map(|(h, (g, h))| h.update(g, h))
.collect()
}

impl HistogramMatrix {
Expand Down

0 comments on commit a3a96f0

Please sign in to comment.