diff --git a/python/jax/carfac.py b/python/jax/carfac.py index 2d622e1..1e73e0a 100644 --- a/python/jax/carfac.py +++ b/python/jax/carfac.py @@ -2301,7 +2301,7 @@ def run_segment( weights: CarfacWeights, state: CarfacState, open_loop: bool = False, -) -> Tuple[jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray]: +) -> Tuple[jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray]: """This function runs the entire CARFAC model. That is, filters a 1 or more channel @@ -2335,6 +2335,8 @@ def run_segment( Returns: naps: neural activity pattern + naps_fibers: neural activity of different fibers + (only populated with non-zeros when ihc_style equals "two_cap_with_syn") state: the updated state of the CARFAC model. BM: The basilar membrane motion seg_ohc & seg_agc are optional extra outputs useful for seeing what the @@ -2343,6 +2345,7 @@ def run_segment( if len(input_waves.shape) < 2: input_waves = jnp.reshape(input_waves, (-1, 1)) [n_samp, n_ears] = input_waves.shape + n_fibertypes = SynDesignParameters.n_classes # TODO(honglinyu): add more assertions using checkify. # if n_ears != cfp.n_ears: @@ -2352,6 +2355,7 @@ def run_segment( n_ch = hypers.ears[0].car.n_ch naps = jnp.zeros((n_samp, n_ch, n_ears)) # allocate space for result + naps_fibers = jnp.zeros((n_samp, n_ch, n_fibertypes, n_ears)) bm = jnp.zeros((n_samp, n_ch, n_ears)) seg_ohc = jnp.zeros((n_samp, n_ch, n_ears)) seg_agc = jnp.zeros((n_samp, n_ch, n_ears)) @@ -2370,7 +2374,7 @@ def run_segment( # Note that we can use naive for loops here because it will make gradient # computation very slow. def run_segment_scan_helper(carry, k): - naps, state, bm, seg_ohc, seg_agc, input_waves = carry + naps, naps_fibers, state, bm, seg_ohc, seg_agc, input_waves = carry agc_updated = False for ear in range(n_ears): # This would be cleaner if we could just get and use a reference to @@ -2385,9 +2389,12 @@ def run_segment_scan_helper(carry, k): ) if hypers.ears[ear].syn.do_syn: - ihc_out, _, state.ears[ear].syn = syn_step( + ihc_out, firings, state.ears[ear].syn = syn_step( v_recep, ear, weights, state.ears[ear].syn ) + naps_fibers = naps_fibers.at[k, :, :, ear].set(firings) + else: + naps_fibers = naps_fibers.at[k, :, :, ear].set(jnp.zeros([jnp.shape(ihc_out)[0], n_fibertypes])) # run the AGC update step, decimating internally, agc_updated, state.ears[ear].agc = agc_step( @@ -2420,11 +2427,11 @@ def close_agc_loop_helper( state, ) - return (naps, state, bm, seg_ohc, seg_agc, input_waves), None + return (naps, naps_fibers, state, bm, seg_ohc, seg_agc, input_waves), None return jax.lax.scan( run_segment_scan_helper, - (naps, state, bm, seg_ohc, seg_agc, input_waves), + (naps, naps_fibers, state, bm, seg_ohc, seg_agc, input_waves), jnp.arange(n_samp), )[0][:-1] @@ -2442,7 +2449,7 @@ def run_segment_jit( weights: CarfacWeights, state: CarfacState, open_loop: bool = False, -) -> Tuple[jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray]: +) -> Tuple[jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray]: """A JITted version of run_segment for convenience. Care should be taken with the hyper parameters in hypers. If the hypers object @@ -2468,6 +2475,8 @@ def run_segment_jit( Returns: naps: neural activity pattern + naps_fibers: neural activity of the different fiber types + (only populated with non-zeros when ihc_style equals "two_cap_with_syn") state: the updated state of the CARFAC model. BM: The basilar membrane motion seg_ohc & seg_agc are optional extra outputs useful for seeing what the @@ -2483,7 +2492,7 @@ def run_segment_jit_in_chunks_notraceable( state: CarfacState, open_loop: bool = False, segment_chunk_length: int = 32 * 48000, -) -> tuple[jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray]: +) -> tuple[jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Runs the jitted segment runner in segment groups. Running CARFAC on an audio segment this way is most useful when running @@ -2526,6 +2535,7 @@ def run_segment_jit_in_chunks_notraceable( if len(input_waves.shape) < 2: input_waves = jnp.reshape(input_waves, (-1, 1)) naps_out = [] + naps_fibers_out = [] bm_out = [] ohc_out = [] agc_out = [] @@ -2534,10 +2544,11 @@ def run_segment_jit_in_chunks_notraceable( [n_samp, _] = input_waves.shape if n_samp >= segment_length: [current_waves, input_waves] = jnp.split(input_waves, [segment_length], 0) - naps_jax, state, bm_jax, seg_ohc_jax, seg_agc_jax = run_segment_jit( + naps_jax, naps_fibers_jax, state, bm_jax, seg_ohc_jax, seg_agc_jax = run_segment_jit( current_waves, hypers, weights, state, open_loop ) naps_out.append(naps_jax) + naps_fibers_out.append(naps_fibers_jax) bm_out.append(bm_jax) ohc_out.append(seg_ohc_jax) agc_out.append(seg_agc_jax) @@ -2546,15 +2557,17 @@ def run_segment_jit_in_chunks_notraceable( [n_samp, _] = input_waves.shape # Take the last few items and just run them. if n_samp > 0: - naps_jax, state, bm_jax, seg_ohc_jax, seg_agc_jax = run_segment_jit( + naps_jax, naps_fibers_jax, state, bm_jax, seg_ohc_jax, seg_agc_jax = run_segment_jit( input_waves, hypers, weights, state, open_loop ) naps_out.append(naps_jax) + naps_fibers_out.append(naps_fibers_jax) bm_out.append(bm_jax) ohc_out.append(seg_ohc_jax) agc_out.append(seg_agc_jax) naps_out = np.concatenate(naps_out, 0) + naps_fibers_out = np.concatenate(naps_fibers_out, 0) bm_out = np.concatenate(bm_out, 0) ohc_out = np.concatenate(ohc_out, 0) agc_out = np.concatenate(agc_out, 0) - return naps_out, state, bm_out, ohc_out, agc_out + return naps_out, naps_fibers_out, state, bm_out, ohc_out, agc_out diff --git a/python/jax/carfac_bench.py b/python/jax/carfac_bench.py index 6a07e6f..4490e7f 100644 --- a/python/jax/carfac_bench.py +++ b/python/jax/carfac_bench.py @@ -179,7 +179,7 @@ def loss_func( weights: carfac_jax.CarfacWeights, state: carfac_jax.CarfacState, ): - nap_output, _, _, _, _ = carfac_jax.run_segment( + nap_output, naps_fibers_output, _, _, _, _ = carfac_jax.run_segment( audio, hypers, weights, state ) return jnp.sum(nap_output), nap_output @@ -242,7 +242,7 @@ def bench_jit_compile_time(state: google_benchmark.State): # that this benchmark is appropriate. n_samp += 1 state.resume_timing() - naps_jax, state_jax, _, _, _ = carfac_jax.run_segment_jit( + naps_jax, naps_fibers_jax, state_jax, _, _, _ = carfac_jax.run_segment_jit( run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False ) naps_jax.block_until_ready() @@ -295,7 +295,7 @@ def bench_jax_in_slices(state: google_benchmark.State): for _, segment in enumerate(silence_slices): if segment.shape not in compiled_shapes: compiled_shapes.add(segment.shape) - naps_jax, _, _, _, _ = carfac_jax.run_segment_jit( + naps_jax, naps_fibers_jax, _, _, _, _ = carfac_jax.run_segment_jit( segment, hypers_jax, weights_jax, state_jax, open_loop=False ) naps_jax.block_until_ready() @@ -316,7 +316,7 @@ def bench_jax_in_slices(state: google_benchmark.State): jax_loop_state = state_jax state.resume_timing() for _, segment in enumerate(run_seg_slices): - seg_naps, jax_loop_state, seg_bm, seg_ohc, seg_agc = ( + seg_naps, seg_naps_fibers, jax_loop_state, seg_bm, seg_ohc, seg_agc = ( carfac_jax.run_segment_jit( segment, hypers_jax, weights_jax, jax_loop_state, open_loop=False ) @@ -389,7 +389,7 @@ def bench_jax(state: google_benchmark.State): params_jax ) short_silence = jnp.zeros(shape=(n_samp, n_ears)) - naps_jax, state_jax, _, _, _ = run_segment_function( + naps_jax, _, state_jax, _, _, _ = run_segment_function( short_silence, hypers_jax, weights_jax, state_jax, open_loop=False ) # This block ensures calculation. @@ -404,7 +404,7 @@ def bench_jax(state: google_benchmark.State): jax.random.normal(random_generator, (n_samp, n_ears)) * _NOISE_FACTOR ).block_until_ready() state.resume_timing() - naps_jax, state_jax, _, _, _ = run_segment_function( + naps_jax, naps_fibers_jax, state_jax, _, _, _ = run_segment_function( run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False ) if state.range(0) != 1: diff --git a/python/jax/carfac_float64_test.py b/python/jax/carfac_float64_test.py index 921dc7d..6bc54a5 100644 --- a/python/jax/carfac_float64_test.py +++ b/python/jax/carfac_float64_test.py @@ -46,7 +46,7 @@ def loss(weights, input_waves, hypers, state): # A loss function for tests. Note that we shouldn't use `run_segment_jit` # here because it will donate the `state` which causes unnecessary # inconvenience for tests. - naps_jax, state_jax, _, _, _ = carfac_jax.run_segment( + naps_jax, naps_fibers_jax, state_jax, _, _, _ = carfac_jax.run_segment( input_waves, hypers, weights, state, open_loop=False ) # For testing, just fit `naps` to 1. diff --git a/python/jax/carfac_test.py b/python/jax/carfac_test.py index 918d074..acb985b 100644 --- a/python/jax/carfac_test.py +++ b/python/jax/carfac_test.py @@ -324,10 +324,10 @@ def test_chunked_naps_same_as_jit(self, random_seed, ihc_style): state_jax_copied = copy.deepcopy(state_jax) # Only tests the JITted version because this is what we will use. - naps_jax, _, bm_jax, ohc_jax, agc_jax = carfac_jax.run_segment_jit( + naps_jax, naps_fibers_jax, _, bm_jax, ohc_jax, agc_jax = carfac_jax.run_segment_jit( run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False ) - naps_jax_chunked, _, bm_chunked, ohc_chunked, agc_chunked = ( + naps_jax_chunked, naps_fibes_jax_chunked, _, bm_chunked, ohc_chunked, agc_chunked = ( carfac_jax.run_segment_jit_in_chunks_notraceable( run_seg_input, hypers_jax, @@ -380,7 +380,7 @@ def test_equal_forward_pass( run_seg_input = jax.random.normal(random_generator, (n_samp, n_ears)) # Only tests the JITted version because this is what we will use. - naps_jax, state_jax, bm_jax, seg_ohc_jax, seg_agc_jax = ( + naps_jax, naps_fibers_jax, state_jax, bm_jax, seg_ohc_jax, seg_agc_jax = ( carfac_jax.run_segment_jit( run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False ) diff --git a/python/jax/carfac_util.py b/python/jax/carfac_util.py index 9071b09..7678571 100644 --- a/python/jax/carfac_util.py +++ b/python/jax/carfac_util.py @@ -38,6 +38,7 @@ def run_multiple_segment_states_shmap( open_loop: bool = False, ) -> Sequence[ Tuple[ + jnp.ndarray, jnp.ndarray, carfac_jax.CarfacState, jnp.ndarray, @@ -88,7 +89,7 @@ def parallel_helper(input_waves, state): """ input_waves = input_waves[0] state = jax.tree_util.tree_map(lambda x: jnp.squeeze(x, axis=0), state) - naps, ret_state, bm, seg_ohc, seg_agc = carfac_jax.run_segment_jit( + naps, naps_fibers, ret_state, bm, seg_ohc, seg_agc = carfac_jax.run_segment_jit( input_waves, hypers, weights, state, open_loop ) ret_state = jax.tree_util.tree_map( @@ -96,13 +97,14 @@ def parallel_helper(input_waves, state): ) return ( naps[None], + naps_fibers[None], ret_state, bm[None], seg_ohc[None], seg_agc[None], ) - stacked_naps, stacked_states, stacked_bm, stacked_ohc, stacked_agc = ( + stacked_naps, stacked_naps_fibers, stacked_states, stacked_bm, stacked_ohc, stacked_agc = ( parallel_helper(input_waves_array, batch_state) ) output_states = _tree_unstack(stacked_states) @@ -112,6 +114,7 @@ def parallel_helper(input_waves, state): for i, output_state in enumerate(output_states): tup = ( stacked_naps[i], + stacked_naps_fibers[i], output_state, stacked_bm[i], stacked_ohc[i], @@ -130,6 +133,7 @@ def run_multiple_segment_pmap( open_loop: bool = False, ) -> Sequence[ Tuple[ + jnp.ndarray, jnp.ndarray, carfac_jax.CarfacState, jnp.ndarray, @@ -155,7 +159,7 @@ def run_multiple_segment_pmap( in_axes=(0, None, None, None, None), static_broadcasted_argnums=[1, 4], ) - stacked_naps, stacked_states, stacked_bm, stacked_ohc, stacked_agc = pmapped( + stacked_naps, stacked_naps_fibers, stacked_states, stacked_bm, stacked_ohc, stacked_agc = pmapped( input_waves_array, hypers, weights, state, open_loop ) @@ -164,6 +168,7 @@ def run_multiple_segment_pmap( for i, output_state in enumerate(output_states): tup = ( stacked_naps[i], + stacked_naps_fibers[i], output_state, stacked_bm[i], stacked_ohc[i], diff --git a/python/jax/carfac_util_test.py b/python/jax/carfac_util_test.py index 753238b..a88c449 100644 --- a/python/jax/carfac_util_test.py +++ b/python/jax/carfac_util_test.py @@ -59,7 +59,7 @@ def test_same_outputs_parallel_for_pmap(self): ], axis=0, ) - nap_out_a, state_out_a, bm_out_a, ohc_out_a, agc_out_a = ( + nap_out_a, nap_fibers_out_a, state_out_a, bm_out_a, ohc_out_a, agc_out_a = ( carfac.run_segment_jit( self.sample_a, self.hypers, @@ -68,7 +68,7 @@ def test_same_outputs_parallel_for_pmap(self): self.open_loop, ) ) - nap_out_b, state_out_b, bm_out_b, ohc_out_b, agc_out_b = ( + nap_out_b, nap_fibers_out_b, state_out_b, bm_out_b, ohc_out_b, agc_out_b = ( carfac.run_segment_jit( self.sample_b, self.hypers, @@ -86,20 +86,22 @@ def test_same_outputs_parallel_for_pmap(self): ) self.assertTrue((combined_output[0][0] == nap_out_a).all()) self.assertTrue((combined_output[1][0] == nap_out_b).all()) - self.assertTrue((combined_output[0][2] == bm_out_a).all()) - self.assertTrue((combined_output[1][2] == bm_out_b).all()) - self.assertTrue((combined_output[0][3] == ohc_out_a).all()) - self.assertTrue((combined_output[1][3] == ohc_out_b).all()) - self.assertTrue((combined_output[0][4] == agc_out_a).all()) - self.assertTrue((combined_output[1][4] == agc_out_b).all()) + self.assertTrue((combined_output[0][1] == nap_fibers_out_a).all()) + self.assertTrue((combined_output[1][1] == nap_fibers_out_b).all()) + self.assertTrue((combined_output[0][3] == bm_out_a).all()) + self.assertTrue((combined_output[1][3] == bm_out_b).all()) + self.assertTrue((combined_output[0][4] == ohc_out_a).all()) + self.assertTrue((combined_output[1][4] == ohc_out_b).all()) + self.assertTrue((combined_output[0][5] == agc_out_a).all()) + self.assertTrue((combined_output[1][5] == agc_out_b).all()) self.assertTrue( jax.tree_util.tree_all( - jax.tree.map(jnp.allclose, state_out_a, combined_output[0][1]) + jax.tree.map(jnp.allclose, state_out_a, combined_output[0][2]) ) ) self.assertTrue( jax.tree_util.tree_all( - jax.tree.map(jnp.allclose, state_out_b, combined_output[1][1]) + jax.tree.map(jnp.allclose, state_out_b, combined_output[1][2]) ) ) @@ -112,7 +114,7 @@ def test_same_outputs_parallel_for_shmap(self): axis=0, ) - nap_out_a, state_out_a, bm_out_a, ohc_out_a, agc_out_a = ( + nap_out_a, nap_fibers_out_a, state_out_a, bm_out_a, ohc_out_a, agc_out_a = ( carfac.run_segment_jit( self.sample_a, self.hypers, @@ -124,7 +126,7 @@ def test_same_outputs_parallel_for_shmap(self): # Run sample B twice, so we have a separate "starting" state for the # test for shmap. - _, state_out_b_first, _, _, _ = carfac.run_segment_jit( + _, _, state_out_b_first, _, _, _ = carfac.run_segment_jit( self.sample_b, self.hypers, self.weights, @@ -132,7 +134,7 @@ def test_same_outputs_parallel_for_shmap(self): self.open_loop, ) - nap_out_b, state_out_b, bm_out_b, ohc_out_b, agc_out_b = ( + nap_out_b, nap_fibers_out_b, state_out_b, bm_out_b, ohc_out_b, agc_out_b = ( carfac.run_segment_jit( self.sample_b, self.hypers, @@ -150,20 +152,22 @@ def test_same_outputs_parallel_for_shmap(self): ) self.assertTrue((combined_output[0][0] == nap_out_a).all()) self.assertTrue((combined_output[1][0] == nap_out_b).all()) - self.assertTrue((combined_output[0][2] == bm_out_a).all()) - self.assertTrue((combined_output[1][2] == bm_out_b).all()) - self.assertTrue((combined_output[0][3] == ohc_out_a).all()) - self.assertTrue((combined_output[1][3] == ohc_out_b).all()) - self.assertTrue((combined_output[0][4] == agc_out_a).all()) - self.assertTrue((combined_output[1][4] == agc_out_b).all()) + self.assertTrue((combined_output[0][1] == nap_fibers_out_a).all()) + self.assertTrue((combined_output[1][1] == nap_fibers_out_b).all()) + self.assertTrue((combined_output[0][3] == bm_out_a).all()) + self.assertTrue((combined_output[1][3] == bm_out_b).all()) + self.assertTrue((combined_output[0][4] == ohc_out_a).all()) + self.assertTrue((combined_output[1][4] == ohc_out_b).all()) + self.assertTrue((combined_output[0][5] == agc_out_a).all()) + self.assertTrue((combined_output[1][5] == agc_out_b).all()) self.assertTrue( jax.tree_util.tree_all( - jax.tree.map(jnp.allclose, state_out_a, combined_output[0][1]) + jax.tree.map(jnp.allclose, state_out_a, combined_output[0][2]) ) ) self.assertTrue( jax.tree_util.tree_all( - jax.tree.map(jnp.allclose, state_out_b, combined_output[1][1]) + jax.tree.map(jnp.allclose, state_out_b, combined_output[1][2]) ) ) @@ -171,12 +175,12 @@ def test_same_outputs_parallel_for_shmap(self): # equality is complete and double sided. self.assertTrue( jax.tree_util.tree_all( - jax.tree.map(jnp.allclose, combined_output[0][1], state_out_a) + jax.tree.map(jnp.allclose, combined_output[0][2], state_out_a) ) ) self.assertTrue( jax.tree_util.tree_all( - jax.tree.map(jnp.allclose, combined_output[1][1], state_out_b) + jax.tree.map(jnp.allclose, combined_output[1][2], state_out_b) ) )