From 8582089f8ec5d1ee912fa46892196c813a7c62e8 Mon Sep 17 00:00:00 2001 From: jinlow Date: Tue, 12 Sep 2023 07:33:53 -0500 Subject: [PATCH] Small tweak --- src/splitter.rs | 86 +++++++++++++++++++++++++++++++++++++------------ 1 file changed, 65 insertions(+), 21 deletions(-) diff --git a/src/splitter.rs b/src/splitter.rs index e093385..0e8fced 100644 --- a/src/splitter.rs +++ b/src/splitter.rs @@ -348,28 +348,72 @@ impl Splitter for MissingBranchSplitter { return None; } - let mut left_weight = constrained_weight( - &self.l2, - left_gradient, - left_hessian, - lower_bound, - upper_bound, - constraint, - ); - let mut right_weight = constrained_weight( - &self.l2, - right_gradient, - right_hessian, - lower_bound, - upper_bound, - constraint, - ); + let (left_weight, right_weight) = if self.force_children_to_bound_parent { + // What if we made it proportional to the size of the + // cover in each node? + let l_p = left_hessian / (right_hessian + left_hessian); + let r_p = left_hessian / (right_hessian + left_hessian); - if self.force_children_to_bound_parent { - (left_weight, right_weight) = bound_to_parent(parent_weight, left_weight, right_weight); - assert!(between(lower_bound, upper_bound, left_weight)); - assert!(between(lower_bound, upper_bound, right_weight)); - } + ( + constrained_weight( + &self.l2, + left_gradient + missing_gradient * l_p, + left_hessian + missing_hessian * l_p, + lower_bound, + upper_bound, + constraint, + ), + constrained_weight( + &self.l2, + right_gradient + missing_gradient * r_p, + right_hessian + missing_hessian * r_p, + lower_bound, + upper_bound, + constraint, + ), + ) + } else { + ( + constrained_weight( + &self.l2, + left_gradient, + left_hessian, + lower_bound, + upper_bound, + constraint, + ), + constrained_weight( + &self.l2, + right_gradient, + right_hessian, + lower_bound, + upper_bound, + constraint, + ), + ) + }; + // let mut left_weight = constrained_weight( + // &self.l2, + // left_gradient, + // left_hessian, + // lower_bound, + // upper_bound, + // constraint, + // ); + // let mut right_weight = constrained_weight( + // &self.l2, + // right_gradient, + // right_hessian, + // lower_bound, + // upper_bound, + // constraint, + // ); + + // if self.force_children_to_bound_parent { + // (left_weight, right_weight) = bound_to_parent(parent_weight, left_weight, right_weight); + // assert!(between(lower_bound, upper_bound, left_weight)); + // assert!(between(lower_bound, upper_bound, right_weight)); + // } let left_gain = gain_given_weight(&self.l2, left_gradient, left_hessian, left_weight); let right_gain = gain_given_weight(&self.l2, right_gradient, right_hessian, right_weight);