Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added naps_fibers parameter as output #19

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions python/jax/carfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -2301,7 +2301,7 @@ def run_segment(
weights: CarfacWeights,
state: CarfacState,
open_loop: bool = False,
) -> Tuple[jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
) -> Tuple[jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""This function runs the entire CARFAC model.

That is, filters a 1 or more channel
Expand Down Expand Up @@ -2335,6 +2335,8 @@ def run_segment(

Returns:
naps: neural activity pattern
naps_fibers: neural activity of different fibers
(only populated with non-zeros when ihc_style equals "two_cap_with_syn")
state: the updated state of the CARFAC model.
BM: The basilar membrane motion
seg_ohc & seg_agc are optional extra outputs useful for seeing what the
Expand All @@ -2343,6 +2345,7 @@ def run_segment(
if len(input_waves.shape) < 2:
input_waves = jnp.reshape(input_waves, (-1, 1))
[n_samp, n_ears] = input_waves.shape
n_fibertypes = SynDesignParameters.n_classes

# TODO(honglinyu): add more assertions using checkify.
# if n_ears != cfp.n_ears:
Expand All @@ -2352,6 +2355,7 @@ def run_segment(

n_ch = hypers.ears[0].car.n_ch
naps = jnp.zeros((n_samp, n_ch, n_ears)) # allocate space for result
naps_fibers = jnp.zeros((n_samp, n_ch, n_fibertypes, n_ears))
bm = jnp.zeros((n_samp, n_ch, n_ears))
seg_ohc = jnp.zeros((n_samp, n_ch, n_ears))
seg_agc = jnp.zeros((n_samp, n_ch, n_ears))
Expand All @@ -2370,7 +2374,7 @@ def run_segment(
# Note that we can use naive for loops here because it will make gradient
# computation very slow.
def run_segment_scan_helper(carry, k):
naps, state, bm, seg_ohc, seg_agc, input_waves = carry
naps, naps_fibers, state, bm, seg_ohc, seg_agc, input_waves = carry
agc_updated = False
for ear in range(n_ears):
# This would be cleaner if we could just get and use a reference to
Expand All @@ -2385,9 +2389,12 @@ def run_segment_scan_helper(carry, k):
)

if hypers.ears[ear].syn.do_syn:
ihc_out, _, state.ears[ear].syn = syn_step(
ihc_out, firings, state.ears[ear].syn = syn_step(
v_recep, ear, weights, state.ears[ear].syn
)
naps_fibers = naps_fibers.at[k, :, :, ear].set(firings)
else:
naps_fibers = naps_fibers.at[k, :, :, ear].set(jnp.zeros([jnp.shape(ihc_out)[0], n_fibertypes]))

# run the AGC update step, decimating internally,
agc_updated, state.ears[ear].agc = agc_step(
Expand Down Expand Up @@ -2420,11 +2427,11 @@ def close_agc_loop_helper(
state,
)

return (naps, state, bm, seg_ohc, seg_agc, input_waves), None
return (naps, naps_fibers, state, bm, seg_ohc, seg_agc, input_waves), None

return jax.lax.scan(
run_segment_scan_helper,
(naps, state, bm, seg_ohc, seg_agc, input_waves),
(naps, naps_fibers, state, bm, seg_ohc, seg_agc, input_waves),
jnp.arange(n_samp),
)[0][:-1]

Expand All @@ -2442,7 +2449,7 @@ def run_segment_jit(
weights: CarfacWeights,
state: CarfacState,
open_loop: bool = False,
) -> Tuple[jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
) -> Tuple[jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""A JITted version of run_segment for convenience.

Care should be taken with the hyper parameters in hypers. If the hypers object
Expand All @@ -2468,6 +2475,8 @@ def run_segment_jit(

Returns:
naps: neural activity pattern
naps_fibers: neural activity of the different fiber types
(only populated with non-zeros when ihc_style equals "two_cap_with_syn")
state: the updated state of the CARFAC model.
BM: The basilar membrane motion
seg_ohc & seg_agc are optional extra outputs useful for seeing what the
Expand All @@ -2483,7 +2492,7 @@ def run_segment_jit_in_chunks_notraceable(
state: CarfacState,
open_loop: bool = False,
segment_chunk_length: int = 32 * 48000,
) -> tuple[jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
) -> tuple[jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Runs the jitted segment runner in segment groups.

Running CARFAC on an audio segment this way is most useful when running
Expand Down Expand Up @@ -2526,6 +2535,7 @@ def run_segment_jit_in_chunks_notraceable(
if len(input_waves.shape) < 2:
input_waves = jnp.reshape(input_waves, (-1, 1))
naps_out = []
naps_fibers_out = []
bm_out = []
ohc_out = []
agc_out = []
Expand All @@ -2534,10 +2544,11 @@ def run_segment_jit_in_chunks_notraceable(
[n_samp, _] = input_waves.shape
if n_samp >= segment_length:
[current_waves, input_waves] = jnp.split(input_waves, [segment_length], 0)
naps_jax, state, bm_jax, seg_ohc_jax, seg_agc_jax = run_segment_jit(
naps_jax, naps_fibers_jax, state, bm_jax, seg_ohc_jax, seg_agc_jax = run_segment_jit(
current_waves, hypers, weights, state, open_loop
)
naps_out.append(naps_jax)
naps_fibers_out.append(naps_fibers_jax)
bm_out.append(bm_jax)
ohc_out.append(seg_ohc_jax)
agc_out.append(seg_agc_jax)
Expand All @@ -2546,15 +2557,17 @@ def run_segment_jit_in_chunks_notraceable(
[n_samp, _] = input_waves.shape
# Take the last few items and just run them.
if n_samp > 0:
naps_jax, state, bm_jax, seg_ohc_jax, seg_agc_jax = run_segment_jit(
naps_jax, naps_fibers_jax, state, bm_jax, seg_ohc_jax, seg_agc_jax = run_segment_jit(
input_waves, hypers, weights, state, open_loop
)
naps_out.append(naps_jax)
naps_fibers_out.append(naps_fibers_jax)
bm_out.append(bm_jax)
ohc_out.append(seg_ohc_jax)
agc_out.append(seg_agc_jax)
naps_out = np.concatenate(naps_out, 0)
naps_fibers_out = np.concatenate(naps_fibers_out, 0)
bm_out = np.concatenate(bm_out, 0)
ohc_out = np.concatenate(ohc_out, 0)
agc_out = np.concatenate(agc_out, 0)
return naps_out, state, bm_out, ohc_out, agc_out
return naps_out, naps_fibers_out, state, bm_out, ohc_out, agc_out
12 changes: 6 additions & 6 deletions python/jax/carfac_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def loss_func(
weights: carfac_jax.CarfacWeights,
state: carfac_jax.CarfacState,
):
nap_output, _, _, _, _ = carfac_jax.run_segment(
nap_output, naps_fibers_output, _, _, _, _ = carfac_jax.run_segment(
audio, hypers, weights, state
)
return jnp.sum(nap_output), nap_output
Expand Down Expand Up @@ -242,7 +242,7 @@ def bench_jit_compile_time(state: google_benchmark.State):
# that this benchmark is appropriate.
n_samp += 1
state.resume_timing()
naps_jax, state_jax, _, _, _ = carfac_jax.run_segment_jit(
naps_jax, naps_fibers_jax, state_jax, _, _, _ = carfac_jax.run_segment_jit(
run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False
)
naps_jax.block_until_ready()
Expand Down Expand Up @@ -295,7 +295,7 @@ def bench_jax_in_slices(state: google_benchmark.State):
for _, segment in enumerate(silence_slices):
if segment.shape not in compiled_shapes:
compiled_shapes.add(segment.shape)
naps_jax, _, _, _, _ = carfac_jax.run_segment_jit(
naps_jax, naps_fibers_jax, _, _, _, _ = carfac_jax.run_segment_jit(
segment, hypers_jax, weights_jax, state_jax, open_loop=False
)
naps_jax.block_until_ready()
Expand All @@ -316,7 +316,7 @@ def bench_jax_in_slices(state: google_benchmark.State):
jax_loop_state = state_jax
state.resume_timing()
for _, segment in enumerate(run_seg_slices):
seg_naps, jax_loop_state, seg_bm, seg_ohc, seg_agc = (
seg_naps, seg_naps_fibers, jax_loop_state, seg_bm, seg_ohc, seg_agc = (
carfac_jax.run_segment_jit(
segment, hypers_jax, weights_jax, jax_loop_state, open_loop=False
)
Expand Down Expand Up @@ -389,7 +389,7 @@ def bench_jax(state: google_benchmark.State):
params_jax
)
short_silence = jnp.zeros(shape=(n_samp, n_ears))
naps_jax, state_jax, _, _, _ = run_segment_function(
naps_jax, _, state_jax, _, _, _ = run_segment_function(
short_silence, hypers_jax, weights_jax, state_jax, open_loop=False
)
# This block ensures calculation.
Expand All @@ -404,7 +404,7 @@ def bench_jax(state: google_benchmark.State):
jax.random.normal(random_generator, (n_samp, n_ears)) * _NOISE_FACTOR
).block_until_ready()
state.resume_timing()
naps_jax, state_jax, _, _, _ = run_segment_function(
naps_jax, naps_fibers_jax, state_jax, _, _, _ = run_segment_function(
run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False
)
if state.range(0) != 1:
Expand Down
2 changes: 1 addition & 1 deletion python/jax/carfac_float64_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def loss(weights, input_waves, hypers, state):
# A loss function for tests. Note that we shouldn't use `run_segment_jit`
# here because it will donate the `state` which causes unnecessary
# inconvenience for tests.
naps_jax, state_jax, _, _, _ = carfac_jax.run_segment(
naps_jax, naps_fibers_jax, state_jax, _, _, _ = carfac_jax.run_segment(
input_waves, hypers, weights, state, open_loop=False
)
# For testing, just fit `naps` to 1.
Expand Down
6 changes: 3 additions & 3 deletions python/jax/carfac_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,10 +324,10 @@ def test_chunked_naps_same_as_jit(self, random_seed, ihc_style):
state_jax_copied = copy.deepcopy(state_jax)

# Only tests the JITted version because this is what we will use.
naps_jax, _, bm_jax, ohc_jax, agc_jax = carfac_jax.run_segment_jit(
naps_jax, naps_fibers_jax, _, bm_jax, ohc_jax, agc_jax = carfac_jax.run_segment_jit(
run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False
)
naps_jax_chunked, _, bm_chunked, ohc_chunked, agc_chunked = (
naps_jax_chunked, naps_fibes_jax_chunked, _, bm_chunked, ohc_chunked, agc_chunked = (
carfac_jax.run_segment_jit_in_chunks_notraceable(
run_seg_input,
hypers_jax,
Expand Down Expand Up @@ -380,7 +380,7 @@ def test_equal_forward_pass(
run_seg_input = jax.random.normal(random_generator, (n_samp, n_ears))

# Only tests the JITted version because this is what we will use.
naps_jax, state_jax, bm_jax, seg_ohc_jax, seg_agc_jax = (
naps_jax, naps_fibers_jax, state_jax, bm_jax, seg_ohc_jax, seg_agc_jax = (
carfac_jax.run_segment_jit(
run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False
)
Expand Down
11 changes: 8 additions & 3 deletions python/jax/carfac_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def run_multiple_segment_states_shmap(
open_loop: bool = False,
) -> Sequence[
Tuple[
jnp.ndarray,
jnp.ndarray,
carfac_jax.CarfacState,
jnp.ndarray,
Expand Down Expand Up @@ -88,21 +89,22 @@ def parallel_helper(input_waves, state):
"""
input_waves = input_waves[0]
state = jax.tree_util.tree_map(lambda x: jnp.squeeze(x, axis=0), state)
naps, ret_state, bm, seg_ohc, seg_agc = carfac_jax.run_segment_jit(
naps, naps_fibers, ret_state, bm, seg_ohc, seg_agc = carfac_jax.run_segment_jit(
input_waves, hypers, weights, state, open_loop
)
ret_state = jax.tree_util.tree_map(
lambda x: jnp.asarray(x).reshape((1, -1)), ret_state
)
return (
naps[None],
naps_fibers[None],
ret_state,
bm[None],
seg_ohc[None],
seg_agc[None],
)

stacked_naps, stacked_states, stacked_bm, stacked_ohc, stacked_agc = (
stacked_naps, stacked_naps_fibers, stacked_states, stacked_bm, stacked_ohc, stacked_agc = (
parallel_helper(input_waves_array, batch_state)
)
output_states = _tree_unstack(stacked_states)
Expand All @@ -112,6 +114,7 @@ def parallel_helper(input_waves, state):
for i, output_state in enumerate(output_states):
tup = (
stacked_naps[i],
stacked_naps_fibers[i],
output_state,
stacked_bm[i],
stacked_ohc[i],
Expand All @@ -130,6 +133,7 @@ def run_multiple_segment_pmap(
open_loop: bool = False,
) -> Sequence[
Tuple[
jnp.ndarray,
jnp.ndarray,
carfac_jax.CarfacState,
jnp.ndarray,
Expand All @@ -155,7 +159,7 @@ def run_multiple_segment_pmap(
in_axes=(0, None, None, None, None),
static_broadcasted_argnums=[1, 4],
)
stacked_naps, stacked_states, stacked_bm, stacked_ohc, stacked_agc = pmapped(
stacked_naps, stacked_naps_fibers, stacked_states, stacked_bm, stacked_ohc, stacked_agc = pmapped(
input_waves_array, hypers, weights, state, open_loop
)

Expand All @@ -164,6 +168,7 @@ def run_multiple_segment_pmap(
for i, output_state in enumerate(output_states):
tup = (
stacked_naps[i],
stacked_naps_fibers[i],
output_state,
stacked_bm[i],
stacked_ohc[i],
Expand Down
Loading