Skip to content

Commit

Permalink
updates to masking and shifting
Browse files Browse the repository at this point in the history
  • Loading branch information
Ryan Ridden authored and Ryan Ridden committed Sep 11, 2023
1 parent f475c70 commit 446f5a5
Showing 1 changed file with 32 additions and 19 deletions.
51 changes: 32 additions & 19 deletions tessreduce/tessreduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def unknown_mask(image):
return mask


def Smooth_bkg(data, extrapolate = True,gauss_smooth=2):
def Smooth_bkg(data, extrapolate = True,gauss_smooth=4):
"""
Interpolate over the masked objects to derive a background estimate.
Expand Down Expand Up @@ -265,6 +265,9 @@ def Calculate_shifts(data,mx,my,daofind):
ind = ind[indo]
shifts[1,indo] = mx[indo] - x[ind]
shifts[0,indo] = my[indo] - y[ind]
else:
shifts[0,indo] = np.nan
shifts[1,indo] = np.nan
return shifts

def image_sub(theta, image, ref):
Expand Down Expand Up @@ -712,21 +715,21 @@ def get_TESS(self,ra=None,dec=None,name=None,Size=None,Sector=None,quality_bitma
self.flux = strip_units(tpf.flux)
self.wcs = tpf.wcs

def make_mask(self,maglim=19,scale=1,strapsize=5):
def make_mask(self,maglim=19,scale=1,strapsize=6):
# make a diagnostic plot for mask
data = strip_units(self.flux)

mask, cat = Cat_mask(self.tpf,maglim,scale,strapsize)
sources = ((mask & 1)+1 ==1) * 1.
sources[sources==0] = np.nan
tmp = np.nansum(data*sources,axis=(1,2))
sky = ((mask & 1)+1 ==1) * 1.
sky[sky==0] = np.nan
tmp = np.nansum(data*sky,axis=(1,2))
tmp[tmp==0] = 1e12 # random big number
ref = data[np.argmin(tmp)] * sources
ref = data[np.argmin(tmp)] * sky
try:
qe = correct_straps(ref,mask,parallel=True)
except:
qe = correct_straps(ref,mask,parallel=False)
#mm = Source_mask(ref * qe * sources)
#mm = Source_mask(ref * qe * sky)
#mm[np.isnan(mm)] = 0
#mm = mm.astype(int)
#mm = abs(mm-1)
Expand All @@ -740,7 +743,13 @@ def make_mask(self,maglim=19,scale=1,strapsize=5):
#mm = (mm*1) | cmask

fullmask = mask | cmask
self.mask = fullmask
sky = ((fullmask & 1)+1 ==1) * 1.
sky[sky==0] = np.nan
masked = ref*sky
mean = np.nanmean(masked) # assume sources weight the mean above the bkg
m_second = (masked > mean).astype(int)

self.mask = fullmask | m_second
self._mask_cat = cat

def background(self,calc_qe=True, strap_iso=True):
Expand Down Expand Up @@ -889,6 +898,7 @@ def centroids_DAO(self,plot=None,savename=None):
y_mid = self.flux.shape[1] / 2
#ind = #((abs(mx - x_mid) <= 30) & (abs(my - y_mid) <= 30) &
ind = (abs(mx - x_mid) >= 5) & (abs(my - y_mid) >= 5)
self._dat_sources = s[ind].to_pandas()
mx = mx[ind]
my = my[ind]
if self.parallel:
Expand All @@ -902,7 +912,7 @@ def centroids_DAO(self,plot=None,savename=None):
for i in range(len(f)):
shifts[i,:,:] = Calculate_shifts(f[i],mx,my,daofind)


self.raw_shifts = shifts
meds = np.nanmedian(shifts,axis = 2)
meds[~np.isfinite(meds)] = 0

Expand Down Expand Up @@ -957,11 +967,12 @@ def fit_shift(self,plot=None,savename=None):
if savename is None:
savename = self.savename

f = self.flux
m = self.ref.copy()#flux[self.ref_ind].copy()
sources = ((self.mask & 1) ==1) * 1.
sources[sources==0] = 0.

f = self.flux * sources[np.newaxis,:,:]
m = self.ref.copy() * sources
if self.parallel:

num_cores = multiprocessing.cpu_count()
shifts = Parallel(n_jobs=num_cores)(
delayed(difference_shifts)(frame,m) for frame in f)
Expand All @@ -978,7 +989,7 @@ def fit_shift(self,plot=None,savename=None):
if self.shift is not None:
self.shift += shifts
else:
self.shift = shift
self.shift = shifts
if plot:
#meds[meds==0] = np.nan
t = self.tpf.time.mjd
Expand Down Expand Up @@ -1662,12 +1673,14 @@ def reduce(self, aper = None, align = None, parallel = None, calibrate=None,
if self.align:
if self.verbose > 0:
print('calculating centroids')
self.fit_shift()
try:
self.centroids_DAO()
if double_shift:
self.shift_images()
self.ref = deepcopy(self.flux[self.ref_ind])
self.fit_shift()
#self.centroids_DAO()
#if double_shift:
#self.shift_images()
self.ref = deepcopy(self.flux[self.ref_ind])
#self.fit_shift()
#self.shift_images()

except:
print('Something went wrong, switching to serial')
Expand Down Expand Up @@ -1700,7 +1713,7 @@ def reduce(self, aper = None, align = None, parallel = None, calibrate=None,
self.ref = deepcopy(self.flux[self.ref_ind])
self.flux -= self.ref

#self.ref -= self.bkg[self.ref_ind]
self.ref -= self.bkg[self.ref_ind]
# remake mask
self.make_mask(maglim=18,strapsize=7,scale=mask_scale*.8)#Source_mask(ref,grid=0)
frac = np.nansum((self.mask== 0) * 1.) / (self.mask.shape[0] * self.mask.shape[1])
Expand Down

0 comments on commit 446f5a5

Please sign in to comment.