Skip to content

Commit

Permalink
Merge pull request #315 from atomicateam/environment-parameters
Browse files Browse the repository at this point in the history
Environment parameters
  • Loading branch information
robynstuart authored May 14, 2019
2 parents 41cbfa2 + ae64811 commit 3f90d31
Show file tree
Hide file tree
Showing 52 changed files with 1,377 additions and 573 deletions.
86 changes: 47 additions & 39 deletions atomica/cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def plot_cascade(results=None, cascade=None, pops=None, year=None, data=None, sh
output = (fig, table)
return output # Either fig or (fig,table)

def sanitize_cascade(framework, cascade) -> tuple:
def sanitize_cascade(framework, cascade, fallback_used:bool =False) -> tuple:
"""
Normalize cascade inputs
Expand Down Expand Up @@ -171,19 +171,18 @@ def sanitize_cascade(framework, cascade) -> tuple:
cascade_name = None
cascade_dict = cascade

validate_cascade(framework, cascade_dict) # Check that the requested cascade dictionary is valid
pop_type = validate_cascade(framework, cascade_dict, fallback_used=fallback_used) # Check that the requested cascade dictionary is valid

return cascade_name,cascade_dict
return cascade_name,cascade_dict, pop_type

def sanitize_pops(pops,pop_source) -> dict:
def sanitize_pops(pops,pop_source,pop_type) -> dict:
"""
Sanitize input populations
The input populations could be specified as
- A list or dict (with single key) containing either population code names
or full names (e.g. from the FE)
- A string like 'all'
- A list or dict (with single key) containing either population code names. List inputs can contain full names (e.g. from the FE)
- A string like 'all' or 'total'
- None, which is shorthand for all populations
For cascade purposes, the specified populations must evaluate to a single
Expand All @@ -194,29 +193,34 @@ def sanitize_pops(pops,pop_source) -> dict:
string)
:param pop_source: Object to draw available populations from (a
:class:`Result` or :class:`PlotData`)
:param pop_type: Population type to select. All returned populations will match this type
:return: A dict with a single key that can be used by :class:`PlotData` to
specify populations
"""

# Retrieve a list mapping result names to labels
if isinstance(pop_source,Result):
available = [(x.name,x.label) for x in pop_source.model.pops]
available = [(x.name,x.label,x.type) for x in pop_source.model.pops]
elif isinstance(pop_source,ProjectData):
available = [(x,y['label']) for x,y in pop_source.pops.items()]
available = [(x,y['label'],y['type']) for x,y in pop_source.pops.items()]
else:
raise Exception("Unrecognized source for pop names - must be a Result or a ProjectData instance")

def sanitize_name(name):
name = name.strip()
for x,y in available:
for x,y,ptype in available:
if x == name or y == name:
if ptype != pop_type:
raise Exception(f'Requested population "{x}" has type "{ptype}" but requested cascade is for "{pop_type}"')
return x,y
raise Exception('Name "%s" not found' % (name))

if pops in [None, 'all', 'All', 'aggregate', 'total']:
# If populations are an aggregation for all pops, then set the dict appropriately
pops = {'Entire population': [x[0] for x in available]}
pops = {'Entire population': [x[0] for x in available if x[2]==pop_type]}
if not pops['Entire population']:
raise Exception('No populations with the requested type were found')

elif isinstance(pops,list) or sc.isstring(pops):
# If it's a list or string, convert it to a dict
Expand All @@ -235,24 +239,25 @@ def sanitize_name(name):

return sc.odict(pops) # Make sure an odict gets returned rather than a dict

def validate_cascade(framework, cascade, cascade_name=None, fallback_used:bool =False) -> bool:
def validate_cascade(framework, cascade, cascade_name=None, fallback_used:bool =False) -> str:
"""
Check if a cascade is valid
A cascade is invalid if any stage does not contain a compartment that appears in subsequent stages i.e.
if the stages are not all nested.
if the stages are not all nested. Also, all compartments referred to must exist in the same population type,
otherwise it is not possible to define a population-specific cascade as it would intrinsically span populations.
:param framework: A :class:`ProjectFramework` instance
:param cascade: A cascade representation supported by :func:`sanitize_cascade`
:param cascade_name: Name of cascade to be printed in error messages
:param fallback_used: If ``True``, then in the event that the cascade is not valid, the error message will reflect the fact that it was not a user-defined cascade
:return: ``True`` if the cascade is valid
:return: The population type if the cascade is valid
:raises: ``InvalidCascade`` if the cascade is not valid
"""

if not isinstance(cascade,dict):
sanitize_cascade(framework, cascade) # This will result in a call to validate_cascade()
sanitize_cascade(framework, cascade, fallback_used=fallback_used) # This will result in a call to validate_cascade()
return True
else:
cascade_dict = cascade
Expand All @@ -265,6 +270,17 @@ def validate_cascade(framework, cascade, cascade_name=None, fallback_used:bool =
for stage, includes in cascade_dict.items():
expanded[stage] = framework.get_charac_includes(includes)

pop_types = set()
comps = framework.comps
for stage in expanded.values():
for comp in stage:
pop_types.add(comps.at[comp,'population type'])
if len(pop_types)>1:
if fallback_used:
raise Exception('The framework defines multiple population types and has characteristics spanning population types. Therefore, a default fallback cascade cannot be automatically constructed. You will need to explicitly define a cascade in the framework file')
else:
raise Exception('Cascade "%s" includes compartments from more than one population type' % (cascade_name))

for i in range(0, len(expanded) - 1):
if not (set(expanded[i + 1]) <= set(expanded[i])):
message = ''
Expand All @@ -289,7 +305,7 @@ def validate_cascade(framework, cascade, cascade_name=None, fallback_used:bool =

raise InvalidCascade(message)

return True
return list(pop_types)[0] # Return the population type

def plot_single_cascade_series(result=None, cascade=None, pops=None, data=None) -> list:
"""
Expand Down Expand Up @@ -319,8 +335,8 @@ def plot_single_cascade_series(result=None, cascade=None, pops=None, data=None)

assert isinstance(result, Result), 'Input must be a single Result object'

cascade_name, cascade_dict = sanitize_cascade(result.framework,cascade)
pops = sanitize_pops(pops,result)
cascade_name, cascade_dict, pop_type = sanitize_cascade(result.framework,cascade)
pops = sanitize_pops(pops, result, pop_type)
d = PlotData(result, outputs=cascade_dict, pops=pops)
d.set_colors(outputs=d.outputs)

Expand Down Expand Up @@ -358,8 +374,8 @@ def plot_single_cascade(result=None, cascade=None, pops=None, year=None, data=No
diffcolor = (0.85, 0.89, 1.00) # (0.74, 0.82, 1.00) # Original: (0.93,0.93,0.93)
losscolor = (0, 0, 0) # (0.8,0.2,0.2)

cascade_name, cascade_dict = sanitize_cascade(result.framework,cascade)
pops = sanitize_pops(pops,result)
cascade_name, cascade_dict, pop_type = sanitize_cascade(result.framework,cascade)
pops = sanitize_pops(pops, result, pop_type)

if not year:
year = result.t[-1] # Draw cascade for last year
Expand Down Expand Up @@ -468,8 +484,8 @@ def plot_multi_cascade(results=None, cascade=None, pops=None, year=None, data=No
elif isinstance(results, NDict):
results = list(results)

cascade_name, cascade_dict = sanitize_cascade(results[0].framework,cascade)
pops = sanitize_pops(pops,results[0])
cascade_name, cascade_dict, pop_type = sanitize_cascade(results[0].framework,cascade)
pops = sanitize_pops(pops, results[0], pop_type)

if not year:
year = results[0].t[-1] # Draw cascade for last year
Expand Down Expand Up @@ -548,12 +564,10 @@ def get_cascade_vals(result, cascade, pops=None, year=None) -> tuple:
"""
Get values for a cascade
If the population list
:param result: A single :class:`Result` instance
:param cascade: A cascade representation supported by
:func:`sanitize_cascade`
:param pops: A string (like ``'all'``), a list of pops to aggregate, or a
dict with a single key specifying an aggregation and the name of the
resulting aggregation
:param cascade: A cascade representation supported by :func:`sanitize_cascade`
:param pops: A population representation supported by :func:`sanitize_pops`
:param year: Optionally specify a subset of years to retrieve values for.
Can be a scalar, list, or array. If ``None``, all time points in the
result will be used
Expand All @@ -565,11 +579,9 @@ def get_cascade_vals(result, cascade, pops=None, year=None) -> tuple:

from .plotting import PlotData # Import here to avoid circular dependencies

if pops in [None, 'all', 'All']:
pops = 'total'

# Sanitize the cascade inputs
_, cascade_dict = sanitize_cascade(result.framework, cascade)
_, cascade_dict, pop_type = sanitize_cascade(result.framework, cascade)
pops = sanitize_pops(pops, result, pop_type) # Get list representation since we don't care about the name of the aggregated pop

if year is None:
d = PlotData(result, outputs=cascade_dict, pops=pops)
Expand Down Expand Up @@ -630,7 +642,7 @@ def cascade_summary(source_data,year:float,pops=None,cascade=0) -> None:

for result in source_data:

cascade_name, cascade_dict = sanitize_cascade(result.framework, cascade)
cascade_name, cascade_dict, pop_type = sanitize_cascade(result.framework, cascade)
absolute, _ = get_cascade_vals(result, cascade_dict, pops=pops, year=year)
percentage = sc.dcp(absolute)
for i in reversed(range(len(percentage))):
Expand Down Expand Up @@ -676,12 +688,8 @@ def get_cascade_data(data, framework, cascade, pops=None, year=None):
"""


if pops is None:
pops = 'all'

_, cascade_dict = sanitize_cascade(framework, cascade)
pops = sanitize_pops(pops,data)[0] # Get list representation since we don't care about the name of the aggregated pop
_, cascade_dict, pop_type = sanitize_cascade(framework, cascade)
pops = sanitize_pops(pops, data, pop_type)[0] # Get list representation since we don't care about the name of the aggregated pop

if year is not None:
t = sc.promotetoarray(year) # Output times are guaranteed to be
Expand Down Expand Up @@ -770,7 +778,7 @@ def __init__(self,framework, cascade, years=None, baseline_results=None, pops=No
cascade_name = None
cascade_dict = cascade
else:
cascade_name, cascade_dict = sanitize_cascade(framework,cascade)
cascade_name, cascade_dict, pop_type = sanitize_cascade(framework,cascade)

if not cascade_name:
cascade_name = 'Cascade'
Expand Down
Loading

0 comments on commit 3f90d31

Please sign in to comment.