Coverage for /wheeldirectory/casa-6.7.0-12-py3.10.el8/lib/py/lib/python3.10/site-packages/casatasks/private/partitionhelper.py: 61%
626 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-01 07:19 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-01 07:19 +0000
1import os
2import sys
3import shutil
4import pprint as pp
5import traceback
6import time
7import numpy as np
8from matplotlib import pyplot as plt
10from casatasks import casalog
11from casatools import table, ms, msmetadata
13import subprocess
15mst_local = ms()
16tbt_local = table()
17msmdt_local = msmetadata()
class convertToMMS():
    '''Convert every MS found in a directory into a Multi-MS (MMS).

    All the work is done by the constructor: it validates the input and
    output directories, runs partition on each MS found at the top level of
    the input directory and finally copies every other (non-MS) file or
    directory to the output directory.
    '''

    def __init__(self,
                 inpdir=None,
                 mmsdir=None,
                 axis='auto',
                 numsubms=4,
                 cleanup=False):
        '''Run the partition task to create MMSs from a directory with MSs

        inpdir   --> directory with the input MSs (mandatory)
        mmsdir   --> directory to save the MMSs to. When None or '/', a
                     directory called mmsdir in the current directory is used
        axis     --> separationaxis parameter of partition (spw, scan, auto)
        numsubms --> number of subMSs to create in each output MMS
        cleanup  --> if True, remove the output directory before starting

        The constructor returns early (after logging) when the input is not
        usable, and calls sys.exit(2) when partition fails for one MS.
        '''
        casalog.origin('convertToMMS')

        self.inpdir = inpdir
        self.outdir = mmsdir
        self.axis = axis
        self.numsubms = numsubms
        self.mmsdir = '/tmp/mmsdir'
        self.cleanup = cleanup

        # Input directory is mandatory
        if self.inpdir is None:
            casalog.post('You must give an input directory to this script')
            self.usage()
            return

        if not os.path.exists(self.inpdir):
            casalog.post('Input directory inpdir does not exist -> '+self.inpdir,'ERROR')
            self.usage()
            return

        if not os.path.isdir(self.inpdir):
            casalog.post('Value of inpdir is not a directory -> '+self.inpdir,'ERROR')
            self.usage()
            return

        # Only work with absolute paths
        self.inpdir = os.path.abspath(self.inpdir)
        casalog.post('Will read input MS from '+self.inpdir)

        # Verify output directory
        if self.outdir is None:
            self.mmsdir = os.path.join(os.getcwd(),'mmsdir')
        elif self.outdir == '/':
            # It is the output directory (mmsdir) that must not be the root
            casalog.post('mmsdir is set to root!', 'WARN')
            self.mmsdir = os.path.join(os.getcwd(),'mmsdir')
        else:
            self.outdir = os.path.abspath(self.outdir)
            self.mmsdir = self.outdir

        if self.mmsdir == self.inpdir:
            casalog.post('Output directory cannot be same of input directory','ERROR')
            return

        # Cleanup output directory
        if self.cleanup:
            casalog.post('Cleaning up output directory '+self.mmsdir)
            if os.path.isdir(self.mmsdir):
                shutil.rmtree(self.mmsdir)

        if not os.path.exists(self.mmsdir):
            os.makedirs(self.mmsdir)

        casalog.post('Will save output MMS to '+self.mmsdir)

        # Only the top level of the input directory is inspected. Generators
        # have no .next() method in Python 3, so use the next() builtin.
        files = next(os.walk(self.inpdir, followlinks=True), None)
        if files is None:
            casalog.post('Cannot list the contents of '+self.inpdir, 'ERROR')
            return

        # Get MS list
        mslist = self.getMSlist(files)

        casalog.post('List of MSs in input directory')
        casalog.post(pp.pformat(mslist))

        # Get non-MS directories and other files
        nonmslist = self.getFileslist(files)

        casalog.post('List of other files in input directory')
        casalog.post(pp.pformat(nonmslist))

        # Create an MMS for each MS in list
        for msname in mslist:
            casalog.post('Will create an MMS for '+msname)
            ret = self.runPartition(msname, self.mmsdir, self.axis, self.numsubms)
            if not ret:
                sys.exit(2)

            # Verify later if this is still needed
            time.sleep(10)

        casalog.origin('convertToMMS')
        casalog.post('--------------- Successfully created MMS -----------------')

        # Copy non-MS files to MMS directory
        for nfile in nonmslist:
            bfile = os.path.basename(nfile)
            lfile = os.path.join(self.mmsdir, bfile)
            casalog.post('Copying non-MS file '+bfile)
            # An argument list (no shell) keeps names with spaces or shell
            # meta-characters intact; -L dereferences symbolic links.
            copied = subprocess.run(['cp', '-RL', nfile, lfile])
            if copied.returncode != 0:
                casalog.post('Could not copy '+nfile+' to '+lfile, 'WARN')

    def getMSlist(self, files):
        '''Get a list of MSs from a directory.
        files -> a tuple that is returned by the following call:
        files = next(os.walk(self.inpdir,followlinks=True))

        It will test if a directory is an MS and will only return
        true MSs, that have Type:Measurement Set in table.info. It will skip
        directories that start with . but it does not require the
        directory name to end with .ms.
        '''

        topdir = files[0]
        mslist = []

        # Loop through list of directories
        for d in files[1]:
            # Skip . entries
            if d.startswith('.'):
                continue

            # Full path for directory
            mydir = os.path.join(topdir, d)

            # Keep plain MSs only (not cal tables, not MMSs)
            if self.isItMS(mydir) == 1:
                mslist.append(mydir)

        return mslist

    def isItMS(self, mydir):
        '''Check the type of a directory.
        mydir --> full path of a directory.
        Returns 1 for an MS, 2 for a cal table and 3 for a MMS.
        If 0 is returned, it means any other type or an error.'''

        ret = 0

        # Listing of this directory
        ldir = os.listdir(mydir)

        if 'table.info' not in ldir:
            return ret

        # Read the Type/SubType lines directly instead of running grep in a
        # shell, so that directory names with spaces are handled correctly.
        try:
            with open(os.path.join(mydir, 'table.info'), errors='replace') as tinfo:
                typelines = [line for line in tinfo if 'Type' in line]
        except OSError:
            return ret

        # Like "grep Type", mytype also includes the SubType lines
        mytype = ''.join(typelines)
        stype = ''.join(line for line in typelines if 'SubType' in line)

        # It is a cal table
        if 'Calibration' in mytype:
            ret = 2

        elif 'Measurement' in mytype:
            # It is a Multi-MS
            if 'CONCATENATED' in stype:
                # Further check
                if 'SUBMSS' in ldir:
                    ret = 3
            # It is an MS
            else:
                ret = 1

        return ret

    def getFileslist(self, files):
        '''Get a list of non-MS files from a directory.
        files -> a tuple that is returned by the following call:
        files = next(os.walk(self.inpdir,followlinks=True))

        It will return files and directories that are not MSs. It will skip
        files that start with . and directories that end with .ms
        '''

        topdir = files[0]
        fileslist = []

        # Get other directories that are not MSs
        for d in files[1]:

            # Skip . entries
            if d.startswith('.'):
                continue

            # Skip MS directories
            if d.endswith('.ms'):
                continue

            # Full path for directory
            mydir = os.path.join(topdir, d)

            # It is not an MS
            if self.isItMS(mydir) != 1:
                fileslist.append(mydir)

        # Get non-directory files
        for f in files[2]:
            # Skip . entries
            if f.startswith('.'):
                continue

            # Full path for file
            fileslist.append(os.path.join(topdir, f))

        return fileslist

    def runPartition(self, ms, mmsdir, axis, subms):
        '''Run partition with default values to create an MMS.
        ms --> full pathname of the MS
        mmsdir --> directory to save the MMS to (kept for compatibility;
                   the output location is taken from self.mmsdir)
        axis --> separationaxis to use (spw, scan, auto)
        subms --> number of subMss to create

        Returns True when the MMS was created and False otherwise.
        '''
        try:
            # CASA 6
            from casatasks import partition
        except ImportError:
            # CASA 5
            from tasks import partition

        if not os.path.lexists(ms):
            return False

        # Create MMS with the same name of the MS, but in a different location
        MSBaseName = os.path.basename(ms)
        MMSFullName = os.path.join(self.mmsdir, MSBaseName)
        if os.path.lexists(MMSFullName):
            casalog.post('Output MMS already exist -->'+MMSFullName,'ERROR')
            return False

        casalog.post('Output MMS will be: '+MMSFullName)

        # Check for remainings of corrupted mms
        corrupted = MMSFullName + '.data'
        if os.path.exists(corrupted):
            casalog.post('Cleaning up left overs','WARN')
            shutil.rmtree(corrupted)

        # Run partition
        partition(vis=ms, outputvis=MMSFullName, createmms=True, datacolumn='all', flagbackup=False,
                  separationaxis=axis, numsubms=subms)
        # Reset the log origin after the partition call
        casalog.origin('convertToMMS')

        # Check if MMS was created
        if not os.path.exists(MMSFullName):
            casalog.post('Cannot create MMS ->'+MMSFullName, 'ERROR')
            return False

        return True

    def usage(self):
        '''Post the usage instructions of convertToMMS to the logger.'''
        casalog.post('=========================================================================')
        casalog.post(' convertToMMS will create a directory with multi-MSs.')
        casalog.post('Usage:\n')
        casalog.post(' import partitionhelper as ph')
        casalog.post(' ph.convertToMMS(inpdir=\'dir\') \n')
        casalog.post('Options:')
        casalog.post(' inpdir <dir> directory with input MS.')
        casalog.post(' mmsdir <dir> directory to save output MMS. If not given, it will save ')
        casalog.post(' the MMS in a directory called mmsdir in the current directory.')
        casalog.post(" axis='auto' separationaxis parameter of partition (spw,scan,auto).")
        casalog.post(" numsubms=4 number of subMSs to create in output MMS")
        casalog.post(' cleanup=False if True it will remove the output directory before starting.\n')

        casalog.post(' NOTE: this script will run using the default values of partition. It will try to ')
        casalog.post(' create an MMS for every MS in the input directory. It will skip non-MS directories ')
        casalog.post(' such as cal tables. If partition succeeds, the script will create a link to every ')
        casalog.post(' other directory or file in the output directory. ')
        casalog.post(' The script will not walk through sub-directories of inpdir. It will also skip ')
        casalog.post(' files or directories that start with a .')
        casalog.post('==========================================================================')
        return
332#
333# -------------- HELPER functions for dealing with an MMS --------------
334#
335# getMMSScans 'Get the list of scans of an MMS dictionary'
336# getScanList 'Get the list of scans of an MS or MMS'
337# getScanNrows 'Get the number of rows of a scan in a MS. It will add the
338# nrows of all sub-scans.'
339# getMMSScanNrows 'Get the number of rows of a scan in an MMS dictionary.'
340# getSpwIds 'Get the Spw IDs of a scan.'
341# getDiskUsage 'Return the size of an MS on disk (as reported by du -hs).'
342#
343# ----------------------------------------------------------------------
345# def getNumberOf(msfile, item='row'):
346# '''Using the msmd tool, it gets the number of
347# scan, spw, antenna, baseline, field, state,
348# channel, row in a MS or MMS'''
349#
350# md = msmdtool() # or msmd() in CASA 6
351# try:
352# md.open(msfile)
353# except:
354# casalog.post('Cannot open the msfile')
355# return 0
356#
357# if item == 'row':
358# numof = md.nrows()
359# elif item == 'scan':
360# numof = md.nscans()
361# elif item == 'spw':
362# numof = md.nspw()
363# elif item == 'antenna':
364# numof = md.nantennas()
365# elif item == 'baseline':
366# numof = md.nbaselines()
367# elif item == 'channel':
368# numof = md.nchan()
369# elif item == 'field':
370# numof = md.nfields()
371# elif item == 'state':
372# numof = md.nstates()
373# else:
374# numof = 0
375#
376# md.close()
377# return numof
380# NOTE
381# There is a bug in ms.getscansummary() that does not give the scans for all
382# observation Ids, but only for the last one. See CAS-4409
def getMMSScans(mmsdict):
    '''Collect the scan numbers found in an MMS dictionary.
       mmsdict --> output dictionary from listpartition(MMS,createdict=true)
       Return a list with the scans of all subMSs, without duplicates.
       An empty list is returned (and an error is logged) when the
       input is not a dictionary. '''

    if not isinstance(mmsdict, dict):
        casalog.post('ERROR: Input is not a dictionary', 'ERROR')
        return []

    # Each value describes one subMS; its 'scanId' entry is keyed by scan
    unique_scans = {scan
                    for subms_info in mmsdict.values()
                    for scan in subms_info['scanId']}

    return list(unique_scans)
def getScanList(msfile, selection={}):
    '''Get the list of scans of an MS or MMS.
       msfile    --> name of MS or MMS
       selection --> dictionary with data selection; it is applied with
                     msselect only when it is a non-empty dictionary
                     (the default dictionary is never modified)

       Return a list of the scans in this MS/MMS. The scans are the keys
       of the dictionary returned by the ms tool getscansummary(). '''

    mst_local.open(msfile)
    try:
        if isinstance(selection, dict) and selection != {}:
            mst_local.msselect(items=selection)

        scand = mst_local.getscansummary()
    finally:
        # Do not leave the shared tool attached to the MS on failure
        mst_local.close()

    # Return a real list, as documented, and not a dictionary view
    return list(scand.keys())
def getScanNrows(msfile, myscan, selection={}):
    '''Get the number of rows of a scan in a MS. It will add the nrows of all sub-scans.
       msfile     --> name of the MS or MMS
       myscan     --> scan ID (int)
       selection  --> dictionary with data selection; when it is a non-empty
                      dictionary it is applied with msselect before the
                      scan summary is computed

       Return the number of rows in the scan, or 0 when the scan is not
       present in the scan summary.

       To compare with the dictionary returned by listpartition, do the following:

        resdict = listpartition('file.mms', createdict=True)
        slist = ph.getMMSScans(resdict)
        for s in slist:
            mmsN = ph.getMMSScanNrows(resdict, s)
            msN = ph.getScanNrows('referenceMS', s)
            assert (mmsN == msN)
    '''
    mst_local.open(msfile)
    try:
        if isinstance(selection, dict) and selection != {}:
            mst_local.msselect(items=selection)

        scand = mst_local.getscansummary()
    finally:
        # Do not leave the shared tool attached to the MS on failure
        mst_local.close()

    # The scan summary is keyed by the scan number as a string
    scankey = str(myscan)
    if scankey not in scand:
        return 0

    Nrows = 0
    for subscan in scand[scankey].values():
        Nrows += subscan['nRow']

    return Nrows
def getMMSScanNrows(thisdict, myscan):
    '''Add up the rows that a scan has in every subMS of an MMS dictionary.
       thisdict  --> output dictionary from listpartition(MMS,createdict=true)
       myscan    --> scan ID (int)
       Return the total number of rows of the given scan; 0 when no subMS
       contains it and -1 (after logging an error) when the input is
       not a dictionary. '''

    if not isinstance(thisdict, dict):
        casalog.post('ERROR: Input is not a dictionary', 'ERROR')
        return -1

    # Only the subMSs that contain the scan contribute to the total
    return sum(subms_info['scanId'][myscan]['nrows']
               for subms_info in thisdict.values()
               if myscan in subms_info['scanId'])
def getSpwIds(msfile, myscan, selection={}):
    '''Get the Spw IDs of a scan.
       msfile     --> name of the MS or MMS
       myscan     --> scan Id (int)
       selection  --> dictionary with data selection; it is applied with
                      msselect only when it is a non-empty dictionary

       Return a list with the Spw IDs. Note that the returned spw IDs are
       sorted and contain no duplicates. An empty list is returned when the
       scan is not present in the scan summary.
    '''
    import numpy as np

    mst_local.open(msfile)
    try:
        if isinstance(selection, dict) and selection != {}:
            mst_local.msselect(items=selection)

        scand = mst_local.getscansummary()
    finally:
        # Do not leave the shared tool attached to the MS on failure
        mst_local.close()

    # The scan summary is keyed by the scan number as a string
    scankey = str(myscan)
    if scankey not in scand:
        return []

    # Gather the spws of every sub-scan
    aspws = np.array([], dtype=int)
    for subscan in scand[scankey].values():
        aspws = np.append(aspws, subscan['SpwIds'])

    # np.unique both sorts and removes the duplicates
    return np.unique(aspws).ravel().tolist()
def getScanSpwSummary(mslist=[]):
    """ Get a consolidated dictionary with scan, spw, channel information
       of a list of MSs. It adds the nrows of all sub-scans of a scan.

       Keyword arguments:
       mslist    --> list with names of MSs

       Returns an empty dictionary for an empty list. Otherwise returns a
       dictionary keyed by the position of each MS in mslist, such as:
       mylist=['subms1.ms','subms2.ms']
       outdict = getScanSpwSummary(mylist)
       outdict = {0: {'MS': 'subms1.ms',
                      'scanId': {30: {'nchans': array([64, 64]),
                                      'nrows': 544,
                                      'spwIds': array([ 0,  1])}},
                      'size': '214M'},
                  1: {'MS': 'ngc5921.ms',
                      'scanId': {1: {'nchans': array([63]),
                                     'nrows': 4509,
                                     'spwIds': array([0])},
                                 2: {'nchans': array([63]),
                                     'nrows': 1890,
                                     'spwIds': array([0])}},
                      'size': '72M'}}

       The number of channels of an spw that is not described in the
       spectral window information of its MS is reported as 0.

       Raises an Exception when the scan or spw information of an MS
       cannot be read.
    """

    if mslist == []:
        return {}

    # Create lists for scan and spw dictionaries of each MS
    msscanlist = []
    msspwlist = []

    # List with the size on disk of each sub-MS (as given by getDiskUsage)
    sizelist = []

    # Loop through all MSs
    for subms in mslist:
        try:
            mst_local.open(subms)
            scans = mst_local.getscansummary()
            msscanlist.append(scans)
            spws = mst_local.getspectralwindowinfo()
            msspwlist.append(spws)
        except Exception as exc:
            # Keep the original error as the cause of the new one
            raise Exception('Cannot get scan/spw information from subMS: {0}'.format(exc)) from exc
        finally:
            mst_local.close()

        # Get the data volume per sub-MS
        sizelist.append(getDiskUsage(subms))

    # Dictionary to return
    outdict = {}

    for ims, subms in enumerate(mslist):
        # Create temp dictionary for each sub-MS
        tempdict = {}
        tempdict['MS'] = os.path.basename(subms)
        tempdict['size'] = sizelist[ims]

        # Scan dictionary for this sub-MS; the keys are the scan numbers
        scandict = msscanlist[ims]

        # Spw dictionary for this sub-MS
        # NOTE: the keys of spwdict.keys() are NOT the spw Ids
        spwdict = msspwlist[ims]

        # Get information per scan
        tempdict['scanId'] = {}
        for scan in scandict:
            newscandict = {}

            # Get spws and nrows per sub-scan
            nrows = 0
            aspws = np.array([], dtype='int32')
            for subscan in scandict[scan].values():
                nrows += subscan['nRow']
                aspws = np.append(aspws, subscan['SpwIds'])

            newscandict['nrows'] = nrows

            # np.unique both sorts and removes the duplicates
            uniquespws = np.unique(aspws)
            newscandict['spwIds'] = uniquespws

            # Array to hold channels. It starts at zero so that an spw
            # without a match below never exposes uninitialized memory.
            charray = np.zeros_like(uniquespws)

            # Now get the number of channels per spw
            for ind, spwid in enumerate(uniquespws):
                for spwinfo in spwdict.values():
                    if spwinfo['SpectralWindowId'] == spwid:
                        charray[ind] = spwinfo['NumChan']

            newscandict['nchans'] = charray
            tempdict['scanId'][int(scan)] = newscandict

        outdict[ims] = tempdict

    return outdict
def getMMSSpwIds(thisdict):
    '''Collect the spw Ids found in an MMS dictionary.
       thisdict  --> output dictionary from listpartition(MMS,createdict=true)
       Return a sorted list, without duplicates, of the spw Ids of all
       the scans of all the subMSs. An empty list is returned (and an error
       is logged) when the input is not a dictionary. '''

    import numpy as np

    if not isinstance(thisdict, dict):
        casalog.post('ERROR: Input is not a dictionary', 'ERROR')
        return []

    # Concatenate the spws of every scan of every subMS
    collected = np.array([], dtype='int32')
    for subms_info in thisdict.values():
        for scan_info in subms_info['scanId'].values():
            collected = np.append(collected, scan_info['spwIds'])

    # np.unique returns the sorted unique values; hand them back as a list
    return np.unique(collected).ravel().tolist()
def getSubMSSpwIds(subms, thisdict):
    '''Get the spw Ids of one subMS from an MMS dictionary.
       subms     --> name of the subMS; only its base name is compared
       thisdict  --> dictionary whose values have an 'MS' name and a
                     'scanId' dictionary with the 'spwIds' of each scan
       Return a sorted list, without duplicates, of the spw Ids of the first
       entry whose 'MS' matches. The list is empty when nothing matches. '''

    import numpy as np

    wanted = os.path.basename(subms)
    collected = np.array([], dtype='int32')

    for subms_info in thisdict.values():
        if subms_info['MS'] != wanted:
            continue

        # get the spwIds of this subMS
        for scan_info in subms_info['scanId'].values():
            collected = np.append(collected, scan_info['spwIds'])

        # Only the first matching subMS is used
        break

    # np.unique returns the sorted unique values; hand them back as a list
    return np.unique(collected).ravel().tolist()
def getDiskUsage(msfile):
    """Return the size of an MS or MMS on disk, as reported by du -hs.

       Keyword arguments:
       msfile  --> name of the MS

       The returned value is the first field of the output of the command
       du -hs, i.e. a human-readable string such as '75M' and not a number
       of bytes. An empty string is returned if the command prints nothing.
       The error output of du is merged into its standard output, so when
       du fails the first word of its error message is returned.
    """

    from subprocess import run, PIPE, STDOUT

    # Run du with an argument list and no shell, so that names with spaces
    # or shell meta-characters are handled correctly. The '--' marks the
    # end of the options, in case the name starts with a dash.
    ducmd = ['du', '-hs', '--', msfile]
    proc = run(ducmd, stdin=None, stdout=PIPE, stderr=STDOUT, close_fds=True)

    # The output looks like this:
    # ' 75M\tuidScan23.data/uidScan23.0000.ms\n'
    # Splitting it gives [size, sub-ms]
    fields = proc.stdout.decode().split()
    if not fields:
        return ''

    return fields[0]
def getSubtables(vis):
    '''Get the names of the sub-tables of a table.
       vis --> name of the table (MS)
       Return a list with the base names of the tables referenced by the
       keywords of vis whose value has the form 'Table: <path>'. The
       SORTED_TABLE keyword is skipped. '''

    tbt_local.open(vis)
    try:
        myKeyw = tbt_local.getkeywords()
    finally:
        # Do not leave the shared tool attached to the table on failure
        tbt_local.close()

    theSubTables = []
    for k, theKeyw in myKeyw.items():
        if k == 'SORTED_TABLE' or not isinstance(theKeyw, str):
            continue

        # A sub-table keyword looks like 'Table: /full/path/to/SUBTABLE'
        fields = theKeyw.split(' ')
        if fields[0] == 'Table:' and len(fields) > 1:
            theSubTables.append(os.path.basename(fields[1]))

    return theSubTables
def _remove_path(path):
    """Best-effort removal of a file, symbolic link or directory tree (like rm -rf)."""
    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path, ignore_errors=True)
    elif os.path.lexists(path):
        try:
            os.remove(path)
        except OSError:
            pass


def makeMMS(outputvis, submslist, copysubtables=False, omitsubtables=[], parallelaxis=''):
    """Create a Multi-MS from a list of MSs

    Keyword arguments:
        outputvis     -- name of the output MMS
        submslist     -- list of input subMSs to create the output from
        copysubtables -- True will copy the sub-tables from the first subMS to the others in the
                         output MMS. Default to False.
        omitsubtables -- List of sub-tables to omit when copying to output MMS. They will be linked instead
        parallelaxis  -- Optionally, set the value to be written to AxisType in table.info of the output MMS
                         Usually this value comes from the separationaxis keyword of partition or mstransform.

        Be AWARE that this function will remove the tables listed in submslist.

    Returns True on success. Raises ValueError when the output already
    exists, when submslist is empty or when any step of the creation fails.
    The current working directory is changed while the symbolic links are
    created and is always restored before returning or raising.
    """

    if os.path.exists(outputvis):
        raise ValueError('Output MS already exists')

    if len(submslist)==0:
        raise ValueError('No SubMSs given')

    ## make an MMS with all sub-MSs contained in a SUBMSS subdirectory
    origpath = os.getcwd()

    # Resolved now, because the working directory changes further down and
    # a relative name would no longer point to the MMS.
    absoutputvis = os.path.abspath(outputvis)

    try:
        try:
            mst_local.createmultims(outputvis,
                                    submslist,
                                    [],
                                    True,  # nomodify
                                    False, # lock
                                    copysubtables,
                                    omitsubtables
                                    ) # when copying the subtables, omit these
        finally:
            mst_local.close()

        # remove the SORTED_TABLE keywords because the sorting is not reliable after partitioning
        try:
            tbt_local.open(outputvis, nomodify=False)
            if 'SORTED_TABLE' in tbt_local.keywordnames():
                tbt_local.removekeyword('SORTED_TABLE')
            tbt_local.close()

            for thesubms in submslist:
                # Strip a trailing / so that the base name is never empty
                subname = os.path.basename(thesubms.rstrip('/'))
                tbt_local.open(outputvis+'/SUBMSS/'+subname, nomodify=False)
                if 'SORTED_TABLE' in tbt_local.keywordnames():
                    # The keyword value has the form 'Table: <path>'
                    tobedel = tbt_local.getkeyword('SORTED_TABLE').split(' ')[1]
                    tbt_local.removekeyword('SORTED_TABLE')
                    _remove_path(tobedel)
                tbt_local.close()
        except Exception:
            tbt_local.close()
            raise

        # Create symbolic links to the subtables of the first SubMS in the reference MS (top one)
        # The links are relative, so they are created from inside the MMS
        os.chdir(outputvis)
        mastersubms = os.path.basename(submslist[0].rstrip('/'))
        thesubtables = getSubtables('SUBMSS/'+mastersubms)

        for s in thesubtables:
            os.symlink('SUBMSS/'+mastersubms+'/'+s, s)

        os.chdir('SUBMSS/'+mastersubms)

        # The SOURCE and HISTORY tables should not be linked. They are
        # filtered out (instead of list.remove) because a sub-table that is
        # not present must not abort the creation of the MMS.
        linkedtables = [s for s in thesubtables if s not in ('SOURCE', 'HISTORY')]

        # Create sym links to all sub-tables in all subMSs
        for othersubms in submslist[1:]:
            thesubms = os.path.basename(othersubms.rstrip('/'))
            os.chdir('../'+thesubms)

            for s in linkedtables:
                _remove_path(s)
                os.symlink('../'+mastersubms+'/'+s, s)

        # Write the AxisType info in the MMS
        if parallelaxis != '':
            setAxisType(absoutputvis, parallelaxis)

    except Exception as exc:
        raise ValueError('Problem in MMS creation: {0}'.format(exc)) from exc
    finally:
        # Always go back to where we started
        os.chdir(origpath)

    return True
def axisType(mmsname):
    """Read the AxisType value stored in the info record of a Multi-MS.
    AxisType is usually added to a Multi-MS to record the axis which the
    data is parallelized across.

    Keyword arguments:
        mmsname    --    name of the Multi-MS

    It returns the value of AxisType or an empty string if it doesn't exist.
    Raises ValueError when the table info cannot be read.
    """

    try:
        tbt_local.open(mmsname, nomodify=True)
        tbinfo = tbt_local.info()
    except Exception as exc:
        raise ValueError('Unable to open table {0}. Exception: {1}'.format(mmsname, exc))
    finally:
        tbt_local.close()

    if 'readme' not in tbinfo:
        return ''

    # Lines of the readme that mention AxisType; the last one wins
    axislines = [line for line in tbinfo['readme'].splitlines() if 'AxisType' in line]
    if not axislines:
        return ''

    # The value is whatever follows the first '=' of the line
    value = axislines[-1].partition('=')[2]
    return value.strip()
def setAxisType(mmsname, axis=''):
    """Set the AxisType keyword in a Multi-MS info. If AxisType already
    exists, it will be overwritten.

    Keyword arguments:
        mmsname    --    name of the Multi-MS
        axis       --    parallel axis of the Multi-MS. Options: scan; spw or scan,spw
                         Leading and trailing white space is ignored.

    Return True when the value read back from the table matches the
    requested axis, False otherwise.
    Raises ValueError when axis is empty, or when the info of the table
    cannot be read or written.
    """

    # axisType() returns the stored value without surrounding white space,
    # so the requested value is normalized the same way before it is
    # written and compared.
    axis = axis.strip()
    if axis == '':
        raise ValueError('Axis value cannot be empty')

    try:
        tbt_local.open(mmsname)
        tbinfo = tbt_local.info()
    except Exception as exc:
        raise ValueError('Unable to open table {0}. Exception: {1}'.format(mmsname, exc)) from exc
    finally:
        tbt_local.close()

    # Save original readme
    readme = tbinfo['readme'] if 'readme' in tbinfo else ''

    # Check if AxisType already exist and remove it
    if axisType(mmsname) != '':
        casalog.post('WARN: Will overwrite the existing AxisType value', 'WARN')
        keptlines = [line for line in readme.splitlines() if 'AxisType' not in line]

        # Recreate the string
        readme = '\n'.join(keptlines).rstrip()

    # The AxisType line goes first, followed by the rest of the readme
    newReadme = 'AxisType = ' + axis + '\n' + readme

    # Create readme record
    readmerec = {'readme': newReadme}

    try:
        tbt_local.open(mmsname, nomodify=False)
        tbt_local.putinfo(readmerec)
    except Exception as exc:
        raise ValueError('Unable to put readme info into table {0}. Exception: {1}'.
                         format(mmsname, exc)) from exc
    finally:
        tbt_local.close()

    # Check if the axis was correctly added
    return axisType(mmsname) == axis
def buildScanDDIMap(scanSummary, ddIspectralWindowInfo):
    """
    Builds a scan->DDI map and 3 dictionaries with the # of visibilities per DDI, scan, field

    :param scanSummary: scan summary dictionary as produced by the mstool (getscansummary)
    :param ddIspectralWindowInfo: SPW info dictionary as produced by the mstool
    (getspectralwindowinfo()), with an 'isWVR' entry for every DDI
    :returns: a dict with a scan->ddi map, and three dict with # of visibilities per
    ddi, scan, and field. DDI and field ids are used as strings in the keys.
    """
    scan_ddi_map = {}

    # Totals of visibilities per ddi, scan and field, kept separately
    vis_per_ddi = {}
    vis_per_scan = {}
    vis_per_field = {}

    for scan in sorted(scanSummary):
        per_scan = scan_ddi_map[scan] = {}

        # One entry of the summary per timestamp of this scan
        for timestamp in scanSummary[scan]:
            entry = scanSummary[scan][timestamp]
            ddids = entry['DDIds']
            field_key = str(entry['FieldId'])

            # Assume all DDIs have the same number of rows
            # In ALMA data WVR DDI has only one row per antenna but it is separated from the other DDIs
            rows_each = entry['nRow'] / len(ddids)

            for raw_ddi in ddids:
                # DDIs are used as strings in all the maps
                ddi_key = str(raw_ddi)
                spwinfo = ddIspectralWindowInfo[ddi_key]

                # First time this DDI shows up in the scan
                if ddi_key not in per_scan:
                    per_scan[ddi_key] = {'nVis': 0,
                                         'fieldId': field_key,
                                         'isWVR': spwinfo['isWVR']}

                # Visibilities contributed by this timestamp
                nvis = rows_each * spwinfo['NumChan'] * spwinfo['NumCorr']

                per_scan[ddi_key]['nVis'] = per_scan[ddi_key]['nVis'] + nvis
                vis_per_ddi[ddi_key] = vis_per_ddi.get(ddi_key, 0) + nvis
                vis_per_scan[scan] = vis_per_scan.get(scan, 0) + nvis
                vis_per_field[field_key] = vis_per_field.get(field_key, 0) + nvis

    return scan_ddi_map, vis_per_ddi, vis_per_scan, vis_per_field
982def getPartitionMap(msfilename, nsubms, selection={}, axis=['field','spw','scan'],plotMode=0):
983 """Generates a partition scan/spw map to obtain optimal load balancing with the following criteria:
985 1st - Maximize the scan/spw/field distribution across sub-MSs
986 2nd - Generate sub-MSs with similar size
988 In order to balance better the size of the subMSs the allocation process
989 iterates over the scan,spw pairs in descending number of visibilities.
991 That is larger chunks are allocated first, and smaller chunks at the final
992 stages so that they can be used to balance the load in a stable way
994 Keyword arguments:
995 msname -- Input MS filename
996 nsubms -- Number of subMSs
997 selection -- Data selection dictionary
998 axis -- Vector of strings containing the axis for load distribution (scan,spw,field)
999 plotMode -- Integer in the range 0-3 to determine the plot generation mode
1000 0 - Don't generate any plots
1001 1 - Show plots but don't save them
1002 2 - Save plots but don't show them
1003 3 - Show and save plots
1005 Returns a map of the sub-MSs with the corresponding scan/spw selections and the number of visibilities
1006 """
1008 # Open ms tool
1009 mst_local.open(msfilename)
1011 # Apply data selection
1012 if isinstance(selection, dict) and selection != {}:
1013 mst_local.msselect(items=selection)
1015 # Get list of DDIs and timestamps per scan
1016 scanSummary = mst_local.getscansummary()
1017 ddIspectralWindowInfo = mst_local.getspectralwindowinfo()
1019 # Close ms tool
1020 mst_local.close()
1022 # Get list of WVR SPWs using the ms metadata tool
1023 msmdt_local.open(msfilename)
1024 wvrspws = msmdt_local.wvrspws()
1025 msmdt_local.close()
1027 # Mark WVR DDIs as identified by the ms metadata tool
1028 for ddi in ddIspectralWindowInfo:
1029 if ddIspectralWindowInfo[ddi] in wvrspws:
1030 ddIspectralWindowInfo[ddi]['isWVR'] = True
1031 else:
1032 ddIspectralWindowInfo[ddi]['isWVR'] = False
1034 scanDdiMap, nVisPerDDI, nVisPerScan, nVisPerField = buildScanDDIMap(scanSummary,
1035 ddIspectralWindowInfo)
1037 # Sort the scan/ddi pairs depending on the number of visibilities
1038 ddiList = list()
1039 scanList = list()
1040 fieldList = list()
1041 nVisList = list()
1042 nScanDDIPairs = 0
1043 for scan in scanDdiMap:
1044 for ddi in scanDdiMap[scan]:
1045 ddiList.append(ddi)
1046 scanList.append(scan)
1047 fieldList.append(scanDdiMap[scan][ddi]['fieldId'])
1048 nVisList.append(scanDdiMap[scan][ddi]['nVis'])
1049 nScanDDIPairs += 1
1052 # Check that the number of available scan/ddi pairs is not greater than the number of subMSs
1053 if nsubms > nScanDDIPairs:
1054 casalog.post("Number of subMSs (%i) is greater than available scan,ddi pairs (%i), setting nsubms to %i"
1055 % (nsubms,nScanDDIPairs,nScanDDIPairs),"WARN","getPartitionMap")
1056 nsubms = nScanDDIPairs
1058 ddiArray = np.array(ddiList)
1059 scanArray = np.array(scanList)
1060 nVisArray = np.array(nVisList)
1062 nVisSortIndex = np.lexsort((ddiArray, scanArray, nVisArray))
1063 # argsort/lexsort return indices by increasing value. This reverses the indices by
1064 # decreasing value
1065 nVisSortIndex[:] = nVisSortIndex[::-1]
1067 ddiArray = ddiArray[nVisSortIndex]
1068 scanArray = scanArray[nVisSortIndex]
1069 nVisArray = nVisArray[nVisSortIndex]
1071 # Make a map for the contribution of each subMS to each scan
1072 scanNvisDistributionPerSubMs = {}
1073 for scan in scanSummary:
1074 scanNvisDistributionPerSubMs[scan] = np.zeros(nsubms)
1077 # Make a map for the contribution of each subMS to each ddi
1078 ddiNvisDistributionPerSubMs = {}
1079 for ddi in ddIspectralWindowInfo:
1080 ddiNvisDistributionPerSubMs[ddi] = np.zeros(nsubms)
1083 # Make a map for the contribution of each subMS to each field
1084 fieldList = np.unique(fieldList)
1085 fieldNvisDistributionPerSubMs = {}
1086 for field in fieldList:
1087 fieldNvisDistributionPerSubMs[field] = np.zeros(nsubms)
1090 # Make an array for total number of visibilites per subms
1091 nvisPerSubMs = np.zeros(nsubms)
1094 # Initialize final map of scans/pw pairs per subms
1095 submScanDdiMap = {}
1096 for subms in range (0,nsubms):
1097 submScanDdiMap[subms] = {}
1098 submScanDdiMap[subms]['scanList'] = list()
1099 submScanDdiMap[subms]['ddiList'] = list()
1100 submScanDdiMap[subms]['fieldList'] = list()
1101 submScanDdiMap[subms]['nVisList'] = list()
1102 submScanDdiMap[subms]['nVisTotal'] = 0
1105 # Iterate over the scan/ddi map and assign each pair to a subMS
1106 for pair in range(len(ddiArray)):
1108 ddi = ddiArray[pair]
1109 scan = scanArray[pair]
1110 field = scanDdiMap[scan][ddi]['fieldId']
1112 # Select the subMS that with bigger (scan/ddi/field gap)
1113 # We use the average as a refLevel to include global structure information
1114 # But we also take into account the actual max value in case we are distributing large uneven chunks
1115 jointNvisGap = np.zeros(nsubms)
1116 if 'scan' in axis:
1117 refLevel = max(nVisPerScan[scan] //
1118 nsubms,scanNvisDistributionPerSubMs[scan].max())
1119 jointNvisGap = jointNvisGap + refLevel - scanNvisDistributionPerSubMs[scan]
1120 if 'spw' in axis:
1121 refLevel = max(nVisPerDDI[ddi] //
1122 nsubms,ddiNvisDistributionPerSubMs[ddi].max())
1123 jointNvisGap = jointNvisGap + refLevel - ddiNvisDistributionPerSubMs[ddi]
1124 if 'field' in axis:
1125 refLevel = max(nVisPerField[field] //
1126 nsubms,fieldNvisDistributionPerSubMs[field].max())
1127 jointNvisGap = jointNvisGap + refLevel - fieldNvisDistributionPerSubMs[field]
1129 optimalSubMs = np.where(jointNvisGap == jointNvisGap.max())
1130 optimalSubMs = optimalSubMs[0] # np.where returns a tuple
1132 # In case of multiple candidates select the subms with minum number of total visibilities
1133 if len(optimalSubMs) > 1:
1134 subIdx = np.argmin(nvisPerSubMs[optimalSubMs])
1135 optimalSubMs = optimalSubMs[subIdx]
1136 else:
1137 optimalSubMs = optimalSubMs[0]
1139 # Store the scan/ddi pair info in the selected optimal subms
1140 nVis = scanDdiMap[scan][ddi]['nVis']
1141 nvisPerSubMs[optimalSubMs] = nvisPerSubMs[optimalSubMs] + nVis
1142 submScanDdiMap[optimalSubMs]['scanList'].append(int(scan))
1143 submScanDdiMap[optimalSubMs]['ddiList'].append(int(ddi))
1144 submScanDdiMap[optimalSubMs]['fieldList'].append(field)
1145 submScanDdiMap[optimalSubMs]['nVisList'].append(nVis)
1146 submScanDdiMap[optimalSubMs]['nVisTotal'] = submScanDdiMap[optimalSubMs]['nVisTotal'] + nVis
1148 # Also update the counters for the subms-scan and subms-ddi maps
1149 scanNvisDistributionPerSubMs[scan][optimalSubMs] = scanNvisDistributionPerSubMs[scan][optimalSubMs] + nVis
1150 ddiNvisDistributionPerSubMs[ddi][optimalSubMs] = ddiNvisDistributionPerSubMs[ddi][optimalSubMs] + nVis
1151 fieldNvisDistributionPerSubMs[field][optimalSubMs] = fieldNvisDistributionPerSubMs[field][optimalSubMs] + nVis
1154 # Generate plots
1155 if plotMode > 0:
1156 plt.close()
1157 plotname_prefix = os.path.basename(msfilename) + ' axis ' + string.join(axis)
1158 plotVisDistribution(nVisPerScan,scanNvisDistributionPerSubMs,plotname_prefix,'scan',plotMode=plotMode)
1159 plotVisDistribution(nVisPerDDI,ddiNvisDistributionPerSubMs,plotname_prefix,'ddi',plotMode=plotMode)
1160 plotVisDistribution(nVisPerField,fieldNvisDistributionPerSubMs,plotname_prefix,'field',plotMode=plotMode)
1163 # Generate list of taql commands
1164 for subms in submScanDdiMap:
1165 # Initialize taql command
1166 from collections import defaultdict
1167 dmytaql = defaultdict(list)
1169 for pair in range(len(submScanDdiMap[subms]['scanList'])):
1170 # Get scan/ddi for this pair
1171 ddi = submScanDdiMap[subms]['ddiList'][pair]
1172 scan = submScanDdiMap[subms]['scanList'][pair]
1173 dmytaql[ddi].append(scan)
1175 mytaql = []
1176 for ddi, scans in dmytaql.items():
1177 scansel = '[' + ', '.join([str(x) for x in scans]) + ']'
1178 mytaql.append(('(DATA_DESC_ID==%i && (SCAN_NUMBER IN %s))') % (ddi, scansel))
1180 mytaql = ' OR '.join(mytaql)
1182 # Store taql
1183 submScanDdiMap[subms]['taql'] = mytaql
1186 # Return map of scan/ddi pairs per subMs
1187 return submScanDdiMap
def plotVisDistribution(nvisMap,idNvisDistributionPerSubMs,filename,idLabel,plotMode=1):
    """Generates a plot to show the distribution of scans/spws across subMSs.

    The plot style is a stacked bar chart, where the ids (spws/scans) with a higher
    total number of visibilities are shown at the bottom.

    Keyword arguments:
    nvisMap                    -- Map of total number of visibilities per id
                                  (keys must be convertible with int())
    idNvisDistributionPerSubMs -- Map of visibilities per subMS for each id
                                  (one sequence of length nsubms per id)
    filename                   -- Name of MS to be shown in the title and plot filename
    idLabel                    -- Label of the id (spw, scan) used for the legend and figure title
    plotMode                   -- Integer in the range 0-3 to determine the plot generation mode
                                    0 - Don't generate any plots
                                    1 - Show plots but don't save them
                                    2 - Save plots but don't show them
                                    3 - Show and save plots

    Returns None. In modes 2 and 3 the figure is written to the current working
    directory as '<title with spaces replaced by dashes>.png'.
    """

    # Honour the documented contract for mode 0: without this guard a figure would
    # be created below that is never shown, saved or closed
    if plotMode == 0:
        return

    # Create a new figure (interactive mode off so that nothing is drawn until requested)
    plt.ioff()

    # If plot is not to be shown then use a pre-defined figure size of 1585x1170 pixels with 75 DPI
    # (we cannot maximize the window to the screen size)
    if plotMode==2:
        plt.figure(figsize=(21.13,15.6),dpi=75) # Size is given in inches
    else:
        plt.figure()

    # Sort the ids according to their total number of visibilities so that we can
    # represent the bigger groups at the bottom and the smaller ones at the top
    idKeys = list(nvisMap)
    idNvisArray = np.array([nvisMap[key] for key in idKeys], dtype=float)
    idSortIndex = np.argsort(idNvisArray)[::-1] # argsort is increasing, we want decreasing
    idKeysSorted = [idKeys[index] for index in idSortIndex]

    # Index the per-subMS distribution by integer id, so that the lookup works
    # regardless of whether the maps are keyed by str ('3') or by int (3)
    nvisPerSubMsById = {}
    for key, values in idNvisDistributionPerSubMs.items():
        nvisPerSubMsById[int(key)] = values

    # Initialize color vector to alternate cold/warm colors: odd indices in
    # increasing order interleaved with even indices in decreasing order.
    # range() objects are immutable in Python 3, hence the explicit lists
    nid = len(nvisMap)
    colorVectorEven = list(range(0, nid, 2))
    colorVectorOdd = list(range(1, nid, 2))
    colorVectorOdd.reverse()
    colorVector = list()
    while colorVectorOdd or colorVectorEven:
        if colorVectorOdd: colorVector.append(colorVectorOdd.pop())
        if colorVectorEven: colorVector.append(colorVectorEven.pop())

    # Generate stacked bar plot
    width = 0.35 # bar width
    # All entries are expected to have one value per subMS, so take the first one
    # (dict views are not subscriptable in Python 3)
    nsubms = len(next(iter(idNvisDistributionPerSubMs.values()), ()))
    idx = np.arange(nsubms) # location of the left edge of the bars in the horizontal axis
    bottomLevel = np.zeros(nsubms) # Reference level for the bars to be stacked after the previous ones
    legendidLabels = list() # List of legend idLabels
    plotHandles = list() # List of plot handles for the legend
    for colorIndex, key in zip(colorVector, idKeysSorted):

        idValue = int(key)
        idNvisPerSubMs = nvisPerSubMsById[idValue]

        # align='edge' makes the bars start at idx, so that the x-ticks placed
        # at idx+width/2 below fall on the bar centers
        idplot = plt.bar(idx, idNvisPerSubMs, width, bottom=bottomLevel, align='edge',
                         color=plt.cm.Paired(1.*colorIndex/nid))

        # Update legend lists
        plotHandles.append(idplot)
        legendidLabels.append(idLabel + ' ' + str(idValue))

        # Update reference level
        bottomLevel = bottomLevel + idNvisPerSubMs

    # Add legend
    plt.legend( plotHandles, legendidLabels, bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.)

    # Add label for y axis
    plt.ylabel('nVis')

    # Add x-ticks
    xticks = ['subMS-' + str(subms) for subms in range(nsubms)]
    plt.xticks(idx+width/2., xticks )

    # Add title
    title = filename + ' distribution of ' + idLabel + ' visibilities across sub-MSs'
    plt.title(title)

    showPlot = plotMode==1 or plotMode==3

    # Resize to full screen
    if showPlot:
        mng = plt.get_current_fig_manager()
        # window.maxsize() is a Tk window method; other GUI backends may not
        # provide it, in which case the figure simply keeps its default size
        maxsize = getattr(getattr(mng, 'window', None), 'maxsize', None)
        if callable(maxsize):
            mng.resize(*maxsize())

    # Show figure
    if showPlot:
        plt.ion()
        plt.show()

    # Save plot
    if plotMode>1:
        title = title.replace(' ','-') + '.png'
        plt.savefig(title)

    # If plot is not to be shown then close it
    if plotMode==2:
        plt.close()