diff --git a/python/jax/carfac.py b/python/jax/carfac.py
index 984fefa..2d622e1 100644
--- a/python/jax/carfac.py
+++ b/python/jax/carfac.py
@@ -556,6 +556,172 @@ def tree_unflatten(cls, _, children):
     return cls(*children)
 
 
+@jax.tree_util.register_pytree_node_class
+@dataclasses.dataclass
+class SynDesignParameters:
+  """Variables needed for the synapse implementation."""
+
+  n_classes: int = 3  # Default. Modify params and redesign to change.
+  ihcs_per_channel: dataclasses.InitVar[int] = 10  # Maybe 20 would be better?
+  healthy_n_fibers: jnp.ndarray = dataclasses.field(
+      default_factory=lambda: (jnp.array([50, 35, 25]))
+  )
+  spont_rates: jnp.ndarray = dataclasses.field(
+      default_factory=lambda: jnp.array([50, 6, 1])
+  )
+  sat_rates: float = 200.0
+  sat_reservoir: float = 0.2
+  v_width: float = 0.02
+  tau_lpf: float = 0.000080
+  reservoir_tau: float = 0.02
+  # Tweaked. The weights of 1.2 were picked before correctly accounting for
+  # sample rate and number of fibers; this form generalizes to other sample
+  # rates and fiber counts.
+  agc_weights: jnp.ndarray = dataclasses.field(
+      default_factory=lambda: (jnp.array([1.2, 1.2, 1.2]) / (22050))
+  )
+
+  def __post_init__(self, ihcs_per_channel):
+    self.healthy_n_fibers *= ihcs_per_channel
+    self.agc_weights /= ihcs_per_channel
+
+  # The following 2 functions are boiler code required by pytree.
+  # Reference: https://jax.readthedocs.io/en/latest/pytrees.html
+  # Note that ihcs_per_channel is set to 1 for flattening/unflattening,
+  # to avoid repeated multiplication in __post_init__.
+  def tree_flatten(self):  # pylint: disable=missing-function-docstring
+    children = (
+        self.n_classes,
+        1,
+        self.healthy_n_fibers,
+        self.spont_rates,
+        self.sat_rates,
+        self.sat_reservoir,
+        self.v_width,
+        self.tau_lpf,
+        self.reservoir_tau,
+        self.agc_weights,
+    )
+    aux_data = (
+        'n_classes',
+        'ihcs_per_channel',
+        'healthy_n_fibers',
+        'spont_rates',
+        'sat_rates',
+        'sat_reservoir',
+        'v_width',
+        'tau_lpf',
+        'reservoir_tau',
+        'agc_weights',
+    )
+    return (children, aux_data)
+
+  @classmethod
+  def tree_unflatten(cls, _, children):
+    return cls(*children)
+
+
+@jax.tree_util.register_pytree_node_class
+@dataclasses.dataclass
+class SynHypers:
+  """Hyperparameters for the IHC synapse. Tagged `static` in `jax.jit`."""
+
+  do_syn: bool = False
+
+  # The following 2 functions are boiler code required by pytree.
+  # Reference: https://jax.readthedocs.io/en/latest/pytrees.html
+  def tree_flatten(self):
+    # A tuple (not a list) so that `__hash__` below can hash it.
+    children = (self.do_syn,)
+    aux_data = ('do_syn',)
+    return (children, aux_data)
+
+  @classmethod
+  def tree_unflatten(cls, _, children):
+    return cls(*children)
+
+  # Needed by `static_argnums` of `jax.jit`.
+  def __hash__(self):
+    children, _ = self.tree_flatten()
+    return hash(children)
+
+  def __eq__(self, other):
+    return _are_two_equal_hypers(self, other)
+
+
+@jax.tree_util.register_pytree_node_class
+@dataclasses.dataclass
+class SynWeights:
+  """Trainable weights of the IHC synapse."""
+
+  n_fibers: jnp.ndarray
+  v_widths: jnp.ndarray
+  v_halfs: jnp.ndarray  # Same units as v_recep and v_widths.
+  a1: float  # Feedback gain.
+  a2: float  # Output gain.
+  agc_weights: jnp.ndarray  # For making a NAP output to feed the AGC input.
+  spont_p: jnp.ndarray  # Used only to init the output LPF.
+  spont_sub: jnp.ndarray
+  res_lpf_inits: jnp.ndarray
+  res_coeff: float
+  lpf_coeff: float
+
+  # The following 2 functions are boiler code required by pytree.
+  # Reference: https://jax.readthedocs.io/en/latest/pytrees.html
+  def tree_flatten(self):
+    """Flatten tree for pytree."""
+    children = (
+        self.n_fibers,
+        self.v_widths,
+        self.v_halfs,
+        self.a1,
+        self.a2,
+        self.agc_weights,
+        self.spont_p,
+        self.spont_sub,
+        self.res_lpf_inits,
+        self.res_coeff,
+        self.lpf_coeff,
+    )
+    aux_data = (
+        'n_fibers',
+        'v_widths',
+        'v_halfs',
+        'a1',
+        'a2',
+        'agc_weights',
+        'spont_p',
+        'spont_sub',
+        'res_lpf_inits',
+        'res_coeff',
+        'lpf_coeff',
+    )
+    return (children, aux_data)
+
+  @classmethod
+  def tree_unflatten(cls, _, children):
+    return cls(*children)
+
+
+@jax.tree_util.register_pytree_node_class
+@dataclasses.dataclass
+class SynState:
+  """All the state variables for the IHC synapse."""
+
+  reservoirs: jnp.ndarray
+  lpf_state: jnp.ndarray
+
+  # The following 2 functions are boiler code required by pytree.
+  # Reference: https://jax.readthedocs.io/en/latest/pytrees.html
+  def tree_flatten(self):
+    children = (self.reservoirs, self.lpf_state)
+    aux_data = ('reservoirs', 'lpf_state')
+    return (children, aux_data)
+
+  @classmethod
+  def tree_unflatten(cls, _, children):
+    return cls(*children)
+
+
 @jax.tree_util.register_pytree_node_class
 @dataclasses.dataclass
 class IhcHypers:
@@ -705,6 +871,9 @@ class EarDesignParameters:
   ihc: IhcDesignParameters = dataclasses.field(
       default_factory=IhcDesignParameters
   )
+  syn: SynDesignParameters = dataclasses.field(
+      default_factory=SynDesignParameters
+  )
 
   # The following 2 functions are boiler code required by pytree.
   # Reference: https://jax.readthedocs.io/en/latest/pytrees.html
@@ -732,14 +901,29 @@ class EarHypers:
       default_factory=list  # One element per AGC layer.
   )
   ihc: Optional[IhcHypers] = None
+  syn: Optional[SynHypers] = None
 
   # The following 2 functions are boiler code required by pytree.
   # Reference: https://jax.readthedocs.io/en/latest/pytrees.html
   def tree_flatten(self):  # pylint: disable=missing-function-docstring
-    children = (self.n_ch, self.pole_freqs, self.max_channels_per_octave,
-                self.car, self.agc, self.ihc)
-    aux_data = ('n_ch', 'pole_freqs', 'max_channels_per_octave',
-                'car', 'agc', 'ihc')
+    children = (
+        self.n_ch,
+        self.pole_freqs,
+        self.max_channels_per_octave,
+        self.car,
+        self.agc,
+        self.ihc,
+        self.syn,
+    )
+    aux_data = (
+        'n_ch',
+        'pole_freqs',
+        'max_channels_per_octave',
+        'car',
+        'agc',
+        'ihc',
+        'syn',
+    )
     return (children, aux_data)
 
   @classmethod
@@ -773,12 +957,13 @@
   car: Optional[CarWeights] = None
   agc: Optional[List[AgcWeights]] = None
   ihc: Optional[IhcWeights] = None
+  syn: Optional[SynWeights] = None
 
   # The following 2 functions are boiler code required by pytree.
   # Reference: https://jax.readthedocs.io/en/latest/pytrees.html
   def tree_flatten(self):  # pylint: disable=missing-function-docstring
-    children = (self.car, self.agc, self.ihc)
-    aux_data = ('car', 'agc', 'ihc')
+    children = (self.car, self.agc, self.ihc, self.syn)
+    aux_data = ('car', 'agc', 'ihc', 'syn')
     return (children, aux_data)
 
   @classmethod
@@ -793,12 +978,13 @@
   car: Optional[CarState] = None
   ihc: Optional[IhcState] = None
   agc: Optional[List[AgcState]] = None
+  syn: Optional[SynState] = None
 
   # The following 2 functions are boiler code required by pytree.
   # Reference: https://jax.readthedocs.io/en/latest/pytrees.html
   def tree_flatten(self):  # pylint: disable=missing-function-docstring
-    children = (self.car, self.ihc, self.agc)
-    aux_data = ('car', 'ihc', 'agc')
+    children = (self.car, self.ihc, self.agc, self.syn)
+    aux_data = ('car', 'ihc', 'agc', 'syn')
     return (children, aux_data)
 
   @classmethod
@@ -1109,6 +1295,87 @@ def ihc_detect(x: Union[float, jnp.ndarray]) -> Union[float, jnp.ndarray]:
   return conductance
 
 
+def design_and_init_ihc_syn(
+    ear: int, params: CarfacDesignParameters, hypers: CarfacHypers
+) -> Tuple[SynHypers, Optional[SynWeights], Optional[SynState]]:
+  """Design the IHC synapse Hypers/Weights/State.
+
+  Only returns Weights and State if the design parameters select the
+  synapse (ihc_style 'two_cap_with_syn'); otherwise they are None.
+
+  Args:
+    ear: the index of the ear.
+    params: all the design parameters.
+    hypers: all the hyperparameters.
+
+  Returns:
+    A tuple containing the newly created and initialised synapse
+    hyperparameters, weights and state.
+  """
+  ear_params = params.ears[ear]
+  ear_hypers = hypers.ears[ear]
+  ihc_params = ear_params.ihc
+  syn_params = ear_params.syn
+  if ihc_params.ihc_style != 'two_cap_with_syn':
+    syn_hypers = SynHypers(do_syn=False)
+    return syn_hypers, None, None
+
+  syn_hypers = SynHypers(do_syn=True)
+  fs = params.fs
+  n_ch = ear_hypers.pole_freqs.shape[0]
+  n_classes = syn_params.n_classes
+  v_widths = syn_params.v_width * jnp.ones((1, n_classes))
+  # Do some design. First, gains to get sat_rate when sigmoid is 1, which
+  # involves the reservoir steady-state solution.
+  # Most of these are not per-channel, but could be expanded that way
+  # later if desired.
+  # Mean sat prob of spike per sample per neuron, likely same for all
+  # classes.
+  # Use names 1 for sat and 0 for spont in some of these.
+  p1 = syn_params.sat_rates / fs
+  p0 = syn_params.spont_rates / fs
+  w1 = syn_params.sat_reservoir
+  q1 = 1 - w1
+  # Assume the sigmoid is switching between 0 and 1 at 50% duty cycle, so
+  # normalized mean value is 0.5 at saturation.
+  s1 = 0.5
+  r1 = s1 * w1
+  # Solve q1 = a1*r1 for gain coeff a1:
+  a1 = q1 / r1
+  # Solve p1 = a2*r1 for gain coeff a2:
+  a2 = p1 / r1
+  # Now work out how to get the desired spont.
+  r0 = p0 / a2
+  q0 = r0 * a1
+  w0 = 1 - q0
+  s0 = r0 / w0
+  # Solve for (negative) sigmoid midpoint offset that gets s0 right.
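+  # Derivation for the next line: at rest the receptor potential v_recep is
+  # 0 (see ihc_step), so x = -v_halfs / v_widths = -offsets and the sigmoid
+  # gives s = 1 / (1 + exp(offsets)); choosing offsets = log((1 - s0) / s0)
+  # therefore yields s = s0, the desired spont operating point, at rest.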
+  offsets = jnp.log((1 - s0) / s0)
+  spont_p = a2 * w0 * s0  # Should equal p0 by construction (sanity check).
+  agc_weights = fs * syn_params.agc_weights
+  spont_sub = (syn_params.healthy_n_fibers * spont_p) @ jnp.transpose(
+      agc_weights
+  )
+  weights = SynWeights(
+      n_fibers=jnp.ones((n_ch, 1)) * syn_params.healthy_n_fibers,
+      v_widths=v_widths,
+      v_halfs=offsets * v_widths,
+      a1=a1,
+      a2=a2,
+      agc_weights=agc_weights,
+      spont_p=spont_p,
+      spont_sub=spont_sub,
+      res_lpf_inits=q0,
+      res_coeff=1 - math.exp(-1 / (syn_params.reservoir_tau * fs)),
+      lpf_coeff=1 - math.exp(-1 / (syn_params.tau_lpf * fs)),
+  )
+  state = SynState(
+      reservoirs=jnp.ones((n_ch, 1)) * weights.res_lpf_inits,
+      lpf_state=jnp.ones((n_ch, 1)) * weights.spont_p,
+  )
+  return syn_hypers, weights, state
+
+
 def design_and_init_ihc(
     ear: int,
     params: CarfacDesignParameters,
@@ -1136,7 +1403,10 @@ def design_and_init_ihc(
     ihc_style_num = 0
   elif ihc_params.ihc_style == 'one_cap':
     ihc_style_num = 1
-  elif ihc_params.ihc_style == 'two_cap':
+  elif (
+      ihc_params.ihc_style == 'two_cap'
+      or ihc_params.ihc_style == 'two_cap_with_syn'
+  ):
     ihc_style_num = 2
   else:
     raise NotImplementedError
@@ -1167,7 +1437,10 @@ def design_and_init_ihc(
         lpf1_state=ihc_weights.rest_output * jnp.ones((n_ch,)),
         lpf2_state=ihc_weights.rest_output * jnp.ones((n_ch,)),
     )
-  elif ihc_params.ihc_style == 'two_cap':
+  elif (
+      ihc_params.ihc_style == 'two_cap'
+      or ihc_params.ihc_style == 'two_cap_with_syn'
+  ):
     g1_max = ihc_detect(10)  # receptor conductance at high level
     r1min = 1 / g1_max
@@ -1458,6 +1731,9 @@
     ihc_hypers, ihc_weights, ihc_state = design_and_init_ihc(
         ear, params, hypers
     )
+    syn_hypers, syn_weights, syn_state = design_and_init_ihc_syn(
+        ear, params, hypers
+    )
 
     ear_hypers.car = car_hypers
     ear_weights.car = car_weights
@@ -1468,6 +1744,9 @@
     ear_hypers.ihc = ihc_hypers
     ear_weights.ihc = ihc_weights
     ear_state.ihc = ihc_state
+    ear_hypers.syn = syn_hypers
+    ear_weights.syn = syn_weights
+    ear_state.syn = syn_state
 
     weights.ears.append(ear_weights)
     state.ears.append(ear_state)
@@ -1612,13 +1891,61 @@ def car_step(
   return ac_diff, car_state
 
 
+def syn_step(
+    v_recep: jnp.ndarray,
+    ear: int,
+    weights: CarfacWeights,
+    syn_state: SynState,
+) -> Tuple[jnp.ndarray, jnp.ndarray, SynState]:
+  """Step the IHC synapse model with one input sample."""
+  syn_weights = weights.ears[ear].syn
+
+  # Drive multiple synapse classes with receptor potential from IHC,
+  # returning instantaneous spike rates per class, for a group of neurons
+  # associated with the CF channel, including reductions due to synaptopathy.
+  # Normalized offset position into neurotransmitter release sigmoid.
+  x = (v_recep - jnp.transpose(syn_weights.v_halfs)) / jnp.transpose(
+      syn_weights.v_widths
+  )
+  x = jnp.transpose(x)
+  s = 1 / (1 + jnp.exp(-x))  # Between 0 and 1; positive at rest.
+  q = syn_state.reservoirs  # aka 1 - w, between 0 and 1; positive at rest.
+  r = (1 - q) * s  # aka w*s, between 0 and 1, proportional to release rate.
+
+  # Smooth once with LPF (receptor potential was already smooth), after
+  # applying the gain coeff a2 to convert to firing prob per sample.
+  syn_state.lpf_state = syn_state.lpf_state + syn_weights.lpf_coeff * (
+      syn_weights.a2 * r - syn_state.lpf_state
+  )  # This is firing probs.
+  firing_probs = syn_state.lpf_state  # Poisson rate per neuron per sample.
+  # Include the number of effective neurons per channel here, so firings is
+  # the expected spike count per sample; the corresponding instantaneous
+  # rates (action potentials per second) can be huge.
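+  # For example, at fs = 22050 with firing_probs near 0.01, the healthy
+  # high-spont group of 500 fibers (50 per IHC times 10 IHCs per channel)
+  # yields about 5 expected spikes per sample, an aggregate rate on the
+  # order of 1e5 spikes per second in that channel.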
+  firings = syn_weights.n_fibers * firing_probs
+
+  # Feedback, to update reservoir state q for next time.
+  syn_state.reservoirs = q + syn_weights.res_coeff * (syn_weights.a1 * r - q)
+  # Make an output that resembles ihc_out, to go to agc_in
+  # (collapse over classes).
+  # Includes synaptopathy's presumed effect of reducing feedback via n_fibers.
+  # But it's relative to the healthy nominal spont, so could potentially go
+  # a bit negative in quiet if there was loss of high-spont or medium-spont
+  # units.
+  # The weight multiplication is an inner product, reducing n_classes
+  # columns to 1 column (contracting over the n_classes axis).
+  syn_out = (
+      syn_weights.n_fibers * firing_probs
+  ) @ syn_weights.agc_weights - syn_weights.spont_sub
+
+  return syn_out, firings, syn_state
+
+
 def ihc_step(
     bm_out: jnp.ndarray,
     ear: int,
     hypers: CarfacHypers,
     weights: CarfacWeights,
     ihc_state: IhcState,
-) -> Tuple[jnp.ndarray, IhcState]:
+) -> Tuple[jnp.ndarray, jnp.ndarray, IhcState]:
   """Step the inner-hair cell model with ont input sample.
 
   One sample-time update of inner-hair-cell (IHC) model, including the
@@ -1632,13 +1959,13 @@
     ihc_state: The ihc state.
 
   Returns:
-    The firing probability of the hair cells in each channel
-    and the new state.
+    The firing probability of the hair cells in each channel,
+    the receptor potential output and the new state.
   """
   ihc_weights = weights.ears[ear].ihc
   ihc_hypers = hypers.ears[ear].ihc
-
+  v_recep = jnp.zeros(hypers.ears[ear].car.n_ch)
   if ihc_hypers.ihc_style == 0:
     ihc_out = jnp.min(2, jnp.max(0, bm_out))  # pytype: disable=wrong-arg-types  # jnp-type
     # limit it for stability
@@ -1683,11 +2010,12 @@
       ihc_out - ihc_state.lpf1_state
   )
   ihc_out = ihc_state.lpf1_state - ihc_weights.rest_output
+  v_recep = ihc_weights.rest_cap1 - ihc_state.cap1_voltage
 
   # for where decimated output is useful
   ihc_state.ihc_accum = ihc_state.ihc_accum + ihc_out
 
-  return ihc_out, ihc_state
+  return ihc_out, v_recep, ihc_state
 
 
 def _agc_step_jit_helper(
@@ -2052,10 +2380,15 @@ def run_segment_scan_helper(carry, k):
     )
 
     # update IHC state & output on every time step, too
-    ihc_out, state.ears[ear].ihc = ihc_step(
+    ihc_out, v_recep, state.ears[ear].ihc = ihc_step(
         car_out, ear, hypers, weights, state.ears[ear].ihc
     )
 
+    if hypers.ears[ear].syn.do_syn:
+      ihc_out, _, state.ears[ear].syn = syn_step(
+          v_recep, ear, weights, state.ears[ear].syn
+      )
+
     # run the AGC update step, decimating internally,
     agc_updated, state.ears[ear].agc = agc_step(
         ihc_out, ear, hypers, weights, state.ears[ear].agc
diff --git a/python/jax/carfac_float64_test.py b/python/jax/carfac_float64_test.py
index 1a406e5..921dc7d 100644
--- a/python/jax/carfac_float64_test.py
+++ b/python/jax/carfac_float64_test.py
@@ -36,7 +36,7 @@ def _assert_almost_equal_pytrees(self, pytree1, pytree2, delta=None):
 
   @parameterized.product(
       random_seed=[x for x in range(20)],
-      ihc_style=['one_cap', 'two_cap'],
+      ihc_style=['one_cap', 'two_cap', 'two_cap_with_syn'],
       n_ears=[1, 2],
   )
   def test_backward_pass(self, random_seed, ihc_style, n_ears):
@@ -68,8 +68,9 @@ def loss(weights, input_waves, hypers, state):
 
     # Computes gradients by `jax.grad`.
     gfunc = jax.grad(loss, has_aux=True)
     params_jax = carfac_jax.CarfacDesignParameters(n_ears=n_ears)
-    params_jax.ears[0].ihc.ihc_style = ihc_style
-    params_jax.ears[0].car.linear_car = False
+    for ear in range(n_ears):
+      params_jax.ears[ear].ihc.ihc_style = ihc_style
+      params_jax.ears[ear].car.linear_car = False
     hypers_jax, weights_jax, state_jax = carfac_jax.design_and_init_carfac(
         params_jax
     )
diff --git a/python/jax/carfac_test.py b/python/jax/carfac_test.py
index 34b552b..918d074 100644
--- a/python/jax/carfac_test.py
+++ b/python/jax/carfac_test.py
@@ -7,6 +7,8 @@
 from absl.testing import parameterized
 import jax
 import jax.numpy as jnp
+from jax.tree_util import tree_flatten
+from jax.tree_util import tree_unflatten
 
 import sys
 sys.path.insert(0, '..')
@@ -21,6 +23,17 @@ def setUp(self):
     # The default tolerance used in `assertAlmostEqual` etc.
     self.default_delta = 1e-6
 
+  def test_post_init_syn_params(self):
+    # We add this test because SynDesignParameters has a __post_init__ that
+    # must not interfere with pytree flattening/unflattening.
+    syn = carfac_jax.SynDesignParameters()
+    leaves, treedef = tree_flatten(syn)
+    reconstituted_syn = tree_unflatten(treedef, leaves)
+    self.assertLess(
+        jnp.max(abs(syn.healthy_n_fibers - reconstituted_syn.healthy_n_fibers)),
+        self.default_delta,
+    )
+
   def test_hypers_hash(self):
     # Tests hyperparameter objects can be hashed. This is needed by `jax.jit`
     # because hyperparameters are tagged `static`.
@@ -108,7 +121,7 @@ def container_comparison(self, left_side, right_side, exclude_keys=None):
           msg='failed comparison on key item %s' % (k),
       )
 
-  @parameterized.parameters(['one_cap', 'two_cap'])
+  @parameterized.parameters(['one_cap', 'two_cap', 'two_cap_with_syn'])
   def test_equal_design(self, ihc_style):
     # Test: the designs are similar.
     cfp = carfac_np.design_carfac(ihc_style=ihc_style)
@@ -192,9 +205,16 @@ def test_equal_design(self, ihc_style):
       # Ihc params very specific for one cap, only a subset is used, so for
       # now we only check these one by one. We could add tests for 2 cap
       # similarly.
-      self.assertEqual(
-          cfp.ihc_params.ihc_style, params_jax.ears[ear_idx].ihc.ihc_style
-      )
+      if ihc_style == 'two_cap_with_syn':
+        # On the numpy side, the IHC keeps the 'two_cap' style name.
+        self.assertEqual('two_cap', cfp.ihc_params.ihc_style)
+        self.assertEqual(
+            'two_cap_with_syn', params_jax.ears[ear_idx].ihc.ihc_style
+        )
+      else:
+        self.assertEqual(
+            cfp.ihc_params.ihc_style, params_jax.ears[ear_idx].ihc.ihc_style
+        )
 
       self.assertEqual(
           cfp.ihc_params.tau_in, params_jax.ears[ear_idx].ihc.tau_in
@@ -221,6 +241,40 @@
           weights_jax.ears[ear_idx].car, ear_params_np.car_coeffs
       )
 
+      if ihc_style == 'two_cap_with_syn':
+        self.container_comparison(
+            weights_jax.ears[ear_idx].syn,
+            ear_params_np.syn_coeffs,
+            exclude_keys=['n_fibers', 'v_widths', 'v_halfs'],
+        )
+        self.assertLess(
+            jnp.max(
+                jnp.abs(
+                    weights_jax.ears[ear_idx].syn.n_fibers
+                    - ear_params_np.syn_coeffs.n_fibers
+                )
+            ),
+            self.default_delta,
+        )
+        self.assertLess(
+            jnp.max(
+                jnp.abs(
+                    weights_jax.ears[ear_idx].syn.v_widths
+                    - ear_params_np.syn_coeffs.v_widths
+                )
+            ),
+            self.default_delta,
+        )
+        self.assertLess(
+            jnp.max(
+                jnp.abs(
+                    weights_jax.ears[ear_idx].syn.v_halfs
+                    - ear_params_np.syn_coeffs.v_halfs
+                ),
+            ),
+            self.default_delta,
+        )
+
       # Test AGC parameters for each stage.
       for stage_idx, ear_agc_coeffs_np in enumerate(ear_params_np.agc_coeffs):
         logging.info('running stage idx %d', stage_idx)
@@ -245,7 +299,7 @@
 
   @parameterized.product(
       random_seed=[x for x in range(5)],
-      ihc_style=['one_cap', 'two_cap'],
+      ihc_style=['one_cap', 'two_cap', 'two_cap_with_syn'],
   )
   def test_chunked_naps_same_as_jit(self, random_seed, ihc_style):
     """Tests whether `run_segment` produces the same results as np version."""
@@ -289,7 +343,7 @@
 
   @parameterized.product(
       random_seed=[x for x in range(20)],
-      ihc_style=['one_cap', 'two_cap'],
+      ihc_style=['one_cap', 'two_cap', 'two_cap_with_syn'],
      n_ears=[1, 2],
       delay_buffer=[False, True],
   )
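
Usage sketch (a minimal example, not part of the patch): the synapse stage is
switched on purely through the design parameters, as the tests above do. The
sketch assumes the JAX module is imported as `carfac_jax` (the alias the tests
use) and that `run_segment` takes `(input_waves, hypers, weights, state)`, as
suggested by the `loss` helper in carfac_float64_test.py; check carfac.py for
the authoritative signature and return values.

  import jax.numpy as jnp

  params = carfac_jax.CarfacDesignParameters(n_ears=1)
  # Selecting this style makes design_and_init_ihc_syn return real SynWeights
  # and SynState, and makes run_segment route IHC output through syn_step.
  params.ears[0].ihc.ihc_style = 'two_cap_with_syn'
  hypers, weights, state = carfac_jax.design_and_init_carfac(params)

  # Input shape (n_samples, n_ears) is assumed here.
  input_waves = jnp.zeros((22050, 1))
  outputs = carfac_jax.run_segment(input_waves, hypers, weights, state)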