Coverage for /home/casatest/venv/lib/python3.12/site-packages/casatasks/private/imagerhelpers/msuvbinflag_algorithms.py: 9%
387 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-12 07:14 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-12 07:14 +0000
1import os
2import numpy as np
4# import numba as nb
5import casatools
6import time
7from scipy.optimize import curve_fit
8from typing import Tuple, List, Union, Optional
10from casatasks import casalog
12import matplotlib.pyplot as pl
# One shared instance of each CASA tool, created when the module is imported.
# Of these, only `tb` is used by the code visible in this file.
ms = casatools.ms()
tb = casatools.table()
me = casatools.measures()
qa = casatools.quanta()
ia = casatools.image()
im = casatools.imager()
22class UVGridFlag:
def __init__(self, binnedvis: str, doplot: bool = False) -> None:
    """
    Store the name of the gridded visibility table and the plot switch.

    Inputs:
    binnedvis    Name of the table opened by the flagging methods
    doplot       If True, the methods produce diagnostic plots
    """
    self.binnedvis, self.doplot = binnedvis, doplot
    # Debugging switch: when True the algorithms still run, but the
    # methods that check it do not write their results back to the table.
    self.dryrun = False
    if doplot:
        # Put matplotlib in interactive mode.
        pl.ion()
31 # @nb.njit(cache=True)
def populate_grid(
    self,
    uvw: np.ndarray,
    stokesI: np.ndarray,
    uvgrid: np.ndarray,
    uvgrid_npt: np.ndarray,
    deltau: float,
    deltav: float,
    npix: int,
):
    """
    Add every visibility to the UV cell it falls in (in place).

    Inputs:
    uvw          Array whose first two rows hold u and v for each sample
    stokesI      One value per sample, accumulated into `uvgrid`
    uvgrid       2D grid receiving the summed values (modified in place)
    uvgrid_npt   2D grid receiving the sample count (modified in place)
    deltau       Cell size along u
    deltav       Cell size along v
    npix         Grid size; the origin is placed at cell npix // 2

    Returns:
    uvgrid, uvgrid_npt   The same two arrays that were passed in

    Samples that land outside the grid are ignored. Previously a negative
    cell index silently wrapped around to the opposite edge of the grid and
    a too-large index raised IndexError.
    """
    u = np.asarray(uvw[0])
    v = np.asarray(uvw[1])
    vis = np.atleast_1d(np.asarray(stokesI))[: u.size]

    # Floor division picks the cell; the offset moves the origin to the centre.
    uidx = (u // deltau + npix // 2).astype(np.intp)
    vidx = (v // deltav + npix // 2).astype(np.intp)

    inside = (
        (uidx >= 0)
        & (uidx < uvgrid.shape[0])
        & (vidx >= 0)
        & (vidx < uvgrid.shape[1])
    )
    cells = (uidx[inside], vidx[inside])

    # np.add.at accumulates correctly when several samples share one cell,
    # which plain fancy-index "+=" does not.
    np.add.at(uvgrid, cells, vis[inside])
    np.add.at(uvgrid_npt, cells, 1)

    return uvgrid, uvgrid_npt
def mad(self, inpdat: np.array) -> float:
    """
    Robust standard-deviation estimate from the median absolute deviation.

    Inputs:
    inpdat       Input numpy array

    Returns:
    std          1.4826 times the median absolute deviation of inpdat
    """
    centre = np.median(inpdat)
    deviations = np.abs(inpdat - centre)

    # For normally distributed data, MAD * 1.4826 equals the standard deviation.
    return 1.4826 * np.median(deviations)
def msuvbin_to_uvgrid(
    self, ms: str, npix: int, deltau: float, deltav: float
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Read UVW and DATA from a table and average the data onto a UV grid.

    Inputs:
    ms           Name of the table to read (opened with the module-level tb tool)
    npix         Number of cells along each axis of the grid
    deltau       Cell size along u
    deltav       Cell size along v

    Returns:
    uvgrid       npix x npix complex grid, mean of the samples in each cell
    uvgrid_npt   npix x npix integer grid, number of samples in each cell
    """
    tb.open(ms)
    try:
        uvw = tb.getcol("UVW")
        data = tb.getcol("DATA")
    finally:
        # Release the table even when a column read fails.
        tb.close()

    # Half the sum of the first two entries along the leading DATA axis.
    # NOTE(review): assumes those are the two parallel-hand correlations.
    stokesI = np.squeeze(0.5 * (data[0] + data[1]))

    uvgrid = np.zeros((npix, npix), dtype=np.complex128)
    uvgrid_npt = np.zeros((npix, npix), dtype=int)
    uvgrid, uvgrid_npt = self.populate_grid(
        uvw, stokesI, uvgrid, uvgrid_npt, deltau, deltav, npix
    )

    # Turn the per-cell sums into per-cell means; empty cells stay zero.
    filled = uvgrid_npt != 0
    uvgrid[filled] = uvgrid[filled] / uvgrid_npt[filled]

    return uvgrid, uvgrid_npt
def resid_cube(
    self, x: float, a: float, b: float, c: float, d: float
) -> float:
    """Evaluate the cubic a*x^3 + b*x^2 + c*x + d (model function for curve_fit)."""
    cubic = a * x**3
    quadratic = b * x**2
    linear = c * x
    return cubic + quadratic + linear + d
def resid_cinco(
    self,
    x: np.ndarray,
    a: float,
    b: float,
    c: float,
    d: float,
    e: float,
    f: float,
) -> np.ndarray:
    """Evaluate the quintic a*x^5 + b*x^4 + c*x^3 + d*x^2 + e*x + f (model function for curve_fit)."""
    high_order = a * x**5 + b * x**4 + c * x**3
    return high_order + d * x**2 + e * x + f
def fit_radial_profile(
    self,
    xvals: np.ndarray,
    yvals: np.ndarray,
    ystd: np.ndarray,
    deg: int = 3,
    clip_sigma: float = 3,
) -> np.ndarray:
    """
    Fit the radial profile with a polynomial, rejecting outliers iteratively.

    Inputs:
    xvals        Radii
    yvals        Profile values; entries equal to zero are ignored
    ystd         Uncertainty of each profile value (sigma of the cubic fit)
    deg          Accepted for interface compatibility; not used at present.
                 The first pass is always a quadratic and every refit a cubic.
    clip_sigma   Points whose residual lies clip_sigma MAD-based sigmas or
                 more away from the median residual are rejected

    Returns:
    pfit         Coefficients, highest power first (usable with np.polyval):
                 three values if nothing was rejected, four otherwise
    """
    # A sigma-weighted cubic is refitted after every rejection pass.
    n_cubic_params = 4

    nonzero = yvals != 0
    xvals, yvals, ystd = xvals[nonzero], yvals[nonzero], ystd[nonzero]

    # First pass: unweighted quadratic.
    pfit = np.polyfit(xvals, yvals, deg=2)

    while True:
        # Outlier rejection while fitting
        resid = yvals - np.polyval(pfit, xvals)
        resid_std = self.mad(resid)
        resid_med = np.median(resid)
        inliers = (resid > resid_med - clip_sigma * resid_std) & (
            resid < resid_med + clip_sigma * resid_std
        )
        n_inliers = int(np.count_nonzero(inliers))

        if n_inliers == xvals.size:
            # Nothing was rejected: converged.
            break
        if n_inliers < n_cubic_params:
            # Too few points left to constrain the cubic; keep the last fit
            # instead of letting curve_fit fail. This is always reached when
            # the MAD is zero, because the strict inequalities above then
            # reject every point.
            break

        xvals, yvals, ystd = xvals[inliers], yvals[inliers], ystd[inliers]
        pfit, _ = curve_fit(
            self.resid_cube, xvals, yvals, sigma=ystd, p0=[1, 1, 1, 1]
        )

    return pfit
def calc_radial_profile_ann(
    self, uvgrid: np.ndarray, uvlen_m: np.ndarray, nbin: int = 30
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Calculate the mean amplitude of the uvgrid in log-spaced annuli.

    Inputs:
    uvgrid       2D grid of visibilities
    uvlen_m      UV distance of every cell; reshaped to the shape of uvgrid
    nbin         Number of annuli (default 30)

    Returns:
    radial_mean  Mean amplitude of the non-zero cells in each annulus
    radial_mad   MAD-based sigma of those amplitudes
    annuli       Outer edge of each annulus (integers)

    Annuli without any non-zero cell get a mean and sigma of 0.
    """
    npixx, npixy = uvgrid.shape[0], uvgrid.shape[1]
    uvlen_m_grid = uvlen_m.reshape([npixx, npixy])

    # Log-spaced annuli to account for reducing UV coverage with radius.
    # Rounding to integers can repeat an edge at small radii; the repeated
    # edge then describes an empty annulus.
    annuli = np.logspace(0, np.log10(uvlen_m_grid.max()), nbin)
    annuli = np.round(annuli).astype(int)

    radial_mean = np.zeros(nbin)
    radial_mad = np.zeros(nbin)

    amplitude = np.abs(uvgrid)

    ann_min = 0
    for idx, ann in enumerate(annuli):
        # Each annulus is the half-open interval [ann_min, ann).
        in_ring = (uvlen_m_grid >= ann_min) & (uvlen_m_grid < ann)
        ring_amp = amplitude[in_ring]
        ring_amp = ring_amp[ring_amp != 0]

        if ring_amp.size:
            radial_mean[idx] = np.mean(ring_amp)
            radial_mad[idx] = self.mad(ring_amp)

        ann_min = ann

    return radial_mean, radial_mad, annuli
def calc_radial_profile_pix(
    self, uvgrid: np.ndarray, deltau: float, deltav: float
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Calculate the mean amplitude of the uvgrid in per-pixel radial bins.

    Inputs:
    uvgrid       2D grid of visibilities
    deltau       Cell size along u; scales the returned radii
    deltav       Cell size along v (not used at present)

    Returns:
    uvgrid       The input grid, unchanged
    radial_mean  Mean amplitude of the non-zero cells in each radial bin
    radial_mad   MAD-based sigma of those amplitudes
    radii        Bin radii in pixels multiplied by deltau

    The statistics are taken on np.abs(uvgrid), as in the other radial
    profile routines of this class. Previously the complex mean was stored
    into a real array, which silently dropped its imaginary part.
    """
    npixx, npixy = uvgrid.shape[0], uvgrid.shape[1]
    cx, cy = npixx // 2, npixy // 2

    # Radial bin edges taken along the diagonal (requires cx == cy).
    x = np.arange(cx)
    y = np.arange(cy)
    rad = np.sqrt(x**2 + y**2).astype(int)

    # Integer radius of every cell; "ij" indexing keeps the uvgrid shape.
    xx, yy = np.meshgrid(
        np.arange(npixx) - cx, np.arange(npixy) - cy, indexing="ij"
    )
    radgrid = np.sqrt(xx**2 + yy**2).astype(int)

    nrad = np.max([cx, cy])
    radial_mean = np.zeros(nrad)
    radial_mad = np.zeros(nrad)

    amplitude = np.abs(uvgrid)
    last = len(rad) - 1

    for idx in range(len(rad)):
        if idx == last:
            # The final bin collects everything from its radius outwards.
            in_bin = radgrid >= rad[idx]
        else:
            in_bin = (radgrid > rad[idx]) & (radgrid <= rad[idx + 1])

        bin_amp = amplitude[in_bin]
        bin_amp = bin_amp[bin_amp != 0]

        if bin_amp.size:
            radial_mean[idx] = np.mean(bin_amp)
            radial_mad[idx] = self.mad(bin_amp)

    return uvgrid, radial_mean, radial_mad, rad * deltau
277 #############################################################3
def calc_radial_profile_and_fit(
    self,
    uvgrid: np.ndarray,
    wgtgrid: np.ndarray,
    flggrid: np.ndarray,
    nsigma: float,
) -> None:
    """
    Fit a weighted radial mean amplitude profile and flag cells lying more
    than nsigma above the fitted profile.

    Inputs:
    uvgrid       2D grid of visibilities; flagged cells are set to 0
    wgtgrid      2D grid of weights; flagged cells are set to 0
    flggrid      2D grid of flags; newly flagged cells are set to True
    nsigma       Threshold above the fitted profile, in units of the
                 amplitude scatter of the radial bin with the largest weight

    All three grids are modified in place. The method returns without
    changing anything when every weight is zero (empty channel) or when the
    quintic fit to the profile cannot be made.
    """
    npixx, npixy = uvgrid.shape[0], uvgrid.shape[1]
    cx, cy = npixx // 2, npixy // 2
    if cx < 1 or cy < 1:
        # Grid too small to hold a radial profile.
        return

    # Largest integer radius reached when moving along the diagonal.
    npoints = int(np.sqrt((cx - 1) ** 2 + (cy - 1) ** 2)) + 1
    xval = np.arange(npoints, dtype="float")

    # Integer radius (in pixels, truncated) of every cell.
    dx = np.arange(npixx) - cx
    dy = np.arange(npixy) - cy
    radius = np.sqrt(dx[:, np.newaxis] ** 2 + dy[np.newaxis, :] ** 2).astype(int)
    amp = np.abs(uvgrid)

    # Cells entering the profile: positive weight and a radius covered by
    # the profile. Row 0 and column 0 are left out, as the original loops
    # started at index 1.
    use = (wgtgrid > 0.0) & (radius < npoints)
    use[0, :] = False
    use[:, 0] = False

    r_use = radius[use]
    w_use = wgtgrid[use]
    a_use = amp[use]
    radwght = np.bincount(r_use, weights=w_use, minlength=npoints)
    radamp = np.bincount(r_use, weights=a_use * w_use, minlength=npoints)
    radamp2 = np.bincount(r_use, weights=a_use * a_use * w_use, minlength=npoints)

    if np.max(radwght) == 0.0:
        # empty channel
        return

    # Weighted mean and weighted mean square of the amplitude per radius.
    filled = radwght != 0
    radamp[filled] = radamp[filled] / radwght[filled]
    radamp2[filled] = radamp2[filled] / radwght[filled]
    sig = np.sqrt(np.abs(radamp2 - radamp * radamp))

    # The scatter of the radius with the largest summed weight is used as
    # the noise estimate for every radius.
    sigtouse = sig[np.argmax(radwght)]

    fitmask = (sig != 0.0) & (radamp != 0)
    xvalnz = xval[fitmask]
    radampnz = radamp[fitmask]
    try:
        fitnz = curve_fit(self.resid_cinco, xvalnz, radampnz)
    except (RuntimeError, ValueError, TypeError):
        # No convergence, non-finite input, or fewer points than
        # parameters: leave this plane untouched.
        return
    radfit = self.resid_cinco(xval, *fitnz[0])

    if self.doplot:
        # The scatter is interpolated over radii that had none, for display.
        sigplot = np.interp(xval, xvalnz, sig[fitmask])
        max_rad_idx = int(np.max(xvalnz))
        ax1 = pl.subplot(211)
        ax1.errorbar(
            xval[0:max_rad_idx],
            radfit[0:max_rad_idx],
            yerr=sigplot[0:max_rad_idx],
            ecolor="lime",
            fmt="none",
            label="sigma",
        )
        ax1.plot(
            xval[0:max_rad_idx],
            radamp[0:max_rad_idx],
            "o",
            color="magenta",
            label="mean radial value",
        )
        ax1.plot(
            xval[0:max_rad_idx],
            radfit[0:max_rad_idx],
            "k",
            label="fitted radial value",
        )
        ax1.set_ylabel("Amplitude")
        ax1.set_xlabel("uvdist in pix")
        ax1.legend()

    # Sweep over all cells that are not already flagged.
    unflagged = ~np.asarray(flggrid, dtype=bool)
    threshold = radfit[np.minimum(radius, npoints - 1)] + nsigma * sigtouse
    toflag = unflagged & (radius < npoints) & (amp > threshold)

    if self.doplot:
        ax2 = pl.subplot(212)
        kept = unflagged & ~toflag
        ax2.plot(radius[toflag], amp[toflag], "go")
        ax2.plot(radius[kept], amp[kept], "b+")

    uvgrid[toflag] = 0
    wgtgrid[toflag] = 0
    flggrid[toflag] = True
393 ######################################################################################################
394 # Bonus joke : The Dalai Lama walks into a pizza shop and says "Can you make me one with everything?"
395 # The cashier hands him the pizza and says "That'll be $12.50." The Dalai Lama hands him a
396 # $20 bill and waits. After a few moments, he asks "Where's my change?" The cashier replies
397 # "Change comes from within."
399 # This needs to be a static method, otherwise numba cannot compile it because it does not understand self
400 @staticmethod
401 # @nb.njit(cache=True)
402 def apply_flags(
403 dat: np.ndarray,
404 flg: np.ndarray,
405 uvlen: np.ndarray,
406 radial_mean: np.ndarray,
407 radial_mad: np.ndarray,
408 annuli: np.ndarray,
409 nsigma: float = 5.0,
410 ) -> np.ndarray:
411 """
412 Apply flags based on the radial profile to the input data column
413 """
415 nrow = uvlen.shape[0]
417 for rr in range(nrow):
419 # feature of searchsorted will return N if above max(annuli)
420 annidx = annuli.size - 1
421 if uvlen[rr] < annuli[-1]:
422 annidx = np.searchsorted(annuli, uvlen[rr])
423 if (
424 np.abs(dat[..., rr])
425 > radial_mean[annidx] + nsigma * radial_mad[annidx]
426 ):
427 flg[..., rr] = True
429 return flg
431 ###########################################################################################
def accumulate_continuum_grid(
    self,
    tb: "casatools.table",
    npol: int,
    nchan: int,
    npoints: int,
    deltau: float,
    deltav: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    If the input msuvbin has multiple channels, loop over them to
    accumulate on a single grid. This allows for a better estimate of the
    radial profile from a "fuller" UV grid before flagging outliers
    per-plane.

    Inputs:
    tb           Table object - must be open
    npol         Number of polarizations
    nchan        Number of channels
    npoints      Number of points along each axis of the grid
    deltau       U spacing in lambda (not used at present)
    deltav       V spacing in lambda (not used at present)

    Returns:
    uvgrid       Sum over all planes of the unflagged DATA values
    wgtgrid      Sum over all planes of the unflagged WEIGHT_SPECTRUM values

    Flagged cells contribute neither data nor weight. They are replaced by
    zero rather than multiplied by zero, so NaN or Inf stored in a flagged
    cell cannot leak into the accumulated grids.
    """
    uvgrid_cont = np.zeros((npoints, npoints), dtype=np.complex128)
    wgtgrid_cont = np.zeros((npoints, npoints), dtype=np.float64)
    grid_shape = [npoints, npoints]

    for pol in range(npol):
        for chan in range(nchan):
            dat = tb.getcolslice("DATA", [pol, chan], [pol, chan], [1, 1])
            flg = tb.getcolslice("FLAG", [pol, chan], [pol, chan], [1, 1])
            wgt = tb.getcolslice(
                "WEIGHT_SPECTRUM", [pol, chan], [pol, chan], [1, 1]
            )

            if dat.size == 0 or flg.size == 0 or wgt.size == 0:
                casalog.post(
                    "Zero size array read. Skipping.",
                    "WARN",
                    "task_msuvbinflag",
                )
                continue

            dat_grid = dat[0, 0, :].reshape(grid_shape)
            flg_grid = flg[0, 0, :].reshape(grid_shape)
            wgt_grid = wgt[0, 0, :].reshape(grid_shape)

            uvgrid_cont += np.where(flg_grid, 0.0, dat_grid)
            wgtgrid_cont += np.where(flg_grid, 0.0, wgt_grid)

    return uvgrid_cont, wgtgrid_cont
491 ##########################################################
def flagViaBin_radial(self, sigma: float = 5):
    """
    Flag outliers against an annular radial profile built from all planes.

    All polarizations and channels are accumulated on one grid, the mean
    amplitude and MAD-based sigma are computed in log-spaced annuli, and
    each plane is then flagged where its amplitude exceeds
    mean + sigma * MAD-sigma of its annulus.

    Inputs:
    sigma        Flagging threshold in units of the per-annulus sigma

    Only the FLAG column is written, and only when self.dryrun is False;
    DATA and WEIGHT_SPECTRUM are not modified by this method. The table is
    released even if an error occurs.
    """
    tb.open(self.binnedvis, nomodify=False)
    try:
        # Keys of the keyword record:
        # 'csys', 'nchan', 'npol', 'numvis', 'nx', 'ny', 'sumweight'
        msuvbinkey = tb.getkeyword("MSUVBIN")

        # in radian
        cdelt = msuvbinkey["csys"]["direction0"]["cdelt"]
        dra, ddec = cdelt[0], cdelt[1]

        nx = msuvbinkey["nx"]
        ny = msuvbinkey["ny"]

        # in radian
        ra_extent = dra * nx
        dec_extent = ddec * ny

        # in Lambda
        deltau = 1.0 / ra_extent
        deltav = 1.0 / dec_extent

        npol = msuvbinkey["npol"]
        nchan = msuvbinkey["nchan"]

        npoints = min(nx, ny)

        uvw = tb.getcol("UVW")
        uvlen_m = np.sqrt(uvw[0] ** 2 + uvw[1] ** 2)

        # Accumulate all channels in a single grid; the accumulated weights
        # are not needed here.
        uvgrid_cont, _ = self.accumulate_continuum_grid(
            tb, npol, nchan, npoints, deltau, deltav
        )
        # Calculate the radial profile
        radial_mean, radial_mad, annuli = self.calc_radial_profile_ann(
            uvgrid_cont, uvlen_m
        )

        if self.doplot:
            import matplotlib.pyplot as plt
            from matplotlib.colors import LogNorm

            fig, ax = plt.subplots()
            ax.plot(annuli[2:], np.abs(radial_mean)[2:], "-o", label="data")
            ax.fill_between(
                annuli,
                np.abs(radial_mean) - radial_mad,
                np.abs(radial_mean) + radial_mad,
                alpha=0.5,
            )
            ax.set_xlabel("Radius")
            ax.set_ylabel("Intensity (Jy)")
            ax.set_title("Radial Mean")
            ax.legend()
            plt.tight_layout()

            plt.savefig("radprof.jpg", bbox_inches="tight")

            fig, ax = plt.subplots(1, 1)
            uvgrid_shape = uvgrid_cont.shape
            ax.imshow(
                np.abs(uvgrid_cont),
                origin="lower",
                norm=LogNorm(vmin=1e-12, vmax=1),
                extent=[
                    -uvgrid_shape[0] // 2,
                    uvgrid_shape[0] // 2,
                    -uvgrid_shape[1] // 2,
                    uvgrid_shape[1] // 2,
                ],
            )
            ax.set_title("UV grid")
            plt.tight_layout()
            plt.show()

        for pol in range(npol):
            for chan in range(nchan):
                dat = np.asarray(
                    tb.getcolslice("DATA", [pol, chan], [pol, chan], [1, 1])
                )
                flg = np.asarray(
                    tb.getcolslice("FLAG", [pol, chan], [pol, chan], [1, 1])
                )
                wgt = np.asarray(
                    tb.getcolslice(
                        "WEIGHT_SPECTRUM", [pol, chan], [pol, chan], [1, 1]
                    )
                )

                if dat.size == 0 or flg.size == 0 or wgt.size == 0:
                    casalog.post(
                        "Zero size array read. Skipping.",
                        "WARN",
                        "task_msuvbinflag",
                    )
                    continue

                # Do the flagging and write back
                flg_new = self.apply_flags(
                    dat,
                    flg,
                    uvlen_m,
                    radial_mean,
                    radial_mad,
                    annuli,
                    nsigma=sigma,
                )

                if not self.dryrun:
                    # dat and wgt were only read, so FLAG is the only
                    # column that needs to be written.
                    tb.putcolslice("FLAG", flg_new, [pol, chan], [pol, chan])
    finally:
        tb.clearlocks()
        tb.close()
        tb.done()
616 ###################################################################
def flag_radial_per_plane(self, sigma=5) -> None:
    """
    Flag every channel/polarization plane against its own fitted radial
    amplitude profile (see calc_radial_profile_and_fit).

    Inputs:
    sigma        Flagging threshold passed to calc_radial_profile_and_fit

    Raises:
    ValueError   If the gridded visibilities are not square (nx != ny)

    FLAG, DATA and WEIGHT_SPECTRUM are written back unless self.dryrun is
    True. The table is released even if an error occurs.
    """
    tb.open(self.binnedvis, nomodify=False)
    try:
        msuvbinkey = tb.getkeyword("MSUVBIN")
        nx = msuvbinkey["nx"]
        ny = msuvbinkey["ny"]
        if nx != ny:
            raise ValueError("Do not deal with non square gridded vis")
        npol = msuvbinkey["npol"]
        nchan = msuvbinkey["nchan"]

        for c in range(nchan):
            dat = tb.getcolslice("DATA", [0, c], [npol - 1, c], [1, 1])
            flg = tb.getcolslice("FLAG", [0, c], [npol - 1, c], [1, 1])
            wgt = tb.getcolslice(
                "WEIGHT_SPECTRUM", [0, c], [npol - 1, c], [1, 1]
            )

            of = np.sum(flg[:, 0, :])
            casalog.post(
                "BEFORE chan %d number of unflagged points: %d max: %f"
                % (
                    c,
                    nx * nx * npol - of,
                    np.max(np.abs(dat[:, 0, :])),
                ),
                "DEBUG",
                "task_msuvbinflag",
            )

            for k in range(npol):
                if self.doplot:
                    pl.clf()
                    ax1 = pl.subplot(211)
                    ax1.set_title(
                        f"radial mean Amp and fit for chan {c} and pol {k} "
                    )
                # The 2D planes are passed on so that the flagging routine
                # can update dat, wgt and flg in place.
                # NOTE(review): relies on reshape returning views of
                # dat/flg/wgt rather than copies - confirm.
                a = dat[k, 0, :].reshape([nx, nx])
                f = flg[k, 0, :].reshape([nx, nx])
                w = wgt[k, 0, :].reshape([nx, nx])
                self.calc_radial_profile_and_fit(a, w, f, sigma)
                if self.doplot:
                    pl.savefig(f"rad_{self.binnedvis}_c{c}_p{k}.jpg")

            of = np.sum(flg[:, 0, :])
            casalog.post(
                "AFTER chan %d number of unflagged points: %d max: %f"
                % (
                    c,
                    nx * nx * npol - of,
                    np.max(np.abs(dat[:, 0, :])),
                ),
                "DEBUG",
                "task_msuvbinflag",
            )

            if not self.dryrun:
                tb.putcolslice("FLAG", flg, [0, c], [npol - 1, c], [1, 1])
                tb.putcolslice("DATA", dat, [0, c], [npol - 1, c], [1, 1])
                tb.putcolslice(
                    "WEIGHT_SPECTRUM", wgt, [0, c], [npol - 1, c], [1, 1]
                )
    finally:
        # Always give the table back, including after the non-square error.
        tb.done()
686 #########################################################################################
def flag_gradient(self) -> None:
    """
    Flag, plane by plane, the cells that deviate from a kriging
    interpolation of the unflagged amplitudes (see locateViaKrige).

    Raises:
    ImportError  If the optional pykrige package is not installed
    ValueError   If the gridded visibilities are not square (nx != ny)

    Cells with zero weight are flagged first. Only the FLAG column is
    written back, and only when self.dryrun is False. The table is released
    even if an error occurs.
    """
    # pykrige is optional. Fail early with a clear message instead of
    # installing packages from inside the running process.
    try:
        from pykrige.ok import OrdinaryKriging  # noqa: F401
    except ImportError as err:
        raise ImportError(
            "flag_gradient requires the 'pykrige' package; "
            "install it (e.g. 'pip install pykrige') and retry"
        ) from err

    factor = 5.0
    tb.open(self.binnedvis, nomodify=False)
    try:
        msuvbinkey = tb.getkeyword("MSUVBIN")
        nx = msuvbinkey["nx"]
        ny = msuvbinkey["ny"]
        if nx != ny:
            raise ValueError("Do not deal with non square gridded vis")
        npol = msuvbinkey["npol"]
        nchan = msuvbinkey["nchan"]

        for c in range(nchan):
            dat = tb.getcolslice("DATA", [0, c], [npol - 1, c], [1, 1])
            flg = tb.getcolslice("FLAG", [0, c], [npol - 1, c], [1, 1])
            wgt = tb.getcolslice(
                "WEIGHT_SPECTRUM", [0, c], [npol - 1, c], [1, 1]
            )

            of = np.sum(flg[:, 0, :])
            print(
                f"BEFORE chan {c} number of unflagged points: {nx*nx*npol-of} max: {np.max(np.abs(dat[:,0,:]))}"
            )

            for k in range(npol):
                a = dat[k, 0, :].reshape([nx, nx])
                f = flg[k, 0, :].reshape([nx, nx])
                w = wgt[k, 0, :].reshape([nx, nx])
                # Cells without weight carry no data.
                f[w == 0.0] = True

                if self.doplot:
                    pl.clf()
                    pl.ion()
                    af = np.ma.array(np.abs(a), mask=f)
                    med = np.ma.median(af)
                    rms = np.ma.std(af)
                    ax1 = pl.subplot(121)
                    pl.imshow(np.abs(a), vmin=med - rms, vmax=med + 4 * rms)
                    pl.title(f"BEFORE chan {c} and pol{k}")

                self.locateViaKrige(a, f, factor)

                if self.doplot:
                    print(f"chan{c}, pol{k}")

                    pl.subplot(122)
                    pl.imshow(np.abs(a), vmin=med - rms, vmax=med + 4 * rms)
                    pl.title("AFTER")
                    pl.show()
                    pl.savefig(f"rad_{self.binnedvis}_c{c}_p{k}.jpg")
                    # Leave the figure on screen for a while.
                    time.sleep(10)

            of = np.sum(flg[:, 0, :])
            print(
                f"AFTER chan {c} number of unflagged points: {nx*nx*npol-of} max: {np.max(np.abs(dat[:,0,:]))}"
            )
            print(
                "======================================================================="
            )

            if not self.dryrun:
                tb.putcolslice("FLAG", flg, [0, c], [npol - 1, c], [1, 1])
    finally:
        # Always give the table back, including after the non-square error.
        tb.done()
765 @staticmethod
766 def locate_von(grid, radius=1, scale=0.3):
768 flagpoints = []
769 gridpoints = []
770 npoints = len(grid)
771 gradients = np.gradient(grid)
772 du = gradients[0]
773 dv = gradients[1]
774 th = np.sqrt(
775 np.ma.max(du) * np.ma.max(du) + np.ma.max(dv) * np.ma.max(dv)
776 )
777 print("Max grad", th, np.ma.max(du), np.ma.max(dv))
778 scale = th * scale
780 for i in range(int(radius), int(npoints - radius)):
781 for j in range(int(radius), int(npoints - radius)):
782 if grid.mask[i, j] == False:
783 du_up = du[i + radius][j]
784 du_down = du[i - radius][j]
785 dv_up = dv[i][j + radius]
786 dv_down = dv[i][j - radius]
788 if (
789 (np.abs(du_up) > scale)
790 or (np.abs(du_down) > scale)
791 or (np.abs(dv_up) > scale)
792 or (np.abs(dv_down) > scale)
793 ):
794 if (np.sign(du_up) == -1 * np.sign(du_down)) or (
795 np.sign(dv_up) == -1 * np.sign(dv_down)
796 ):
797 flagpoints.append(grid[i][j])
798 gridpoints.append((i, j))
800 return flagpoints, gridpoints
@staticmethod
def locateViaKrige(grid: np.array, flag: np.array, factor=5):
    """
    Flag cells whose amplitude differs strongly from a kriging
    interpolation of the unflagged amplitudes.

    Inputs:
    grid         2D grid of visibilities; newly flagged cells are set to 0
    flag         2D grid of flags; newly flagged cells are set to True
    factor       A cell is flagged when |abs(grid) - interpolation| exceeds
                 factor times the median of that difference over the
                 unflagged cells

    Both grid and flag are modified in place.
    """
    from pykrige.ok import OrdinaryKriging

    amplitudes = np.abs(grid[flag == False])
    ou = np.where(flag == False)
    rows, cols = ou[0], ou[1]
    print(f"number of uv points {len(ou[1])}")

    # Column index is passed as x and row index as y.
    kriging = OrdinaryKriging(
        cols, rows, amplitudes, variogram_model="linear", exact_values=False
    )
    axis_points = np.array(np.arange(0, len(grid)), dtype=np.float64)
    z, ss = kriging.execute("grid", axis_points, axis_points)

    # Difference between the data and the interpolated surface.
    diffmap = np.ma.array(np.abs(np.abs(grid) - z), mask=flag)
    med = np.ma.median(diffmap)

    for row, col in zip(rows, cols):
        if diffmap[row, col] > factor * med:
            flag[row, col] = True
            grid[row, col] = 0.0