From 92cd3c73e76babdca076af89d1840159dc22a1e6 Mon Sep 17 00:00:00 2001 From: Spencer Date: Thu, 15 Aug 2019 12:12:06 -0700 Subject: [PATCH] `ngmix_catalog` now handles `gauss`, `cm` and `bdf` which required some updates to `ngmix2gs()` #79. Random rotations are now applied to the truth table `g1g2` values. Closes #76 --- balrog/balinput.py | 16 ++++++- balrog/balobject.py | 72 +++++++++++++++++++++++------ balrog/balrog_injection.py | 4 -- balrog/ngmix_catalog.py | 93 ++++++++++++++++++++++++++------------ balrog/tile.py | 20 ++------ 5 files changed, 142 insertions(+), 63 deletions(-) diff --git a/balrog/balinput.py b/balrog/balinput.py index 9843e69..acfaee6 100644 --- a/balrog/balinput.py +++ b/balrog/balinput.py @@ -62,7 +62,11 @@ def _register_input_type(self, input_type, CATMOD, LOADMOD, has_nobj=False): return - def _setup_proxy_catalog(self, save_proxy=False): + def _setup_proxy_catalog(self, save_proxy=False, get_subtype=False): + # Some input catalogs have additional subtypes, like ngmix_catalog + # Sometimes need to have this information for truth catalog updates, but + # can only do it here if save_proxy is False + conf = deepcopy(self.gsconfig) # Need to remove other input types as they may not have been registered @@ -86,6 +90,11 @@ def _setup_proxy_catalog(self, save_proxy=False): self.cat = cat_proxy.getCatalog() self.nobjects = cat_proxy.getNObjects() + if get_subtype is True: + self.sub_type = cat_proxy.getSubtype() + else: + self.sub_type = None + # Need to load in additional parametric catalog for some input types try: self.parametric_cat = cat_proxy.getParamCatalog() @@ -97,9 +106,11 @@ def _setup_proxy_catalog(self, save_proxy=False): def generate_inj_catalog(self, config, tile, realization, mixed, mixed_grid=None): inp_type = self.input_type inj_type = self.inj_type + sub_type = self.sub_type self.inj_catalog = balobject.build_bal_inject_cat(inp_type, inj_type, + sub_type, tile, needs_band=self.needs_band, mixed=mixed) @@ -146,7 +157,8 @@ def __init__(self, gsconfig, indx=None, tilename=None): # TODO: Can we grab the injection type from the registered GS catalog? self.inj_type = 'ngmixGalaxy' - self._setup_proxy_catalog() + # ngmix catalogs have additional subtypes, e.g. cm + self._setup_proxy_catalog(get_subtype=True) return diff --git a/balrog/balobject.py b/balrog/balobject.py index 1a3033b..d817fe8 100644 --- a/balrog/balobject.py +++ b/balrog/balobject.py @@ -1,15 +1,15 @@ import numpy as np import galsim import galsim.config.input as gsinput +import ngmix import os import copy +from numpy.lib.recfunctions import append_fields # Balrog files import mathutil as util import grid -# import pudb - # TODO: Implement as needed class BalObject(object): ''' @@ -20,9 +20,10 @@ def __init__(self): class BalInjectionCatalog(object): - def __init__(self, input_type, inj_type, tile, needs_band=False, mixed=False): + def __init__(self, input_type, inj_type, sub_type, tile, needs_band=False, mixed=False): self.input_type = input_type self.inj_type = inj_type + self.sub_type = sub_type self.needs_band = needs_band self.mixed = False @@ -276,8 +277,24 @@ def get_truth_outfile(self, base_outfile, real): self.truth_outfile[real] = os.path.join(base_outfile, truth_fname) return self.truth_outfile[real] - def write_new_positions(self, truth_cat, realization): - pos = self.pos[realization] + def update_truth_cols(self, config, truth_cat, real): + self.write_new_positions(truth_cat, real) + self.update_colnames(truth_cat, real) + + if self.rotate[real] is not None: + self.update_truth_shapes(config, truth_cat, real) + + def update_truth_shapes(self, config, truth_cat, real): + # Only here to be inherited by subclasses, but not bothering with + # a full abstract class for now + raise NotImplementedError('Need to implement `update_truth_shapes()` to apply rotations to ' + + 'custom BalObjects!') + + def update_colnames(self, truth_cat, real): + pass + + def write_new_positions(self, truth_cat, real): + pos = self.pos[real] # If nothing is set for a given custom input, try the obvious try: @@ -294,9 +311,6 @@ def write_new_positions(self, truth_cat, realization): return - def update_colnames(self, truth_cat, realization): - pass - def setup_chip_config(self, config, bal_config, chip, chip_indx): # Many injection types will requite nothing special in setup pass @@ -308,10 +322,10 @@ def build_multi_chip_config(self, config, bal_config, chip, chip_indx, input_ind pass class DESInjectionCatalog(BalInjectionCatalog): - def __init__(self, input_type, inj_type, tile, needs_band, mixed=False): + def __init__(self, input_type, inj_type, sub_type, tile, needs_band, mixed=False): # All catalogs require band input assert needs_band is True - super(DESInjectionCatalog, self).__init__(input_type, inj_type, tile, needs_band, mixed) + super(DESInjectionCatalog, self).__init__(input_type, inj_type, sub_type, tile, needs_band, mixed) return @@ -355,6 +369,33 @@ def generate_objects(self, config, realization, mixed_grid=None): return mixed_grid + def update_truth_shapes(self, config, truth_cat, real): + g_colname = self.sub_type + '_g' + g1 = truth_cat[self.inj_type][g_colname][:,0] + g2 = truth_cat[self.inj_type][g_colname][:,1] + + # Need to unpack array of '{rotation in deg} deg' + deg2rad = np.pi / 180. + theta = np.array([float(r.split()[0].strip()) for r in self.rotate[real]]) + + g1_rot, g2_rot = ngmix.shape.rotate_shape(g1, g2, deg2rad*theta) + + # Update truth catalog shape information + truth_cat[self.inj_type][g_colname][:,0] = g1_rot + truth_cat[self.inj_type][g_colname][:,1] = g2_rot + + # These values are also stored in pars + truth_cat[self.inj_type][self.sub_type+'_pars'][:,2] = g1_rot + truth_cat[self.inj_type][self.sub_type+'_pars'][:,3] = g2_rot + + # Add rotation angles to truth catalog + truth_cat[self.inj_type] = append_fields(truth_cat[self.inj_type], + 'rotation', + theta, + usemask=False) + + return + class MEDSInjectionCatalog(DESInjectionCatalog): def generate_objects(self, config, realization, mixed_grid=None): mixed_grid = super(MEDSInjectionCatalog, self).generate_objects(config, @@ -409,8 +450,8 @@ def build_multi_chip_config(self, config, bal_config, chip, chip_indx, input_ind return class DESStarInjectionCatalog(DESInjectionCatalog): - def __init__(self, input_type, inj_type, tile, needs_band=False, mixed=False): - super(DESStarInjectionCatalog, self).__init__(input_type, inj_type, tile, needs_band, mixed) + def __init__(self, input_type, inj_type, sub_type, tile, needs_band=False, mixed=False): + super(DESStarInjectionCatalog, self).__init__(input_type, inj_type, sub_type, tile, needs_band, mixed) # Might add something like this in the future, if we end up passing configs # during init... @@ -489,6 +530,10 @@ def _generate_sahar_coords(self, config, realization): return + def update_truth_shapes(self, config, truth_cat, real): + # Stars are injected as delta functions, so no need to update shape info + pass + def write_new_positions(self, truth_cat, realization): # Currently, all used DES star catalogs have Sahar's naming scheme anyway, # so this is check is not needed @@ -603,11 +648,12 @@ class Star(object): def __init__(self): pass -def build_bal_inject_cat(input_type, inj_type, tile, needs_band, mixed=False): +def build_bal_inject_cat(input_type, inj_type, sub_type, tile, needs_band, mixed=False): if input_type in BALROG_INJECTION_TYPES: # User-defined injection catalog construction inject_cat = BALROG_INJECTION_TYPES[input_type](input_type, inj_type, + sub_type, tile, needs_band, mixed) diff --git a/balrog/balrog_injection.py b/balrog/balrog_injection.py index 131bfa3..ae01099 100644 --- a/balrog/balrog_injection.py +++ b/balrog/balrog_injection.py @@ -15,15 +15,11 @@ import tile as Tile import fileio as io -# Use for debugging -# import pudb - #------------------------------------------------------------------------------- # Important todo's: # TODO: Check pixel origin for transformations! # Some extra todo's: -# TODO: Have code automatically check fitsio vs astropy # TODO: Add check for python path! # TODO: Fix some bugged multi-line print statements diff --git a/balrog/ngmix_catalog.py b/balrog/ngmix_catalog.py index 4e6e986..5b78fcb 100644 --- a/balrog/ngmix_catalog.py +++ b/balrog/ngmix_catalog.py @@ -12,7 +12,6 @@ import logging import warnings from past.builtins import basestring # Python 2&3 compatibility -# import pudb # TODO: Include noise, pixscale @@ -54,29 +53,33 @@ class ngmixCatalog(object): _req_params = { 'file_name' : str, 'bands' : str} _opt_params = { 'dir' : str, 'catalog_type' : str, 'snr_min' : float, 'snr_max' : float, 't_frac' : float, 't_min' : float, 't_max' : float, 'version' : str, - 'de_redden' : bool} + 'de_redden' : bool, 'TdByTe' : float} _single_params = [] _takes_rng = False # Only these ngmix catalog types currently supported # `gauss`: Single Gaussian fit # `cm`: Combined bulge+disk model - # `mof`: CM model with multi-object fitting + # `bdf`: CM model with buldge-disk ratio fixed # NOTE: In principle, should be able to handle any type supported by ngmix - _valid_catalog_types = ['gauss','cm','mof'] + # See Issue #79 + _valid_catalog_types = ['gauss','cm','bdf'] # Only these color bands are currently supported for an ngmix catalog _valid_band_types = 'griz' # Dictionary of color band flux to array index in ngmix catalogs + # TODO: This should be grabbed from the fits header rather than be hardcoded + # See Issue #80 _band_index = {'g' : 0, 'r' : 1, 'i' : 2, 'z' : 3} - # The catalog column name prefix doens't always match the catalog type (e.g. 'mof' has a prefix - # of 'cm' for most columns). Set this for each new supported catalog type. - _cat_col_prefix = {'gauss' : 'gauss', 'cm' : 'cm', 'mof' : 'cm'} + # NOTE: In previous versions, the catalog column name prefix didn't always + # match the catalog type (e.g. 'mof' had a prefix of 'cm' for most columns). + # This shouldn't be needed in the future but leaving for now + _cat_col_prefix = {'gauss' : 'gauss', 'cm' : 'cm', 'bdf' : 'bdf'} def __init__(self, file_name, bands, dir=None, catalog_type=None, snr_min=None, snr_max=None, - t_frac=None, t_min=None, t_max=None, version=None, de_redden=False, + t_frac=None, t_min=None, t_max=None, version=None, de_redden=False, TdByTe=None, _nobjects_only=False): if dir: @@ -102,10 +105,10 @@ def __init__(self, file_name, bands, dir=None, catalog_type=None, snr_min=None, # Attempt to determine catalog type from filename (this is generally true for DES ngmix # catalogs) match = 0 - for type in self._valid_catalog_types: - if type in self.file_name: + for t in self._valid_catalog_types: + if t in self.file_name: match +=1 - self.cat_type = type + self.cat_type = t # Reject if multiple matches if match == 0: raise ValueError("No inputted ngmix catalog type, and no matches in filename! " @@ -114,8 +117,17 @@ def __init__(self, file_name, bands, dir=None, catalog_type=None, snr_min=None, raise ValueError("No inputted ngmix catalog type, and multiple matches in filename!" " Please set a valid catalog type.") + if TdByTe is not None: + if self.cat_type != 'bdf': + raise ValueError('Can only set a constant `TdByTe` for ngmix type `bdf`!') + if TdByTe < 0: + raise ValueError('TdByTe must be non-negative!') + else: + # This should almost always be 1 + TdByTe = 1. + self._TdByTe = TdByTe + # Catalog column name prefixes don't always match catalog type - # (e.g. 'cm' is still used for many 'mof' columns) self.col_prefix = self._cat_col_prefix[self.cat_type] if isinstance(bands, basestring): @@ -242,13 +254,13 @@ def getFlags(self): # General flags self.flags = self.catalog['flags'] - # TODO: Look at these in more detail! + # We don't want to cut on these explicitly anymore: + #self.obj_flags = self.catalog['obj_flags'] # ngmix catalog-specific flags #self.ngmix_flags = self.catalog[self.col_prefix+'_flags'] - # TODO: Check for additional flags # if self.cat_type == 'mof': # # mof has additional flags # self.mof_flags = self.catalog[self.col_prefix+'_mof_flags'] @@ -269,7 +281,8 @@ def makeMask(self): # For now, remove objects with any flags present mask[self.flags != 0] = False - # TODO: We probably want to remove these! + + # No longer do these explicitly: # mask[self.obj_flags !=0] = False # mask[self.ngmix_flags !=0] = False # Extra flags for 'mof' catalogs @@ -408,24 +421,43 @@ def ngmix2gs(self, index, band, gsparams=None): flux = self.catalog[flux_colname][index][self._band_index[band]] - # Gaussian-Mixture parameters are in the format of: + # NOTE: It used to be the case that all Gaussian-Mixture parameters + # were in the format of: # gm_pars = [centroid1, centroid2, g1, g2, T, flux] - # (this is identical to ngmix catalogs, except that flux is a vector - # of fluxes in all color bands) - gm_pars = [0.0, 0.0, g1, g2, T, flux] - - # Build the appropriate Gaussian mixture for a cm-model - fracdev = self.catalog[cp+'_fracdev'][index] - TdByTe = self.catalog[cp+'_TdByTe'][index] - gm = ngmix.gmix.GMixCM(fracdev, TdByTe, gm_pars) + # (this is identical to ngmix `pars`, except that flux is a vector + # of fluxes in all bands) + # However, this now depends on the gmix type, so we have to wait + # gm_pars = [0.0, 0.0, g1, g2, T, flux] + + # TODO: Implement this once we get a response back from Erin why + # CM isn't included in this function + # https://github.com/esheldon/ngmix/blob/master/ngmix/gmix.py#L39 + # This allows us to construct the given gmix type without knowing + # gm.make_gmix_model(pars, model, **kw): + + # Build the appropriate Gaussian mixture for a given model + if ct == 'gauss': + # Uses 'simple' pars scheme + gm_pars = [0.0, 0.0, g1, g2, T, flux] + gm = ngmix.gmix.GMixModel(gm_pars, 'gaussian') + + elif ct == 'cm': + fracdev = self.catalog[cp+'_fracdev'][index] + TdByTe = self.catalog[cp+'_TdByTe'][index] + # Uses 'simple' pars scheme + gm_pars = [0.0, 0.0, g1, g2, T, flux] + gm = ngmix.gmix.GMixCM(fracdev, TdByTe, gm_pars) + + elif ct == 'bdf': + fracdev = self.catalog[cp+'_fracdev'][index] + TdByTe = self._TdByTe + # Uses different 'bdf' pars scheme + gm_pars = [0.0, 0.0, g1, g2, T, fracdev, flux] + gm = ngmix.gmix.GMixBDF(pars=gm_pars, TdByTe=TdByTe) # The majority of the conversion will be handled by `ngmix.gmix.py` gs_gal = gm.make_galsim_object(gsparams=gsp) - # NOTE: Can add any model-specific requirements in future if needed - # if ct == 'mof': - # gal = ... - return gs_gal #---------------------------------------------------------------------------------------------- @@ -486,6 +518,11 @@ def selectRandomIndex(self, n_random=1, rng=None, _n_rng_calls=False): #---------------------------------------------------------------------------------------------- + # The catalog type is referred to as a 'subtype' w.r.t BalInput types + # (i.e. the BalInput type is 'ngmix_catalog' with subtype self.catalog_type) + def getSubtype(self): + return self.cat_type + def getNObjects(self): # Used by input/logger methods return self.nobjects diff --git a/balrog/tile.py b/balrog/tile.py index 3a0ed17..89c9900 100644 --- a/balrog/tile.py +++ b/balrog/tile.py @@ -20,8 +20,6 @@ import balobject as balobj import grid -# import pudb - #------------------------------------------------------------------------------- # Tile class and functions @@ -856,8 +854,7 @@ def write_truth_catalog(self, config): # Parametric truth catalog previously loaded in `load_input_catalogs()` truth[inj_type] = inpt.parametric_cat[inj.indx[real]] - self._write_new_positions(truth, inj) - self._update_colnames(truth, inj) + self._update_truth_cols(config, truth, inj) for inj_type, outfile in outfiles.items(): try: @@ -866,8 +863,6 @@ def write_truth_catalog(self, config): truth_table.write(truth[inj_type]) # Fill primary HDU with simulation metadata - # hdr = fits.Header() - # TODO: Add ext_fact and ext_mag ! hdr = {} hdr['run_name'] = config.run_name hdr['config_file'] = config.args.config_file @@ -901,17 +896,10 @@ def write_truth_catalog(self, config): return - def _write_new_positions(self, truth_cat, inj_cat): - # Position re-writing (including default behaviour) has been moved to class - # methods in `balobject.py` - inj_cat.write_new_positions(truth_cat, self.curr_real) - - return - - def _update_colnames(self, truth_cat, inj_cat): - # Column re-writing (including default behaviour) has been moved to class + def _update_truth_cols(self, config, truth_cat, inj_cat): + # Column re-writing (including positions) has been moved to class # methods in `balobject.py` - inj_cat.update_colnames(truth_cat, self.curr_real) + inj_cat.update_truth_cols(config, truth_cat, self.curr_real) return