Skip to content

Commit

Permalink
Fix a forgotten change in carfac_bench to carfac_np.design_carfac, wh…
Browse files Browse the repository at this point in the history
…ich was using the one_cap argument and broken.

This updates to use the ihc_style string as elsewhere in the codebase and fixes the benchmark.

PiperOrigin-RevId: 684205274
  • Loading branch information
Rob Schonberger authored and copybara-github committed Oct 9, 2024
1 parent 926e96e commit d2a5ccd
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions python/jax/carfac_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def bench_numpy_in_slices(state: google_benchmark.State):
state: the benchmark state for this execution run.
"""
random_seed = 1
one_cap = False
cfp = carfac_np.design_carfac(one_cap=one_cap)
ihc_style = 'two_cap'
cfp = carfac_np.design_carfac(ihc_style=ihc_style)

carfac_np.carfac_init(cfp)
cfp.ears[0].car_coeffs.linear = False
Expand All @@ -72,7 +72,7 @@ def bench_numpy_in_slices(state: google_benchmark.State):
np_random.standard_normal(size=(n_samp, n_ears)) * _NOISE_FACTOR
)
run_seg_slices = np.array_split(run_seg_input_full, split_count)
cfp = carfac_np.design_carfac(one_cap=one_cap)
cfp = carfac_np.design_carfac(ihc_style=ihc_style)
carfac_np.carfac_init(cfp)
cfp.ears[0].car_coeffs.linear = False
state.resume_timing()
Expand Down Expand Up @@ -109,8 +109,8 @@ def bench_numpy(state: google_benchmark.State):
state: the benchmark state for this execution run.
"""
random_seed = 1
one_cap = False
cfp = carfac_np.design_carfac(one_cap=one_cap)
ihc_style = 'two_cap'
cfp = carfac_np.design_carfac(ihc_style=ihc_style)

carfac_np.carfac_init(cfp)
cfp.ears[0].car_coeffs.linear = False
Expand All @@ -123,7 +123,7 @@ def bench_numpy(state: google_benchmark.State):
run_seg_input = (
np_random.standard_normal(size=(n_samp, n_ears)) * _NOISE_FACTOR
)
cfp = carfac_np.design_carfac(one_cap=one_cap)
cfp = carfac_np.design_carfac(ihc_style=ihc_style)
carfac_np.carfac_init(cfp)
cfp.ears[0].car_coeffs.linear = False
state.resume_timing()
Expand Down

0 comments on commit d2a5ccd

Please sign in to comment.