Coverage for /wheeldirectory/casa-6.7.0-12-py3.10.el8/lib/py/lib/python3.10/site-packages/casatasks/private/task_sdbaseline.py: 93%

317 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-01 07:19 +0000

1from collections import Counter 

2import datetime 

3import os 

4import shutil 

5 

6from casatasks import casalog 

7from casatools import ms as mstool 

8from casatools import singledishms 

9 

10from . import sdutil 

11from .mstools import write_history 

12 

# Module-level ms tool instance. It is used for msseltoindex() in
# _do_apply/_do_fit and passed to write_history() in sdbaseline.
ms = mstool()

14 

15 

@sdutil.callable_sdtask_decorator
def sdbaseline(infile=None, datacolumn=None, antenna=None, field=None,
               spw=None, timerange=None, scan=None, pol=None, intent=None,
               reindex=None, maskmode=None, thresh=None, avg_limit=None,
               minwidth=None, edge=None, blmode=None, dosubtract=None,
               blformat=None, bloutput=None, bltable=None, blfunc=None,
               order=None, npiece=None, applyfft=None, fftmethod=None,
               fftthresh=None, addwn=None, rejwn=None, clipthresh=None,
               clipniter=None, blparam=None, verbose=None,
               updateweight=None, sigmavalue=None,
               showprogress=None, minnrow=None,
               outfile=None, overwrite=None):
    """Fit and/or subtract spectral baselines of single-dish data.

    blmode='fit' fits baselines via _do_fit (subtracting them when
    ``dosubtract`` is True). blmode='apply' applies the baseline table
    ``bltable`` via _do_apply. When ``dosubtract`` is False, the output MS
    produced along the way is removed before returning. That removal never
    touches a pre-existing outfile that this run did not write, and never
    touches ``infile``.

    ``showprogress`` and ``minnrow`` are accepted but not used here.

    Raises:
        ValueError: if infile or blparam (blfunc='variable') does not
            exist, or outfile exists while overwrite is False and
            dosubtract is True. Helpers raise ValueError for other
            invalid parameters.
    """
    temp_outfile = ''
    # State for the cleanup in ``finally``. Until the existence check below
    # has run, assume outfile is pre-existing user data that must be kept.
    outfile_preexisted = True
    outfile_written = False

    try:
        # CAS-12985 requests the following params be given case insensitively,
        # so they need to be converted to lowercase here (2021/1/28 WK)
        blfunc = blfunc.lower()
        blmode = blmode.lower()
        fftmethod = fftmethod.lower()
        if isinstance(fftthresh, str):
            fftthresh = fftthresh.lower()

        if (spw == ''):
            spw = '*'

        if not os.path.exists(infile):
            raise ValueError("infile='" + str(infile) + "' does not exist.")
        if (outfile == '') or not isinstance(outfile, str):
            outfile = infile.rstrip('/') + '_bs'
            casalog.post("outfile is empty or non-string. set to '" + outfile + "'")
        outfile_preexisted = os.path.exists(outfile)
        if (not overwrite) and dosubtract and os.path.exists(outfile):
            raise ValueError("outfile='%s' exists, and cannot overwrite it." % (outfile))
        if (blfunc == 'variable') and not os.path.exists(blparam):
            raise ValueError("input file '%s' does not exists" % blparam)

        if (blmode == 'fit'):
            temp_outfile = _do_fit(infile, datacolumn, antenna, field, spw, timerange, scan,
                                   pol, intent, reindex, maskmode, thresh, avg_limit, minwidth,
                                   edge, dosubtract, blformat, bloutput, blfunc, order, npiece,
                                   applyfft, fftmethod, fftthresh, addwn, rejwn, clipthresh,
                                   clipniter, blparam, verbose, updateweight, sigmavalue,
                                   outfile, overwrite)
            outfile_written = True
        elif (blmode == 'apply'):
            _do_apply(infile, datacolumn, antenna, field, spw, timerange, scan, pol, intent,
                      reindex, bltable, updateweight, sigmavalue, outfile, overwrite)
            outfile_written = True

        # Remove {WEIGHT|SIGMA}_SPECTRUM columns if updateweight=True (CAS-13161).
        # When _do_fit wrote to a temporary MS (dosubtract=False and outfile
        # already existed), that temporary MS is the product of this run; the
        # pre-existing outfile must not be modified.
        if updateweight:
            weight_target = temp_outfile if temp_outfile != '' else outfile
            with sdutil.table_manager(weight_target, nomodify=False) as mytb:
                cols_spectrum = ['WEIGHT_SPECTRUM', 'SIGMA_SPECTRUM']
                cols_remove = [col for col in cols_spectrum if col in mytb.colnames()]
                if len(cols_remove) > 0:
                    mytb.removecols(' '.join(cols_remove))

        # Write history to outfile
        if dosubtract:
            # NOTE(review): ``sdbaseline`` here is the decorated object; this
            # assumes the decorator keeps the original __code__ visible --
            # confirm against sdutil.callable_sdtask_decorator.
            param_names = sdbaseline.__code__.co_varnames[:sdbaseline.__code__.co_argcount]
            var_local = locals()
            param_vals = [var_local[p] for p in param_names]
            write_history(ms, outfile, 'sdbaseline', param_names,
                          param_vals, casalog)

    finally:
        if not dosubtract:
            # Remove the (skeleton) output MS produced by this run only.
            if temp_outfile != '':
                remove_data(temp_outfile)
            elif (outfile_written or not outfile_preexisted) and \
                    os.path.realpath(outfile) != os.path.realpath(infile):
                remove_data(outfile)

86 

87 

# Baseline-output formats understood by this task, and the extension used for
# each when bloutput does not name the file. The two lists are parallel:
# blformat_item[i] pairs with blformat_ext[i]. Their order also fixes the order
# of the list returned by normalise_bloutput().
blformat_item = ['csv', 'text', 'table']
blformat_ext = ['csv', 'txt', 'bltable']

# Error message raised for any invalid addwn/rejwn specification.
mesg_invalid_wavenumber = 'wrong value given for addwn/rejwn'

92 

93 

def remove_data(filename):
    """Remove whatever filesystem entry exists at *filename*.

    Directories are removed recursively and files are deleted. A symbolic
    link is removed itself, whether it points to a file, a directory, or
    nothing, and its target is left untouched. Nothing happens when no
    entry exists at *filename*.
    """
    if os.path.islink(filename):
        # Check links first: os.path.isdir() follows the link, and
        # shutil.rmtree() refuses to operate on a symlink. A dangling link is
        # also invisible to os.path.exists().
        os.remove(filename)
    elif os.path.isdir(filename):
        shutil.rmtree(filename)
    elif os.path.lexists(filename):
        os.remove(filename)

105 

106 

def is_empty(blformat):
    """Tell whether a blformat specification selects no format at all.

    A scalar counts as empty when it is falsy (None, ''). A list counts as
    empty when every element is itself empty, so [] and ['', '', ...] are
    both empty. Nested lists are checked recursively.
    """
    if not isinstance(blformat, list):
        return not blformat
    return all(is_empty(element) for element in blformat)

117 

118 

def prepare_for_blformat_bloutput(infile, blformat, bloutput, overwrite):
    """Validate blformat/bloutput and resolve the output file names.

    Args:
        infile: input MS name, used to build default output names.
        blformat: str or list of str naming the requested formats.
        bloutput: str or list of str with the matching file names. The
            default '' means "use default names for every format".
        overwrite: whether existing output files may be removed.

    Returns:
        (blformat, bloutput): blformat as a list of str, and bloutput
        rearranged in the order of ``blformat_item`` ('' for formats not
        requested).

    Raises:
        ValueError: on non-string input, length mismatch, or duplicate
            entries.
    """
    # force to string list
    blformat = force_to_string_list(blformat, 'blformat')
    bloutput = force_to_string_list(bloutput, 'bloutput')

    # The default bloutput value '' is expanded to a list with the length of
    # blformat, filled with ''. Build a new list rather than using ``*=`` so
    # the caller's list is never mutated in place.
    if (bloutput == ['']):
        bloutput = bloutput * len(blformat)

    # check length
    if (len(blformat) != len(bloutput)):
        raise ValueError('blformat and bloutput must have the same length.')

    # Check duplication. Format names are matched case-insensitively later
    # (get_normalised_name), so duplicates must be detected the same way;
    # otherwise e.g. ['csv', 'CSV'] would silently drop one entry.
    if has_duplicate_nonnull_element([fmt.lower() for fmt in blformat]):
        raise ValueError('duplicate elements in blformat.')
    if has_duplicate_nonnull_element_ex(bloutput, blformat):
        raise ValueError('duplicate elements in bloutput.')

    # fill bloutput items to be output, then rearrange them
    # in the order of blformat_item.
    bloutput = normalise_bloutput(infile, blformat, bloutput, overwrite)

    return blformat, bloutput

144 

145 

def force_to_string_list(s, name):
    """Return *s* as a new list of strings.

    A str becomes a one-element list. A list of str is copied, so callers
    may modify the result without affecting the caller's argument.

    Args:
        s: value to convert.
        name: parameter name used in the error message.

    Raises:
        ValueError: if *s* is neither a str nor a list whose elements are
            all str.
    """
    mesg = '%s must be string or list of string.' % name
    if isinstance(s, str):
        return [s]
    if isinstance(s, list) and all(isinstance(item, str) for item in s):
        return list(s)
    raise ValueError(mesg)

157 

158 

def has_duplicate_nonnull_element(in_list):
    """Return True when some element other than '' occurs more than once.

    Repeated empty strings are ignored, since '' marks an unused slot.
    """
    return any(item != '' and count > 1
               for item, count in Counter(in_list).items())

170 

171 

def has_duplicate_nonnull_element_ex(lst, base):
    """Check *lst* for repeated non-empty entries, skipping unused slots.

    *lst* and *base* are parallel lists of the same length. Only the
    entries of *lst* whose counterpart in *base* is not '' are kept. The
    kept entries are then checked with has_duplicate_nonnull_element().
    """
    kept = [item for idx, item in enumerate(lst) if base[idx] != '']
    return has_duplicate_nonnull_element(kept)

183 

184 

def normalise_bloutput(infile, blformat, bloutput, overwrite):
    """Return one resolved output file name per entry of ``blformat_item``.

    The result is ordered like ``blformat_item``. Each name comes from
    get_normalised_name(), which yields '' for formats not requested.
    """
    resolved = []
    for format_name, extension in zip(blformat_item, blformat_ext):
        resolved.append(get_normalised_name(infile, blformat, bloutput,
                                            format_name, extension, overwrite))
    return resolved

188 

189 

def get_normalised_name(infile, blformat, bloutput, name, ext, overwrite):
    """Resolve the output file name for the format *name*.

    Args:
        infile: input MS name, used to build the default file name.
        blformat: list of requested format names (matched case-insensitively).
        bloutput: list of file names parallel to *blformat*.
        name: format to resolve (an entry of ``blformat_item``).
        ext: extension used for the default name.
        overwrite: whether an existing file at the resolved name may be
            removed.

    Returns:
        '' if *name* is not requested. Otherwise the user-given file name,
        or ``<infile>_blparam.<ext>`` when that entry is ''.

    Raises:
        ValueError: if the resolved file exists and *overwrite* is False.
    """
    blformat_lower = [s.lower() for s in blformat]
    if name not in blformat_lower:
        return ''

    fname = bloutput[blformat_lower.index(name)]
    if fname == '':
        # Strip trailing '/' (as done for the default outfile) so that e.g.
        # 'data.ms/' yields 'data.ms_blparam.csv' next to the MS, not a file
        # inside the MS directory.
        fname = infile.rstrip('/') + '_blparam.' + ext
    if os.path.exists(fname):
        if not overwrite:
            raise ValueError(fname + ' exists.')
        remove_data(fname)
    return fname

203 

204 

def output_bloutput_text_header(blformat, bloutput, blfunc, maskmode, infile, outfile):
    """Create the 'text' baseline output file and write its header.

    The file name is the entry of *bloutput* at the 'text' position of
    ``blformat_item``. When that entry is '', nothing is written. The file
    is opened in 'w' mode, so existing content is replaced. ``blformat``
    and ``blfunc`` are accepted but not used.
    """
    text_path = bloutput[blformat_item.index('text')]
    if text_path == '':
        return

    header_rows = (('Source Table', infile),
                   ('Output File', infile if outfile == '' else outfile),
                   ('Mask mode', maskmode))
    rule = '#' * 60 + '\n'

    with open(text_path, 'w') as handle:
        handle.write(rule)
        for label, value in header_rows:
            handle.write('%12s: %s\n' % (label, value))
        handle.write(rule)
        handle.write('\n')

222 

223 

def get_temporary_file_name(basename):
    """Build a temporary name from *basename*, the process id and a timestamp.

    The result is
    ``<basename>_sdbaseline_pid<pid>_<YYYYmmddHHMMSSffffff>``. The pid and
    microsecond timestamp are meant to keep it from clashing with existing
    names.
    """
    stamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')
    return ''.join([basename, '_sdbaseline_pid', str(os.getpid()), '_', stamp])

228 

229 

def parse_wavenumber_param(wn):
    """Normalise an addwn/rejwn wavenumber specification to a string.

    Accepted forms:
      * int >= 0                        -> 'a'
      * list/tuple of ints >= 0         -> sorted, de-duplicated, e.g. '0,1,3'
      * 'a,b,c'                         -> sorted, de-duplicated values
      * 'a-b', 'a~b'                    -> every integer from min to max
      * '<=a', '=<a', 'a>=', 'a=>'      -> 0..a
      * '<a', 'a>'                      -> 0..a-1 (a must be > 0)
      * '>=a', '=>a', 'a<=', 'a=<'      -> 'a,-999'
      * '>a', 'a<'                      -> 'a+1,-999'
      * 'a'                             -> 'a'
      * '' (or whitespace only)         -> '' (no wavenumbers, as for [])
    A trailing -999 is interpreted on the C++ side as "up to the Nyquist
    wavenumber" (CAS-3759).

    Returns:
        str: comma-separated wavenumbers.

    Raises:
        ValueError: for a bool or float *wn*, a string containing '.',
            negative values, or otherwise malformed specifications.
    """
    if isinstance(wn, bool):
        # bool is a subclass of int, so it has to be rejected first.
        raise ValueError(mesg_invalid_wavenumber)
    elif isinstance(wn, (list, tuple)):
        __check_positive_or_zero(wn)
        return ','.join(__get_strlist(sorted(set(wn))))
    elif isinstance(wn, int):
        __check_positive_or_zero(wn)
        return str(wn)
    elif not isinstance(wn, str):
        raise ValueError(mesg_invalid_wavenumber)

    if wn.strip() == '':
        # An empty string selects no wavenumber, exactly like an empty list.
        # Without this check, wn[0] below raised IndexError.
        return ''

    if '.' in wn:
        # case of float value as string
        raise ValueError(mesg_invalid_wavenumber)
    elif ',' in wn:
        # cases 'a,b,c,...'
        items = wn.split(',')
        __check_positive_or_zero(items)
        res = sorted(set(int(v) for v in items))
    elif '-' in wn or '~' in wn:
        # cases 'a-b', 'a~b' : return [a,a+1,...,b-1,b] ('-' takes precedence)
        bounds = wn.split('-' if '-' in wn else '~')
        __check_positive_or_zero(bounds)
        low, high = sorted([int(bounds[0]), int(bounds[1])])
        res = list(range(low, high + 1))
    elif wn[:2] in ('<=', '=<'):
        # cases '<=a','=<a' : return [0,1,...,a-1,a]
        val = wn[2:]
        __check_positive_or_zero(val)
        res = list(range(int(val) + 1))
    elif wn[-2:] in ('>=', '=>'):
        # cases 'a>=','a=>' : return [0,1,...,a-1,a]
        val = wn[:-2]
        __check_positive_or_zero(val)
        res = list(range(int(val) + 1))
    elif wn[0] == '<':
        # case '<a' : return [0,1,...,a-2,a-1]
        val = wn[1:]
        __check_positive_or_zero(val, False)
        res = list(range(int(val)))
    elif wn[-1] == '>':
        # case 'a>' : return [0,1,...,a-2,a-1]
        val = wn[:-1]
        __check_positive_or_zero(val, False)
        res = list(range(int(val)))
    elif wn[:2] in ('>=', '=>'):
        # cases '>=a','=>a' : return [a,-999] (CAS-3759)
        val = wn[2:]
        __check_positive_or_zero(val)
        res = [int(val), -999]
    elif wn[-2:] in ('<=', '=<'):
        # cases 'a<=','a=<' : return [a,-999] (CAS-3759)
        val = wn[:-2]
        __check_positive_or_zero(val)
        res = [int(val), -999]
    elif wn[0] == '>':
        # case '>a' : return [a+1,-999] (CAS-3759)
        val = int(wn[1:]) + 1
        __check_positive_or_zero(val)
        res = [val, -999]
    elif wn[-1] == '<':
        # case 'a<' : return [a+1,-999] (CAS-3759)
        val = int(wn[:-1]) + 1
        __check_positive_or_zero(val)
        res = [val, -999]
    else:
        # case 'a'
        __check_positive_or_zero(wn)
        res = [int(wn)]

    return ','.join(__get_strlist(res))

336 

337 

def __get_strlist(param):
    """Return a list holding str() of every element of *param*."""
    return list(map(str, param))

340 

341 

def check_fftthresh(fftthresh):
    """Validate fftthresh value.

    The fftthresh must be one of the following:
    (1) positive value (float, integer or string)
    (2) 'top' + positive integer value
    (3) positive float value + 'sigma'

    Positivity is tested as ``not (value > 0)``, so NaN (e.g. float('nan')
    or 'nansigma') is rejected rather than slipping through a ``<= 0``
    comparison.

    Raises:
        ValueError: for a wrong type (including bool), an unparsable string,
            or a non-positive threshold.
    """
    has_invalid_type = False
    val_not_positive = False

    if isinstance(fftthresh, bool):
        # Checking for bool must precede checking for integer
        has_invalid_type = True
    elif isinstance(fftthresh, (int, float)):
        val_not_positive = not (fftthresh > 0.0)
    elif isinstance(fftthresh, str):
        try:
            if (3 < len(fftthresh)) and (fftthresh[:3] == 'top'):
                val_not_positive = not (int(fftthresh[3:]) > 0)
            elif (5 < len(fftthresh)) and (fftthresh[-5:] == 'sigma'):
                val_not_positive = not (float(fftthresh[:-5]) > 0.0)
            else:
                val_not_positive = not (float(fftthresh) > 0.0)
        except ValueError:
            # int()/float() on a str signal malformed input with ValueError.
            raise ValueError('fftthresh has a wrong format.') from None
    else:
        has_invalid_type = True

    if has_invalid_type:
        raise ValueError('fftthresh must be float or integer or string.')
    if val_not_positive:
        raise ValueError('threshold given to fftthresh must be positive.')

379 

380 

def __check_positive_or_zero(param, allowzero=True):
    """Validate that a wavenumber (or each of several) is non-negative.

    *param* may be an int, a str holding an int, or a list/tuple of such
    values (each converted with int() and checked in order). Zero is
    accepted only when *allowzero* is True.

    Raises:
        ValueError: if a value is out of range or *param* has another type.
            int() conversion errors propagate unchanged.
    """
    if isinstance(param, (list, tuple)):
        for element in param:
            __do_check_positive_or_zero(int(element), allowzero)
        return

    if isinstance(param, str):
        param = int(param)
    elif not isinstance(param, int):
        raise ValueError(mesg_invalid_wavenumber)
    __do_check_positive_or_zero(param, allowzero)

391 

392 

def __do_check_positive_or_zero(param, allowzero):
    """Raise ValueError if *param* is negative, or zero while *allowzero* is False."""
    is_negative = param < 0
    is_forbidden_zero = (param == 0) and not allowzero
    if is_negative or is_forbidden_zero:
        raise ValueError(mesg_invalid_wavenumber)

396 

397 

def prepare_for_baselining(**keywords):
    """Pick the singledishms baseline method and its keyword arguments.

    ``keywords['blfunc']`` selects the method on ``keywords['sdms']``:
    'poly' and 'chebyshev' use ``subtract_baseline``; 'cspline',
    'sinusoid' and 'variable' use ``subtract_baseline_<blfunc>``. The
    returned params dict holds, taken from *keywords*, the common keys, the
    blfunc-specific keys, the clipping keys (all but 'variable') and the
    line-finding keys.

    Returns:
        (params, baseline_func): kwargs dict and the bound method to call.

    Raises:
        ValueError: for an unsupported blfunc.
        KeyError: if a required keyword is missing.
    """
    blfunc = keywords['blfunc']

    # (accepted blfunc names, specific keys, method name carries a suffix)
    dispatch = (
        (('poly', 'chebyshev'), ['blfunc', 'order'], False),
        (('cspline',), ['npiece'], True),
        (('sinusoid',), ['applyfft', 'fftmethod', 'fftthresh', 'addwn', 'rejwn'], True),
        (('variable',), ['blparam', 'verbose'], True),
    )
    for names, specific_keys, suffixed in dispatch:
        if blfunc in names:
            break
    else:
        raise ValueError("Unsupported blfunc = %s" % blfunc)

    wanted = ['datacolumn', 'outfile', 'bloutput', 'dosubtract', 'spw',
              'updateweight', 'sigmavalue'] + specific_keys
    if blfunc != 'variable':
        wanted += ['clip_threshold_sigma', 'num_fitting_max']
    wanted += ['linefinding', 'threshold', 'avg_limit', 'minwidth', 'edge']

    params = {key: keywords[key] for key in wanted}

    method_name = 'subtract_baseline' + ('_' + blfunc if suffixed else '')
    return params, getattr(keywords['sdms'], method_name)

427 

428 

def remove_sorted_table_keyword(infile):
    """Remove the SORTED_TABLE keyword from *infile*, remembering its value.

    Returns:
        dict with 'is_sorttab' (whether the keyword was present),
        'sorttab_keywd' (the keyword name, or '') and 'sorttab_name' (its
        former value, or ''). Pass it to restore_sorted_table_keyword().
    """
    keyword = 'SORTED_TABLE'
    info = {'is_sorttab': False, 'sorttab_keywd': '', 'sorttab_name': ''}

    with sdutil.table_manager(infile, nomodify=False) as tb:
        if keyword not in tb.keywordnames():
            return info
        info['is_sorttab'] = True
        info['sorttab_keywd'] = keyword
        info['sorttab_name'] = tb.getkeyword(keyword)
        tb.removekeyword(keyword)

    return info

441 

442 

def restore_sorted_table_keyword(infile, sorttab_info):
    """Put back the keyword removed by remove_sorted_table_keyword().

    Nothing is written unless the keyword was present and had a non-empty
    value.
    """
    if not sorttab_info['is_sorttab'] or sorttab_info['sorttab_name'] == '':
        return
    with sdutil.table_manager(infile, nomodify=False) as tb:
        tb.putkeyword(sorttab_info['sorttab_keywd'], sorttab_info['sorttab_name'])

447 

448 

def _do_apply(infile, datacolumn, antenna, field, spw, timerange, scan, pol, intent,
              reindex, bltable, updateweight, sigmavalue, outfile, overwrite):
    """Apply the baseline solutions in *bltable* to *infile*, writing *outfile*.

    The SORTED_TABLE keyword of *infile* is removed for the duration of the
    operation. It is restored afterwards even if applying the table fails,
    so *infile* is not left modified.

    Raises:
        ValueError: if *bltable* does not exist.
    """
    if not os.path.exists(bltable):
        raise ValueError("file specified in bltable '%s' does not exist." % bltable)

    # Note: the condition "infile != outfile" in the following line is for safety
    # to prevent from accidentally removing infile by setting outfile=infile.
    # Don't remove it.
    if overwrite and (infile != outfile) and os.path.exists(outfile):
        remove_data(outfile)

    sorttab_info = remove_sorted_table_keyword(infile)

    try:
        with sdutil.tool_manager(infile, singledishms) as mysdms:
            selection = ms.msseltoindex(vis=infile, spw=spw, field=field,
                                        baseline=antenna, time=timerange,
                                        scan=scan)
            mysdms.set_selection(spw=sdutil.get_spwids(selection), field=field,
                                 antenna=antenna, timerange=timerange,
                                 scan=scan, polarization=pol, intent=intent,
                                 reindex=reindex)
            mysdms.apply_baseline_table(bltable=bltable,
                                        datacolumn=datacolumn,
                                        spw=spw,
                                        updateweight=updateweight,
                                        sigmavalue=sigmavalue,
                                        outfile=outfile)
    finally:
        # Runs after the tool has been released (the with-block has exited).
        restore_sorted_table_keyword(infile, sorttab_info)

478 

479 

def _do_fit(infile, datacolumn, antenna, field, spw, timerange, scan, pol, intent,
            reindex, maskmode, thresh, avg_limit, minwidth, edge, dosubtract, blformat,
            bloutput, blfunc, order, npiece, applyfft, fftmethod, fftthresh, addwn,
            rejwn, clipthresh, clipniter, blparam, verbose, updateweight, sigmavalue,
            outfile, overwrite):
    """Fit baselines (and subtract them when *dosubtract*) with singledishms.

    Parameters are validated before anything on disk is removed or written.
    For blfunc='variable', the SORTED_TABLE keyword of *infile* is removed
    during the fit and restored afterwards, even on failure.

    Returns:
        str: name of the temporary MS written instead of *outfile* (when
        *dosubtract* is False and *outfile* already exists), else ''. The
        caller removes it. If the fit fails, this function removes the
        temporary MS itself, because the caller never receives its name.

    Raises:
        ValueError: for invalid parameters (see the helper functions).
    """
    temp_outfile = ''

    if (not dosubtract) and is_empty(blformat):
        raise ValueError("blformat must be specified when dosubtract is False")

    # Validate the sinusoid-specific parameters first. Otherwise a typo in
    # addwn/rejwn/fftthresh would only be reported after outfile and existing
    # bloutput files had already been removed and the text header written.
    if (blfunc == 'sinusoid'):
        addwn = parse_wavenumber_param(addwn)
        rejwn = parse_wavenumber_param(rejwn)
        check_fftthresh(fftthresh)

    blformat, bloutput = prepare_for_blformat_bloutput(infile, blformat, bloutput, overwrite)

    output_bloutput_text_header(blformat, bloutput, blfunc, maskmode, infile, outfile)

    # Set temporary name for output MS if dosubtract is False and outfile exists
    # for not removing/overwriting outfile that already exists
    if os.path.exists(outfile):
        # Note: the condition "infile != outfile" in the following line is for safety
        # to prevent from accidentally removing infile by setting outfile=infile
        # Don't remove it.
        if dosubtract and overwrite and (infile != outfile):
            remove_data(outfile)
        elif (not dosubtract):
            outfile = get_temporary_file_name(infile)
            temp_outfile = outfile

    sorttab_info = remove_sorted_table_keyword(infile) if blfunc == 'variable' else None

    try:
        with sdutil.tool_manager(infile, singledishms) as mysdms:
            selection = ms.msseltoindex(vis=infile, spw=spw, field=field, baseline=antenna,
                                        time=timerange, scan=scan)
            mysdms.set_selection(spw=sdutil.get_spwids(selection), field=field, antenna=antenna,
                                 timerange=timerange, scan=scan, polarization=pol, intent=intent,
                                 reindex=reindex)
            params, func = prepare_for_baselining(sdms=mysdms,
                                                  blfunc=blfunc,
                                                  datacolumn=datacolumn,
                                                  outfile=outfile,
                                                  bloutput=','.join(bloutput),
                                                  dosubtract=dosubtract,
                                                  spw=spw,
                                                  pol=pol,
                                                  linefinding=(maskmode == 'auto'),
                                                  threshold=thresh,
                                                  avg_limit=avg_limit,
                                                  minwidth=minwidth,
                                                  edge=edge,
                                                  order=order,
                                                  npiece=npiece,
                                                  applyfft=applyfft,
                                                  fftmethod=fftmethod,
                                                  fftthresh=fftthresh,
                                                  addwn=addwn,
                                                  rejwn=rejwn,
                                                  clip_threshold_sigma=clipthresh,
                                                  num_fitting_max=clipniter + 1,
                                                  blparam=blparam,
                                                  verbose=verbose,
                                                  updateweight=updateweight,
                                                  sigmavalue=sigmavalue)
            func(**params)
    except Exception:
        # The caller only learns the temporary name from the return value, so
        # clean it up here before propagating the error.
        if temp_outfile != '':
            remove_data(temp_outfile)
        raise
    finally:
        if sorttab_info is not None:
            restore_sorted_table_keyword(infile, sorttab_info)

    return temp_outfile

551 return temp_outfile