SST1RSoXSDB: prototype new Dataset workflow #130

Draft
wants to merge 11 commits into
base: main
3 changes: 2 additions & 1 deletion requirements.txt
@@ -12,7 +12,8 @@ pygix
scikit-image
scipy
pillow
xarray
# the following minimum version is due to new explicit Coordinate code
xarray>2023.6
tqdm
pydata_sphinx_theme
# the following pin is due to a security update to numexpr: https://github.com/pydata/numexpr/issues/442
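For context, the "explicit Coordinate code" behind this pin is xarray's newer xr.Coordinates API, which this PR leans on when it attaches the 'system' multi-index in SST1RSoXSDB.py (the assign_coords(mindex_coords) change below). A minimal sketch of the pattern, assuming an xarray release that ships Coordinates.from_pandas_multiindex; the variable names echo the diff, the data is invented:

import pandas as pd
import xarray as xr

# Build a MultiIndex over per-image metadata, then wrap it in explicit
# Coordinates rather than relying on implicit index creation.
midx = pd.MultiIndex.from_arrays(
    [[0.1, 0.1, 2.0], [270.0, 285.0, 285.0]], names=("exposure", "energy")
)
mindex_coords = xr.Coordinates.from_pandas_multiindex(midx, dim="system")

da = xr.DataArray([1.0, 2.0, 3.0], dims="system").assign_coords(mindex_coords)
print(da.sel(energy=285.0))  # select along one level of the multi-index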
16 changes: 12 additions & 4 deletions src/PyHyperScattering/PFEnergySeriesIntegrator.py
@@ -215,17 +215,25 @@ def integrateImageStack(self,img_stack,method=None,chunksize=None):
'''
Integrate a stack of images, dispatching to the chunked (dask) or legacy backend based on `method` and `self.use_chunked_processing`.

If img_stack is an xr.Dataset, its 'I_raw' variable is integrated and the result is merged back into the Dataset as 'I_integ'; a bare DataArray comes back as a DataArray.
'''

full_stack = None
if isinstance(img_stack,xr.Dataset):
full_stack = img_stack
img_stack = img_stack['I_raw']
if (self.use_chunked_processing and method is None) or method=='dask':
func_args = {}
if chunksize is not None:
func_args['chunksize'] = chunksize
return self.integrateImageStack_dask(img_stack,**func_args)
retval = self.integrateImageStack_dask(img_stack,**func_args)
elif (method is None) or method == 'legacy':
return self.integrateImageStack_legacy(img_stack)
retval = self.integrateImageStack_legacy(img_stack)
else:
raise NotImplementedError(f'unsupported integration method {method}')

if isinstance(full_stack,xr.Dataset):
retval.name='I_integ'
return xr.merge([full_stack,retval])
else:
return retval



def createIntegrator(self,en,recreate=False):
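The new tail of this method is the Dataset-aware piece: whatever the backend returns is named 'I_integ' and merged back alongside the untouched input variables. A runnable toy of just that merge step, with synthetic arrays standing in for the real image stack and reduction (the 'q' output dimension here is an assumption for illustration):

import numpy as np
import xarray as xr

full_stack = xr.Dataset({"I_raw": ("q", np.zeros(10)), "I0": 2.5})
retval = xr.DataArray(np.ones(10), dims="q")  # stand-in for the integration result

# what the new code does when handed a Dataset:
retval.name = "I_integ"
merged = xr.merge([full_stack, retval])
print(list(merged.data_vars))  # ['I_raw', 'I0', 'I_integ']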
22 changes: 16 additions & 6 deletions src/PyHyperScattering/PFGeneralIntegrator.py
@@ -437,18 +437,28 @@ def integrateImageStack_dask(self, data, chunksize=5):
integ_fly = integ_fly.unstack('pyhyper_internal_multiindex')
return integ_fly

def integrateImageStack(self, img_stack, method=None, chunksize=None):
''' '''

if (self.use_chunked_processing and method is None) or method == 'dask':
def integrateImageStack(self,img_stack,method=None,chunksize=None):
'''
Integrate a stack of images, dispatching to the chunked (dask) or legacy backend based on `method` and `self.use_chunked_processing`.

If img_stack is an xr.Dataset, its 'I_raw' variable is integrated and the result is merged back into the Dataset as 'I_integ'; a bare DataArray comes back as a DataArray.
'''
full_stack = None
if isinstance(img_stack,xr.Dataset):
full_stack = img_stack
img_stack = img_stack['I_raw']
if (self.use_chunked_processing and method is None) or method=='dask':
func_args = {}
if chunksize is not None:
func_args['chunksize'] = chunksize
return self.integrateImageStack_dask(img_stack, **func_args)
retval = self.integrateImageStack_dask(img_stack,**func_args)
elif (method is None) or method == 'legacy':
return self.integrateImageStack_legacy(img_stack)
retval = self.integrateImageStack_legacy(img_stack)
else:
raise NotImplementedError(f'unsupported integration method {method}')
if isinstance(full_stack,xr.Dataset):
retval.name='I_integ'
return xr.merge([full_stack,retval])
else:
return retval

def loadPolyMask(self, maskpoints=[], **kwargs):
'''
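This wrapper is identical to the one in PFEnergySeriesIntegrator above, so both integrators share one dispatch contract: method=None defers to self.use_chunked_processing, 'dask' forces the chunked path (optionally with a chunksize), 'legacy' forces the eager path, and anything else raises NotImplementedError. A hedged usage sketch; integ and ds stand for an already-constructed integrator and a loaded stack, both assumed:

reduced = integ.integrateImageStack(ds)                               # auto-select backend
reduced = integ.integrateImageStack(ds, method="dask", chunksize=10)  # force chunked path
reduced = integ.integrateImageStack(ds, method="legacy")              # force eager path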
127 changes: 83 additions & 44 deletions src/PyHyperScattering/SST1RSoXSDB.py
@@ -760,7 +760,7 @@ def subtract_dark(img, pedestal=100, darks=None):
monitors = (
monitors.rename({"time": "system"})
.reset_index("system")
.assign_coords(system=index)
.assign_coords(mindex_coords)
)

if "system_" in monitors.indexes.keys():
@@ -802,6 +802,7 @@ def subtract_dark(img, pedestal=100, darks=None):
# retxr = (index,monitors,retxr)
monitors.attrs.update(retxr.attrs)
retxr = monitors.merge(retxr)
retxr = self._normalize_monitor_names(retxr)

if self.use_chunked_loading:
# dask and multiindexes are like PEO and PPO. They're kinda the same thing and they don't like each other.
@@ -818,6 +819,7 @@ def loadMonitors(
integrate_onto_images: bool = True,
useShutterThinning: bool = True,
n_thinning_iters: int = 5,
directLoadPulsedMonitors: bool = True
):
"""Load the monitor streams for entry.

@@ -838,51 +840,52 @@ def loadMonitors(
useShutterThinning : bool, optional
Whether or not to attempt to thin (filter) the raw time streams to remove data collected during shutter opening/closing, by default True
As of 9 Feb 2023 at NSLS2 SST1, using useShutterThinning= True for exposure times of < 0.5s is
not recommended because the shutter data is unreliable and too many points will be culled
not recommended because the shutter data is unreliable and too many points will be removed
n_thinning_iters : int, optional
how many iterations of thinning to perform, by default 5
If the data is becoming too sparse, try fewer iterations
directLoadPulsedMonitors : bool, optional
Whether or not to load the pulsed monitors using direct reading, by default True
This only applies if integrate_onto_images is True; otherwise the raw, unbinned time series is returned.
If False, the pulsed monitors will be loaded using a shutter-thinning and masking approach as with continuous counters

Returns
-------
xr.Dataset
xarray dataset containing all monitor streams as data variables mapped against the dimension "time"
"""

monitors = None
raw_monitors = None


# Iterate through the list of streams held by the Bluesky document 'entry'
# Iterate through the list of streams held by the Bluesky document 'entry' and build up the raw monitor dataset
for stream_name in list(entry.keys()):
# Add monitor streams to the output xr.Dataset
if "monitor" in stream_name:
if monitors is None: # First one
if raw_monitors is None: # First one
# incantation to extract the dataset from the bluesky stream
monitors = entry[stream_name].data.read()
raw_monitors = entry[stream_name].data.read()
else: # merge into the existing output xarray
monitors = xr.merge((monitors, entry[stream_name].data.read()))
raw_monitors = xr.merge((raw_monitors, entry[stream_name].data.read()))

# At this stage raw_monitors has dimension time and all streams as data variables
# the time dimension inherited all time values from all streams
# the data variables (Mesh current, sample current, etc.) are all sparse, with lots of nans

# if there are no monitors, return an empty xarray Dataset
if monitors is None:
if raw_monitors is None:
return xr.Dataset()

# For each nan value, fill with the closest preceding value in time (ffill)
# For nans remaining at the start, fill with the closest following value (bfill)
monitors = monitors.ffill("time").bfill("time")
monitors = raw_monitors.ffill("time").bfill("time")

# If we need to remap timepoints to match timepoints for data acquisition
if integrate_onto_images:
try:
# Pull out ndarray of 'primary' timepoints (measurement timepoints)
try:
primary_time = entry.primary.data["time"].values
except AttributeError:
if type(entry.primary.data["time"]) == tiled.client.array.DaskArrayClient:
primary_time = entry.primary.data["time"].read().compute()
elif type(entry.primary.data["time"]) == tiled.client.array.ArrayClient:
primary_time = entry.primary.data["time"].read()
primary_time = entry.primary.data["time"].__array__()
primary_time_bins = np.insert(primary_time, 0, 0)

# If we want to exclude values for when the shutter was opening or closing
# This doesn't work for exposure times ~ < 0.5 s, because shutter stream isn't reliable
@@ -904,22 +907,50 @@ def loadMonitors(
"time"
)

# Bin the monitor values in 'time' using the intervals between timepoints in 'primary_time' and take the mean of each bin
# Then rename the resulting 'time_bins' dimension back to 'time'
monitors = (
monitors.groupby_bins("time", np.insert(primary_time, 0, 0))
monitors.groupby_bins("time",primary_time_bins,include_lowest=True)
.mean()
.rename_dims({"time_bins": "time"})
.rename({"time_bins": "time"})
)

'''
# Add primary measurement time as a coordinate in monitors that is named 'time'
# Remove the coordinate 'time_bins' from the array
monitors = (
monitors.assign_coords({"time": primary_time})
.drop_indexes("time_bins")
.reset_coords("time_bins", drop=True)
)

)'''

# load direct/pulsed monitors

for stream_name in list(entry.keys()):
if "monitor" in stream_name and ("Beamstop" in stream_name or "Sample" in stream_name):
# the pulsed monitors we know about are "SAXS Beamstop", "WAXS Beamstop", "Sample Current"
# if others show up here, they could be added
out_name = stream_name.replace("_monitor", "")
mon = entry[stream_name].data.read()[out_name].compute()
# Build a binary on/off mask: a sample counts as "beam on" when the
# signal exceeds 10% of the stream mean
SIGNAL_THRESHOLD = 0.1
threshold = SIGNAL_THRESHOLD*mon.mean('time')
mon_filter = xr.zeros_like(mon)
mon_filter[mon<threshold] = 0
mon_filter[mon>threshold] = 1
# Erode the mask by one sample to drop points caught on a pulse edge
mon_filter.values = scipy.ndimage.binary_erosion(mon_filter)
mon_filtered = mon.where(mon_filter==1)
# Average the surviving samples into bins bounded by the primary timestamps
mon_binned = (mon_filtered.groupby_bins("time",primary_time_bins,include_lowest=True)
.mean()
.rename({"time_bins":"time"})
)

if not directLoadPulsedMonitors:
# stash the thresholded read under a 'pl_' prefix so the shutter-thinned stream keeps the canonical name
out_name = 'pl_' + out_name

monitors[out_name] = mon_binned
monitors = monitors.assign_coords({"time": primary_time})


except Exception as e:
# raise e # for testing
warnings.warn(
@@ -932,6 +963,27 @@ def loadMonitors(
)
return monitors
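The pulsed-monitor branch added above boils down to a three-step recipe: threshold at 10% of the stream mean to get an on/off mask, binary-erode the mask to drop samples caught on a pulse edge, then average the survivors into bins bounded by the primary timestamps. A runnable toy version with a synthetic diode-like stream; all data is invented, only the recipe matches the code:

import numpy as np
import scipy.ndimage
import xarray as xr

# Synthetic pulsed signal: ~10.0 while the beam is on, near zero otherwise
t = np.linspace(0.0, 10.0, 101)
mon = xr.DataArray(np.where((t % 2.0) < 1.0, 10.0, 0.01),
                   coords={"time": t}, dims="time")

threshold = 0.1 * mon.mean("time")            # SIGNAL_THRESHOLD = 0.1
mon_filter = xr.zeros_like(mon)
mon_filter[mon > threshold] = 1               # on/off mask
mon_filter.values = scipy.ndimage.binary_erosion(mon_filter)  # trim pulse edges
mon_filtered = mon.where(mon_filter == 1)     # NaN outside pulses

primary_time_bins = np.insert(np.array([2.0, 4.0, 6.0, 8.0, 10.0]), 0, 0)
mon_binned = (
    mon_filtered.groupby_bins("time", primary_time_bins, include_lowest=True)
    .mean()
    .rename({"time_bins": "time"})
)
print(mon_binned.values)  # one mean intensity per exposure window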

def _normalize_monitor_names(self,run):
"""
Normalize instrument-local monitor names to PyHyper 1.0+ standard names

"""
rename_dict = {
'RSoXS Au Mesh Current' : 'I0',
'NSLS-II Ring Current' : 'Isrc'
}

if run.attrs['rsoxs_config'] == 'saxs':
rename_dict['SAXS Beamstop'] = 'It'
rename_dict['Small Angle CCD Detector_image'] = 'I_raw'
elif run.attrs['rsoxs_config'] == 'waxs':
rename_dict['WAXS Beamstop'] = 'It'
rename_dict['Wide Angle CCD Detector_image'] = 'I_raw'
else:
pass
return run.rename(rename_dict)


def loadMd(self, run):
"""
return a dict of metadata entries from the databroker run xarray
@@ -1036,19 +1088,12 @@ def loadMd(self, run):
# print(f'Loading from primary: {phs}, value {primary[rsoxs].values}')
except (KeyError, HTTPStatusError):
try:
blval = baseline[rsoxs]
if (
type(blval) == tiled.client.array.ArrayClient
or type(blval) == tiled.client.array.DaskArrayClient
):
blval = blval.read()
blval = baseline[rsoxs].__array__()
md[phs] = blval.mean().round(4)
if blval.var() > 0:
if blval.var() > 1e-4*abs(blval.mean()):
warnings.warn(
(
f"While loading {rsoxs} to infill metadata entry for {phs}, found"
f" beginning and end values unequal: {baseline[rsoxs]}. It is"
" possible something is messed up."
f"{phs} changed during scan: {blval}."
),
stacklevel=2,
)
@@ -1057,28 +1102,20 @@ def loadMd(self, run):
md[phs] = primary[md_secondary_lookup[phs]].read()
except (KeyError, HTTPStatusError):
try:
blval = baseline[md_secondary_lookup[phs]]
if (
type(blval) == tiled.client.array.ArrayClient
or type(blval) == tiled.client.array.DaskArrayClient
):
blval = blval.read()
blval = baseline[md_secondary_lookup[phs]].__array__()
md[phs] = blval.mean().round(4)
if blval.var() > 0:
if blval.var() > 1e-4*abs(blval.mean()):
warnings.warn(
(
f"While loading {md_secondary_lookup[phs]} to infill"
f" metadata entry for {phs}, found beginning and end"
f" values unequal: {baseline[rsoxs]}. It is possible"
" something is messed up."
f"{phs} changed during scan: {blval}."
),
stacklevel=2,
)
except (KeyError, HTTPStatusError):
warnings.warn(
(
f"Could not find {rsoxs} in either baseline or primary. "
f" Needed to infill value {phs}. Setting to None."
f"Could not find {rsoxs} in either baseline or primary while"
f" looking for {phs}. Setting to None."
),
stacklevel=2,
)
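The relaxed drift check above (blval.var() > 1e-4*abs(blval.mean()), replacing var() > 0) warns only when the baseline spread is meaningful relative to its magnitude, so floating-point jitter between the start and end readings no longer trips it. The arithmetic as a runnable toy:

import numpy as np

noise = np.array([270.0001, 270.0002])  # jitter: var ~2.5e-9 vs threshold ~2.7e-2
drift = np.array([270.0, 285.0])        # real drift: var ~56 vs threshold ~2.8e-2

assert not noise.var() > 1e-4 * abs(noise.mean())  # stays quiet
assert drift.var() > 1e-4 * abs(drift.mean())      # warns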
@@ -1146,6 +1183,7 @@ def loadMd(self, run):
md.update(run.metadata)
return md

'''
def loadSingleImage(self, filepath, coords=None, return_q=False, **kwargs):
"""
DO NOT USE
@@ -1216,3 +1254,4 @@ def loadSingleImage(self, filepath, coords=None, return_q=False, **kwargs):
)
else:
return xr.DataArray(img, dims=["pix_x", "pix_y"], attrs=headerdict)
'''