Skip to content

Commit

Permalink
Small tweak
Browse files Browse the repository at this point in the history
  • Loading branch information
jinlow committed Sep 12, 2023
1 parent f07d187 commit 8582089
Showing 1 changed file with 65 additions and 21 deletions.
86 changes: 65 additions & 21 deletions src/splitter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,28 +348,72 @@ impl Splitter for MissingBranchSplitter {
return None;
}

let mut left_weight = constrained_weight(
&self.l2,
left_gradient,
left_hessian,
lower_bound,
upper_bound,
constraint,
);
let mut right_weight = constrained_weight(
&self.l2,
right_gradient,
right_hessian,
lower_bound,
upper_bound,
constraint,
);
let (left_weight, right_weight) = if self.force_children_to_bound_parent {
// What if we made it proportional to the size of the
// cover in each node?
let l_p = left_hessian / (right_hessian + left_hessian);
let r_p = left_hessian / (right_hessian + left_hessian);

if self.force_children_to_bound_parent {
(left_weight, right_weight) = bound_to_parent(parent_weight, left_weight, right_weight);
assert!(between(lower_bound, upper_bound, left_weight));
assert!(between(lower_bound, upper_bound, right_weight));
}
(
constrained_weight(
&self.l2,
left_gradient + missing_gradient * l_p,
left_hessian + missing_hessian * l_p,
lower_bound,
upper_bound,
constraint,
),
constrained_weight(
&self.l2,
right_gradient + missing_gradient * r_p,
right_hessian + missing_hessian * r_p,
lower_bound,
upper_bound,
constraint,
),
)
} else {
(
constrained_weight(
&self.l2,
left_gradient,
left_hessian,
lower_bound,
upper_bound,
constraint,
),
constrained_weight(
&self.l2,
right_gradient,
right_hessian,
lower_bound,
upper_bound,
constraint,
),
)
};
// let mut left_weight = constrained_weight(
// &self.l2,
// left_gradient,
// left_hessian,
// lower_bound,
// upper_bound,
// constraint,
// );
// let mut right_weight = constrained_weight(
// &self.l2,
// right_gradient,
// right_hessian,
// lower_bound,
// upper_bound,
// constraint,
// );

// if self.force_children_to_bound_parent {
// (left_weight, right_weight) = bound_to_parent(parent_weight, left_weight, right_weight);
// assert!(between(lower_bound, upper_bound, left_weight));
// assert!(between(lower_bound, upper_bound, right_weight));
// }

let left_gain = gain_given_weight(&self.l2, left_gradient, left_hessian, left_weight);
let right_gain = gain_given_weight(&self.l2, right_gradient, right_hessian, right_weight);
Expand Down

0 comments on commit 8582089

Please sign in to comment.