Coverage for /wheeldirectory/casa-6.7.0-12-py3.10.el8/lib/py/lib/python3.10/site-packages/casatasks/private/task_sdbaseline.py: 11%
317 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-10-31 19:53 +0000
1from collections import Counter
2import datetime
3import os
4import shutil
6from casatasks import casalog
7from casatools import ms as mstool
8from casatools import singledishms
10from . import sdutil
11from .mstools import write_history
# Module-level casatools ms tool instance, shared by this module: used for
# ms.msseltoindex() in _do_fit/_do_apply and passed to write_history().
ms = mstool()
@sdutil.callable_sdtask_decorator
def sdbaseline(infile=None, datacolumn=None, antenna=None, field=None,
               spw=None, timerange=None, scan=None, pol=None, intent=None,
               reindex=None, maskmode=None, thresh=None, avg_limit=None,
               minwidth=None, edge=None, blmode=None, dosubtract=None,
               blformat=None, bloutput=None, bltable=None, blfunc=None,
               order=None, npiece=None, applyfft=None, fftmethod=None,
               fftthresh=None, addwn=None, rejwn=None, clipthresh=None,
               clipniter=None, blparam=None, verbose=None,
               updateweight=None, sigmavalue=None,
               showprogress=None, minnrow=None,
               outfile=None, overwrite=None):
    """Fit baselines (blmode='fit') or apply a baseline table (blmode='apply').

    The processed data are written to ``outfile``. If ``outfile`` is empty or
    not a string, it defaults to ``infile`` with trailing '/' stripped plus
    '_bs'. When ``dosubtract`` is False, the MS written during processing
    (``outfile`` or a temporary name chosen by _do_fit) is deleted again at
    the end.

    Raises:
        ValueError: if ``infile`` does not exist, if ``outfile`` exists while
            ``overwrite`` is False and ``dosubtract`` is True, or if the
            ``blparam`` file is missing for blfunc='variable'.
    """
    temp_outfile = ''
    # State used by the cleanup in ``finally``. outfile counts as
    # pre-existing until it has actually been checked, so a failure during
    # argument handling never deletes a file the user already had.
    outfile_preexisted = True
    processing_done = False

    try:
        # CAS-12985 requests the following params be given case insensitively,
        # so they need to be converted to lowercase here (2021/1/28 WK)
        blfunc = blfunc.lower()
        blmode = blmode.lower()
        fftmethod = fftmethod.lower()
        if isinstance(fftthresh, str):
            fftthresh = fftthresh.lower()

        if spw == '':
            spw = '*'

        if not os.path.exists(infile):
            raise ValueError("infile='" + str(infile) + "' does not exist.")
        if (outfile == '') or not isinstance(outfile, str):
            outfile = infile.rstrip('/') + '_bs'
            casalog.post("outfile is empty or non-string. set to '" + outfile + "'")
        if (not overwrite) and dosubtract and os.path.exists(outfile):
            raise ValueError("outfile='%s' exists, and cannot overwrite it." % (outfile))
        if (blfunc == 'variable') and not os.path.exists(blparam):
            raise ValueError("input file '%s' does not exists" % blparam)

        outfile_preexisted = os.path.exists(outfile)
        if blmode == 'fit':
            temp_outfile = _do_fit(infile, datacolumn, antenna, field, spw, timerange, scan,
                                   pol, intent, reindex, maskmode, thresh, avg_limit, minwidth,
                                   edge, dosubtract, blformat, bloutput, blfunc, order, npiece,
                                   applyfft, fftmethod, fftthresh, addwn, rejwn, clipthresh,
                                   clipniter, blparam, verbose, updateweight, sigmavalue,
                                   outfile, overwrite)
        elif blmode == 'apply':
            _do_apply(infile, datacolumn, antenna, field, spw, timerange, scan, pol, intent,
                      reindex, bltable, updateweight, sigmavalue, outfile, overwrite)
        processing_done = True

        # Remove {WEIGHT|SIGMA}_SPECTRUM columns if updateweight=True (CAS-13161)
        if updateweight:
            with sdutil.table_manager(outfile, nomodify=False) as mytb:
                cols_spectrum = ['WEIGHT_SPECTRUM', 'SIGMA_SPECTRUM']
                cols_remove = [col for col in cols_spectrum if col in mytb.colnames()]
                if len(cols_remove) > 0:
                    mytb.removecols(' '.join(cols_remove))

        # Write history to outfile
        if dosubtract:
            param_names = sdbaseline.__code__.co_varnames[:sdbaseline.__code__.co_argcount]
            var_local = locals()
            param_vals = [var_local[p] for p in param_names]
            write_history(ms, outfile, 'sdbaseline', param_names,
                          param_vals, casalog)

    finally:
        if not dosubtract:
            # Remove the (skeleton) MS written as a by-product.
            if temp_outfile != '':
                remove_data(temp_outfile)
            elif (processing_done or not outfile_preexisted) and (outfile != infile):
                # After a failure, delete outfile only if this call created
                # it; a pre-existing outfile must survive. Never delete infile.
                remove_data(outfile)
# Baseline output formats understood by sdbaseline. The two lists are paired
# index-by-index: blformat_ext[i] is the file-name extension used for
# blformat_item[i] when no explicit bloutput name is given.
blformat_item = [
    'csv',
    'text',
    'table',
]
blformat_ext = [
    'csv',
    'txt',
    'bltable',
]

# Message of the ValueError raised for an invalid addwn/rejwn value.
mesg_invalid_wavenumber = 'wrong value given for addwn/rejwn'
def remove_data(filename):
    """Delete ``filename``: a directory tree, a regular file or a symlink.

    A symbolic link is unlinked itself; the file or directory it points to is
    left untouched. Nothing is done when os.path.exists(filename) is False,
    which includes a dangling symlink.
    """
    if not os.path.exists(filename):
        return

    if os.path.islink(filename):
        # Must precede the isdir() test: isdir() follows symlinks, and
        # shutil.rmtree() raises OSError when given a symlink to a directory.
        os.remove(filename)
    elif os.path.isdir(filename):
        shutil.rmtree(filename)
    else:
        # regular file or any other non-directory entry
        os.remove(filename)
def is_empty(blformat):
    """Return True if blformat names no format at all.

    None, '' and [] are empty, and so is any list whose items are all empty
    in this same (recursive) sense, e.g. ['', '', ''].
    """
    if not isinstance(blformat, list):
        return not blformat
    return all(is_empty(item) for item in blformat)
def prepare_for_blformat_bloutput(infile, blformat, bloutput, overwrite):
    """Validate blformat/bloutput and resolve the baseline output file names.

    Returns:
        (blformat, bloutput): blformat as a list of strings, and bloutput as
        a list with one file name per entry of blformat_item (in that order),
        '' for formats that are not requested.

    Raises:
        ValueError: for non-string input, a length mismatch between blformat
            and bloutput, or duplicated non-empty entries.
    """
    # force to string list
    blformat = force_to_string_list(blformat, 'blformat')
    bloutput = force_to_string_list(bloutput, 'bloutput')

    # The default bloutput value '' is expanded to a list with the length of
    # blformat, filled with ''. A new list is built instead of using ``*=``:
    # force_to_string_list returns a list argument as-is, so repeating it in
    # place would modify the caller's list.
    if bloutput == ['']:
        bloutput = [''] * len(blformat)

    # check length
    if len(blformat) != len(bloutput):
        raise ValueError('blformat and bloutput must have the same length.')

    # check duplication
    if has_duplicate_nonnull_element(blformat):
        raise ValueError('duplicate elements in blformat.')
    if has_duplicate_nonnull_element_ex(bloutput, blformat):
        raise ValueError('duplicate elements in bloutput.')

    # fill bloutput items to be output, then rearrange them
    # in the order of blformat_item.
    bloutput = normalise_bloutput(infile, blformat, bloutput, overwrite)

    return blformat, bloutput
def force_to_string_list(s, name):
    """Return s as a list of strings; name is used in the error message.

    A str is wrapped in a new one-element list; a list whose items are all
    str is returned unchanged (the same object). Anything else raises
    ValueError.
    """
    if isinstance(s, str):
        return [s]
    if isinstance(s, list) and all(isinstance(item, str) for item in s):
        return s
    raise ValueError('%s must be string or list of string.' % name)
def has_duplicate_nonnull_element(in_list):
    """Return True if some element other than '' occurs more than once."""
    occurrences = Counter(in_list)
    return any(elem != '' for elem, count in occurrences.items() if count > 1)
def has_duplicate_nonnull_element_ex(lst, base):
    """Check lst for duplicated non-empty items, considering only used slots.

    lst and base must have the same length. Only the items of lst whose
    counterpart in base (same index) is not '' are taken into account, and
    those are checked with has_duplicate_nonnull_element().
    """
    selected = [item for idx, item in enumerate(lst) if base[idx] != '']
    return has_duplicate_nonnull_element(selected)
def normalise_bloutput(infile, blformat, bloutput, overwrite):
    """Return one output file name per entry of blformat_item, in that order.

    Each name is resolved by get_normalised_name() from the matching
    blformat_ext extension.
    """
    names = []
    for fmt_name, fmt_ext in zip(blformat_item, blformat_ext):
        names.append(get_normalised_name(infile, blformat, bloutput,
                                         fmt_name, fmt_ext, overwrite))
    return names
def get_normalised_name(infile, blformat, bloutput, name, ext, overwrite):
    """Return the output file name for baseline format ``name``.

    The entries of ``blformat`` are lowercased before being compared with
    ``name``. If ``name`` is not among them, '' is returned. Otherwise the
    matching entry of ``bloutput`` is used or, if that is '', the default
    '<infile>_blparam.<ext>'. An existing file of that name is removed when
    ``overwrite`` is True.

    Raises:
        ValueError: if the resolved file exists and ``overwrite`` is False.
    """
    blformat_lower = [s.lower() for s in blformat]
    if name not in blformat_lower:
        return ''

    fname = bloutput[blformat_lower.index(name)]
    if fname == '':
        # Strip trailing '/' (e.g. 'foo.ms/') so the default file is created
        # next to infile rather than inside it; sdbaseline() derives its
        # default outfile the same way.
        fname = infile.rstrip('/') + '_blparam.' + ext
    if os.path.exists(fname):
        if overwrite:
            remove_data(fname)
        else:
            raise ValueError(fname + ' exists.')
    return fname
def output_bloutput_text_header(blformat, bloutput, blfunc, maskmode, infile, outfile):
    """Write the header of the 'text' baseline output file, if one is requested.

    The file name is the 'text' slot of the normalised bloutput list; nothing
    is written when it is ''. blformat and blfunc are accepted but unused.
    """
    text_name = bloutput[blformat_item.index('text')]
    if text_name == '':
        return

    rule = '#' * 60 + '\n'
    with open(text_name, 'w') as fh:
        rows = (
            ('Source Table', infile),
            ('Output File', infile if outfile == '' else outfile),
            ('Mask mode', maskmode),
        )
        fh.write(rule)
        for label, value in rows:
            fh.write('%12s: %s\n' % (label, value))
        fh.write(rule)
        fh.write('\n')
def get_temporary_file_name(basename):
    """Return basename + '_sdbaseline_pid<PID>_<timestamp>'.

    The timestamp is the current local time formatted as %Y%m%d%H%M%S%f.
    """
    pid_part = str(os.getpid())
    stamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')
    return ''.join([basename, '_sdbaseline_pid', pid_part, '_', stamp])
def parse_wavenumber_param(wn):
    """Normalise an addwn/rejwn specification into a comma-separated string.

    Accepted forms:
      * int, or list/tuple of values (duplicates removed, sorted);
      * str:
          'a'                    -> 'a'
          'a,b,c'                -> deduplicated and sorted
          'a-b' / 'a~b'          -> a, a+1, ..., b (bounds in either order)
          '<=a' '=<a' 'a>=' 'a=>' -> 0, 1, ..., a
          '<a'  'a>'             -> 0, 1, ..., a-1
          '>=a' '=>a' 'a<=' 'a=<' -> 'a,-999'
          '>a'  'a<'             -> 'a+1,-999'
        The trailing -999 is interpreted on the C++ side as "up to the
        Nyquist wavenumber" (CAS-3759).

    Raises:
        ValueError: for bool, a float scalar or float string, negative
            wavenumbers (or 0 for '<a'/'a>'), an empty string, a range with
            more than two bounds, non-integer text, or any other type.
    """
    if isinstance(wn, bool):
        # bool is a subclass of int, so it has to be rejected first
        raise ValueError(mesg_invalid_wavenumber)
    if isinstance(wn, (list, tuple)):
        __check_positive_or_zero(wn)
        return ','.join(__get_strlist(sorted(set(wn))))
    if isinstance(wn, int):
        __check_positive_or_zero(wn)
        return str(wn)
    if not isinstance(wn, str) or wn == '':
        # An empty string would otherwise fail with IndexError at wn[0].
        raise ValueError(mesg_invalid_wavenumber)

    if '.' in wn:
        # float value given as string
        raise ValueError(mesg_invalid_wavenumber)
    elif ',' in wn:
        # 'a,b,c,...'
        items = wn.split(',')
        __check_positive_or_zero(items)
        res = sorted(set(int(v) for v in items))
    elif '-' in wn:
        # 'a-b' : [a, a+1, ..., b-1, b]
        res = _expand_wavenumber_range(wn, '-')
    elif '~' in wn:
        # 'a~b' : [a, a+1, ..., b-1, b]
        res = _expand_wavenumber_range(wn, '~')
    elif wn[:2] in ('<=', '=<'):
        # '<=a', '=<a' : [0, 1, ..., a-1, a]
        val = wn[2:]
        __check_positive_or_zero(val)
        res = list(range(int(val) + 1))
    elif wn[-2:] in ('>=', '=>'):
        # 'a>=', 'a=>' : [0, 1, ..., a-1, a]
        val = wn[:-2]
        __check_positive_or_zero(val)
        res = list(range(int(val) + 1))
    elif wn[0] == '<':
        # '<a' : [0, 1, ..., a-2, a-1]
        val = wn[1:]
        __check_positive_or_zero(val, False)
        res = list(range(int(val)))
    elif wn[-1] == '>':
        # 'a>' : [0, 1, ..., a-2, a-1]
        val = wn[:-1]
        __check_positive_or_zero(val, False)
        res = list(range(int(val)))
    elif wn[:2] in ('>=', '=>'):
        # '>=a', '=>a' : [a, -999], expanded on the C++ side to
        # [a, a+1, ..., a_nyq] (CAS-3759)
        val = wn[2:]
        __check_positive_or_zero(val)
        res = [int(val), -999]
    elif wn[-2:] in ('<=', '=<'):
        # 'a<=', 'a=<' : [a, -999] (see above, CAS-3759)
        val = wn[:-2]
        __check_positive_or_zero(val)
        res = [int(val), -999]
    elif wn[0] == '>':
        # '>a' : [a+1, -999], expanded on the C++ side to
        # [a+1, a+2, ..., a_nyq] (CAS-3759)
        val = int(wn[1:]) + 1
        __check_positive_or_zero(val)
        res = [val, -999]
    elif wn[-1] == '<':
        # 'a<' : [a+1, -999] (see above, CAS-3759)
        val = int(wn[:-1]) + 1
        __check_positive_or_zero(val)
        res = [val, -999]
    else:
        # 'a'
        __check_positive_or_zero(wn)
        res = [int(wn)]

    return ','.join(__get_strlist(res))


def _expand_wavenumber_range(wn, sep):
    """Expand 'a<sep>b' into [min(a, b), ..., max(a, b)] (both ends inclusive)."""
    bounds = wn.split(sep)
    if len(bounds) != 2:
        # e.g. '1-2-3': reject instead of silently ignoring the extra parts
        raise ValueError(mesg_invalid_wavenumber)
    __check_positive_or_zero(bounds)
    low, high = sorted(int(b) for b in bounds)
    return list(range(low, high + 1))
def __get_strlist(param):
    """Return a list holding str() of every element of param."""
    return list(map(str, param))
def check_fftthresh(fftthresh):
    """Validate fftthresh value.

    The fftthresh must be one of the following:
    (1) positive value (float, integer or string)
    (2) 'top' + positive integer value
    (3) positive float value + 'sigma'

    NaN (as a number or in a string such as 'nan' or 'nansigma') is not
    positive and is rejected.

    Raises:
        ValueError: for an invalid type, an unparsable string, or a threshold
            that is not positive.
    """
    if isinstance(fftthresh, bool):
        # bool is a subclass of int, so it must be rejected before the int check
        raise ValueError('fftthresh must be float or integer or string.')

    if isinstance(fftthresh, (int, float)):
        value = fftthresh
    elif isinstance(fftthresh, str):
        try:
            if len(fftthresh) > 3 and fftthresh.startswith('top'):
                value = int(fftthresh[3:])
            elif len(fftthresh) > 5 and fftthresh.endswith('sigma'):
                value = float(fftthresh[:-5])
            else:
                value = float(fftthresh)
        except ValueError:
            raise ValueError('fftthresh has a wrong format.') from None
    else:
        raise ValueError('fftthresh must be float or integer or string.')

    # "not value > 0" rather than "value <= 0" so that NaN is rejected too:
    # every comparison with NaN is False.
    if not value > 0:
        raise ValueError('threshold given to fftthresh must be positive.')
def __check_positive_or_zero(param, allowzero=True):
    """Raise ValueError unless every wavenumber in param is >= 0.

    param may be an int, a str holding an int, or a list/tuple of such values
    (each converted with int(), one at a time). With allowzero=False a value
    of 0 is rejected as well. Any other type raises ValueError.
    """
    if isinstance(param, (list, tuple)):
        # lazy conversion keeps the original convert-then-check order
        values = (int(item) for item in param)
    elif isinstance(param, int):
        values = (param,)
    elif isinstance(param, str):
        values = (int(param),)
    else:
        raise ValueError(mesg_invalid_wavenumber)

    for value in values:
        __do_check_positive_or_zero(value, allowzero)
def __do_check_positive_or_zero(param, allowzero):
    """Raise ValueError if param is negative, or zero while allowzero is False."""
    is_negative = param < 0
    is_forbidden_zero = (not allowzero) and param == 0
    if is_negative or is_forbidden_zero:
        raise ValueError(mesg_invalid_wavenumber)
def prepare_for_baselining(**keywords):
    """Pick the singledishms fitting method and its keyword arguments.

    ``keywords`` must provide 'sdms' (the tool instance), 'blfunc' and every
    parameter name selected below. Returns (params, baseline_func) so that
    the caller can run ``baseline_func(**params)``.

    Raises:
        ValueError: if blfunc is not 'poly', 'chebyshev', 'cspline',
            'sinusoid' or 'variable'.
    """
    blfunc = keywords['blfunc']

    if blfunc in ['poly', 'chebyshev']:
        # both use the generic method, which receives blfunc itself
        method_name = 'subtract_baseline'
        func_keys = ['blfunc', 'order']
    elif blfunc == 'cspline':
        method_name = 'subtract_baseline_cspline'
        func_keys = ['npiece']
    elif blfunc == 'sinusoid':
        method_name = 'subtract_baseline_sinusoid'
        func_keys = ['applyfft', 'fftmethod', 'fftthresh', 'addwn', 'rejwn']
    elif blfunc == 'variable':
        method_name = 'subtract_baseline_variable'
        func_keys = ['blparam', 'verbose']
    else:
        raise ValueError("Unsupported blfunc = %s" % blfunc)

    names = ['datacolumn', 'outfile', 'bloutput', 'dosubtract', 'spw',
             'updateweight', 'sigmavalue'] + func_keys
    if blfunc != 'variable':
        # clipping and line-finding settings are not passed in 'variable' mode
        names += ['clip_threshold_sigma', 'num_fitting_max',
                  'linefinding', 'threshold', 'avg_limit', 'minwidth', 'edge']

    params = {name: keywords[name] for name in names}
    return params, getattr(keywords['sdms'], method_name)
def remove_sorted_table_keyword(infile):
    """Remove the SORTED_TABLE table keyword from infile and remember it.

    Returns a dict for restore_sorted_table_keyword(): 'is_sorttab' (whether
    the keyword was present), 'sorttab_keywd' (the keyword name, else '') and
    'sorttab_name' (its value, else '').
    """
    keyword = 'SORTED_TABLE'
    info = {'is_sorttab': False, 'sorttab_keywd': '', 'sorttab_name': ''}

    with sdutil.table_manager(infile, nomodify=False) as tb:
        if keyword not in tb.keywordnames():
            return info
        info.update(is_sorttab=True,
                    sorttab_keywd=keyword,
                    sorttab_name=tb.getkeyword(keyword))
        tb.removekeyword(keyword)

    return info
def restore_sorted_table_keyword(infile, sorttab_info):
    """Put back the keyword saved by remove_sorted_table_keyword().

    Does nothing unless the keyword was present and had a non-empty value.
    """
    if not sorttab_info['is_sorttab'] or sorttab_info['sorttab_name'] == '':
        return
    with sdutil.table_manager(infile, nomodify=False) as tb:
        tb.putkeyword(sorttab_info['sorttab_keywd'], sorttab_info['sorttab_name'])
def _do_apply(infile, datacolumn, antenna, field, spw, timerange, scan, pol, intent,
              reindex, bltable, updateweight, sigmavalue, outfile, overwrite):
    """Apply the baseline table ``bltable`` to the selected data of ``infile``.

    The result is written to ``outfile`` via singledishms.apply_baseline_table.
    The SORTED_TABLE keyword of infile is removed for the duration of the
    call and is restored afterwards, even if applying fails.

    Raises:
        ValueError: if ``bltable`` does not exist.
    """
    if not os.path.exists(bltable):
        raise ValueError("file specified in bltable '%s' does not exist." % bltable)

    # Note: the condition "infile != outfile" in the following line is for safety
    # to prevent from accidentally removing infile by setting outfile=infile.
    # Don't remove it.
    if overwrite and (infile != outfile) and os.path.exists(outfile):
        remove_data(outfile)

    sorttab_info = remove_sorted_table_keyword(infile)
    try:
        with sdutil.tool_manager(infile, singledishms) as mysdms:
            selection = ms.msseltoindex(vis=infile, spw=spw, field=field,
                                        baseline=antenna, time=timerange,
                                        scan=scan)
            mysdms.set_selection(spw=sdutil.get_spwids(selection), field=field,
                                 antenna=antenna, timerange=timerange,
                                 scan=scan, polarization=pol, intent=intent,
                                 reindex=reindex)
            mysdms.apply_baseline_table(bltable=bltable,
                                        datacolumn=datacolumn,
                                        spw=spw,
                                        updateweight=updateweight,
                                        sigmavalue=sigmavalue,
                                        outfile=outfile)
    finally:
        # Restore the keyword even on failure so infile is not left modified.
        restore_sorted_table_keyword(infile, sorttab_info)
def _do_fit(infile, datacolumn, antenna, field, spw, timerange, scan, pol, intent,
            reindex, maskmode, thresh, avg_limit, minwidth, edge, dosubtract, blformat,
            bloutput, blfunc, order, npiece, applyfft, fftmethod, fftthresh, addwn,
            rejwn, clipthresh, clipniter, blparam, verbose, updateweight, sigmavalue,
            outfile, overwrite):
    """Fit baselines to the selected spectra of infile, subtracting them if dosubtract.

    Returns:
        The temporary MS name used in place of ``outfile`` when dosubtract is
        False and ``outfile`` already exists (sdbaseline() deletes it
        afterwards), or '' when ``outfile`` itself was used.

    Raises:
        ValueError: if blformat is empty while dosubtract is False, or for
            invalid blformat/bloutput, addwn/rejwn or fftthresh values.
    """
    temp_outfile = ''

    if (not dosubtract) and is_empty(blformat):
        raise ValueError("blformat must be specified when dosubtract is False")

    blformat, bloutput = prepare_for_blformat_bloutput(infile, blformat, bloutput, overwrite)

    output_bloutput_text_header(blformat, bloutput, blfunc, maskmode, infile, outfile)

    # Set temporary name for output MS if dosubtract is False and outfile exists
    # for not removing/overwriting outfile that already exists
    if os.path.exists(outfile):
        # Note: the condition "infile != outfile" in the following line is for safety
        # to prevent from accidentally removing infile by setting outfile=infile
        # Don't remove it.
        if dosubtract and overwrite and (infile != outfile):
            remove_data(outfile)
        elif not dosubtract:
            outfile = get_temporary_file_name(infile)
            temp_outfile = outfile

    sorttab_info = None
    if blfunc == 'variable':
        sorttab_info = remove_sorted_table_keyword(infile)
    elif blfunc == 'sinusoid':
        addwn = parse_wavenumber_param(addwn)
        rejwn = parse_wavenumber_param(rejwn)
        check_fftthresh(fftthresh)

    try:
        with sdutil.tool_manager(infile, singledishms) as mysdms:
            selection = ms.msseltoindex(vis=infile, spw=spw, field=field, baseline=antenna,
                                        time=timerange, scan=scan)
            mysdms.set_selection(spw=sdutil.get_spwids(selection), field=field, antenna=antenna,
                                 timerange=timerange, scan=scan, polarization=pol, intent=intent,
                                 reindex=reindex)
            params, func = prepare_for_baselining(sdms=mysdms,
                                                  blfunc=blfunc,
                                                  datacolumn=datacolumn,
                                                  outfile=outfile,
                                                  bloutput=','.join(bloutput),
                                                  dosubtract=dosubtract,
                                                  spw=spw,
                                                  pol=pol,
                                                  linefinding=(maskmode == 'auto'),
                                                  threshold=thresh,
                                                  avg_limit=avg_limit,
                                                  minwidth=minwidth,
                                                  edge=edge,
                                                  order=order,
                                                  npiece=npiece,
                                                  applyfft=applyfft,
                                                  fftmethod=fftmethod,
                                                  fftthresh=fftthresh,
                                                  addwn=addwn,
                                                  rejwn=rejwn,
                                                  clip_threshold_sigma=clipthresh,
                                                  num_fitting_max=clipniter + 1,
                                                  blparam=blparam,
                                                  verbose=verbose,
                                                  updateweight=updateweight,
                                                  sigmavalue=sigmavalue)
            func(**params)
    except BaseException:
        # The caller only learns temp_outfile from a successful return, so the
        # temporary MS is deleted here instead of being left behind.
        if temp_outfile != '':
            remove_data(temp_outfile)
        raise
    finally:
        # Restore SORTED_TABLE even when fitting failed, so infile is not
        # left without the keyword.
        if sorttab_info is not None:
            restore_sorted_table_keyword(infile, sorttab_info)

    return temp_outfile