Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support an init_state argument in Term.init_fields #319

Merged
merged 7 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ History

0.3.8 (2024-09-29)
------------------
* Support an `init_state` argument into both `Term.init_fields`
and `Transformer.init_fields` (:pr:`319`)
* Use virtualenv to setup github CI test environments (:pr:`321`)
* Update to NumPy 2.0.0 (:pr:`317`)
* Update to python-casacore 3.6.1 (:pr:`317`)
Expand Down
10 changes: 5 additions & 5 deletions africanus/experimental/rime/fused/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ class ArgumentDependencies:
REQUIRED_ARGS = ("time", "antenna1", "antenna2", "feed1", "feed2")
KEY_ARGS = (
"utime",
"time_index",
"time_inverse",
"uantenna",
"antenna1_index",
"antenna2_index",
"antenna1_inverse",
"antenna2_inverse",
"ufeed",
"feed1_index",
"feed2_index",
"feed1_inverse",
"feed2_inverse",
)

def __init__(self, arg_names, terms, transformers):
Expand Down
2 changes: 1 addition & 1 deletion africanus/experimental/rime/fused/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def impl(*args):

for s in range(nsrc):
for r in range(nrow):
t = state.time_index[r]
t = state.time_inverse[r]
a1 = state.antenna1[r]
a2 = state.antenna2[r]
f1 = state.feed1[r]
Expand Down
137 changes: 100 additions & 37 deletions africanus/experimental/rime/fused/intrinsics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@

from africanus.averaging.support import _unique_internal

from africanus.experimental.rime.fused.arguments import ArgumentPack
from africanus.experimental.rime.fused.arguments import (
ArgumentDependencies,
ArgumentPack,
)
from africanus.experimental.rime.fused.terms.core import StateStructRef

try:
Expand Down Expand Up @@ -212,13 +215,13 @@ def _add(x, y):
class IntrinsicFactory:
KEY_ARGS = (
"utime",
"time_index",
"time_inverse",
"uantenna",
"antenna1_index",
"antenna2_index",
"antenna1_inverse",
"antenna2_inverse",
"ufeed",
"feed1_index",
"feed2_index",
"feed1_inverse",
"feed2_inverse",
)

def __init__(self, arg_dependencies):
Expand Down Expand Up @@ -312,13 +315,13 @@ def pack_index(typingctx, args):

key_types = {
"utime": arg_info["time"][0],
"time_index": types.int64[:],
"time_inverse": types.int64[:],
"uantenna": arg_info["antenna1"][0],
"antenna1_index": types.int64[:],
"antenna2_index": types.int64[:],
"antenna1_inverse": types.int64[:],
"antenna2_inverse": types.int64[:],
"ufeed": arg_info["feed1"][0],
"feed1_index": types.int64[:],
"feed2_index": types.int64[:],
"feed1_inverse": types.int64[:],
"feed2_inverse": types.int64[:],
}

if tuple(key_types.keys()) != argdeps.KEY_ARGS:
Expand Down Expand Up @@ -365,23 +368,23 @@ def codegen(context, builder, signature, args):
fn_sig = types.Tuple(list(key_types.values()))(*fn_arg_types)

def _indices(time, antenna1, antenna2, feed1, feed2):
utime, _, time_index, _ = _unique_internal(time)
utime, _, time_inverse, _ = _unique_internal(time)
uants = np.unique(np.concatenate((antenna1, antenna2)))
ufeeds = np.unique(np.concatenate((feed1, feed2)))
antenna1_index = np.searchsorted(uants, antenna1)
antenna2_index = np.searchsorted(uants, antenna2)
feed1_index = np.searchsorted(ufeeds, feed1)
feed2_index = np.searchsorted(ufeeds, feed2)
antenna1_inverse = np.searchsorted(uants, antenna1)
antenna2_inverse = np.searchsorted(uants, antenna2)
feed1_inverse = np.searchsorted(ufeeds, feed1)
feed2_inverse = np.searchsorted(ufeeds, feed2)

return (
utime,
time_index,
time_inverse,
uants,
antenna1_index,
antenna2_index,
antenna1_inverse,
antenna2_inverse,
ufeeds,
feed1_index,
feed2_index,
feed1_inverse,
feed2_inverse,
)

index = context.compile_internal(builder, _indices, fn_sig, fn_args)
Expand Down Expand Up @@ -409,26 +412,30 @@ def pack_transformed_fn(self, arg_names):
@intrinsic
def pack_transformed(typingctx, args):
assert len(args) == len(arg_names)
it = zip(arg_names, args, range(len(arg_names)))
arg_info = {n: (t, i) for n, t, i in it}
arg_pack = ArgumentPack(arg_names, args, range(len(arg_names)))

rvt = typingctx.resolve_value_type_prefer_literal
transform_output_types = []

init_state_arg_fields = [
(k, arg_pack.type(k)) for k in ArgumentDependencies.KEY_ARGS
]
init_state_type = StateStructRef(init_state_arg_fields)

for transformer in transformers:
# Figure out argument types for calling init_fields
kw = {}

for a in transformer.ARGS:
kw[a] = arg_info[a][0]
kw[a] = arg_pack.type(a)

for a, d in transformer.KWARGS.items():
try:
kw[a] = arg_info[a][0]
kw[a] = arg_pack.type(a)
except KeyError:
kw[a] = rvt(d)

fields, _ = transformer.init_fields(typingctx, **kw)
fields, _ = transformer.init_fields(typingctx, init_state_type, **kw)

if len(transformer.OUTPUTS) == 0:
raise TypingError(f"{transformer} produces no outputs")
Expand Down Expand Up @@ -460,6 +467,32 @@ def codegen(context, builder, signature, args):
llvm_ret_type = context.get_value_type(return_type)
ret_tuple = cgutils.get_null_value(llvm_ret_type)

def make_init_struct():
return structref.new(init_state_type)

init_state = context.compile_internal(
builder, make_init_struct, init_state_type(), []
)
U = structref._Utils(context, builder, init_state_type)
init_data_struct = U.get_data_struct(init_state)

for arg_name in ArgumentDependencies.KEY_ARGS:
value = builder.extract_value(args[0], arg_pack.index(arg_name))
value_type = signature.args[0][arg_pack.index(arg_name)]
# We increment the reference count here
# as we're taking a reference from data in
# the args tuple and placing it on the structref
context.nrt.incref(builder, value_type, value)
field_type = init_state_type.field_dict[arg_name]
casted = context.cast(builder, value, value_type, field_type)
context.nrt.incref(builder, value_type, casted)

# The old value on the structref is being replaced,
# decrease it's reference count
old_value = getattr(init_data_struct, arg_name)
context.nrt.decref(builder, value_type, old_value)
setattr(init_data_struct, arg_name, casted)

# Extract supplied arguments from original arg tuple
# and insert into the new one
for i, typ in enumerate(signature.args[0]):
Expand All @@ -478,16 +511,13 @@ def codegen(context, builder, signature, args):
if o != out_names[i + j + n]:
raise TypingError(f"{o} != {out_names[i + j + n]}")

transform_args = []
transform_types = []
transform_args = [init_state]
transform_types = [init_state_type]

# Get required arguments out of the argument pack
for name in transformer.ARGS:
try:
typ, j = arg_info[name]
except KeyError:
raise TypingError(f"{name} is not present in arg_types")

typ = arg_pack.type(name)
j = arg_pack.index(name)
value = builder.extract_value(args[0], j)
transform_args.append(value)
transform_types.append(typ)
Expand Down Expand Up @@ -561,12 +591,19 @@ def term_state(typingctx, args):
term_fields = []
constructors = []

init_state_arg_fields = [
(k, arg_pack.type(k)) for k in ArgumentDependencies.KEY_ARGS
]
init_state_type = StateStructRef(init_state_arg_fields)

# Query Terms for fields and their associated types
# that should be created on the State object
for term in argdeps.terms:
it = zip(term.ALL_ARGS, arg_pack.indices(*term.ALL_ARGS))
arg_types = {a: args[i] for a, i in it}
fields, constructor = term.init_fields(typingctx, **arg_types)
fields, constructor = term.init_fields(
typingctx, init_state_type, **arg_types
)
term.validate_constructor(constructor)
term_fields.append(fields)
state_fields.extend(fields)
Expand All @@ -584,8 +621,34 @@ def codegen(context, builder, signature, args):
typingctx = context.typing_context
rvt = typingctx.resolve_value_type_prefer_literal

# Create the initial state struct
def make_init_struct():
return structref.new(init_state_type)

init_state = context.compile_internal(
builder, make_init_struct, init_state_type(), []
)
U = structref._Utils(context, builder, init_state_type)
init_data_struct = U.get_data_struct(init_state)

for arg_name in ArgumentDependencies.KEY_ARGS:
value = builder.extract_value(args[0], arg_pack.index(arg_name))
value_type = signature.args[0][arg_pack.index(arg_name)]
# We increment the reference count here
# as we're taking a reference from data in
# the args tuple and placing it on the structref
context.nrt.incref(builder, value_type, value)
field_type = init_state_type.field_dict[arg_name]
casted = context.cast(builder, value, value_type, field_type)
context.nrt.incref(builder, value_type, casted)

# The old value on the structref is being replaced,
# decrease it's reference count
old_value = getattr(init_data_struct, arg_name)
context.nrt.decref(builder, value_type, old_value)
setattr(init_data_struct, arg_name, casted)

def make_struct():
"""Allocate the structure"""
return structref.new(state_type)

state = context.compile_internal(builder, make_struct, state_type(), [])
Expand Down Expand Up @@ -616,8 +679,8 @@ def make_struct():
# need to extract those arguments necessary to construct
# the term StructRef
for term in argdeps.terms:
cargs = []
ctypes = []
cargs = [init_state]
ctypes = [init_state_type]

arg_types = arg_pack.types(*term.ALL_ARGS)
arg_index = arg_pack.indices(*term.ALL_ARGS)
Expand Down
13 changes: 11 additions & 2 deletions africanus/experimental/rime/fused/terms/brightness.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,21 @@ def dask_schema(self, stokes, spi, ref_freq, chan_freq, spi_base="standard"):
LOG10 = 2

def init_fields(
self, typingctx, stokes, spi, ref_freq, chan_freq, spi_base="standard"
self,
typingctx,
init_state,
stokes,
spi,
ref_freq,
chan_freq,
spi_base="standard",
):
expected_nstokes = len(self.stokes)
fields = [("spectral_model", stokes.dtype[:, :, :])]

def brightness(stokes, spi, ref_freq, chan_freq, spi_base="standard"):
def brightness(
init_state, stokes, spi, ref_freq, chan_freq, spi_base="standard"
):
nsrc, nstokes = stokes.shape
(nchan,) = chan_freq.shape
nspi = spi.shape[1]
Expand Down
36 changes: 17 additions & 19 deletions africanus/experimental/rime/fused/terms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ class members on the subclass based on the above
signatures
"""

REQUIRED = ("init_fields", "dask_schema", "sampler")
REQUIRED_METHODS = ("init_fields", "dask_schema", "sampler")
INIT_FIELDS_REQUIRED_ARGS = ("self", "typingctx", "init_state")

@classmethod
def _expand_namespace(cls, name, namespace):
Expand All @@ -53,34 +54,31 @@ def _expand_namespace(cls, name, namespace):
"""
methods = []

for method_name in cls.REQUIRED:
for method_name in cls.REQUIRED_METHODS:
try:
method = namespace[method_name]
except KeyError:
raise NotImplementedError(f"{name}.{method_name}")
else:
methods.append(method)

methods = dict(zip(cls.REQUIRED, methods))
methods = dict(zip(cls.REQUIRED_METHODS, methods))
init_fields_sig = inspect.signature(methods["init_fields"])
field_params = list(init_fields_sig.parameters.values())
sig_error = InvalidSignature(
f"{name}.init_fields{init_fields_sig} "
f"should be "
f"{name}.init_fields({', '.join(cls.INIT_FIELDS_REQUIRED_ARGS)}, ...)"
)

if len(init_fields_sig.parameters) < 2:
raise InvalidSignature(
f"{name}.init_fields{init_fields_sig} "
f"should be "
f"{name}.init_fields(self, typingctx, ...)"
)
if len(init_fields_sig.parameters) < 3:
raise sig_error

it = iter(init_fields_sig.parameters.items())
first, second = next(it), next(it)
expected_args = tuple((next(it)[0], next(it)[0], next(it)[0]))

if first[0] != "self" or second[0] != "typingctx":
raise InvalidSignature(
f"{name}.init_fields{init_fields_sig} "
f"should be "
f"{name}.init_fields(self, typingctx, ...)"
)
if expected_args != cls.INIT_FIELDS_REQUIRED_ARGS:
raise sig_error

for n, p in it:
if p.kind == p.VAR_POSITIONAL:
Expand All @@ -98,7 +96,7 @@ def _expand_namespace(cls, name, namespace):
)

dask_schema_sig = inspect.signature(methods["dask_schema"])
expected_dask_params = field_params[0:1] + field_params[2:]
expected_dask_params = field_params[0:1] + field_params[3:]
expected_dask_sig = init_fields_sig.replace(parameters=expected_dask_params)

if dask_schema_sig != expected_dask_sig:
Expand Down Expand Up @@ -127,15 +125,15 @@ def _expand_namespace(cls, name, namespace):
n
for n, p in init_fields_sig.parameters.items()
if p.kind in {p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD}
and n not in {"self", "typingctx"}
and n not in set(cls.INIT_FIELDS_REQUIRED_ARGS)
and p.default is p.empty
)

kw = [
(n, p.default)
for n, p in init_fields_sig.parameters.items()
if p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY}
and n not in {"self", "typingctx"}
and n not in set(cls.INIT_FIELDS_REQUIRED_ARGS)
and p.default is not p.empty
]

Expand Down
Loading