Coverage for /wheeldirectory/casa-6.7.0-12-py3.10.el8/lib/py/lib/python3.10/site-packages/casatasks/private/task_sdatmcor.py: 15%
274 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-10-31 19:53 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-10-31 19:53 +0000
1import collections
2import contextlib
3import itertools
4import os
5import shutil
7import numpy as np
9from casatasks import casalog
10from casatasks.private import sdutil, simutil
11from casatools import ms as mstool
12from casatools import msmetadata, quanta, singledishms
# Module-level tool instances shared by the functions below.
ut = simutil.simutil()  # simutil helper; used for xyz2long() in get_default_altitude
qa = quanta()  # quanta tool used for unit parsing, conversion and time arithmetic
sdms = singledishms()  # single-dish MS tool that performs the correction in sdatmcor
class ATMParameterConfigurator(collections.abc.Iterator):
    """Iterator yielding at most one ``(key, value)`` configuration pair.

    If ``do_config`` is true, the iterator produces the single pair
    ``(key, value)``. Otherwise it is exhausted from the start. Several
    configurators can be chained and passed to ``dict`` to assemble a
    configuration that only contains the enabled entries.
    """

    def __init__(self, key, value, do_config=True):
        pairs = []
        if do_config:
            pairs.append((key, value))
        self._iter = iter(pairs)

    def __next__(self):
        # Delegate to the list iterator; StopIteration is raised once the
        # optional pair has been consumed.
        return next(self._iter)
@contextlib.contextmanager
def open_msmd(path):
    """Yield an msmetadata tool opened on ``path``.

    The tool is closed when the ``with`` block exits, whether it exits
    normally or by an exception.

    Args:
        path (str): name of MS
    """
    metadata_tool = msmetadata()
    metadata_tool.open(path)
    try:
        yield metadata_tool
    finally:
        metadata_tool.close()
38def _ms_remove(path):
39 if (os.path.exists(path)):
40 if (os.path.isdir(path)):
41 shutil.rmtree(path)
42 else:
43 os.remove(path)
def get_default_params():
    """Return the built-in default ATM parameters.

    The constants are taken from atmcor_20200807.py (CSV-3320).

    Returns:
        dict: a new dict with keys ``atmtype``, ``maxalt``, ``lapserate``,
        ``scaleht``, ``dosmooth``, ``dp`` and ``dpm``.
    """
    return {
        # atmType parameter for at (1: tropical, 2: mid lat summer, 3: mid lat winter, etc)
        'atmtype': 2,
        # maxAltitude parameter for at (km)
        'maxalt': 120,
        # dTem_dh parameter for at (lapse rate; K/km)
        'lapserate': -5.6,
        # h0 parameter for at (water scale height; km)
        'scaleht': 2.0,
        # convolve dTa* spectra with [0.25, 0.5, 0.25] to mimic Hanning
        # spectral response; set to True if spectral averaging was not
        # employed for the spw
        'dosmooth': False,
        # initATMProfile defaults
        'dp': 10.0,
        'dpm': 1.2,
    }
def parse_gainfactor(gainfactor):
    """Parse the gainfactor parameter into a per-spw mapping.

    Args:
        gainfactor (float, dict, str): gain factor.
            float: the value applies to every spw.
            dict: spw id -> factor pairs. Keys are converted to str.
            str: name of a caltable. For each spw in the table, the
                factor is the mean of the inverse square of the real part
                of the FPARAM column (or CPARAM if FPARAM is absent).

    Raises:
        FileNotFoundError: specified caltable does not exist.
        RuntimeError: specified table has neither FPARAM nor CPARAM.

    Returns:
        collections.defaultdict: keys are spw ids as str and values are
        the factors for each spw. Unknown spws map to 1.0 for dict/str
        input, and to the given value for float input.
    """
    if isinstance(gainfactor, dict):
        gaindict = collections.defaultdict(lambda: 1.0)
        # make sure keys are str
        gaindict.update({str(spw_id): factor for spw_id, factor in gainfactor.items()})
        return gaindict

    if isinstance(gainfactor, str):
        gaindict = collections.defaultdict(lambda: 1.0)
        # should be the name of caltable
        if not os.path.exists(gainfactor):
            raise FileNotFoundError('"{}" should exist.'.format(gainfactor))
        with sdutil.table_manager(gainfactor) as tb:
            available = tb.colnames()
            if 'FPARAM' in available:
                column = 'FPARAM'
            elif 'CPARAM' in available:
                column = 'CPARAM'
            else:
                raise RuntimeError('{} is not a caltable'.format(gainfactor))
            for spw_id in set(tb.getcol('SPECTRAL_WINDOW_ID')):
                subtable = tb.query('SPECTRAL_WINDOW_ID=={}'.format(spw_id))
                try:
                    gains = subtable.getcol(column).real
                finally:
                    subtable.close()
                gaindict[str(spw_id)] = np.mean(1 / np.square(gains))
        return gaindict

    # otherwise it should be a number applied to every spw
    uniform = float(gainfactor)
    return collections.defaultdict(lambda: uniform)
def gaindict2list(msname, gaindict):
    """Convert a gain factor mapping into a per-spw array.

    Args:
        msname (str): name of MS. The number of rows of its
            SPECTRAL_WINDOW subtable sets the length of the result.
        gaindict (dict): spw id (str or int) -> factor, typically the
            output of ``parse_gainfactor``. Ids outside
            ``[0, nspw)`` are ignored.

    Returns:
        np.ndarray: float array of length nspw. Spws without an explicit
        entry get the default of ``gaindict`` if it is a ``defaultdict``
        with a ``default_factory``, otherwise 1.0.
    """
    with sdutil.table_manager(os.path.join(msname, 'SPECTRAL_WINDOW')) as tb:
        nspw = tb.nrows()

    # Ask the factory for the default instead of indexing gaindict[0]:
    # indexing a defaultdict inserts the missing key and would mutate the
    # caller's mapping.
    if isinstance(gaindict, collections.defaultdict) and gaindict.default_factory is not None:
        default = gaindict.default_factory()
    else:
        default = 1.0

    gainlist = np.full(nspw, default, dtype=float)
    for key, factor in gaindict.items():
        spw = int(key)
        if 0 <= spw < nspw:
            gainlist[spw] = factor
    return gainlist
def get_all_spws_from_main(msname):
    """Return ids of spectral windows referenced by the MAIN table.

    The unique DATA_DESC_ID values of the MAIN table are mapped to
    spectral window ids through msmetadata.

    Args:
        msname (str): name of MS

    Returns:
        list: spectral window id for each unique data description id,
        in ascending data description id order.
    """
    with sdutil.table_manager(msname) as tb:
        data_desc_ids = np.unique(tb.getcol('DATA_DESC_ID'))

    with open_msmd(msname) as msmd:
        return [msmd.spwfordatadesc(ddid) for ddid in data_desc_ids]
def get_selected_spws(msname, spw):
    """Return spectral window ids matched by an spw selection string.

    Args:
        msname (str): name of MS
        spw (str): spectral window selection. An empty string selects
            every spectral window, the same as '*'.

    Raises:
        TypeError: spw is not string

    Returns:
        the 'spw' entry of the ``ms.msseltoindex`` result, i.e. the
        selected spectral window ids.
    """
    if not isinstance(spw, str):
        raise TypeError('spw selection must be string')

    # '' indicates all spws, which is equivalent to '*'
    selection = spw if spw else '*'
    indices = mstool().msseltoindex(msname, spw=selection)
    return indices['spw']
def parse_spw(msname, spw=''):
    """Return selected spw ids that also have data in the MAIN table.

    The result is the intersection of the spws referenced by the MAIN
    table of ``msname`` and the spws matched by ``spw``.

    Args:
        msname (str): name of MS
        spw (str): spw selection ('' selects all)

    Raises:
        TypeError: spw selection is not str

    Returns:
        list: selected spw ids, in no particular order. The list is empty
        when nothing overlaps; errors reported by the underlying selection
        tool propagate unchanged.
    """
    spws_with_data = set(get_all_spws_from_main(msname))
    spws_selected = set(get_selected_spws(msname, spw))
    return list(spws_with_data & spws_selected)
def get_mount_off_source_commands(msname):
    """Return flag commands whose REASON is "Mount_is_off_source".

    Args:
        msname (str): name of MS

    Returns:
        the COMMAND column of the matching FLAG_CMD rows, or an empty
        list when the FLAG_CMD table has no rows.
    """
    with sdutil.table_manager(os.path.join(msname, 'FLAG_CMD')) as tb:
        if tb.nrows() == 0:
            return []
        selected = tb.query('REASON=="Mount_is_off_source"')
        try:
            return selected.getcol('COMMAND')
        finally:
            selected.close()
def get_antenna_name(antenna_selection):
    """Extract the antenna name from an antenna selection token.

    The token is expected to look like ``antenna='<ANTENNA_NAME>&&*'``.
    The text between the first and second '=' is taken, and quote, '&'
    and '*' characters are stripped from both of its ends.

    Args:
        antenna_selection (str): antenna selection token

    Returns:
        str: antenna name
    """
    parts = antenna_selection.split('=')
    value = parts[1]
    return value.strip("'&*")
def get_time_delta(time_range):
    """Convert a time range token into a duration in seconds.

    The token is expected to look like
    ``<key>='YYYY/MM/DD/hh:mm:ss~YYYY/MM/DD/hh:mm:ss'``. The text after
    the first '=' is unquoted and split at '~'. The first two parts are
    taken as the start and end times.

    Args:
        time_range (str): time range token

    Returns:
        float: absolute duration between the two times in sec
    """
    bounds = time_range.split('=')[1].strip("'").split('~')
    instants = [qa.quantity(bound) for bound in bounds]
    elapsed = qa.sub(instants[1], instants[0])
    seconds = qa.convert(elapsed, 's')['value']
    return abs(seconds)
def cmd_to_ant_and_time(cmd):
    """Extract antenna name and flagged duration from a flag command.

    The command is split on whitespace. The first token starting with
    'antenna' and the first token starting with 'time' are parsed.

    Args:
        cmd (str): flag command

    Raises:
        IndexError: the command has no antenna or no time token.

    Returns:
        tuple: (antenna name, time duration in sec)
    """
    tokens = cmd.split()
    antenna_token = [tok for tok in tokens if tok.startswith('antenna')][0]
    time_token = [tok for tok in tokens if tok.startswith('time')][0]

    name = get_antenna_name(antenna_token)
    duration = get_time_delta(time_token)
    return name, duration
def inspect_flag_cmd(msname):
    """Summarize "Mount_is_off_source" flag commands per antenna.

    Args:
        msname (str): name of MS

    Returns:
        tuple: two defaultdicts keyed by antenna name (missing names
        read as 0). The first holds the number of commands per antenna,
        the second the total flagged duration (sec) per antenna.
    """
    commands = get_mount_off_source_commands(msname)

    cmd_counts = collections.defaultdict(int)
    time_counts = collections.defaultdict(int)

    for antenna_name, duration in map(cmd_to_ant_and_time, commands):
        cmd_counts[antenna_name] += 1
        time_counts[antenna_name] += duration

    return cmd_counts, time_counts
# Argument parameter handling
def parse_atm_params(user_param, user_default, task_default, default_unit=''):
    """Resolve a scalar ATM parameter into a numeric value.

    Args:
        user_param (str,int,float): User input, optionally with a unit.
        user_default (str,int,float): Value meaning "not set by the user".
        task_default (str,int,float): Value used when not customized.
        default_unit (str): Unit the result is expressed in.

    Raises:
        ValueError: user_param has a unit incompatible with default_unit.

    Returns:
        tuple: (value, is_customized). ``value`` is the numeric 'value'
        field of the resolved quantity. ``is_customized`` is True when
        user_param differs from user_default and is not None.
    """
    is_customized = user_param != user_default and user_param is not None

    try:
        task_default_quanta = (qa.quantity(task_default) if qa.isquantity(task_default)
                               else qa.quantity(task_default, default_unit))
    except Exception as e:
        casalog.post('INTERNAL ERROR: {}'.format(e), priority='SEVERE')
        raise

    if not is_customized:
        return task_default_quanta['value'], is_customized

    user_param_quanta = qa.quantity(user_param)
    if user_param_quanta['unit'] == '':
        # bare number: interpret it in default_unit
        user_param_quanta = qa.quantity(user_param_quanta['value'], default_unit)
    else:
        user_param_quanta = qa.convert(user_param_quanta, default_unit)

    if not qa.compare(user_param_quanta, task_default_quanta):
        raise ValueError('User input "{}" should have the unit compatible with "{}"'.format(
            user_param,
            default_unit
        ))
    return user_param_quanta['value'], is_customized
def parse_atm_list_params(user_param, user_default='', task_default=None, default_unit=''):
    """Resolve a list-type ATM parameter into a list of numeric values.

    Args:
        user_param (str, list, np.ndarray): User input. A str is treated
            as a comma-separated list. Each element may carry a unit.
        user_default (str): Value meaning "not set by the user". It is
            also the per-element default passed to parse_atm_params.
        task_default (list, optional): Returned as-is when the input is
            not customized. Defaults to a new empty list per call; it is
            no longer a shared mutable default.
        default_unit (str): Unit for output values.

    Raises:
        ValueError: user_param is invalid, or is not a list, np.ndarray
            or str.

    Returns:
        tuple: (values, is_customized). ``values`` is task_default when
        not customized, otherwise a list of floats in default_unit.
        ``is_customized`` is True when user_param differs from
        user_default and is not None.
    """
    if task_default is None:
        task_default = []

    if isinstance(user_param, np.ndarray):
        # Elementwise '!=' on an array makes the truth value ambiguous
        # (ValueError), so compare arrays as a whole.
        is_customized = not np.array_equal(user_param, user_default)
    else:
        is_customized = user_param != user_default and user_param is not None

    if not is_customized:
        return task_default, is_customized

    if isinstance(user_param, (list, np.ndarray)):
        try:
            values = [parse_atm_params(p, user_default, 0, default_unit=default_unit)[0]
                      for p in user_param]
            values = [qa.convert(v, default_unit)['value'] for v in values]
        except Exception as e:
            casalog.post('ERROR during handling list input: {}'.format(e))
            raise ValueError('list input "{}" is invalid.'.format(user_param)) from e
        return values, is_customized

    if isinstance(user_param, str):
        try:
            values, _ = parse_atm_list_params(user_param.split(','), user_default,
                                              task_default, default_unit)
        except Exception as e:
            casalog.post('ERROR during handling comma-separated str input: {}'.format(e))
            raise ValueError('str input "{}" is invalid.'.format(user_param)) from e
        return values, is_customized

    raise ValueError('user_param for parse_atm_list_params should be either list or str.')
def get_default_antenna(msname):
    """Determine the default antenna id from the FLAG_CMD table.

    Procedure:

    (1) extract flag commands whose reason is "Mount_is_off_source".
    (2) compile them into a command count and a total flagged duration
        per antenna.
    (3) select the antenna with the shortest flagged duration.
    (4) on a tie in (3), keep the antennas with the fewest commands.
    (5) on a tie in (4), take the first of them (antenna id order).

    If no such flag command exists, the first antenna in the MAIN table
    is chosen.

    Args:
        msname (str): name of MS

    Raises:
        Exception: no antenna was found in the MAIN table

    Returns:
        int: default antenna id
    """
    # antenna ids that appear in the MAIN table (sorted by np.unique)
    with sdutil.table_manager(msname) as tb:
        ant_list = np.unique(tb.getcol('ANTENNA1'))

    if len(ant_list) == 0:
        raise Exception("No Antenna was found.")

    with open_msmd(msname) as msmd:
        ant_name = [msmd.antennanames(i)[0] for i in ant_list]

    name_to_id = dict(zip(ant_name, ant_list))

    cmd_counts, flagged_durations = inspect_flag_cmd(msname)

    if len(cmd_counts) == 0:
        # No flag command: all antennas are assumed healthy.
        default_id, default_name = ant_list[0], ant_name[0]
    else:
        durations = {name: flagged_durations[name] for name in name_to_id}
        shortest = min(durations.values())
        candidates = [name for name, dur in durations.items() if dur == shortest]
        fewest = min(cmd_counts[name] for name in candidates)
        default_name = next(name for name in candidates if cmd_counts[name] == fewest)
        default_id = name_to_id[default_name]

    casalog.post('Select {} (ID {}) as a default antenna'.format(default_name, default_id))
    return default_id
def get_default_altitude(msname, antid):
    """Return the default 'Altitude' for the ATM correction.

    The ITRF Earth-centered position stored in row ``antid`` of the
    ANTENNA subtable is converted with simutil ``xyz2long(..., 'WGS84')``.
    The third element of that result is returned: the geodetic elevation
    in metres, relative to the closest point on the WGS84 reference
    ellipsoid. See
    https://casa.nrao.edu/casadocs/casa-5.6.0/simulation/simutil

    Args:
        msname (str): name of MS
        antid (int): antenna id (row of the ANTENNA subtable)

    Returns:
        geodetic elevation of the antenna [m]
    """
    with sdutil.table_manager(os.path.join(msname, 'ANTENNA')) as tb:
        position = tb.getcell('POSITION', antid)
        X, Y, Z = (float(coord) for coord in position)
        ref = tb.getcolkeyword('POSITION', 'MEASINFO')['Ref']

    # xyz2long result: [0] longitude, [1] latitude, [2] geodetic elevation
    long_lat_elev = ut.xyz2long(X, Y, Z, 'WGS84')
    geodetic_elevation = long_lat_elev[2]

    casalog.post("Default Altitude")
    casalog.post(" - Antenna ID: %d. " % antid)
    casalog.post(" - Ref = %s. " % ref)
    casalog.post(" - Position: (%s, %s, %s)." % (X, Y, Z))
    casalog.post(" Altitude (geodetic elevation): %f" % geodetic_elevation)

    return geodetic_elevation
class ATMScalarParameterConfigurator(ATMParameterConfigurator):
    """Configurator for a scalar ATM parameter resolved by parse_atm_params.

    The ``(key, value)`` pair is produced when ``is_mandatory`` is true,
    or when ``is_effective`` is true and the user customized the value.
    """

    def __init__(self, key, user_input, impl_default, default_unit,
                 api_default='', is_mandatory=True, is_effective=True):
        resolved, customized = parse_atm_params(
            user_param=user_input, user_default=api_default,
            task_default=impl_default, default_unit=default_unit)
        enabled = is_mandatory or (is_effective and customized)
        super().__init__(key=key, value=resolved, do_config=enabled)
class ATMListParameterConfigurator(ATMParameterConfigurator):
    """Configurator for a list ATM parameter resolved by parse_atm_list_params.

    The ``(key, value)`` pair is produced when ``is_mandatory`` is true,
    or when ``is_effective`` is true and the user customized the value.
    """

    def __init__(self, key, user_input, impl_default, default_unit,
                 api_default='', is_mandatory=True, is_effective=True):
        resolved, customized = parse_atm_list_params(
            user_param=user_input, user_default=api_default,
            task_default=impl_default, default_unit=default_unit)
        enabled = is_mandatory or (is_effective and customized)
        super().__init__(key=key, value=resolved, do_config=enabled)
def _get_omp_num_threads(default=-1):
    """Return the thread count requested via OMP_NUM_THREADS, or ``default``.

    OMP_NUM_THREADS may be unset, empty, or (per the OpenMP specification)
    a comma-separated list of per-nesting-level counts such as "4,2". The
    first entry is used. ``default`` is returned when the variable is unset
    or its first entry is not an integer, instead of aborting the task.
    """
    raw = os.getenv('OMP_NUM_THREADS')
    if raw is None:
        return default
    first = raw.split(',')[0].strip()
    try:
        return int(first)
    except ValueError:
        casalog.post('Ignoring invalid OMP_NUM_THREADS value "{}"'.format(raw), priority='WARN')
        return default


def get_configuration_for_atmcor(infile, spw, outputspw, gainfactor, user_inputs):
    """Build the configuration dict passed to ``sdms.atmcor``.

    Args:
        infile (str): name of input MS
        spw (str): spw selection for processing
        outputspw (str): spw selection for output
        gainfactor (float, dict, str): see parse_gainfactor
        user_inputs (dict): task arguments. Keys read: 'atmdetail',
            'altitude', 'atmtype', 'dtem_dh', 'h0', 'dp', 'dpm',
            'pressure', 'temperature', 'humidity', 'pwv',
            'layerboundaries', 'layertemperature'.

    Returns:
        dict: configuration. Always contains processspw, gainfactor,
        refant, atmType, maxAltitude, lapseRate, scaleHeight,
        pressureStep, pressureStepFactor, siteAltitude and nthreads.
        pressure, temperature, humidity, pwv, layerBoundaries and
        layerTemperatures are included only when atmdetail is set and
        the user customized them.
    """
    # requested list of output spws and processing spws
    # processing spws are the intersection of these
    outputspws_param = parse_spw(infile, outputspw)
    spws_param = parse_spw(infile, spw)
    all_processing_spws = np.asarray(list(set(spws_param).intersection(set(outputspws_param))))

    # generate gain factor dictionary
    gaindict = parse_gainfactor(gainfactor)
    gainlist = gaindict2list(infile, gaindict)

    # default parameter values (from Tsuyoshi's original script)
    default_params = get_default_params()

    # reference antenna_id to calculate Azimuth/Elevation
    reference_antenna = int(get_default_antenna(infile))

    # altitude of reference antenna
    default_altitude = get_default_altitude(infile, reference_antenna)
    user_altitude = user_inputs['altitude'] if user_inputs['atmdetail'] else ''

    parameters = [
        ATMParameterConfigurator(key='processspw', value=all_processing_spws),
        ATMParameterConfigurator(key='gainfactor', value=gainlist),
        ATMParameterConfigurator(key='refant', value=reference_antenna),
        ATMParameterConfigurator(key='atmType', value=user_inputs['atmtype']),
        ATMParameterConfigurator(key='maxAltitude', value=float(default_params['maxalt'])),
        ATMScalarParameterConfigurator(
            key='lapseRate', user_input=user_inputs['dtem_dh'],
            impl_default=default_params['lapserate'], default_unit='K/km',
        ),
        ATMScalarParameterConfigurator(
            key='scaleHeight', user_input=user_inputs['h0'],
            impl_default=default_params['scaleht'], default_unit='km',
        ),
        ATMScalarParameterConfigurator(
            key='pressureStep', user_input=user_inputs['dp'],
            impl_default=default_params['dp'], default_unit='mbar'
        ),
        ATMScalarParameterConfigurator(
            key='pressureStepFactor', user_input=user_inputs['dpm'],
            impl_default=default_params['dpm'], default_unit='', api_default=-1
        ),
        ATMScalarParameterConfigurator(
            key='siteAltitude', user_input=user_altitude,
            impl_default=default_altitude, default_unit='m',
            api_default='',
            is_mandatory=True, is_effective=user_inputs['atmdetail']
        ),
        ATMScalarParameterConfigurator(
            key='pressure', user_input=user_inputs['pressure'],
            impl_default=0, default_unit='mbar',
            is_mandatory=False, is_effective=user_inputs['atmdetail']
        ),
        ATMScalarParameterConfigurator(
            key='temperature', user_input=user_inputs['temperature'],
            impl_default=0, default_unit='K',
            is_mandatory=False, is_effective=user_inputs['atmdetail']
        ),
        ATMScalarParameterConfigurator(
            key='humidity', user_input=user_inputs['humidity'],
            impl_default=0, default_unit='%',
            api_default=-1,
            is_mandatory=False, is_effective=user_inputs['atmdetail']
        ),
        ATMScalarParameterConfigurator(
            key='pwv', user_input=user_inputs['pwv'],
            impl_default=0, default_unit='mm',
            is_mandatory=False, is_effective=user_inputs['atmdetail']
        ),
        ATMListParameterConfigurator(
            key='layerBoundaries', user_input=user_inputs['layerboundaries'],
            impl_default=[], default_unit='m',
            is_mandatory=False, is_effective=user_inputs['atmdetail']
        ),
        ATMListParameterConfigurator(
            key='layerTemperatures', user_input=user_inputs['layertemperature'],
            impl_default=[], default_unit='K',
            is_mandatory=False, is_effective=user_inputs['atmdetail']
        )
    ]

    # each configurator yields at most one (key, value) pair
    config = dict(itertools.chain(*parameters))

    # number of threads for OpenMP
    # if config['nthreads'] is set to -1, task will decide number of threads automatically
    config['nthreads'] = _get_omp_num_threads()

    return config
def _is_same_path(path1, path2):
    """Return True if two paths name the same filesystem location.

    Besides plain equality, str paths are compared after resolving
    relative components and symbolic links. For example, 'a.ms' and
    './a.ms' are considered the same.
    """
    if path1 == path2:
        return True
    if isinstance(path1, str) and isinstance(path2, str):
        return os.path.realpath(path1) == os.path.realpath(path2)
    return False


@sdutil.sdtask_decorator
def sdatmcor(
        infile=None, datacolumn=None, outfile=None, overwrite=None,
        field=None, spw=None, scan=None, antenna=None,
        correlation=None, timerange=None, intent=None,
        observation=None, feed=None, msselect=None,
        outputspw=None,
        gainfactor=None,
        dtem_dh=None, h0=None, atmtype=None,
        atmdetail=None,
        altitude=None, temperature=None, pressure=None, humidity=None, pwv=None,
        dp=None, dpm=None,
        layerboundaries=None, layertemperature=None):
    """Apply offline ATM correction to single-dish data.

    Validates the inputs, then removes an existing ``outfile`` if
    ``overwrite`` is set. It then builds the configuration with
    get_configuration_for_atmcor and runs ``sdms.atmcor`` to write the
    corrected data to ``outfile``. All parameters are validated before
    anything is deleted, and the ``sdms`` tool is closed again even on
    failure. Any error is logged with SEVERE priority and re-raised.
    """
    try:
        # Input/Output error check and internal set up.
        if infile == '':
            errmsg = "infile MUST BE specified."
            raise Exception(errmsg)

        if outfile == '':
            errmsg = "outfile MUST BE specified."
            raise Exception(errmsg)

        # Protection, in case infile == outfile (including different
        # spellings of the same path, which overwrite would otherwise delete)
        if _is_same_path(infile, outfile):
            errmsg = "You are attempting to write the output on your input file."
            raise Exception(errmsg)

        # File Info
        casalog.post("INPUT/OUTPUT")
        casalog.post(" Input MS file = %s " % infile)
        casalog.post(" Output MS file = %s " % outfile)

        # infile Inaccessible
        if not os.path.exists(infile):
            errmsg = "Specified infile does not exist."
            raise Exception(errmsg)

        # Validate the remaining parameters BEFORE touching outfile so that
        # invalid input never deletes an existing output with overwrite=True.

        # Inspect atmtype
        atmtype_int = int(atmtype)
        if atmtype_int not in (1, 2, 3, 4, 5):
            errmsg = "atmtype (=%s) should be any one of (1, 2, 3, 4, 5)." % atmtype
            raise Exception(errmsg)

        # Inspect humidity (float): -1 (unset) or 0.0 <= humidity <= 100.0 [%]
        humidity_float = float(humidity)
        if humidity_float != -1.0 and not (0.0 <= humidity_float <= 100.0):
            errmsg = "humidity (=%s) should be in range 0~100" % humidity
            raise Exception(errmsg)

        # datacolumn check (by XML definition)
        datacolumn_upper = datacolumn.upper()
        if datacolumn_upper not in ['DATA', 'CORRECTED', 'FLOAT_DATA']:
            errmsg = "Specified column name (%s) Unacceptable." % datacolumn
            raise Exception(errmsg)

        # outfile Protected
        if os.path.exists(outfile):
            if overwrite:
                casalog.post("Overwrite: Overwrite specified. Delete the existing output file.")
                _ms_remove(outfile)
            else:
                errmsg = "Specified outfile already exist."
                raise Exception(errmsg)

        # tweak antenna selection string to include autocorr data
        antenna_autocorr = sdutil.get_antenna_selection_include_autocorr(infile, antenna)

        # C++ re-implementation
        sdms.open(infile)
        try:
            sdms.set_selection(spw=outputspw, field=field,
                               antenna=antenna_autocorr,
                               timerange=timerange, scan=scan,
                               polarization=correlation, intent=intent,
                               observation=observation, feed=feed,
                               taql=msselect,
                               reindex=False)

            config = get_configuration_for_atmcor(
                infile=infile,
                spw=spw,
                outputspw=outputspw,
                gainfactor=gainfactor,
                user_inputs=locals()
            )

            sdms.atmcor(config=config, datacolumn=datacolumn, outfile=outfile)
        finally:
            # sdms is a module-level tool; release infile on every path.
            sdms.close()

    except Exception as err:
        casalog.post('%s' % err, priority='SEVERE')
        raise