diff --git a/HISTORY.rst b/HISTORY.rst index 69c3be3d..cc3c2771 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -4,6 +4,8 @@ History 0.3.8 (2024-09-29) ------------------ +* Support an `init_state` argument into both `Term.init_fields` + and `Transformer.init_fields` (:pr:`319`) * Use virtualenv to setup github CI test environments (:pr:`321`) * Update to NumPy 2.0.0 (:pr:`317`) * Update to python-casacore 3.6.1 (:pr:`317`) diff --git a/africanus/experimental/rime/fused/arguments.py b/africanus/experimental/rime/fused/arguments.py index b7cfc52e..36fa0b95 100644 --- a/africanus/experimental/rime/fused/arguments.py +++ b/africanus/experimental/rime/fused/arguments.py @@ -45,13 +45,13 @@ class ArgumentDependencies: REQUIRED_ARGS = ("time", "antenna1", "antenna2", "feed1", "feed2") KEY_ARGS = ( "utime", - "time_index", + "time_inverse", "uantenna", - "antenna1_index", - "antenna2_index", + "antenna1_inverse", + "antenna2_inverse", "ufeed", - "feed1_index", - "feed2_index", + "feed1_inverse", + "feed2_inverse", ) def __init__(self, arg_names, terms, transformers): diff --git a/africanus/experimental/rime/fused/core.py b/africanus/experimental/rime/fused/core.py index 661dcfcc..a6595a8f 100644 --- a/africanus/experimental/rime/fused/core.py +++ b/africanus/experimental/rime/fused/core.py @@ -94,7 +94,7 @@ def impl(*args): for s in range(nsrc): for r in range(nrow): - t = state.time_index[r] + t = state.time_inverse[r] a1 = state.antenna1[r] a2 = state.antenna2[r] f1 = state.feed1[r] diff --git a/africanus/experimental/rime/fused/intrinsics.py b/africanus/experimental/rime/fused/intrinsics.py index 04b16ce4..a2fce378 100644 --- a/africanus/experimental/rime/fused/intrinsics.py +++ b/africanus/experimental/rime/fused/intrinsics.py @@ -13,7 +13,10 @@ from africanus.averaging.support import _unique_internal -from africanus.experimental.rime.fused.arguments import ArgumentPack +from africanus.experimental.rime.fused.arguments import ( + ArgumentDependencies, + ArgumentPack, +) from africanus.experimental.rime.fused.terms.core import StateStructRef try: @@ -212,13 +215,13 @@ def _add(x, y): class IntrinsicFactory: KEY_ARGS = ( "utime", - "time_index", + "time_inverse", "uantenna", - "antenna1_index", - "antenna2_index", + "antenna1_inverse", + "antenna2_inverse", "ufeed", - "feed1_index", - "feed2_index", + "feed1_inverse", + "feed2_inverse", ) def __init__(self, arg_dependencies): @@ -312,13 +315,13 @@ def pack_index(typingctx, args): key_types = { "utime": arg_info["time"][0], - "time_index": types.int64[:], + "time_inverse": types.int64[:], "uantenna": arg_info["antenna1"][0], - "antenna1_index": types.int64[:], - "antenna2_index": types.int64[:], + "antenna1_inverse": types.int64[:], + "antenna2_inverse": types.int64[:], "ufeed": arg_info["feed1"][0], - "feed1_index": types.int64[:], - "feed2_index": types.int64[:], + "feed1_inverse": types.int64[:], + "feed2_inverse": types.int64[:], } if tuple(key_types.keys()) != argdeps.KEY_ARGS: @@ -365,23 +368,23 @@ def codegen(context, builder, signature, args): fn_sig = types.Tuple(list(key_types.values()))(*fn_arg_types) def _indices(time, antenna1, antenna2, feed1, feed2): - utime, _, time_index, _ = _unique_internal(time) + utime, _, time_inverse, _ = _unique_internal(time) uants = np.unique(np.concatenate((antenna1, antenna2))) ufeeds = np.unique(np.concatenate((feed1, feed2))) - antenna1_index = np.searchsorted(uants, antenna1) - antenna2_index = np.searchsorted(uants, antenna2) - feed1_index = np.searchsorted(ufeeds, feed1) - feed2_index = np.searchsorted(ufeeds, feed2) + antenna1_inverse = np.searchsorted(uants, antenna1) + antenna2_inverse = np.searchsorted(uants, antenna2) + feed1_inverse = np.searchsorted(ufeeds, feed1) + feed2_inverse = np.searchsorted(ufeeds, feed2) return ( utime, - time_index, + time_inverse, uants, - antenna1_index, - antenna2_index, + antenna1_inverse, + antenna2_inverse, ufeeds, - feed1_index, - feed2_index, + feed1_inverse, + feed2_inverse, ) index = context.compile_internal(builder, _indices, fn_sig, fn_args) @@ -409,26 +412,30 @@ def pack_transformed_fn(self, arg_names): @intrinsic def pack_transformed(typingctx, args): assert len(args) == len(arg_names) - it = zip(arg_names, args, range(len(arg_names))) - arg_info = {n: (t, i) for n, t, i in it} + arg_pack = ArgumentPack(arg_names, args, range(len(arg_names))) rvt = typingctx.resolve_value_type_prefer_literal transform_output_types = [] + init_state_arg_fields = [ + (k, arg_pack.type(k)) for k in ArgumentDependencies.KEY_ARGS + ] + init_state_type = StateStructRef(init_state_arg_fields) + for transformer in transformers: # Figure out argument types for calling init_fields kw = {} for a in transformer.ARGS: - kw[a] = arg_info[a][0] + kw[a] = arg_pack.type(a) for a, d in transformer.KWARGS.items(): try: - kw[a] = arg_info[a][0] + kw[a] = arg_pack.type(a) except KeyError: kw[a] = rvt(d) - fields, _ = transformer.init_fields(typingctx, **kw) + fields, _ = transformer.init_fields(typingctx, init_state_type, **kw) if len(transformer.OUTPUTS) == 0: raise TypingError(f"{transformer} produces no outputs") @@ -460,6 +467,32 @@ def codegen(context, builder, signature, args): llvm_ret_type = context.get_value_type(return_type) ret_tuple = cgutils.get_null_value(llvm_ret_type) + def make_init_struct(): + return structref.new(init_state_type) + + init_state = context.compile_internal( + builder, make_init_struct, init_state_type(), [] + ) + U = structref._Utils(context, builder, init_state_type) + init_data_struct = U.get_data_struct(init_state) + + for arg_name in ArgumentDependencies.KEY_ARGS: + value = builder.extract_value(args[0], arg_pack.index(arg_name)) + value_type = signature.args[0][arg_pack.index(arg_name)] + # We increment the reference count here + # as we're taking a reference from data in + # the args tuple and placing it on the structref + context.nrt.incref(builder, value_type, value) + field_type = init_state_type.field_dict[arg_name] + casted = context.cast(builder, value, value_type, field_type) + context.nrt.incref(builder, value_type, casted) + + # The old value on the structref is being replaced, + # decrease it's reference count + old_value = getattr(init_data_struct, arg_name) + context.nrt.decref(builder, value_type, old_value) + setattr(init_data_struct, arg_name, casted) + # Extract supplied arguments from original arg tuple # and insert into the new one for i, typ in enumerate(signature.args[0]): @@ -478,16 +511,13 @@ def codegen(context, builder, signature, args): if o != out_names[i + j + n]: raise TypingError(f"{o} != {out_names[i + j + n]}") - transform_args = [] - transform_types = [] + transform_args = [init_state] + transform_types = [init_state_type] # Get required arguments out of the argument pack for name in transformer.ARGS: - try: - typ, j = arg_info[name] - except KeyError: - raise TypingError(f"{name} is not present in arg_types") - + typ = arg_pack.type(name) + j = arg_pack.index(name) value = builder.extract_value(args[0], j) transform_args.append(value) transform_types.append(typ) @@ -561,12 +591,19 @@ def term_state(typingctx, args): term_fields = [] constructors = [] + init_state_arg_fields = [ + (k, arg_pack.type(k)) for k in ArgumentDependencies.KEY_ARGS + ] + init_state_type = StateStructRef(init_state_arg_fields) + # Query Terms for fields and their associated types # that should be created on the State object for term in argdeps.terms: it = zip(term.ALL_ARGS, arg_pack.indices(*term.ALL_ARGS)) arg_types = {a: args[i] for a, i in it} - fields, constructor = term.init_fields(typingctx, **arg_types) + fields, constructor = term.init_fields( + typingctx, init_state_type, **arg_types + ) term.validate_constructor(constructor) term_fields.append(fields) state_fields.extend(fields) @@ -584,8 +621,34 @@ def codegen(context, builder, signature, args): typingctx = context.typing_context rvt = typingctx.resolve_value_type_prefer_literal + # Create the initial state struct + def make_init_struct(): + return structref.new(init_state_type) + + init_state = context.compile_internal( + builder, make_init_struct, init_state_type(), [] + ) + U = structref._Utils(context, builder, init_state_type) + init_data_struct = U.get_data_struct(init_state) + + for arg_name in ArgumentDependencies.KEY_ARGS: + value = builder.extract_value(args[0], arg_pack.index(arg_name)) + value_type = signature.args[0][arg_pack.index(arg_name)] + # We increment the reference count here + # as we're taking a reference from data in + # the args tuple and placing it on the structref + context.nrt.incref(builder, value_type, value) + field_type = init_state_type.field_dict[arg_name] + casted = context.cast(builder, value, value_type, field_type) + context.nrt.incref(builder, value_type, casted) + + # The old value on the structref is being replaced, + # decrease it's reference count + old_value = getattr(init_data_struct, arg_name) + context.nrt.decref(builder, value_type, old_value) + setattr(init_data_struct, arg_name, casted) + def make_struct(): - """Allocate the structure""" return structref.new(state_type) state = context.compile_internal(builder, make_struct, state_type(), []) @@ -616,8 +679,8 @@ def make_struct(): # need to extract those arguments necessary to construct # the term StructRef for term in argdeps.terms: - cargs = [] - ctypes = [] + cargs = [init_state] + ctypes = [init_state_type] arg_types = arg_pack.types(*term.ALL_ARGS) arg_index = arg_pack.indices(*term.ALL_ARGS) diff --git a/africanus/experimental/rime/fused/terms/brightness.py b/africanus/experimental/rime/fused/terms/brightness.py index f48ad490..58b37285 100644 --- a/africanus/experimental/rime/fused/terms/brightness.py +++ b/africanus/experimental/rime/fused/terms/brightness.py @@ -144,12 +144,21 @@ def dask_schema(self, stokes, spi, ref_freq, chan_freq, spi_base="standard"): LOG10 = 2 def init_fields( - self, typingctx, stokes, spi, ref_freq, chan_freq, spi_base="standard" + self, + typingctx, + init_state, + stokes, + spi, + ref_freq, + chan_freq, + spi_base="standard", ): expected_nstokes = len(self.stokes) fields = [("spectral_model", stokes.dtype[:, :, :])] - def brightness(stokes, spi, ref_freq, chan_freq, spi_base="standard"): + def brightness( + init_state, stokes, spi, ref_freq, chan_freq, spi_base="standard" + ): nsrc, nstokes = stokes.shape (nchan,) = chan_freq.shape nspi = spi.shape[1] diff --git a/africanus/experimental/rime/fused/terms/core.py b/africanus/experimental/rime/fused/terms/core.py index a4c6bae1..4a528eef 100644 --- a/africanus/experimental/rime/fused/terms/core.py +++ b/africanus/experimental/rime/fused/terms/core.py @@ -36,7 +36,8 @@ class members on the subclass based on the above signatures """ - REQUIRED = ("init_fields", "dask_schema", "sampler") + REQUIRED_METHODS = ("init_fields", "dask_schema", "sampler") + INIT_FIELDS_REQUIRED_ARGS = ("self", "typingctx", "init_state") @classmethod def _expand_namespace(cls, name, namespace): @@ -53,7 +54,7 @@ def _expand_namespace(cls, name, namespace): """ methods = [] - for method_name in cls.REQUIRED: + for method_name in cls.REQUIRED_METHODS: try: method = namespace[method_name] except KeyError: @@ -61,26 +62,23 @@ def _expand_namespace(cls, name, namespace): else: methods.append(method) - methods = dict(zip(cls.REQUIRED, methods)) + methods = dict(zip(cls.REQUIRED_METHODS, methods)) init_fields_sig = inspect.signature(methods["init_fields"]) field_params = list(init_fields_sig.parameters.values()) + sig_error = InvalidSignature( + f"{name}.init_fields{init_fields_sig} " + f"should be " + f"{name}.init_fields({', '.join(cls.INIT_FIELDS_REQUIRED_ARGS)}, ...)" + ) - if len(init_fields_sig.parameters) < 2: - raise InvalidSignature( - f"{name}.init_fields{init_fields_sig} " - f"should be " - f"{name}.init_fields(self, typingctx, ...)" - ) + if len(init_fields_sig.parameters) < 3: + raise sig_error it = iter(init_fields_sig.parameters.items()) - first, second = next(it), next(it) + expected_args = tuple((next(it)[0], next(it)[0], next(it)[0])) - if first[0] != "self" or second[0] != "typingctx": - raise InvalidSignature( - f"{name}.init_fields{init_fields_sig} " - f"should be " - f"{name}.init_fields(self, typingctx, ...)" - ) + if expected_args != cls.INIT_FIELDS_REQUIRED_ARGS: + raise sig_error for n, p in it: if p.kind == p.VAR_POSITIONAL: @@ -98,7 +96,7 @@ def _expand_namespace(cls, name, namespace): ) dask_schema_sig = inspect.signature(methods["dask_schema"]) - expected_dask_params = field_params[0:1] + field_params[2:] + expected_dask_params = field_params[0:1] + field_params[3:] expected_dask_sig = init_fields_sig.replace(parameters=expected_dask_params) if dask_schema_sig != expected_dask_sig: @@ -127,7 +125,7 @@ def _expand_namespace(cls, name, namespace): n for n, p in init_fields_sig.parameters.items() if p.kind in {p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD} - and n not in {"self", "typingctx"} + and n not in set(cls.INIT_FIELDS_REQUIRED_ARGS) and p.default is p.empty ) @@ -135,7 +133,7 @@ def _expand_namespace(cls, name, namespace): (n, p.default) for n, p in init_fields_sig.parameters.items() if p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY} - and n not in {"self", "typingctx"} + and n not in set(cls.INIT_FIELDS_REQUIRED_ARGS) and p.default is not p.empty ] diff --git a/africanus/experimental/rime/fused/terms/cube_dde.py b/africanus/experimental/rime/fused/terms/cube_dde.py index 870270b1..78eeb782 100644 --- a/africanus/experimental/rime/fused/terms/cube_dde.py +++ b/africanus/experimental/rime/fused/terms/cube_dde.py @@ -68,6 +68,7 @@ def dask_schema( def init_fields( self, typingctx, + init_state, beam, beam_lm_extents, beam_freq_map, @@ -88,6 +89,7 @@ def init_fields( ] def beam( + init_state, beam, beam_lm_extents, beam_freq_map, @@ -169,8 +171,8 @@ def sampler(self): zero_vis = zero_vis_factory(ncorr) def cube_dde(state, s, r, t, f1, f2, a1, a2, c): - a = state.antenna1_index[r] if left else state.antenna2_index[r] - feed = state.feed1_index[r] if left else state.feed2_index[r] + a = state.antenna1_inverse[r] if left else state.antenna2_inverse[r] + feed = state.feed1_inverse[r] if left else state.feed2_inverse[r] sin_pa = state.beam_parangle[t, feed, a, 0] cos_pa = state.beam_parangle[t, feed, a, 1] diff --git a/africanus/experimental/rime/fused/terms/feed_rotation.py b/africanus/experimental/rime/fused/terms/feed_rotation.py index 9b18ca0a..f97c8247 100644 --- a/africanus/experimental/rime/fused/terms/feed_rotation.py +++ b/africanus/experimental/rime/fused/terms/feed_rotation.py @@ -29,8 +29,8 @@ def __init__(self, configuration, feed_type, corrs): super().__init__(configuration) self.feed_type = feed_type - def init_fields(self, typingctx, feed_parangle): - def dummy(feed_parangle): + def init_fields(self, typingctx, init_state, feed_parangle): + def dummy(init_state, feed_parangle): pass return [], dummy @@ -43,8 +43,8 @@ def sampler(self): linear = self.feed_type == "linear" def feed_rotation(state, s, r, t, f1, f2, a1, a2, c): - a = state.antenna1_index[r] if left else state.antenna2_index[r] - f = state.feed1_index[r] if left else state.feed2_index[r] + a = state.antenna1_inverse[r] if left else state.antenna2_inverse[r] + f = state.feed1_inverse[r] if left else state.feed2_inverse[r] sin_a = state.feed_parangle[t, f, a, 0, 0] cos_a = state.feed_parangle[t, f, a, 0, 1] sin_b = state.feed_parangle[t, f, a, 1, 0] diff --git a/africanus/experimental/rime/fused/terms/gaussian.py b/africanus/experimental/rime/fused/terms/gaussian.py index 8222a731..e12e35f1 100644 --- a/africanus/experimental/rime/fused/terms/gaussian.py +++ b/africanus/experimental/rime/fused/terms/gaussian.py @@ -18,7 +18,7 @@ def dask_schema(self, uvw, chan_freq, gauss_shape): "gauss_shape": ("source", "gauss_shape_params"), } - def init_fields(self, typingctx, uvw, chan_freq, gauss_shape): + def init_fields(self, typingctx, init_state, uvw, chan_freq, gauss_shape): guv_dtype = typingctx.unify_types(uvw.dtype, chan_freq.dtype, gauss_shape.dtype) fields = [("gauss_uv", guv_dtype[:, :, :]), ("scaled_freq", chan_freq)] @@ -26,7 +26,7 @@ def init_fields(self, typingctx, uvw, chan_freq, gauss_shape): fwhminv = 1.0 / fwhm gauss_scale = fwhminv * np.sqrt(2.0) * np.pi / lightspeed - def gaussian_init(uvw, chan_freq, gauss_shape): + def gaussian_init(init_state, uvw, chan_freq, gauss_shape): nsrc, _ = gauss_shape.shape nrow, _ = uvw.shape diff --git a/africanus/experimental/rime/fused/terms/phase.py b/africanus/experimental/rime/fused/terms/phase.py index a0757aff..f950b520 100644 --- a/africanus/experimental/rime/fused/terms/phase.py +++ b/africanus/experimental/rime/fused/terms/phase.py @@ -20,11 +20,13 @@ def dask_schema(self, lm, uvw, chan_freq, convention="fourier"): "convention": None, } - def init_fields(self, typingctx, lm, uvw, chan_freq, convention="fourier"): + def init_fields( + self, typingctx, init_state, lm, uvw, chan_freq, convention="fourier" + ): phase_dt = typingctx.unify_types(lm.dtype, uvw.dtype, chan_freq.dtype) fields = [("phase_dot", phase_dt[:, :])] - def phase(lm, uvw, chan_freq, convention="fourier"): + def phase(init_state, lm, uvw, chan_freq, convention="fourier"): nsrc, _ = lm.shape nrow, _ = uvw.shape (nchan,) = chan_freq.shape diff --git a/africanus/experimental/rime/fused/transformers/core.py b/africanus/experimental/rime/fused/transformers/core.py index ad975f7d..fd8c1aa3 100644 --- a/africanus/experimental/rime/fused/transformers/core.py +++ b/africanus/experimental/rime/fused/transformers/core.py @@ -26,13 +26,14 @@ class members on the subclass based on the above signatures """ - REQUIRED = ("dask_schema", "init_fields") + REQUIRED_METHODS = ("dask_schema", "init_fields") + INIT_FIELDS_REQUIRED_ARGS = ("self", "typingctx", "init_state") @classmethod def _expand_namespace(cls, name, namespace): methods = [] - for method_name in cls.REQUIRED: + for method_name in cls.REQUIRED_METHODS: try: method = namespace[method_name] except KeyError: @@ -40,26 +41,24 @@ def _expand_namespace(cls, name, namespace): else: methods.append(method) - methods = dict(zip(cls.REQUIRED, methods)) + methods = dict(zip(cls.REQUIRED_METHODS, methods)) init_fields_sig = inspect.signature(methods["init_fields"]) field_params = list(init_fields_sig.parameters.values()) - if len(init_fields_sig.parameters) < 2: - raise InvalidSignature( - f"{name}.init_fields{init_fields_sig} " - f"should be " - f"{name}.init_fields(self, typingctx, ...)" - ) + sig_error = InvalidSignature( + f"{name}.init_fields{init_fields_sig} " + f"should be " + f"{name}.init_fields({', '.join(cls.INIT_FIELDS_REQUIRED_ARGS)}, ...)" + ) + + if len(init_fields_sig.parameters) < 3: + raise sig_error it = iter(init_fields_sig.parameters.items()) - first, second = next(it), next(it) + expected_args = tuple((next(it)[0], next(it)[0], next(it)[0])) - if first[0] != "self" or second[0] != "typingctx": - raise InvalidSignature( - f"{name}.init_fields{init_fields_sig} " - f"should be " - f"{name}.init_fields(self, typingctx, ...)" - ) + if expected_args != cls.INIT_FIELDS_REQUIRED_ARGS: + raise sig_error for n, p in it: if p.kind == p.VAR_POSITIONAL: @@ -77,7 +76,7 @@ def _expand_namespace(cls, name, namespace): ) dask_schema_sig = inspect.signature(methods["dask_schema"]) - expected_dask_params = field_params[0:1] + field_params[2:] + expected_dask_params = field_params[0:1] + field_params[3:] expected_dask_sig = init_fields_sig.replace(parameters=expected_dask_params) if dask_schema_sig != expected_dask_sig: @@ -106,7 +105,7 @@ def _expand_namespace(cls, name, namespace): n for n, p in init_fields_sig.parameters.items() if p.kind in {p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD} - and n not in {"self", "typingctx"} + and n not in set(cls.INIT_FIELDS_REQUIRED_ARGS) and p.default is p.empty ) @@ -114,7 +113,7 @@ def _expand_namespace(cls, name, namespace): (n, p.default) for n, p in init_fields_sig.parameters.items() if p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY} - and n not in {"self", "typingctx"} + and n not in set(cls.INIT_FIELDS_REQUIRED_ARGS) and p.default is not p.empty ) diff --git a/africanus/experimental/rime/fused/transformers/lm.py b/africanus/experimental/rime/fused/transformers/lm.py index 2b67c9a6..d490f393 100644 --- a/africanus/experimental/rime/fused/transformers/lm.py +++ b/africanus/experimental/rime/fused/transformers/lm.py @@ -6,11 +6,11 @@ class LMTransformer(Transformer): OUTPUTS = ["lm"] - def init_fields(self, typingctx, radec, phase_dir): + def init_fields(self, typingctx, init_state, radec, phase_dir): dt = typingctx.unify_types(radec.dtype, phase_dir.dtype) fields = [("lm", dt[:, :])] - def lm(radec, phase_dir): + def lm(init_state, radec, phase_dir): lm = np.empty_like(radec) pc_ra = phase_dir[0] pc_dec = phase_dir[1] diff --git a/africanus/experimental/rime/fused/transformers/parangle.py b/africanus/experimental/rime/fused/transformers/parangle.py index 64cb1448..15b0d119 100644 --- a/africanus/experimental/rime/fused/transformers/parangle.py +++ b/africanus/experimental/rime/fused/transformers/parangle.py @@ -15,15 +15,18 @@ def __init__(self, process_pool): def init_fields( self, typingctx, - utime, - ufeed, - uantenna, + init_state, antenna_position, phase_dir, receptor_angle=None, ): + fdict = init_state.field_dict + dt = typingctx.unify_types( - utime.dtype, ufeed.dtype, antenna_position.dtype, phase_dir.dtype + fdict["utime"].dtype, + fdict["ufeed"].dtype, + antenna_position.dtype, + phase_dir.dtype, ) fields = [ ("feed_parangle", dt[:, :, :, :, :]), @@ -46,17 +49,15 @@ def parangle_stub(time, antenna, phase_dir): return out - def parangles( - utime, ufeed, uantenna, antenna_position, phase_dir, receptor_angle=None - ): - (ntime,) = utime.shape - (nant,) = uantenna.shape - (nfeed,) = ufeed.shape + def parangles(init_state, antenna_position, phase_dir, receptor_angle=None): + (ntime,) = init_state.utime.shape + (nant,) = init_state.uantenna.shape + (nfeed,) = init_state.ufeed.shape # Select out the antennae we're interested in - antenna_position = antenna_position[uantenna] + antenna_position = antenna_position[init_state.uantenna] - parangles = parangle_stub(utime, antenna_position, phase_dir) + parangles = parangle_stub(init_state.utime, antenna_position, phase_dir) feed_pa = np.empty((ntime, nfeed, nant, 2, 2), parangles.dtype) beam_pa = np.empty((ntime, nfeed, nant, 2), parangles.dtype) @@ -65,10 +66,10 @@ def parangles( raise ValueError("receptor_angle.ndim != 2") if receptor_angle.shape[1] != 2: - raise ValueError("Only 2 receptor angles " "currently supported") + raise ValueError("Only 2 receptor angles currently supported") # Select out the feeds we're interested in - receptor_angle = receptor_angle[ufeed, :] + receptor_angle = receptor_angle[init_state.ufeed, :] for t in range(ntime): for f in range(nfeed): @@ -94,10 +95,8 @@ def parangles( return fields, parangles - def dask_schema( - self, utime, ufeed, uantenna, antenna_position, phase_dir, receptor_angle=None - ): - dt = np.result_type(utime, ufeed, antenna_position, phase_dir, receptor_angle) + def dask_schema(self, antenna_position, phase_dir, receptor_angle=None): + dt = np.result_type(antenna_position, phase_dir, receptor_angle) inputs = {"antenna_position": ("antenna", "ant-comp"), "phase_dir": ("radec",)} if receptor_angle is not None: diff --git a/docs/experimental.rst b/docs/experimental.rst index 8b5f50d6..f0d40bfc 100644 --- a/docs/experimental.rst +++ b/docs/experimental.rst @@ -138,7 +138,7 @@ defined on the `Phase` term, called `init_fields`. from africanus.experimental.rime.fused.terms.core import Term class Phase(Term) - def init_fields(self, typingctx, lm, uvw, chan_freq): + def init_fields(self, typingctx, init_state, lm, uvw, chan_freq): # Given the numba types of the lm, uvw and chan_freq # arrays, derive a unified output numba type numba_type = typingctx.unify_types(lm.dtype, @@ -241,7 +241,7 @@ In the following code snippet, ``LMTransformer.init_fields`` # OUTPUTS class attribute OUTPUTS = ["lm"] - def init_fields(self, typingctx, radec, phase_dir): + def init_fields(self, typingctx, init_state, radec, phase_dir): # Type and provide method for initialising the lm output dt = typingctx.unify_types(radec.dtype, phase_dir.dtype) fields = [("lm", dt[:, :])] @@ -272,6 +272,61 @@ In the following code snippet, ``LMTransformer.init_fields`` The ``lm`` array will be available on the ``state`` object and as a valid input for :meth:`Term.init_fields`. +Indexing arrays ++++++++++++++++ + +The ``init_state`` and ``state`` objects contains NumPy arrays storing +Measurement Set v2.0 indexing information. + +.. code-block:: python + + class State: + utime # Unique times + uantenna # Unique antenna indices + ufeed # Unique feed indices + time_inverse # Maps the time at a row into utime + antenna1_inverse # Maps the antenna1 index at a row into uantenna + antenna2_inverse # Maps the antenna2 index at a row into uantenna + feed1_inverse # Maps the feed1 index at a row into ufeed + feed2_inverse # Maps the feed2 index at a row into ufeed + ... + +These arrays are useful in cases where the developer wishes to avoid +recomputing values multiple times for each row in the sampling function. +Instead they can be pre-computed for unique times, antennas and feeds +in :meth:`Term.init_fields` and then looked up in :meth:`Term.sampler`. + +.. code-block:: python + + class MyTerm(Term): + def init_fields(self, typingctx, init_state, ...): + fields = [("precomputed", numba.float64[:, :, :])] + + def precompute(init_state, ...): + ntime = init_state.utime.shape[0] + nfeed = init_state.ufeed.shape[0] + nant = init_state.uantenna.shape[0] + precomputed = np.empty((ntime, nfeed, nant), np.float64) + + for t in range(ntime): + for f in range(nfeed): + for a in range(nant): + precomputed[t, f, a] = ... + + return precomputed + + return fields, precompute + + def sampler(self, state, s, r, t, f1, f2, a1, a2, c): + left = self.configuration == "left" + + def sample_precomputed(state, s, r, t, f1, f2, a1, a2, c): + f = state.feed1_inverse[r] if left else state.feed2_inverse[r] + a = state.antenna1_inverse[r] if left else state.antenna2_inverse[r] + return state.precomputed[t, f, a] + + return sample_precomputed + Invoking the RIME +++++++++++++++++ @@ -429,7 +484,8 @@ API def __init__(self, configuration): super().__init__(configuration) - .. py:method:: Term.init_fields(self, typing_ctx, arg1, ..., argn, \ + .. py:method:: Term.init_fields(self, typing_ctx, init_state, \ + arg1, ..., argn, \ kwarg1=None, ..., kwargn=None) Requests inputs to the RIME term, ensuring that they are @@ -445,7 +501,7 @@ API ``init_fields`` should return a :code:`(fields, function)` tuple. ``fields`` should be a list of the form :code:`[(name, numba_type)]`, while ``function`` should be a function of the form - :code:`fn(arg1, ..., argn, kwarg1=None, .., kwargn=None)` + :code:`fn(init_state, arg1, ..., argn, kwarg1=None, .., kwargn=None)` and should return the variables of the type defined in ``fields``. Note that it's signature therefore matches that of ``init_fields`` from after the ``typingctx`` @@ -453,6 +509,7 @@ API :ref:`Simple Example `. :param typingctx: A Numba typing context. + :param init_state: State object holding index information. :param arg1...argn: Required RIME inputs for this Term. :param kwarg1...kwargn: Optional RIME inputs for this Term. \ Types here should be simple: ints, floats, complex numbers @@ -528,8 +585,9 @@ API This should correspond to the fields produced by :meth:`Transformer.init_fields`. - .. py:method:: Transformer.init_fields(self, typing_ctx, arg1, ..., argn, \ - kwarg1=None, ..., kwargn=None) + .. py:method:: Transformer.init_fields(self, typing_ctx, init_state, \ + arg1, ..., argn, \ + kwarg1=None, ..., kwargn=None) Requests inputs to the Transformer, and specifies new fields and the function for creating them on the ``state`` object. @@ -547,8 +605,9 @@ API in Numba's `nopython `_ mode. - .. py:method:: dask_schema(self, arg1, ..., argn, \ - kwargs1=None, ..., kwargn=None) + .. py:method:: dask_schema(self, init_state, \ + arg1, ..., argn, \ + kwargs1=None, ..., kwargn=None) @@ -571,7 +630,6 @@ API - Predefined Terms ++++++++++++++++