Coverage for /wheeldirectory/casa-6.7.0-12-py3.10.el8/lib/py/lib/python3.10/site-packages/casatasks/private/parallel/parallel_task_helper.py: 24%
417 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-10-31 19:10 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-10-31 19:10 +0000
1#!/usr/bin/env python
2import os
3import sys
4import copy
5import shutil
6import inspect
8from .. import partitionhelper as ph
9from casatools import table as tbtool
10from casatools import ms as mstool
11from casatasks import casalog
12from casatasks.private.parallel.rflag_post_proc import combine_rflag_subreport, is_rflag_report
13from casatasks.private.parallel.rflag_post_proc import finalize_agg_rflag_thresholds
def strfind(str_instance, a):
    """Return the lowest index of substring ``a`` in ``str_instance``, or -1 if absent."""
    position = str_instance.find(a)
    return position
18# common function to use to get a dictionary values iterator
def locitervalues(adict):
    """Return the values view of ``adict`` (iterable over its values)."""
    values_view = adict.values()
    return values_view
22# To handle thread-based Tier-2 parallelization
23import threading
25# jagonzal (CAS-4106): Properly report all the exceptions and errors in the cluster framework
26import traceback
28# jagonzal (Migration to MPI)
29try:
30 from casampi.MPIEnvironment import MPIEnvironment
31 from casampi.MPICommandClient import MPICommandClient
32 mpi_available = True
33except ImportError:
34 mpi_available = False
class JobData:
    """
    This class encapsulates a single job: one or more task invocations that
    are executed one after the other.

    Each invocation is stored as a ``JobData.CommandInfo`` holding the task
    name (``commandName``), a dictionary of keyword arguments
    (``commandInfo``) and the name of the variable its return value is bound
    to in the generated command line (``returnVar0``, ``returnVar1``, ...).
    """

    class CommandInfo:
        """A single task invocation: name, keyword arguments and return variable."""

        def __init__(self, commandName, commandInfo, returnVariable):
            self.commandName = commandName
            self.commandInfo = commandInfo
            self.returnVariable = returnVariable

        def getReturnVariable(self):
            """Return the name of the variable the command's result is assigned to."""
            return self.returnVariable

        def getCommandLine(self):
            """
            Build the Python statement that runs this command, e.g.
            ``returnVar0 = flagdata(vis = 'a.ms', ntime = 2)``.

            String values are rendered with repr() so that quotes or
            backslashes inside them (e.g. in file paths) still yield a valid
            statement. For plain strings repr() yields the same single-quoted
            text as before. Non-string values are rendered with str().
            """
            arguments = []
            for arg, value in self.commandInfo.items():
                if isinstance(value, str):
                    arguments.append("%s = %r" % (arg, value))
                else:
                    arguments.append("%s = %s" % (arg, value))
            return "%s = %s(%s)" % (self.returnVariable, self.commandName,
                                    ', '.join(arguments))

    def __init__(self, commandName, commandInfo=None):
        """
        :param commandName: name of the first task to run.
        :param commandInfo: dictionary of keyword arguments for that task.
            Defaults to a fresh empty dict. A shared ``{}`` default would leak
            mutations made through getCommandArguments() into other jobs.
        """
        if commandInfo is None:
            commandInfo = {}
        self._commandList = []
        self.status = 'new'
        self.addCommand(commandName, commandInfo)
        self._returnValues = None

    def addCommand(self, commandName, commandInfo):
        """
        Add an additional command to this Job, to be executed after the
        previously added commands. Its return value is bound to
        ``returnVar<N>``, where N is the command's position in the job.
        """
        rtnVar = "returnVar%d" % len(self._commandList)
        self._commandList.append(JobData.CommandInfo(commandName,
                                                     commandInfo,
                                                     rtnVar))

    def getCommandLine(self):
        """
        Return the command line(s) to be executed, joined with '; '.
        This is usually only needed for debugging or for the JobQueueManager.
        """
        return '; '.join(command.getCommandLine() for command in self._commandList)

    def getCommandNames(self):
        """Return the list of command names associated with this job, in order."""
        return [command.commandName for command in self._commandList]

    def getCommandArguments(self, commandName=None):
        """
        Return the command arguments associated with this job.
          * If commandName is not None, only commands with that name are
            considered.
          * Otherwise a dictionary is returned. Its keys are the command names
            and its values are the argument dictionaries.
          * If only a single command matches, its argument dictionary is
            returned directly.
        Note that commands sharing a name overwrite each other in the
        returned dictionary (the last one wins).
        """
        returnValue = {}
        for command in self._commandList:
            if commandName is None or commandName == command.commandName:
                returnValue[command.commandName] = command.commandInfo

        if len(returnValue) == 1:
            return list(returnValue.values())[0]
        return returnValue

    def getReturnVariableList(self):
        """Return the names of the return variables, one per command."""
        return [ci.returnVariable for ci in self._commandList]

    def setReturnValues(self, valueList):
        """Store the list of values returned by the job's commands."""
        self._returnValues = valueList

    def getReturnValues(self):
        """
        Return the stored return values. A single-element list is unwrapped
        to its element. Returns None if setReturnValues() was never called.
        """
        if self._returnValues is None:
            return None
        if len(self._returnValues) == 1:
            return self._returnValues[0]
        return self._returnValues
class ParallelTaskHelper:
    """
    This is the extension of the TaskHelper to allow for parallel operation
    over the sub-MSs of a Multi-MS (MMS), either on an MPI cluster or
    sequentially when MPI is unavailable. For simple tasks all that should be
    required to make a task parallel is to use this rather than the
    TaskHelper.

    go() drives the work: initialize() -> generateJobs() -> executeJobs()
    -> postExecution().
    """

    # Class-wide switches, controlled through the static setters/getters below.
    # __bypass_parallel_processing: 0 = use the MPI cluster,
    #   1 = process each sub-MS sequentially, 2 = process the MMS as a normal MS.
    #   __init__ sets an instance attribute of 1 when MPI is unavailable or
    #   disabled, which shadows the class-wide value for that instance.
    __bypass_parallel_processing = 0
    # When True, go() returns the list of MPI command request ids instead of
    # waiting for and consolidating the results.
    __async_mode = False
    # When True, postExecution() releases the ParallelTaskWorker lock while it
    # waits for the MPI responses (thread-based Tier-2 parallelization).
    __multithreading = False

    def __init__(self, task_name, args=None):
        """
        :param task_name: name of the task to run on each sub-MS.
        :param args: dictionary of task parameters. It must contain 'vis'
            (used by initialize/generateJobs). It is copied into self._arg.
        """
        if args is None:
            args = {}
        self._arg = dict(args)
        self._arguser = {}
        self._taskName = task_name
        self._executionList = []
        self._jobQueue = None
        # Cache the initial inputs
        self.__originalParams = args
        # jagonzal: Add reference to cluster object
        self._cluster = None
        self._mpi_cluster = False
        self._command_request_id_list = None
        if not mpi_available or not MPIEnvironment.is_mpi_enabled:
            self.__bypass_parallel_processing = 1
        if self.__bypass_parallel_processing == 0:
            self._mpi_cluster = True
            self._command_request_id_list = []
            self._cluster = MPICommandClient()
        # jagonzal: To inhibit return values consolidation
        self._consolidateOutput = True
        # jagonzal (CAS-4287): Add a cluster-less mode to by-pass parallel processing for MMSs as requested
        # This is actually a dict, with key=vis and value= the 'success' field of the cmd.
        # (exception: for tasks with parameter outputvis (like partition), key=outputvis)
        self._sequential_return_list = {}

    def override_arg(self, arg, value):
        """
        Override parameter ``arg`` per sub-MS. ``value`` is indexed by the
        position of each sub-MS in the MMS (see generateJobs), so it should be
        a sequence with at least one entry per sub-MS.
        """
        self._arguser[arg] = value

    def initialize(self):
        """
        This is the setup portion.
        Currently it:
          * Finds the full path for the input vis.
          * Starts the MPICommandClient services (MPI cluster only).
        """
        self._arg['vis'] = os.path.abspath(self._arg['vis'])

        # jagonzal (Migration to MPI)
        if self._mpi_cluster:
            self._cluster.start_services()

    def getNumberOfServers(self):
        """
        Return the number of servers of the MPI cluster, or None when not
        running in parallel.
        """
        if mpi_available and self.__bypass_parallel_processing == 0:
            return len(MPIEnvironment.mpi_server_rank_list())
        return None

    def generateJobs(self):
        """
        Generate all of the jobs to be done. The default assumes the input
        vis is a reference MS (MMS) and builds one job per referenced sub-MS,
        applying any per-sub-MS overrides registered with override_arg().

        :returns: True on success. Returns False (after posting a WARN) if the
            MS cannot be opened, is not a MultiMS, or job creation fails.
        """
        casalog.origin("ParallelTaskHelper")

        msTool = None
        try:
            msTool = mstool()
            if not msTool.open(self._arg['vis']):
                raise ValueError("Unable to open MS %s," % self._arg['vis'])
            if not msTool.ismultims():
                raise ValueError("MS is not a MultiMS, simple parallelization failed")

            for subMs_idx, subMS in enumerate(msTool.getreferencedtables()):
                localArgs = copy.deepcopy(self._arg)
                localArgs['vis'] = subMS

                # Per-sub-MS overrides are indexed by the sub-MS position.
                for key, values in self._arguser.items():
                    localArgs[key] = values[subMs_idx]

                if self._mpi_cluster:
                    self._executionList.append([self._taskName + '()', localArgs])
                else:
                    self._executionList.append(JobData(self._taskName, localArgs))
            return True
        except Exception as instance:
            casalog.post("Error handling MMS %s: %s" % (self._arg['vis'], instance), "WARN", "generateJobs")
            return False
        finally:
            if msTool is not None:
                msTool.close()

    def executeJobs(self):
        """
        Run the generated jobs.

        In sequential (bypass) mode each JobData command line is exec'd
        in-process and its outcome is recorded in self._sequential_return_list.
        Errors are logged and not raised: errors mentioning NullSelection are
        ignored with an INFO message, and any other error is reported as a
        WARN. In MPI mode each job is pushed to the cluster and its request id
        is recorded.
        """
        casalog.origin("ParallelTaskHelper")

        # jagonzal (CAS-4287): Add a cluster-less mode to by-pass parallel processing for MMSs as requested
        if self.__bypass_parallel_processing == 1:
            for job in self._executionList:
                parameters = job.getCommandArguments()
                try:
                    gvars = globals()
                    try:
                        exec("from casatasks import *; " + job.getCommandLine(), gvars)
                    except Exception as exc:
                        casalog.post("exec in parallel_task_helper.executeJobs failed: {}".format(exc))
                        raise

                    # jagonzal: Special case for partition
                    # The 'True' values emulate the command_response['successful'] that
                    # we'd get in parallel runs from other MPI processes.
                    if 'outputvis' in parameters:
                        self._sequential_return_list[parameters['outputvis']] = True
                    else:
                        self._sequential_return_list[parameters['vis']] = gvars['returnVar0'] or True

                except Exception as instance:
                    str_instance = str(instance)
                    if "NullSelection" not in str_instance:
                        casalog.post("Error running task sequentially %s: %s" % (job.getCommandLine(), str_instance), "WARN", "executeJobs")
                        traceback.print_tb(sys.exc_info()[2])
                    else:
                        casalog.post("Ignoring NullSelection error from %s" % (parameters['vis']), "INFO", "executeJobs")
            self._executionList = []
        else:
            for job in self._executionList:
                command_request_id = self._cluster.push_command_request(job[0], False, None, job[1])
                self._command_request_id_list.append(command_request_id[0])

    def postExecution(self):
        """
        Collect the per-sub-MS results into a dict keyed by vis. For the
        sequential partition case the key is outputvis. Unless
        self._consolidateOutput is False, the result is consolidated with
        consolidateResults().

        :returns: the (consolidated) results, or None if there is neither a
            sequential run nor an MPI cluster.
        """
        casalog.origin("ParallelTaskHelper")

        if self.__bypass_parallel_processing == 1:
            ret_list = self._sequential_return_list
            self._sequential_return_list = {}
        elif self._cluster is not None:
            # jagonzal (CAS-7631): Support for thread-based Tier-2 parallelization
            if ParallelTaskHelper.getMultithreadingMode():
                event = self._cluster.get_command_response_event(self._command_request_id_list)
                # Let other ParallelTaskWorker threads run while we block.
                ParallelTaskWorker.releaseTaskLock()
                event.wait()
                ParallelTaskWorker.acquireTaskLock()
            # Get command response
            command_response_list = self._cluster.get_command_response(self._command_request_id_list, True, True)
            # Format list in the form of vis dict
            ret_list = {}
            for command_response in command_response_list:
                vis = command_response['parameters']['vis']
                if 'uvcontsub' in command_response['command']:
                    # One more particular case, similar as in 'executeJob' for partition.
                    # The design of these lists and how they are used in different ways in
                    # tasks uvcontsub, setjy, flagdata, etc. is evil
                    # uvcontsub expects a 'success' True/False value for every subMS rather
                    # than the return value of the subMS uvcontsub.
                    ret_list[vis] = command_response['successful']
                else:
                    ret_list[vis] = command_response['ret']
        else:
            return None

        if self._consolidateOutput:
            return ParallelTaskHelper.consolidateResults(ret_list, self._taskName)
        return ret_list

    @staticmethod
    def consolidateResults(ret_list, taskname):
        """
        Merge the per-sub-MS results of ``ret_list`` (a dict keyed by sub-MS).

          * If the first value is a bool: return True only if every value is
            truthy, posting a WARN for each failed sub-MS.
          * Else, if any value is a dict: merge all dict values with
            combine_dictionaries() and finalize them with
            finalize_consolidate_results().
          * Otherwise return None.

        An empty ``ret_list`` raises IndexError.
        """
        if isinstance(list(ret_list.values())[0], bool):
            retval = True
            for subMs in ret_list:
                if not ret_list[subMs]:
                    casalog.post("%s failed for sub-MS %s" % (taskname, subMs), "WARN", "consolidateResults")
                    retval = False
            return retval
        elif any(isinstance(v, dict) for v in locitervalues(ret_list)):
            ret_dict = {}
            for _key, subMS_dict in ret_list.items():
                casalog.post(" ***** consolidateResults, subMS: {0}".format(subMS_dict),
                             "WARN", "consolidateResults")
                if isinstance(subMS_dict, dict):
                    try:
                        ret_dict = ParallelTaskHelper.combine_dictionaries(subMS_dict, ret_dict)
                    except Exception as instance:
                        casalog.post("Error post processing MMS results {0}: {1}".format(
                            subMS_dict, instance), 'WARN', 'consolidateResults')
                        raise
            return ParallelTaskHelper.finalize_consolidate_results(ret_dict)

    @staticmethod
    def combine_dictionaries(dict_list, ret_dict):
        """
        Combines a flagging (sub-)report dictionary dict_list (from a subMS) into an overall
        report dictionary (ret_dict), which is updated in place and returned.

        Nested dicts are merged recursively, and RFlag reports are combined
        with combine_rflag_subreport(). Existing non-string leaves are summed,
        except 'nreport', which is an index and is overwritten. String leaves
        are overwritten.
        """
        for key, item in dict_list.items():
            if isinstance(item, dict):
                if key in ret_dict:
                    if is_rflag_report(item):
                        ret_dict[key] = combine_rflag_subreport(item, ret_dict[key])
                    else:
                        ret_dict[key] = ParallelTaskHelper.combine_dictionaries(item, ret_dict[key])
                else:
                    ret_dict[key] = ParallelTaskHelper.combine_dictionaries(item, {})
            else:
                if key in ret_dict:
                    # the 'nreport' field should not be summed - it's an index
                    if not isinstance(ret_dict[key], str) and 'nreport' != key:
                        # This is a good default for all reports that have flag counters
                        ret_dict[key] += item
                    else:
                        ret_dict[key] = item
                else:
                    ret_dict[key] = item

        return ret_dict

    @staticmethod
    def finalize_consolidate_results(ret):
        """ Applies final step to the items of the report dictionary.
        For now only needs specific processing to finalize the aggregation of the RFlag
        thresholds (freqdev/timedev) vectors. """

        for key, item in ret.items():
            if isinstance(item, dict) and is_rflag_report(item):
                ret[key] = finalize_agg_rflag_thresholds(item)

        return ret

    @staticmethod
    def getResult(command_request_id_list, taskname):
        """
        Fetch the MPI responses for ``command_request_id_list`` from the
        MPICommandClient singleton and consolidate them by vis.
        """
        # Access MPICommandClient singleton instance
        client = MPICommandClient()

        # Get response list
        command_response_list = client.get_command_response(command_request_id_list, True, True)

        # Format list in the form of vis dict
        ret_list = {}
        for command_response in command_response_list:
            vis = command_response['parameters']['vis']
            ret_list[vis] = command_response['ret']

        # Consolidate results and return
        return ParallelTaskHelper.consolidateResults(ret_list, taskname)

    def go(self):
        """
        Run the whole pipeline.

        :returns: the list of MPI request ids in async mode, otherwise the
            result of postExecution(). Returns False if job generation or
            post-processing fails. The casalog origin is restored to the task
            name on every exit path.
        """
        casalog.origin("ParallelTaskHelper")
        try:
            self.initialize()
            if not self.generateJobs():
                return False

            self.executeJobs()

            if ParallelTaskHelper.__async_mode:
                return [] if self._command_request_id_list is None else list(self._command_request_id_list)

            try:
                return self.postExecution()
            except Exception as instance:
                casalog.post("Error post processing MMS results %s: %s" % (self._arg['vis'], instance), "WARN", "go")
                traceback.print_tb(sys.exc_info()[2])
                return False
        finally:
            # Restore casalog origin
            casalog.origin(self._taskName)

    @staticmethod
    def getReferencedMSs(vis):
        """
        Return the list of sub-MSs referenced by the MMS ``vis``.

        :raises ValueError: if ``vis`` cannot be opened or is not a reference MS.
        """
        msTool = mstool()
        if not msTool.open(vis):
            raise ValueError("Unable to open MS %s." % vis)

        try:
            if not msTool.ismultims():
                raise ValueError("MS %s is not a reference MS." % vis)
            rtnValue = msTool.getreferencedtables()
        finally:
            # Close the tool on the error path too.
            msTool.close()

        if not isinstance(rtnValue, list):
            rtnValue = [rtnValue]
        return rtnValue

    @staticmethod
    def _replaceSubtable(subms, subtable, mastersubms):
        """
        Make ``subms/subtable`` agree with ``mastersubms/subtable``.

        A symlink is never copied over. If it points into a SubMS other than
        the master, it is replaced by a relative link
        ``../<master basename>/<subtable>``. A real directory is replaced by a
        copy of the master's subtable. The current working directory is not
        changed.
        """
        masterbase = os.path.basename(mastersubms)
        theSubTab = os.path.join(subms, subtable)
        if os.path.islink(theSubTab):
            if os.path.basename(os.path.dirname(os.path.realpath(theSubTab))) != masterbase:
                # the mastersubms has changed: make new link.
                # shutil.rmtree refuses symlinks (silently, with
                # ignore_errors=True), so the link must be unlinked explicitly.
                os.unlink(theSubTab)
                os.symlink('../' + masterbase + '/' + subtable, theSubTab)
        else:
            shutil.rmtree(theSubTab, ignore_errors=True)
            shutil.copytree(mastersubms + '/' + subtable, theSubTab)

    @staticmethod
    def restoreSubtableAgreement(vis, mastersubms='', subtables=None):
        """
        Tidy up the MMS vis by replacing the subtables of all SubMSs
        by the subtables from the SubMS given by "mastersubms".
        If specified, only the subtables in the list "subtables"
        are replaced, otherwise (None or empty) all.
        If "mastersubms" is not given, the SubMS referenced by the MMS's
        ANTENNA table keyword is used as master.

        :raises ValueError: if a requested subtable is not a subtable of
            the master SubMS.
        """
        msTool = mstool()
        msTool.open(vis)
        theSubMSs = msTool.getreferencedtables()
        msTool.close()

        if mastersubms == '':
            tbTool = tbtool()
            tbTool.open(vis)
            myKeyw = tbTool.getkeywords()
            tbTool.close()
            # assume ANTENNA is present; the second space-separated field of
            # its keyword value is taken as the subtable path
            mastersubms = os.path.dirname(myKeyw['ANTENNA'].split(' ')[1])

        mastersubms = os.path.abspath(mastersubms)

        theSubTables = ph.getSubtables(mastersubms)

        if not subtables:
            subtables = theSubTables
        else:
            for s in subtables:
                if s not in theSubTables:
                    raise ValueError(s + ' is not a subtable of ' + mastersubms)

        masterbase = os.path.basename(mastersubms)

        for r in theSubMSs:
            if os.path.basename(r) == masterbase:
                continue
            for s in subtables:
                ParallelTaskHelper._replaceSubtable(r, s, mastersubms)

        return True

    @staticmethod
    def bypassParallelProcessing(switch=1):
        """
        # jagonzal (CAS-4287): Add a cluster-less mode to by-pass parallel processing for MMSs as requested
        switch=1 => Process each sub-Ms sequentially
        switch=2 => Process the MMS as a normal MS
        """
        ParallelTaskHelper.__bypass_parallel_processing = switch

    @staticmethod
    def getBypassParallelProcessing():
        """
        # jagonzal (CAS-4287): Add a cluster-less mode to by-pass parallel processing for MMSs as requested
        switch=1 => Process each sub-Ms sequentially
        switch=2 => Process the MMS as a normal MS
        """
        return ParallelTaskHelper.__bypass_parallel_processing

    @staticmethod
    def setAsyncMode(async_mode=False):
        """Enable/disable async mode (go() returns MPI request ids)."""
        ParallelTaskHelper.__async_mode = async_mode

    @staticmethod
    def getAsyncMode():
        """Return whether async mode is enabled."""
        return ParallelTaskHelper.__async_mode

    @staticmethod
    def setMultithreadingMode(multithreading=False):
        """Enable/disable thread-based Tier-2 parallelization support."""
        ParallelTaskHelper.__multithreading = multithreading

    @staticmethod
    def getMultithreadingMode():
        """Return whether multithreading mode is enabled."""
        return ParallelTaskHelper.__multithreading

    @staticmethod
    def isParallelMS(vis):
        """
        This method will let us know if we can do the simple form
        of parallelization by invoking on many referenced mss.

        :raises ValueError: if ``vis`` cannot be opened.
        """
        # jagonzal (CAS-4287): Add a cluster-less mode to by-pass parallel processing for MMSs as requested
        if ParallelTaskHelper.__bypass_parallel_processing == 2:
            return False

        msTool = mstool()
        if not msTool.open(vis):
            raise ValueError("Unable to open MS %s," % vis)
        try:
            return msTool.ismultims() and \
                isinstance(msTool.getreferencedtables(), list)
        finally:
            msTool.close()

    @staticmethod
    def findAbsPath(input):
        """
        Return the absolute path of a string, or the list of absolute paths
        for a list of strings. Any other input is returned unchanged.
        """
        if isinstance(input, str):
            return os.path.abspath(input)

        if isinstance(input, list):
            return [os.path.abspath(file_i) for file_i in input]

        # You're on your own, don't know what to do
        return input

    @staticmethod
    def isMPIEnabled():
        """Return True if casampi is importable and MPI is enabled."""
        return MPIEnvironment.is_mpi_enabled if mpi_available else False

    @staticmethod
    def isMPIClient():
        """Return True if casampi is importable and this process is the MPI client."""
        return MPIEnvironment.is_mpi_client if mpi_available else False

    @staticmethod
    def listToCasaString(inputList):
        """
        This Method will take a list of integers and try to express them as a
        compact set using the CASA notation, e.g. [1, 2, 3, 5] -> '1~3,5'.
        The input is not modified. Returns '' for None or an empty list.
        """
        if inputList is None or len(inputList) == 0:
            return ''

        def selectionString(rangeStart, rangeEnd):
            if rangeStart == rangeEnd:
                return str(rangeStart)
            return "%d~%d" % (rangeStart, rangeEnd)

        # Sort a copy: sorting in place would reorder the caller's list.
        values = sorted(inputList)
        compactStrings = []
        rangeStart = values[0]
        lastValue = values[0]
        for val in values[1:]:
            if val > lastValue + 1:
                compactStrings.append(selectionString(rangeStart, lastValue))
                rangeStart = val
            lastValue = val
        compactStrings.append(selectionString(rangeStart, lastValue))

        return ','.join(compactStrings)
class ParallelTaskWorker:
    """
    Evaluate a single Python expression (``cmd``) in a background daemon
    thread and hand its value back through getResult().

    All workers serialize on one class-wide lock while their command runs.
    ParallelTaskHelper.postExecution releases and re-acquires that lock
    while it waits for MPI responses.
    """

    # Initialize task lock (shared by all workers)
    __task_lock = threading.Lock()

    def __init__(self, cmd):
        """
        :param cmd: source text of a Python expression. It is compiled here
            in 'eval' mode, so a syntax error is raised immediately.
        """
        # Keep the source text for error messages; the compiled code object's
        # repr does not show the command.
        self.__cmd_text = cmd
        self.__cmd = compile(cmd, "ParallelTaskWorker", "eval")
        self.__state = "initialized"
        self.__res = None
        self.__thread = None
        self.__environment = self.getEnvironment()
        self.__formatted_traceback = None
        self.__completion_event = threading.Event()

    def getEnvironment(self):
        """
        Return the globals the command should run with.

        When casampi is importable this is an empty dict. Otherwise, walk up
        the call stack and return a copy of the globals of the first frame
        that defines 'update_params'.

        :raises Exception: if no such frame is found.
        """
        try:
            # casampi should not depend on globals (casashell). And CASA6/casashell doesn't
            # anyway have init_tasks:update_params. Keep going w/o globals
            import casampi
            return {}
        except ImportError:
            frame = sys._getframe(0)
            while frame is not None:
                if 'update_params' in frame.f_globals:
                    return dict(frame.f_globals)
                frame = frame.f_back

            raise Exception("CASA top level environment not found")

    def start(self):
        """Spawn the daemon thread that runs the command."""
        # Initialize completion event
        self.__completion_event.clear()

        # Mark state as running *before* spawning the thread. Otherwise a fast
        # command could finish (setting "successful"/"failed") first and have
        # its final state overwritten with "running".
        self.__state = "running"

        # Spawn daemon thread (Thread.setDaemon is deprecated)
        self.__thread = threading.Thread(target=self.runCmd, daemon=True)
        self.__thread.start()

    def runCmd(self):
        """
        Thread body: evaluate the command while holding the task lock, record
        the result or the formatted traceback, and always signal completion.
        """
        # Acquire lock
        ParallelTaskWorker.acquireTaskLock()

        try:
            # Update environment with globals from calling context
            globals().update(self.__environment)
            # Run compiled command
            self.__res = eval(self.__cmd)
            # Mark state as successful
            self.__state = "successful"
            # Release task lock
            ParallelTaskWorker.releaseTaskLock()
        except Exception:
            # Mark state as failed
            self.__state = "failed"
            # Release task lock if necessary
            if ParallelTaskWorker.checkTaskLock():
                ParallelTaskWorker.releaseTaskLock()
            # Post error message
            self.__formatted_traceback = traceback.format_exc()
            casalog.post("Exception executing command '%s': %s"
                         % (self.__cmd_text, self.__formatted_traceback),
                         "SEVERE", "ParallelTaskWorker::runCmd")
        finally:
            # Send completion event signal even if error reporting fails,
            # so getResult() cannot block forever.
            self.__completion_event.set()

    def getResult(self):
        """
        Block until the command finishes (if it is running) and return its
        value. Returns None and logs a message if the worker was never
        started or the command failed.
        """
        if self.__state == "running":
            # Wait until completion event signal is received
            self.__completion_event.wait()

        if self.__state == "initialized":
            casalog.post("Worker not started",
                         "WARN", "ParallelTaskWorker::getResult")
        elif self.__state == "successful":
            return self.__res
        elif self.__state == "failed":
            casalog.post("Exception executing command '%s': %s"
                         % (self.__cmd_text, self.__formatted_traceback),
                         "SEVERE", "ParallelTaskWorker::getResult")
        return None

    @staticmethod
    def acquireTaskLock():
        """Acquire the class-wide task lock (blocking)."""
        ParallelTaskWorker.__task_lock.acquire()

    @staticmethod
    def releaseTaskLock():
        """Release the class-wide task lock."""
        ParallelTaskWorker.__task_lock.release()

    @staticmethod
    def checkTaskLock():
        """Return True if the class-wide task lock is currently held by anyone."""
        return ParallelTaskWorker.__task_lock.locked()