Coverage for  / home / casatest / venv / lib / python3.12 / site-packages / casatasks / private / imagerhelpers / msuvbinflag_algorithms.py: 9%

387 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-12 07:14 +0000

1import os 

2import numpy as np 

3 

4# import numba as nb 

5import casatools 

6import time 

7from scipy.optimize import curve_fit 

8from typing import Tuple, List, Union, Optional 

9 

10from casatasks import casalog 

11 

12import matplotlib.pyplot as pl 

13 

# Module-level CASA tool instances, created once at import time.
# Only `tb` is referenced by the code visible in this section (all table I/O
# in UVGridFlag goes through it); the other tools are not used by the code
# shown here.
ms = casatools.ms()
tb = casatools.table()
me = casatools.measures()
qa = casatools.quanta()
ia = casatools.image()
im = casatools.imager()

20 

21 

22class UVGridFlag: 

def __init__(self, binnedvis: str, doplot: bool = False) -> None:
    """Store the gridded-visibility table name and the plotting preference.

    Args:
        binnedvis: Name of the table that the flagging methods open.
        doplot: When true, diagnostic plots are produced and matplotlib
            interactive mode is switched on immediately.
    """
    self.binnedvis, self.doplot = binnedvis, doplot
    # Debugging switch: when set, the algorithm runs but flag_radial_per_plane
    # and flag_gradient skip writing the results back to the table.
    self.dryrun = False
    if doplot:
        pl.ion()

30 

31 # @nb.njit(cache=True) 

def populate_grid(
    self,
    uvw: np.ndarray,
    stokesI: np.ndarray,
    uvgrid: np.ndarray,
    uvgrid_npt: np.ndarray,
    deltau: float,
    deltav: float,
    npix: int,
):
    """Accumulate visibilities onto a UV grid, in place.

    Each sample ``ii`` is added to cell ``(uvw[0][ii] // deltau + npix // 2,
    uvw[1][ii] // deltav + npix // 2)`` of ``uvgrid`` and the matching cell
    of ``uvgrid_npt`` is incremented by one.

    Args:
        uvw: Array whose first two rows hold the u and v coordinates.
        stokesI: One value per sample, in the same order as ``uvw[0]``.
        uvgrid: Grid that receives the summed values (modified in place).
        uvgrid_npt: Grid that receives the per-cell sample count (modified
            in place).
        deltau: Cell size along u.
        deltav: Cell size along v.
        npix: Number of cells per axis; ``npix // 2`` is the centre offset.

    Returns:
        The same ``uvgrid`` and ``uvgrid_npt`` objects that were passed in.

    NOTE(review): indices are not range-checked. As with the element-wise
    loop this replaces, a negative index wraps around to the far side of the
    grid and an index >= npix raises IndexError - confirm callers guarantee
    the grid is large enough.
    """
    u = np.asarray(uvw[0])
    nrow = len(u)
    v = np.asarray(uvw[1])[:nrow]
    values = np.asarray(stokesI)[:nrow]
    centre = npix // 2

    # Floor division already yields integral values, so no extra rounding is
    # needed before the integer cast.
    uidx = (u // deltau + centre).astype(int)
    vidx = (v // deltav + centre).astype(int)

    # ufunc.at is unbuffered: samples that fall in the same cell are all
    # accumulated, in input order, exactly like a per-sample "+=" loop.
    np.add.at(uvgrid, (uidx, vidx), values)
    np.add.at(uvgrid_npt, (uidx, vidx), 1)

    return uvgrid, uvgrid_npt

50 

def mad(self, inpdat: np.array) -> float:
    """
    Estimate the standard deviation of the input from its median absolute
    deviation (MAD).

    Inputs:
    inpdat      Input numpy array

    Returns:
    std         1.4826 * median(|inpdat - median(inpdat)|)
    """
    centre = np.median(inpdat)
    abs_dev = np.abs(inpdat - centre)

    # 1.4826 converts a MAD into the equivalent standard deviation for
    # normally distributed data.
    return 1.4826 * np.median(abs_dev)

68 

def msuvbin_to_uvgrid(
    self, ms: str, npix: int, deltau: float, deltav: float
) -> Tuple[np.ndarray, np.ndarray]:
    """Grid the DATA column of a table onto an ``npix`` x ``npix`` UV grid.

    The UVW and DATA columns are read through the module-level table tool,
    the first two entries along the leading DATA axis are averaged, the
    result is accumulated with ``populate_grid`` and every populated cell is
    divided by its sample count.

    Args:
        ms: Name of the table to open (note: shadows the module-level ``ms``
            tool inside this method).
        npix: Number of grid cells per axis.
        deltau: Cell size along u.
        deltav: Cell size along v.

    Returns:
        ``(uvgrid, uvgrid_npt)``: the per-cell average (complex) and the
        per-cell sample count.

    NOTE(review): ``0.5 * (data[0] + data[1])`` presumably forms Stokes I
    from two parallel-hand correlations - confirm the correlation order of
    the input, and that at least two correlations are present.
    """
    tb.open(ms)
    try:
        uvw = tb.getcol("UVW")
        data = tb.getcol("DATA")
    finally:
        # Release the table even when a column read fails.
        tb.close()

    uvgrid = np.zeros((npix, npix), dtype=np.complex128)
    uvgrid_npt = np.zeros((npix, npix), dtype=int)

    stokesI = np.squeeze(0.5 * (data[0] + data[1]))

    uvgrid, uvgrid_npt = self.populate_grid(
        uvw, stokesI, uvgrid, uvgrid_npt, deltau, deltav, npix
    )

    # Average per uv cell; cells that received no samples stay at zero.
    filled = np.where(uvgrid_npt != 0)
    uvgrid[filled] = uvgrid[filled] / uvgrid_npt[filled]

    return uvgrid, uvgrid_npt

98 

def resid_cube(
    self, x: float, a: float, b: float, c: float, d: float
) -> float:
    """Evaluate the cubic ``a*x**3 + b*x**2 + c*x + d`` (curve_fit model)."""
    cubic_term = a * x**3
    quadratic_term = b * x**2
    linear_term = c * x
    return cubic_term + quadratic_term + linear_term + d

103 

def resid_cinco(
    self,
    x: np.ndarray,
    a: float,
    b: float,
    c: float,
    d: float,
    e: float,
    f: float,
) -> np.ndarray:
    """Evaluate the quintic ``a*x**5 + b*x**4 + c*x**3 + d*x**2 + e*x + f``.

    Used as the model function handed to ``curve_fit``.
    """
    high_order = a * x**5 + b * x**4 + c * x**3
    low_order = d * x**2
    return high_order + low_order + e * x + f

115 

def fit_radial_profile(
    self,
    xvals: np.ndarray,
    yvals: np.ndarray,
    ystd: np.ndarray,
    deg: int = 3,
    clip_sigma: float = 3,
) -> np.ndarray:
    """
    Fit the radial profile with a polynomial, iteratively rejecting outliers.

    Points with ``yvals == 0`` are dropped first. A quadratic ``np.polyfit``
    seeds the iteration; each pass clips points whose residual lies outside
    ``median +/- clip_sigma * MAD-based sigma`` and refits a cubic with
    ``curve_fit`` (weighted by ``ystd``). Iteration stops when a pass clips
    nothing.

    Inputs:
    xvals       Abscissa of the profile
    yvals       Profile values
    ystd        Per-point uncertainty passed to curve_fit as ``sigma``
    deg         Currently NOT used: the seed fit is always quadratic and the
                refit always cubic (NOTE(review): confirm intended behaviour)
    clip_sigma  Half-width of the acceptance window in units of sigma

    Returns:
    pfit        Polynomial coefficients, highest power first: 3 values when
                no cubic refit was performed, 4 otherwise

    If a clipping pass would leave fewer points than the cubic has
    parameters (for example when the MAD of the residuals is zero), the
    iteration stops and the last valid fit is returned instead of failing
    inside curve_fit.
    """
    n_cubic_params = 4

    nonzero = np.where(yvals != 0)
    xvals, yvals, ystd = xvals[nonzero], yvals[nonzero], ystd[nonzero]

    pfit = np.polyfit(xvals, yvals, deg=2)

    while True:
        # Outlier rejection while fitting
        resid = yvals - np.polyval(pfit, xvals)
        resid_sigma = self.mad(resid)
        resid_med = np.median(resid)
        keep = (resid > resid_med - clip_sigma * resid_sigma) & (
            resid < resid_med + clip_sigma * resid_sigma
        )
        n_keep = int(np.count_nonzero(keep))

        if n_keep == len(xvals):
            # Nothing clipped: converged.
            break
        if n_keep < n_cubic_params:
            # Too few survivors to constrain a cubic; keep the last fit.
            break

        xvals, yvals, ystd = xvals[keep], yvals[keep], ystd[keep]
        pfit, _ = curve_fit(
            self.resid_cube, xvals, yvals, sigma=ystd, p0=[1, 1, 1, 1]
        )

    return pfit

183 

def calc_radial_profile_ann(
    self, uvgrid: np.ndarray, uvlen_m: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Calculate the annular average amplitude of the uvgrid.

    The grid is split into 30 annuli whose outer edges are log-spaced
    between 1 and the largest value of ``uvlen_m`` (rounded to integers).
    For every annulus the mean and the MAD-based sigma of the non-zero
    amplitudes are computed; annuli without non-zero cells report 0.

    Inputs:
    uvgrid      2D (complex) grid
    uvlen_m     UV distance of every cell; must reshape to uvgrid's shape

    Returns:
    radial_mean Mean amplitude per annulus
    radial_mad  MAD-based sigma of the amplitude per annulus
    annuli      Outer edge of each annulus (integer)

    Entry ``i`` covers ``annuli[i-1] <= uvlen < annuli[i]`` (with a lower
    edge of 0 for the first annulus). Because the edges are rounded to
    integers, neighbouring edges can coincide and produce empty annuli.
    """
    nbin = 30

    npixx, npixy = uvgrid.shape[0], uvgrid.shape[1]
    uvlen_m_grid = uvlen_m.reshape([npixx, npixy])

    # Log-spaced annuli to account for reducing UV coverage with radius
    annuli = np.logspace(0, np.log10(uvlen_m_grid.max()), nbin)
    annuli = np.round(annuli).astype(int)

    radial_mean = np.zeros(nbin)
    radial_mad = np.zeros(nbin)

    # Amplitudes are the only quantity used below; compute them once.
    amp = np.abs(uvgrid)
    has_data = amp != 0

    ann_min = 0
    for idx, ann in enumerate(annuli):
        in_ring = (uvlen_m_grid >= ann_min) & (uvlen_m_grid < ann) & has_data
        if in_ring.any():
            ring_amp = amp[in_ring]
            radial_mean[idx] = np.mean(ring_amp)
            radial_mad[idx] = self.mad(ring_amp)
        ann_min = ann

    return radial_mean, radial_mad, annuli

231 

def calc_radial_profile_pix(
    self, uvgrid: np.ndarray, deltau: float, deltav: float
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Calculate a per-pixel radial profile of the uvgrid.

    Radial bin edges are the integer-truncated radii of the diagonal pixels
    ``(i, i)`` out to half the grid size. For each bin the mean and the
    MAD-based sigma of the non-zero cells are stored; bins without non-zero
    cells stay at 0.

    Returns:
    uvgrid      The input grid, unchanged
    radial_mean Mean of the selected cells per bin
    radial_mad  MAD-based sigma of the selected cells per bin
    radius      Bin edges multiplied by deltau

    NOTE(review): unlike calc_radial_profile_ann this averages the raw cell
    values, not their amplitudes - confirm that is intended for complex
    input. ``deltav`` is not used.
    """
    n_axis0, n_axis1 = uvgrid.shape[0], uvgrid.shape[1]
    half0, half1 = n_axis0 // 2, n_axis1 // 2

    # Bin edges: truncated radius of the diagonal pixels, 0 .. max
    rad = np.sqrt(np.arange(half0) ** 2 + np.arange(half1) ** 2).astype(int)

    # Truncated radius of every grid cell relative to the centre
    offsets0 = np.arange(n_axis0) - half0
    offsets1 = np.arange(n_axis1) - half1
    yy, xx = np.meshgrid(offsets0, offsets1)
    radgrid = np.sqrt(xx**2 + yy**2).astype(int)

    n_profile = np.max([half0, half1])
    radial_mean = np.zeros(n_profile)
    radial_mad = np.zeros(n_profile)

    final = len(rad) - 1
    for idx in range(len(rad)):
        if idx != final:
            ring = (radgrid > rad[idx]) & (radgrid <= rad[idx + 1])
        else:
            # The last bin collects everything at or beyond the last edge.
            ring = radgrid >= rad[idx]

        samples = uvgrid[np.nonzero(ring)]
        samples = samples[np.abs(samples) != 0]

        if len(samples) != 0:
            radial_mean[idx] = np.mean(samples)
            radial_mad[idx] = self.mad(samples)
        else:
            radial_mean[idx] = 0
            radial_mad[idx] = 0.0

    return uvgrid, radial_mean, radial_mad, rad * deltau

276 

277 #############################################################3 

def calc_radial_profile_and_fit(
    self,
    uvgrid: np.ndarray,
    wgtgrid: np.ndarray,
    flggrid: np.ndarray,
    nsigma: float,
) -> None:
    """
    Flag grid cells whose amplitude lies well above the radial profile.

    A weighted radial mean of the amplitude is computed per integer-radius
    bin (radius in pixels from the grid centre), a 5th-order polynomial is
    fitted to it, and every not-yet-flagged cell whose amplitude exceeds
    ``fit(radius) + nsigma * sigma`` is flagged. ``sigma`` is the weighted
    scatter of the amplitudes in the radial bin that carries the largest
    total weight.

    Inputs:
    uvgrid      2D complex grid; flagged cells are set to 0 (in place)
    wgtgrid     2D weight grid; flagged cells are set to 0 (in place)
    flggrid     2D flag grid; newly flagged cells are set to True (in place)
    nsigma      Threshold in units of sigma above the fitted profile

    Returns nothing. The grids are left untouched when no cell has a
    positive weight (empty channel) or when the polynomial fit fails.
    Zeroing uvgrid and wgtgrid at the flagged cells is not strictly
    necessary, but is kept.

    Notes:
    - As before, the first row and first column are left out of the profile
      (they are still examined for flagging).
    - Cells whose radius bin lies beyond the profile are ignored when
      building the profile and are never flagged. Previously such cells
      raised IndexError on grids with an odd number of pixels.
    """
    npixx, npixy = uvgrid.shape[0], uvgrid.shape[1]
    cx, cy = npixx // 2, npixy // 2

    # Number of radial bins: truncated radius of the outermost diagonal
    # pixel of the half-grid, plus one.
    rad = np.sqrt(np.arange(cx) ** 2 + np.arange(cy) ** 2).astype(int)
    npoints = int(np.max(rad)) + 1
    xval = np.arange(npoints, dtype="float")

    # Truncated radius (in pixels) of every cell, indexed like the grids.
    dk = np.arange(npixx) - cx
    dj = np.arange(npixy) - cy
    rgrid = np.sqrt(dk[:, np.newaxis] ** 2 + dj[np.newaxis, :] ** 2).astype(int)
    in_range = rgrid < npoints
    amp = np.abs(uvgrid)

    # Cells that contribute to the profile.
    use = (wgtgrid > 0.0) & in_range
    use[0, :] = False
    use[:, 0] = False

    # Column-major flattening reproduces the accumulation order of the
    # former nested loops (outer index: second axis), keeping the sums
    # bit-for-bit identical.
    sel = use.ravel(order="F")
    r_sel = rgrid.ravel(order="F")[sel]
    a_sel = amp.ravel(order="F")[sel]
    w_sel = np.asarray(wgtgrid).ravel(order="F")[sel]
    radwght = np.bincount(r_sel, weights=w_sel, minlength=npoints)
    radamp = np.bincount(r_sel, weights=a_sel * w_sel, minlength=npoints)
    radamp2 = np.bincount(
        r_sel, weights=a_sel * a_sel * w_sel, minlength=npoints
    )

    if np.max(radwght) == 0.0:
        # empty channel
        return

    filled = radwght != 0
    radamp[filled] = radamp[filled] / radwght[filled]
    radamp2[filled] = radamp2[filled] / radwght[filled]

    # Weighted scatter per bin; the bin with the most weight sets the
    # threshold scale.
    sig = np.sqrt(np.abs(radamp2 - radamp * radamp))
    sigtouse = sig[np.argmax(radwght)]

    fit_bins = (sig != 0.0) & (radamp != 0)
    xvalnz = xval[fit_bins]
    radampnz = radamp[fit_bins]
    try:
        popt, _ = curve_fit(self.resid_cinco, xvalnz, radampnz)
    except Exception:
        # Best effort: too few populated bins or a non-converging fit
        # leaves this plane unflagged.
        return
    radfit = self.resid_cinco(xval, *popt)

    if self.doplot:
        # Scatter interpolated onto every bin, for the error bars only.
        sig_plot = np.interp(xval, xvalnz, sig[fit_bins])
        max_rad_idx = int(np.max(xvalnz))
        ax1 = pl.subplot(211)
        ax1.errorbar(
            xval[0:max_rad_idx],
            radfit[0:max_rad_idx],
            yerr=sig_plot[0:max_rad_idx],
            ecolor="lime",
            fmt="none",
            label="sigma",
        )
        ax1.plot(
            xval[0:max_rad_idx],
            radamp[0:max_rad_idx],
            "o",
            color="magenta",
            label="mean radial value",
        )
        ax1.plot(
            xval[0:max_rad_idx],
            radfit[0:max_rad_idx],
            "k",
            label="fitted radial value",
        )
        ax1.set_ylabel("Amplitude")
        ax1.set_xlabel("uvdist in pix")
        ax1.legend()
        ax2 = pl.subplot(212)

    # Sweep over all cells that are not already flagged.
    cut = radfit[np.where(in_range, rgrid, 0)] + nsigma * sigtouse
    unflagged = np.logical_not(flggrid)
    outlier = unflagged & in_range & (amp > cut)

    if self.doplot:
        kept = unflagged & ~outlier
        ax2.plot(rgrid[kept], amp[kept], "b+")
        ax2.plot(rgrid[outlier], amp[outlier], "go")

    uvgrid[outlier] = 0
    wgtgrid[outlier] = 0
    flggrid[outlier] = True

392 

393 ###################################################################################################### 

394 # Bonus joke : The Dalai Lama walks into a pizza shop and says "Can you make me one with everything?" 

395 # The cashier hands him the pizza and says "That'll be $12.50." The Dalai Lama hands him a 

396 # $20 bill and waits. After a few moments, he asks "Where's my change?" The cashier replies 

397 # "Change comes from within." 

398 

399 # This needs to be a static method, otherwise numba cannot compile it because it does not understand self 

400 @staticmethod 

401 # @nb.njit(cache=True) 

402 def apply_flags( 

403 dat: np.ndarray, 

404 flg: np.ndarray, 

405 uvlen: np.ndarray, 

406 radial_mean: np.ndarray, 

407 radial_mad: np.ndarray, 

408 annuli: np.ndarray, 

409 nsigma: float = 5.0, 

410 ) -> np.ndarray: 

411 """ 

412 Apply flags based on the radial profile to the input data column 

413 """ 

414 

415 nrow = uvlen.shape[0] 

416 

417 for rr in range(nrow): 

418 

419 # feature of searchsorted will return N if above max(annuli) 

420 annidx = annuli.size - 1 

421 if uvlen[rr] < annuli[-1]: 

422 annidx = np.searchsorted(annuli, uvlen[rr]) 

423 if ( 

424 np.abs(dat[..., rr]) 

425 > radial_mean[annidx] + nsigma * radial_mad[annidx] 

426 ): 

427 flg[..., rr] = True 

428 

429 return flg 

430 

431 ########################################################################################### 

def accumulate_continuum_grid(
    self,
    tb: "casatools.table",
    npol: int,
    nchan: int,
    npoints: int,
    deltau: float,
    deltav: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    If the input msuvbin has multiple channels, loop over them to
    accumulate on a single grid. This allows for a better estimate of the
    radial profile from a "fuller" UV grid before flagging outliers
    per-plane.

    Inputs:
    tb          Table object - must be open
    npol        Number of polarizations
    nchan       Number of channels
    npoints     Number of points per axis of the (square) grid
    deltau      U spacing in lambda (currently unused)
    deltav      V spacing in lambda (currently unused)

    Returns:
    uvgrid      Sum over all planes of the unflagged DATA values
    wgtgrid     Sum over all planes of WEIGHT_SPECTRUM (flagged cells
                included)

    Flagged cells contribute exactly zero to uvgrid, even when the flagged
    value is NaN or Inf. Planes for which an empty array is read are
    skipped with a warning.
    """

    uvgrid_cont = np.zeros((npoints, npoints), dtype=np.complex128)
    wgtgrid_cont = np.zeros((npoints, npoints), dtype=np.float64)
    grid_shape = [npoints, npoints]

    for pol in range(npol):
        for chan in range(nchan):
            corner = [pol, chan]
            dat = tb.getcolslice("DATA", corner, corner, [1, 1])
            flg = tb.getcolslice("FLAG", corner, corner, [1, 1])
            wgt = tb.getcolslice("WEIGHT_SPECTRUM", corner, corner, [1, 1])

            if dat.size == 0 or flg.size == 0 or wgt.size == 0:
                casalog.post(
                    "Zero size array read. Skipping.",
                    "WARN",
                    "task_msuvbinflag",
                )
                continue

            dat_grid = dat[0, 0, :].reshape(grid_shape)
            flg_grid = flg[0, 0, :].reshape(grid_shape)
            wgt_grid = wgt[0, 0, :].reshape(grid_shape)

            # Select instead of multiplying by the inverted flags:
            # NaN * 0 is NaN, which would poison the accumulated cell.
            uvgrid_cont += np.where(flg_grid, 0, dat_grid)
            wgtgrid_cont += wgt_grid

    return uvgrid_cont, wgtgrid_cont

490 

491 ########################################################## 

def flagViaBin_radial(self, sigma: float = 5):
    """
    Flag the gridded visibilities against an annular radial profile.

    All polarizations and channels are accumulated on one continuum grid,
    the annular mean/sigma profile of that grid is computed, and every
    plane is then flagged with ``apply_flags`` using ``sigma`` as the
    threshold. Only the FLAG column is written back (DATA and
    WEIGHT_SPECTRUM are not modified by this method), and nothing is
    written when ``self.dryrun`` is set.

    The table is always released, also when an error occurs part-way.
    """
    tb.open(self.binnedvis, nomodify=False)
    try:
        # Keys seen in MSUVBIN:
        # 'csys', 'nchan', 'npol', 'numvis', 'nx', 'ny', 'sumweight'
        msuvbinkey = tb.getkeyword("MSUVBIN")

        # Cell size in radian
        cdelt = msuvbinkey["csys"]["direction0"]["cdelt"]
        dra, ddec = cdelt[0], cdelt[1]

        nx = msuvbinkey["nx"]
        ny = msuvbinkey["ny"]

        # Image extent in radian, hence UV cell size in lambda
        deltau = 1.0 / (dra * nx)
        deltav = 1.0 / (ddec * ny)

        npol = msuvbinkey["npol"]
        nchan = msuvbinkey["nchan"]

        npoints = min(nx, ny)

        uvw = tb.getcol("UVW")
        uvlen_m = np.sqrt(uvw[0] ** 2 + uvw[1] ** 2)

        # Accumulate all channels in a single grid
        uvgrid_cont, _ = self.accumulate_continuum_grid(
            tb, npol, nchan, npoints, deltau, deltav
        )
        # Calculate the radial profile
        radial_mean, radial_mad, annuli = self.calc_radial_profile_ann(
            uvgrid_cont, uvlen_m
        )

        if self.doplot:
            import matplotlib.pyplot as plt
            from matplotlib.colors import LogNorm

            fig, ax = plt.subplots()
            ax.plot(annuli[2:], np.abs(radial_mean)[2:], "-o", label="data")
            ax.fill_between(
                annuli,
                np.abs(radial_mean) - radial_mad,
                np.abs(radial_mean) + radial_mad,
                alpha=0.5,
            )
            ax.set_xlabel("Radius")
            ax.set_ylabel("Intensity (Jy)")
            ax.set_title("Radial Mean")
            ax.legend()
            plt.tight_layout()

            plt.savefig("radprof.jpg", bbox_inches="tight")

            fig, ax = plt.subplots(1, 1)
            uvgrid_shape = uvgrid_cont.shape
            ax.imshow(
                np.abs(uvgrid_cont),
                origin="lower",
                norm=LogNorm(vmin=1e-12, vmax=1),
                extent=[
                    -uvgrid_shape[0] // 2,
                    uvgrid_shape[0] // 2,
                    -uvgrid_shape[1] // 2,
                    uvgrid_shape[1] // 2,
                ],
            )
            ax.set_title("UV grid")
            plt.tight_layout()
            plt.show()

        for pol in range(npol):
            for chan in range(nchan):
                corner = [pol, chan]
                dat = np.asarray(
                    tb.getcolslice("DATA", corner, corner, [1, 1])
                )
                flg = np.asarray(
                    tb.getcolslice("FLAG", corner, corner, [1, 1])
                )
                wgt = np.asarray(
                    tb.getcolslice("WEIGHT_SPECTRUM", corner, corner, [1, 1])
                )

                if dat.size == 0 or flg.size == 0 or wgt.size == 0:
                    casalog.post(
                        "Zero size array read. Skipping.",
                        "WARN",
                        "task_msuvbinflag",
                    )
                    continue

                # Do the flagging and write back
                flg_new = self.apply_flags(
                    dat,
                    flg,
                    uvlen_m,
                    radial_mean,
                    radial_mad,
                    annuli,
                    nsigma=sigma,
                )

                if not self.dryrun:
                    tb.putcolslice("FLAG", flg_new, corner, corner)
    finally:
        tb.clearlocks()
        tb.close()
        tb.done()

615 

616 ################################################################### 

def flag_radial_per_plane(self, sigma=5) -> None:
    """
    Flag every channel/polarization plane against its own radial profile.

    For each channel all polarizations are read, each plane is passed to
    ``calc_radial_profile_and_fit`` (which flags in place and zeroes the
    flagged data and weights), and FLAG, DATA and WEIGHT_SPECTRUM are
    written back unless ``self.dryrun`` is set. The number of unflagged
    points before and after is logged at DEBUG level.

    Inputs:
    sigma       Threshold, in sigma above the fitted profile, for flagging

    Raises:
    ValueError  If the grid is not square (nx != ny).

    The table is always released, also when an error occurs. Plot files
    are named after the base name of ``binnedvis`` so that a table given
    with a directory path still yields a valid file name.
    """
    tb.open(self.binnedvis, nomodify=False)
    try:
        msuvbinkey = tb.getkeyword("MSUVBIN")
        nx = msuvbinkey["nx"]
        ny = msuvbinkey["ny"]
        if nx != ny:
            raise ValueError("Do not deal with non square gridded vis")
        npol = msuvbinkey["npol"]
        nchan = msuvbinkey["nchan"]
        vis_tag = os.path.basename(os.path.normpath(self.binnedvis))

        for c in range(nchan):
            blc, trc, inc = [0, c], [npol - 1, c], [1, 1]
            dat = tb.getcolslice("DATA", blc, trc, inc)
            flg = tb.getcolslice("FLAG", blc, trc, inc)
            wgt = tb.getcolslice("WEIGHT_SPECTRUM", blc, trc, inc)

            of = np.sum(flg[:, 0, :])
            casalog.post(
                "BEFORE chan %d number of unflagged points: %d  max: %f"
                % (
                    c,
                    nx * nx * npol - of,
                    np.max(np.abs(dat[:, 0, :])),
                ),
                "DEBUG",
                "task_msuvbinflag",
            )

            for k in range(npol):
                if self.doplot:
                    pl.clf()
                    ax1 = pl.subplot(211)
                    ax1.set_title(
                        f"radial mean  Amp and fit for chan {c} and pol {k} "
                    )
                a = dat[k, 0, :].reshape([nx, nx])
                f = flg[k, 0, :].reshape([nx, nx])
                w = wgt[k, 0, :].reshape([nx, nx])
                # Flags (and zeroes) outliers in a, w and f in place.
                self.calc_radial_profile_and_fit(a, w, f, sigma)
                if self.doplot:
                    pl.savefig(f"rad_{vis_tag}_c{c}_p{k}.jpg")

            of = np.sum(flg[:, 0, :])
            casalog.post(
                "AFTER chan %d number of unflagged points: %d  max: %f"
                % (
                    c,
                    nx * nx * npol - of,
                    np.max(np.abs(dat[:, 0, :])),
                ),
                "DEBUG",
                "task_msuvbinflag",
            )

            if not self.dryrun:
                tb.putcolslice("FLAG", flg, blc, trc, inc)
                tb.putcolslice("DATA", dat, blc, trc, inc)
                tb.putcolslice("WEIGHT_SPECTRUM", wgt, blc, trc, inc)
    finally:
        tb.done()

685 

686 ######################################################################################### 

def flag_gradient(self) -> None:
    """
    Flag every channel/polarization plane with the kriging-based detector.

    For each plane, cells with zero weight are flagged first, then
    ``locateViaKrige`` flags (and zeroes) cells that deviate from the
    kriged surface. Only the FLAG column is written back, and only when
    ``self.dryrun`` is not set.

    Raises:
    ValueError  If the grid is not square (nx != ny).

    pykrige is required. Until it is installed by default, a missing
    pykrige is installed on the fly by running pip as a subprocess of the
    current interpreter (pip does not support being driven in-process).
    The table is always released, also when an error occurs.
    """
    # temporary till pykrige is installed by default
    try:
        from pykrige.ok import OrdinaryKriging  # noqa: F401
    except ImportError:
        import importlib
        import subprocess
        import sys

        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "pykrige"]
        )
        # Make the freshly installed package visible to this process.
        importlib.invalidate_caches()
        from pykrige.ok import OrdinaryKriging  # noqa: F401

    factor = 5.0
    tb.open(self.binnedvis, nomodify=False)
    try:
        msuvbinkey = tb.getkeyword("MSUVBIN")
        nx = msuvbinkey["nx"]
        ny = msuvbinkey["ny"]
        if nx != ny:
            raise ValueError("Do not deal with non square gridded vis")
        npol = msuvbinkey["npol"]
        nchan = msuvbinkey["nchan"]
        vis_tag = os.path.basename(os.path.normpath(self.binnedvis))

        for c in range(nchan):
            blc, trc, inc = [0, c], [npol - 1, c], [1, 1]
            dat = tb.getcolslice("DATA", blc, trc, inc)
            flg = tb.getcolslice("FLAG", blc, trc, inc)
            wgt = tb.getcolslice("WEIGHT_SPECTRUM", blc, trc, inc)

            of = np.sum(flg[:, 0, :])
            print(
                f"BEFORE chan {c} number of unflagged points: {nx*nx*npol-of}  max: {np.max(np.abs(dat[:,0,:]))}"
            )

            for k in range(npol):
                a = dat[k, 0, :].reshape([nx, nx])
                f = flg[k, 0, :].reshape([nx, nx])
                w = wgt[k, 0, :].reshape([nx, nx])
                # Cells without weight carry no data: treat as flagged.
                f[w == 0.0] = True

                if self.doplot:
                    pl.clf()
                    pl.ion()
                    af = np.ma.array(np.abs(a), mask=f)
                    med = np.ma.median(af)
                    rms = np.ma.std(af)
                    ax1 = pl.subplot(121)
                    pl.imshow(np.abs(a), vmin=med - rms, vmax=med + 4 * rms)
                    pl.title(f"BEFORE chan {c} and pol{k}")

                # Flags (and zeroes) outliers in a and f in place.
                self.locateViaKrige(a, f, factor)

                if self.doplot:
                    print(f"chan{c}, pol{k}")

                    pl.subplot(122)
                    pl.imshow(np.abs(a), vmin=med - rms, vmax=med + 4 * rms)
                    pl.title("AFTER")
                    pl.show()
                    pl.savefig(f"rad_{vis_tag}_c{c}_p{k}.jpg")
                    # Leave the interactive figure on screen for a while.
                    time.sleep(10)

            of = np.sum(flg[:, 0, :])
            print(
                f"AFTER chan {c} number of unflagged points: {nx*nx*npol-of}  max: {np.max(np.abs(dat[:,0,:]))}"
            )
            print(
                "======================================================================="
            )

            if not self.dryrun:
                tb.putcolslice("FLAG", flg, blc, trc, inc)
    finally:
        tb.done()

764 

765 @staticmethod 

766 def locate_von(grid, radius=1, scale=0.3): 

767 

768 flagpoints = [] 

769 gridpoints = [] 

770 npoints = len(grid) 

771 gradients = np.gradient(grid) 

772 du = gradients[0] 

773 dv = gradients[1] 

774 th = np.sqrt( 

775 np.ma.max(du) * np.ma.max(du) + np.ma.max(dv) * np.ma.max(dv) 

776 ) 

777 print("Max grad", th, np.ma.max(du), np.ma.max(dv)) 

778 scale = th * scale 

779 

780 for i in range(int(radius), int(npoints - radius)): 

781 for j in range(int(radius), int(npoints - radius)): 

782 if grid.mask[i, j] == False: 

783 du_up = du[i + radius][j] 

784 du_down = du[i - radius][j] 

785 dv_up = dv[i][j + radius] 

786 dv_down = dv[i][j - radius] 

787 

788 if ( 

789 (np.abs(du_up) > scale) 

790 or (np.abs(du_down) > scale) 

791 or (np.abs(dv_up) > scale) 

792 or (np.abs(dv_down) > scale) 

793 ): 

794 if (np.sign(du_up) == -1 * np.sign(du_down)) or ( 

795 np.sign(dv_up) == -1 * np.sign(dv_down) 

796 ): 

797 flagpoints.append(grid[i][j]) 

798 gridpoints.append((i, j)) 

799 

800 return flagpoints, gridpoints 

801 

802 @staticmethod 

803 def locateViaKrige(grid: np.array, flag: np.array, factor=5): 

804 from pykrige.ok import OrdinaryKriging 

805 

806 b = np.abs(grid[flag == False]) 

807 ou = np.where(flag == False) 

808 print(f"number of uv points {len(ou[1])}") 

809 OK = OrdinaryKriging( 

810 ou[1], ou[0], b, variogram_model="linear", exact_values=False 

811 ) 

812 gridpoints = np.array(np.arange(0, len(grid)), dtype=np.float64) 

813 z, ss = OK.execute("grid", gridpoints, gridpoints) 

814 diffmap = np.ma.array(np.abs(np.abs(grid) - z), mask=flag) 

815 med = np.ma.median(diffmap) 

816 for k in range(len(ou[1])): 

817 # difference with average around point 

818 # diff=(9*z[ou[0][k], ou[1][k]]-(np.sum(z[ou[0][k]-1:ou[0][k]+2, ou[1][k]-1:ou[1][k]+2])))/8.0 

819 diff = diffmap[ou[0][k], ou[1][k]] 

820 # frac=np.abs(diff)/grid[ou[0][k], ou[1][k]] 

821 if diff > factor * med: 

822 flag[ou[0][k], ou[1][k]] = True 

823 grid[ou[0][k], ou[1][k]] = 0.0