LCOV - code coverage report
Current view: top level - synthesis/MeasurementComponents - VisCalSolver2.cc (source / functions) Hit Total Coverage
Test: casacpp_coverage.info Lines: 269 479 56.2 %
Date: 2024-12-11 20:54:31 Functions: 14 17 82.4 %

          Line data    Source code
       1             : //# VisCalSolver2.cc: Implementation of generic visibility solving
       2             : //# Copyright (C) 1996,1997,1998,1999,2000,2001,2002,2003
       3             : //# Associated Universities, Inc. Washington DC, USA.
       4             : //#
       5             : //# This library is free software; you can redistribute it and/or modify it
       6             : //# under the terms of the GNU Library General Public License as published by
       7             : //# the Free Software Foundation; either version 2 of the License, or (at your
       8             : //# option) any later version.
       9             : //#
      10             : //# This library is distributed in the hope that it will be useful, but WITHOUT
      11             : //# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
      12             : //# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Library General Public
      13             : //# License for more details.
      14             : //#
      15             : //# You should have received a copy of the GNU Library General Public License
      16             : //# along with this library; if not, write to the Free Software Foundation,
      17             : //# Inc., 675 Massachusetts Ave, Cambridge, MA 02139, USA.
      18             : //#
      19             : //# Correspondence concerning AIPS++ should be addressed as follows:
      20             : //#        Internet email: casa-feedback@nrao.edu.
      21             : //#        Postal address: AIPS++ Project Office
      22             : //#                        National Radio Astronomy Observatory
      23             : //#                        520 Edgemont Road
      24             : //#                        Charlottesville, VA 22903-2475 USA
      25             : //#
      26             : 
      27             : #include <synthesis/MeasurementComponents/VisCalSolver2.h>
      28             : 
      29             : #include <msvis/MSVis/VisBuffer.h>
      30             : 
      31             : #include <casacore/casa/Arrays/ArrayMath.h>
      32             : #include <casacore/casa/Arrays/MaskArrMath.h>
      33             : #include <casacore/casa/Arrays/ArrayLogical.h>
      34             : #include <casacore/casa/Arrays/ArrayIter.h>
      35             : //#include <scimath/Mathematics/MatrixMathLA.h>
      36             : #include <casacore/casa/BasicSL/String.h>
      37             : #include <casacore/casa/BasicMath/Math.h>
      38             : #include <casacore/casa/Utilities/Assert.h>
      39             : #include <casacore/casa/Exceptions/Error.h>
      40             : #include <casacore/casa/OS/Memory.h>
      41             : #include <casacore/casa/OS/Path.h>
      42             : 
      43             : #include <sstream>
      44             : 
      45             : #include <casacore/casa/Logging/LogMessage.h>
      46             : #include <casacore/casa/Logging/LogSink.h>
      47             : 
      48             : #define VCS2_PRTLEV 0
      49             : 
      50             : namespace casa { //# NAMESPACE CASA - BEGIN
      51             : 
      52             : using namespace casacore;
      53             : 
      54             : 
      55             : // **********************************************************
      56             : //  VisCalSolver2 Implementations
      57             : //
      58             : 
      59           0 : VisCalSolver2::VisCalSolver2() :
      60           0 :   SDBs_(NULL),
      61           0 :   ve_(NULL),
      62           0 :   svc_(NULL),
      63           0 :   nPar_(0),
      64           0 :   maxIter_(50),
      65           0 :   chiSq_(0.0),
      66           0 :   chiSqV_(4,0.0),
      67           0 :   lastChiSq_(0.0),dChiSq_(0.0),
      68           0 :   sumWt_(0.0),sumWtV_(4,0.0),nWt_(0),
      69           0 :   cvrgcount_(0),
      70           0 :   par_(), parOK_(), parErr_(), lastPar_(),
      71           0 :   dpar_(), 
      72           0 :   grad_(),hess_(),
      73           0 :   lambda_(2.0),
      74           0 :   optstep_(True),
      75           0 :   doL1_(false),
      76           0 :   L1clamp_(0),
      77           0 :   doRMSThresh_(false),
      78           0 :   RMSThresh_(0),
      79           0 :   nRMSThresh_(0),
      80           0 :   prtlev_(VCS2_PRTLEV)
      81             : {
      82           0 :   if (prtlev()>0) cout << "VCS2::VCS2()" << endl;
      83           0 : }
      84             : 
      85       15936 : VisCalSolver2::VisCalSolver2(String solmode, Vector<Float>& rmsthresh) :
      86       15936 :   SDBs_(NULL),
      87       15936 :   ve_(NULL),
      88       15936 :   svc_(NULL),
      89       15936 :   nPar_(0),
      90       15936 :   maxIter_(50),
      91       15936 :   chiSq_(0.0),
      92       15936 :   chiSqV_(4,0.0),
      93       15936 :   lastChiSq_(0.0),dChiSq_(0.0),
      94       15936 :   sumWt_(0.0),sumWtV_(4,0.0),nWt_(0),
      95       15936 :   cvrgcount_(0),
      96       15936 :   par_(), parOK_(), parErr_(), lastPar_(),
      97       15936 :   dpar_(), 
      98       15936 :   grad_(),hess_(),
      99       15936 :   lambda_(2.0),
     100       15936 :   optstep_(True),
     101       15936 :   doL1_(false),
     102       15936 :   L1clamp_(std::vector<Float>({5e-3, 5e-4, 5e-5})),
     103       15936 :   doRMSThresh_(false),
     104       15936 :   RMSThresh_(rmsthresh),  // 
     105       15936 :   nRMSThresh_(rmsthresh.nelements()),
     106       15936 :   prtlev_(VCS2_PRTLEV)
     107             : {
     108       15936 :   if (prtlev()>0) cout << "VCS2::VCS2(solmode)" << endl;
     109             : 
     110       15936 :   if (solmode.contains("L1")) doL1_=true;
     111       15936 :   if (solmode.contains("R")) doRMSThresh_=true;
     112             : 
     113       15936 :   if (doRMSThresh_ && nRMSThresh_==0) {
     114           0 :     doRMSThresh_=false;
     115             :     //RMSThresh_=Vector<Float>(std::vector<Float>({7.0,5.0,4.0,3.5,3.0,2.8,2.6,2.4,2.2}));
     116             :     //nRMSThresh_=RMSThresh_.nelements();
     117             :   }
     118             : 
     119       15936 : }
     120             : 
     121       15936 : VisCalSolver2::~VisCalSolver2() 
     122             : {  
     123       15936 :   if (prtlev()>0) cout << "VCS2::~VCS2()" << endl;
     124       15936 : }
     125             : 
     126             : 
     127             : // New SDBList version
     128       17184 : Bool VisCalSolver2::solve(VisEquation& ve, SolvableVisCal& svc, SDBList& sdbs) {
     129             : 
     130             :   // If L1 and/or outlier flagging requested, call specialize method
     131       17184 :   if (doL1_ || doRMSThresh_)
     132           0 :     return solveL1R(ve,svc,sdbs);
     133             : 
     134       17184 :   if (prtlev()>1) cout << "VCS2::solve(,,SDBs)" << endl;
     135             : 
     136             :   /*
     137             :   LogSink logsink;
     138             :   {
     139             :     LogMessage message(LogOrigin("VisCalSolver2", "solve"));
     140             :     ostringstream o; o<<"Beginning solve...";
     141             :     message.message(o);
     142             :     logsink.post(message);
     143             :   }
     144             :   */
     145             :   // Pointers to local ve,svc
     146       17184 :   ve_=&ve;
     147       17184 :   svc_=&svc;
     148       17184 :   SDBs_=&sdbs;
     149             : 
     150             :   // Verify that VisEq has the correct svc:
     151             :   // TBD?
     152             : 
     153             :   // Initialize everything 
     154       17184 :   initSolve();
     155             : 
     156       17184 :   Vector<Float> steplist(maxIter_+2,0.0);
     157       17184 :   Vector<Float> rsteplist(maxIter_+2,0.0);
     158             : 
     159             :   // Verify Data's validity for solve w.r.t. baselines available
     160             :   //   (this sets parOK() on per-antenna basis (for focusChan)
     161             :   //    based on data weights and baseline participation)
     162       17184 :   Bool oktosolve = svc_->verifyConstraints(*SDBs_);
     163             : 
     164       17184 :   if (oktosolve) {
     165             :     
     166        9414 :     if (prtlev()>1) cout << "First guess:" << endl
     167           0 :                          << "amp = " << amplitude(par()) << endl
     168           0 :                          << "pha = " << phase(par()) 
     169           0 :                          << endl;
     170             : 
     171             :     // Iterate solution
     172        9414 :     Int iter(0);
     173        9414 :     Bool done(False);
     174       77145 :     while (!done) {
     175             :       
     176       77145 :       if (prtlev()>2) cout << " Beginning iteration " << iter 
     177           0 :                            << "---------------------------------" << endl;
     178             :       
     179             :       // Differentiate the VB and get current Chi2
     180       77145 :       differentiate2();
     181       77145 :       chiSquare2();
     182       77145 :       if (chiSq()==0.0) {
     183           0 :         cout << "CHI2 IS SPURIOUSLY ZERO!*************************************" << endl;
     184             :         //cout << "R() = " << R() << endl;
     185             :         //      cout << "sum(wtmat) = " << sum(wtmat) << endl;
     186           0 :         return False;
     187             :       }
     188             : 
     189       77145 :       dChiSq() = chiSq()-lastChiSq();
     190             : 
     191             :       //      cout << "chi2 = " << chiSq() << " " << dChiSq() << " " << dChiSq()/chiSq() << endl;
     192             :       
     193             :       // Continuue if we haven't converged
     194       77145 :       if (!converged()) {
     195             :         
     196       67731 :         if (dChiSq()<=0.0) {
     197             :           // last step was good...
     198       67054 :           lastChiSq()=chiSq();
     199             :           
     200             :           // so accumulate new grad/hess...
     201       67054 :           accGradHess2();
     202             :           
     203             :           //...and adjust lambda downward
     204             :           //    lambda()/=2.0;
     205             :           //    lambda()=0.8;
     206       67054 :           lambda()=1.0;
     207             :         }
     208             :         else {
     209             :           //      cout << "reverting..." << chiSq() << " " << dChiSq() << " (" << iter << ")" << endl;
     210             :           // last step was bad, revert to previous 
     211         677 :           revert();
     212             :           //...with a larger lambda
     213             :           //    lambda()*=4.0;
     214         677 :           lambda()=1.0;
     215             :         }
     216             :         
     217             :         // Solve for the parameter step
     218       67731 :         solveGradHess();
     219             :         
     220             :         // Remember curr pars
     221       67731 :         lastPar()=par();
     222             : 
     223             :         // Refine the step size by exploring chi2 in the
     224             :         //  gradient direction
     225       67731 :         if (optstep_) //  && cvrgcount_>=3)
     226       67731 :           optStepSize2();
     227             :         
     228             :         // Update current parameters (saves a copy of them)
     229       67731 :         updatePar();
     230             : 
     231             : 
     232       67731 :         steplist(iter)=max(amplitude(dpar()));
     233       67731 :         rsteplist(iter)=max(amplitude(dpar())/amplitude(par()));
     234             : 
     235             :       }
     236             :       else {
     237             :         // Convergence means we're done!
     238        9414 :         done=True;
     239             : 
     240        9414 :         if (prtlev()>0) {
     241           0 :           cout << "par()=" << par() << endl;
     242             :         }
     243             : 
     244             : 
     245             :         /*
     246             :         cout << " good pars=" << ntrue(parOK())
     247             :              << " iterations=" << iter << endl
     248             :              << " steps=" << steplist(IPosition(1,0),IPosition(1,iter)) 
     249             :              << endl
     250             :              << " rsteps=" << rsteplist(IPosition(1,0),IPosition(1,iter)) 
     251             :              << endl;
     252             :         */   
     253             : 
     254             :         // Get parameter errors:
     255        9414 :         accGradHess2();
     256        9414 :         getErrors();
     257             : 
     258             :         // Return, signaling success if at least 1 good solution
     259        9414 :         return (ntrue(parOK())>0);
     260             :         
     261             :       }
     262             :       
     263             :       // Escape iteration loop via iteration limit
     264       67731 :       if (iter==maxIter()) {
     265           0 :         cout << "Reached iteration limit: " << iter << " iterations.  " << endl;
     266             :         //      cout << " good pars = " << ntrue(parOK())
     267             :         //           << "  steps = " << steplist
     268             :         //           << endl;
     269           0 :         done=True;
     270             :       }
     271             :       
     272             :       // Advance iteration counter
     273       67731 :       iter++;
     274             :     }
     275             :     
     276             :   }
     277             :   else {
     278        7770 :     cout << " Insufficient unflagged antennas to proceed with this solve." << endl;
     279             :   }
     280             : 
     281        7770 :   return False;
     282             :     
     283       17184 : }
     284             : 
     285             : // New L1(R)-capable version
     286           0 : Bool VisCalSolver2::solveL1R(VisEquation& ve, SolvableVisCal& svc, SDBList& sdbs) {
     287             : 
     288           0 :   if (prtlev()>1) cout << "VCS2::solve(,,SDBs)" << endl;
     289             : 
     290             :   /*
     291             :   LogSink logsink;
     292             :   {
     293             :     LogMessage message(LogOrigin("VisCalSolver2", "solve"));
     294             :     ostringstream o; o<<"Beginning solve...";
     295             :     message.message(o);
     296             :     logsink.post(message);
     297             :   }
     298             :   */
     299             :   // Pointers to local ve,svc
     300           0 :   ve_=&ve;
     301           0 :   svc_=&svc;
     302           0 :   SDBs_=&sdbs;
     303             : 
     304             :   // Verify that VisEq has the correct svc:
     305             :   // TBD?
     306             : 
     307             :   // Initialize everything 
     308           0 :   initSolve();
     309             : 
     310           0 :   Vector<Float> steplist(maxIter_+2,0.0);
     311           0 :   Vector<Float> rsteplist(maxIter_+2,0.0);
     312             : 
     313             :   // Verify Data's validity for solve w.r.t. baselines available
     314             :   //   (this sets parOK() on per-antenna basis (for focusChan)
     315             :   //    based on data weights and baseline participation)
     316           0 :   Bool oktosolve = svc_->verifyConstraints(*SDBs_);
     317             : 
     318           0 :   if (oktosolve) {
     319             :     
     320             :     // Tweak guess in L1 case, to avoid degeneracy...
     321           0 :     if (doL1_)
     322           0 :       par()*=Complex(1.0001,0.0);
     323             :   
     324           0 :     if (prtlev()>1) cout << "First guess:" << endl
     325           0 :                          << "amp = " << amplitude(par()) << endl
     326           0 :                          << "pha = " << phase(par()) 
     327           0 :                          << endl;
     328             : 
     329             :     // Iterate solution
     330           0 :     Int iter(0);
     331           0 :     Bool done(False);
     332           0 :     Bool applyWorkingFlags(false);
     333           0 :     Int L1iter(0), IRiter(0);
     334           0 :     while (!done) {
     335             :       
     336           0 :       if (prtlev()>2) cout << " Beginning iteration " << iter 
     337           0 :                            << "---------------------------------" << endl;
     338             :       
     339             :       // Differentiate the VB and get current Chi2
     340           0 :       differentiate2();
     341             : 
     342           0 :       if (doRMSThresh_ && applyWorkingFlags) {
     343           0 :         SDBs_->updateWorkingFlags();
     344           0 :         applyWorkingFlags=false;   // must be explicitly triggered below
     345             :       }
     346             : 
     347             :       // Set up working weights
     348           0 :       if (doL1_)
     349           0 :         SDBs_->updateWorkingWeights(doL1_,L1clamp_(L1iter));
     350             :       else
     351           0 :         SDBs_->updateWorkingWeights(false);
     352             : 
     353             : 
     354           0 :       chiSquare2();
     355           0 :       if (chiSq()==0.0) {
     356           0 :         cout << "CHI2 IS SPURIOUSLY ZERO!*************************************" << endl;
     357             :         //cout << "R() = " << R() << endl;
     358             :         //      cout << "sum(wtmat) = " << sum(wtmat) << endl;
     359           0 :         return False;
     360             :       }
     361             : 
     362           0 :       dChiSq() = chiSq()-lastChiSq();
     363             : 
     364             :       //cout << "iter=" << iter << " X2=" << chiSq() << " dX2=" << dChiSq() << " dX2/X2=" << dChiSq()/chiSq(); // << endl;
     365             :       
     366             :       // Continuue if we haven't converged
     367           0 :       if (!converged()) {
     368             :         
     369             :         //if (dChiSq()<=0.0) {
     370             :         if (true || dChiSq()<=0.0) {
     371             :           // last step was good...
     372           0 :           lastChiSq()=chiSq();
     373             :           
     374             :           // so accumulate new grad/hess...
     375           0 :           accGradHess2();
     376             :           
     377             :           //...and adjust lambda downward
     378             :           //    lambda()/=2.0;
     379             :           //    lambda()=0.8;
     380           0 :           lambda()=1.0;
     381             :         }
     382             :         else {
     383             :           //      cout << "reverting..." << chiSq() << " " << dChiSq() << " (" << iter << ")" << endl;
     384             :           // last step was bad, revert to previous 
     385             :           revert();
     386             :           //...with a larger lambda
     387             :           //    lambda()*=4.0;
     388             :           lambda()=1.0;
     389             :         }
     390             :         
     391             :         // Solve for the parameter step
     392           0 :         solveGradHess();
     393             :         
     394             :         // Remember curr pars
     395           0 :         lastPar()=par();
     396             : 
     397             :         // Refine the step size by exploring chi2 in the
     398             :         //  gradient direction
     399           0 :         if (optstep_ && !doL1_) //  && cvrgcount_>=3)
     400           0 :           optStepSize2();
     401             :         
     402             :         // Update current parameters (saves a copy of them)
     403           0 :         updatePar();
     404             : 
     405           0 :         steplist(iter)=max(amplitude(dpar()));
     406           0 :         rsteplist(iter)=max(amplitude(dpar())/amplitude(par()));
     407             : 
     408             :         //cout << "  rstep=" << rsteplist(iter) << endl;
     409             : 
     410             :       }
     411             :       else {
     412             : 
     413             :         // Convergence means we're done, NOMINALLY
     414           0 :         done=True;
     415             : 
     416             :         // Override convergence if we need to solve again with
     417             :         //  revised weight/flag conditions for robustness
     418           0 :         if (doL1_ && L1iter<Int(L1clamp_.nelements())-1) {
     419             :           //cout << "*~*~*~*~*~*~* Converged w/ L1clamp = " << L1clamp_(L1iter) << " *~*~*~*~*~*~*~*~*~*~*~*~*~*~*~*~*" << endl;
     420           0 :           done=false;
     421           0 :           ++L1iter;
     422           0 :           iter=-1;
     423           0 :           cvrgcount_=0;
     424           0 :           lastChiSq()=DBL_MAX;
     425             :         }
     426           0 :         else if (doRMSThresh_ && IRiter<nRMSThresh_) {
     427             :           //cout << "*~*~*~*~*~*~* Applying RMSThresh = " << RMSThresh_(IRiter) << " *~*~*~*~*~*~*~*~*~*~*~*~*~*~*~*~*" << endl;
     428           0 :           RMSThresh(IRiter);
     429           0 :           ++IRiter;
     430           0 :           applyWorkingFlags=true;  // force apply of the RMSThresh'd flags at the top of loop _after_ differentiation
     431           0 :           done=false;
     432           0 :           L1iter=0;
     433           0 :           iter=-1;
     434           0 :           cvrgcount_=0;
     435           0 :           lastChiSq()=DBL_MAX;
     436             :         }
     437             : 
     438             :         // If still done (robustness options absent or exhausted), escape solve loop
     439           0 :         if (done) {
     440             : 
     441           0 :           if (prtlev()>0) {
     442           0 :             cout << "par()=" << par() << endl;
     443             :           }
     444             : 
     445             :           /*
     446             :           cout << " good pars=" << ntrue(parOK())
     447             :                << " iterations=" << iter << endl
     448             :                << " steps=" << steplist(IPosition(1,0),IPosition(1,iter)) 
     449             :                << endl
     450             :                << " rsteps=" << rsteplist(IPosition(1,0),IPosition(1,iter)) 
     451             :                << endl;
     452             :           */
     453             : 
     454             :           // Get parameter errors:
     455           0 :           accGradHess2();
     456           0 :           getErrors();
     457             :           
     458             :           // Return, signaling success if at least 1 good solution
     459           0 :           return (ntrue(parOK())>0);
     460             :         }
     461             : 
     462             :       }  // converged?
     463             :       
     464             :       // Escape iteration loop via iteration limit
     465           0 :       if (iter==maxIter()) {
     466           0 :         cout << "Reached iteration limit: " << iter << " iterations.  " << endl;
     467             :         //      cout << " good pars = " << ntrue(parOK())
     468             :         //           << "  steps = " << steplist
     469             :         //           << endl;
     470           0 :         done=True;
     471             :       }
     472             :       
     473             :       // Advance iteration counter
     474           0 :       iter++;
     475             :     }
     476             :     
     477             :   }
     478             :   else {
     479           0 :     cout << " Insufficient unflagged antennas to proceed with this solve." << endl;
     480             :   }
     481             : 
     482           0 :   return False;
     483             :     
     484           0 : }
     485             :   
     486       17184 : void VisCalSolver2::initSolve() {
     487             :     
     488       17184 :   if (prtlev()>2) cout << " VCS2::initSolve()" << endl;
     489             : 
     490             :   // Get total number of cal parameters from svc info
     491       17184 :   nPar()=svc().nTotalPar();
     492             : 
     493       17184 :   if (prtlev()>2)
     494           0 :     cout << "  Total parameters in solve: " << nPar() << endl;
     495             : 
     496             :   // Chi2 and weights
     497       17184 :   chiSq()=0.0;
     498       17184 :   lastChiSq()=DBL_MAX;
     499       17184 :   dChiSq()=0.0;
     500             :     
     501       17184 :   sumWt()=0.0;
     502       17184 :   nWt()=0;
     503             : 
     504             :   // Link up svc's internal pars with local reference
     505             :   //   (only if shape is correct)
     506             : 
     507       17184 :   if (svc().solveCPar().nelements()==uInt(nPar())) {
     508       17184 :     par().reference(svc().solveCPar().reform(IPosition(1,nPar())));
     509       17184 :     parOK().reference(svc().solveParOK().reform(IPosition(1,nPar())));
     510       17184 :     parErr().reference(svc().solveParErr().reform(IPosition(1,nPar())));
     511             :   }
     512             :   else
     513           0 :     throw(AipsError("Solver and SVC cannot synchronize parameters."));
     514             : 
     515             :   // Pars
     516             : 
     517       17184 :   dpar().resize(nPar());
     518       17184 :   dpar()=0.0;
     519             : 
     520       17184 :   lastPar().resize(nPar());
     521             : 
     522             :   // Gradient and Hessian
     523       17184 :   grad().resize(nPar());
     524       17184 :   grad()=0.0;
     525             : 
     526       17184 :   hess().resize(nPar());
     527       17184 :   hess()=0.0;
     528             : 
     529             :   // Levenberg-Marquardt factor
     530       17184 :   lambda()=2.0;
     531             : 
     532             :   // Convergence anticipation
     533       17184 :   cvrgcount_=0;
     534             : 
     535       17184 : }
     536             : 
     537      177187 : void VisCalSolver2::residualate2() {
     538             : 
     539             :   //  if (prtlev()>2) cout << "   VCS2::residualate()" << endl;
     540             : 
     541             :   // For now, just use ve.diffResid, until we have
     542             :   //  implemented focuschan-aware trial corrupt in SVC
     543             :   //  (this will hurt performance a bit)
     544             : 
     545      416014 :   for (Int isdb=0;isdb<sdbs().nSDB();++isdb) 
     546      238827 :     ve().differentiate(sdbs()(isdb));
     547      177187 : }
     548             : 
     549       77145 : void VisCalSolver2::differentiate2() {
     550             : 
     551       77145 :   if (prtlev()>2) cout << "  VCS2::differentiate(SDB version)" << endl;
     552             : 
     553             :   // TBD:  Should this be packaged in the SolveDataBuffer such
     554             :   //  that is could be called there with a reference to the svc()?
     555             :   //  Eg:
     556             :   //  sdbs().differentiate(svc());  // an aggregate method in SDBList
     557             :   //   ...which then does:
     558             :   //      svc.differentiate(this)  (for each SDB)
     559             :   //
     560             :   //  also consider whether VE is in the loop here?
     561             :   //
     562             :   //  (don't wind this up in a way that makes it harder to extend....)
     563             : 
     564             :   // Delegate to VisEquation
     565      168085 :   for (Int isdb=0;isdb<sdbs().nSDB();++isdb)
     566       90940 :     ve().differentiate(sdbs()(isdb));
     567             : 
     568       77145 : }
     569             : 
     570      254332 : void VisCalSolver2::chiSquare2() {
     571             : 
     572      254332 :   if (prtlev()>2) cout << "   VCS2::chiSquare(SDB version)" << endl;
     573             : 
     574             :   // TBD: per-ant/bln chiSq?
     575             : 
     576      254332 :   chiSq()=0.0;
     577      254332 :   chiSqV()=0.0;
     578      254332 :   sumWt()=0.0;
     579      254332 :   sumWtV()=0.0;
     580      254332 :   nWt()=0;
     581             : 
     582      254332 :   Cube<Complex> R;
     583             : 
     584             :   // Loop over SDBs
     585      584099 :   for (Int isdb=0;isdb<sdbs().nSDB();++isdb) {
     586             : 
     587             :     // Current SDB
     588      329767 :     SolveDataBuffer& sdb(sdbs()(isdb));
     589      329767 :     R.reference(sdb.residuals());
     590             : 
     591             :     // _const_ access to working flags and weights
     592      329767 :     const Cube<Bool>& wFC(sdb.const_workingFlagCube());
     593      329767 :     const Cube<Float>& wWS(sdb.const_workingWtSpec());
     594             : 
     595             :     // Shapes for iteration
     596      329767 :     IPosition shR(R.shape());
     597      329767 :     Int nCorr=shR(0);
     598      329767 :     Int nChan=shR(1);
     599      329767 :     Int nRow=shR(2);
     600             : 
     601             :     // Simple indexed accumulation of chiSq
     602             :     //  TBD: optimize w.r.t. indexing?
     603      329767 :     Double chisq0(0.0);
     604    21809974 :     for (Int irow=0;irow<nRow;++irow) { 
     605    21480207 :       if (!sdb.flagRow()(irow)) {
     606    42960414 :         for (Int ich=0;ich<nChan;++ich) {
     607    92104019 :           for (Int icorr=0;icorr<nCorr;++icorr) {
     608             :             //if (!sdb.residFlagCube()(icorr,ich,irow)) {    // OLD: residFlagCube
     609    70623812 :             const Bool& fl(wFC(icorr,ich,irow));             // NEW: workingFlagCube  CORRECT?
     610    70623812 :             if (!fl) {      
     611    70009102 :               const Float& wt(wWS(icorr,ich,irow));
     612    70009102 :               if (wt>0.0) {
     613    59612014 :                 Complex& Ri(R(icorr,ich,irow));
     614             : 
     615             :                 // This element's contribution
     616    59612014 :                 chisq0=Double(wt*real(Ri*conj(Ri)));  //  cf:  square(abs(R))?  
     617             : 
     618             :                 // Accumulate per-corr
     619    59612014 :                 chiSqV()(icorr)+=chisq0;
     620    59612014 :                 sumWtV()(icorr)+=wt;
     621    59612014 :                 nWt()++;
     622             :               }  // wt>0     
     623             :             } // !flag
     624             :           } // icorr
     625             :         } // ich
     626             :       } // !flagRow
     627             :     } // irow
     628             : 
     629      329767 :   } // isdb
     630             : 
     631             :   //cout << "chiSqV() = " << chiSqV() << endl;
     632             : 
     633             :   // Totals over corrs
     634      254332 :   chiSq()=sum(chiSqV());  
     635      254332 :   sumWt()=sum(sumWtV());  
     636             : 
     637      254332 : }
     638             : 
     639             : // RMS calculation (for thresholding)
     640           0 : void VisCalSolver2::RMSThresh(Int RejIter) {
     641             : 
     642           0 :   if (prtlev()>2) cout << "   VCS2::RMS(SDB version)" << endl;
     643             : 
     644           0 :   const Float threshold(RMSThresh_(RejIter));
     645           0 :   Bool dolog=(RejIter==nRMSThresh_-1);
     646             : 
     647             :   // TBD: per-ant/bln chiSq?
     648             : 
     649           0 :   Int nCorr=sdbs().nCorrelations();
     650           0 :   Vector<Double> xxV(nCorr,0.0);
     651           0 :   Vector<Double> sWtV(nCorr,0.0);
     652             : 
     653           0 :   Cube<Complex> R;
     654             : 
     655             :   // Loop over SDBs
     656           0 :   for (Int isdb=0;isdb<sdbs().nSDB();++isdb) {
     657             : 
     658             :     // Current SDB
     659           0 :     SolveDataBuffer& sdb(sdbs()(isdb));
     660           0 :     R.reference(sdb.residuals());
     661             : 
     662             :     // Shapes for iteration
     663           0 :     IPosition shR(R.shape());
     664           0 :     Int nCorr=shR(0);
     665           0 :     Int nChan=shR(1);
     666           0 :     Int nRow=shR(2);
     667             : 
     668           0 :     const Cube<Bool>& wFC(sdb.const_workingFlagCube());
     669             : 
     670             :     // Simple indexed accumulation of XX
     671           0 :     Double xx0(0.0);
     672           0 :     for (Int irow=0;irow<nRow;++irow) { 
     673           0 :       if (!sdb.flagRow()(irow)) {
     674           0 :         for (Int ich=0;ich<nChan;++ich) {
     675           0 :           for (Int icorr=0;icorr<nCorr;++icorr) {
     676           0 :             if (!wFC(icorr,ich,irow)) { 
     677           0 :               Float& wt(sdb.infocusWtSpec()(icorr,ich,irow));
     678           0 :               if (wt>0.0) {
     679           0 :                 Complex& Ri(R(icorr,ich,irow));
     680             :                 
     681             :                 // This element's contribution
     682           0 :                 xx0=Double(wt*real(Ri*conj(Ri)));  //  cf:  square(abs(R))?  
     683             :                 
     684             :                 // Accumulate per-corr
     685           0 :                 xxV(icorr)+=xx0;
     686           0 :                 sWtV(icorr)+=wt;
     687             :               }  // wt>0     
     688             :             } // !flag
     689             :           } // icorr
     690             :         } // ich
     691             :       } // !flagRow
     692             :     } // irow
     693             :     
     694           0 :   } // isdb
     695             : 
     696           0 :   Vector<Float> rmsV(nCorr,0.0);
     697           0 :   for (Int icorr=0;icorr<nCorr;++icorr) {
     698           0 :     if (sWtV(icorr)>0.0)
     699           0 :       rmsV(icorr)=Float(sqrt(xxV(icorr)/sWtV(icorr)));
     700             :   }
     701             : 
     702             :   // Now Apply the threshold
     703             : 
     704           0 :   LogIO logsink;
     705             : 
     706             :   // Loop over SDBs
     707           0 :   for (Int isdb=0;isdb<sdbs().nSDB();++isdb) {
     708             : 
     709             :     // Current SDB
     710           0 :     SolveDataBuffer& sdb(sdbs()(isdb));
     711           0 :     R.reference(sdb.residuals());
     712             : 
     713             :     // Initialize wFC afresh
     714           0 :     sdb.workingFlagCube().resize(0,0,0);
     715           0 :     sdb.workingFlagCube().assign(sdb.residFlagCube());
     716             : 
     717             :     // Shapes for iteration
     718           0 :     IPosition shR(R.shape());
     719           0 :     Int nCorr=shR(0);
     720           0 :     Int nChan=shR(1);
     721           0 :     Int nRow=shR(2);
     722             : 
     723           0 :     for (Int irow=0;irow<nRow;++irow) { 
     724           0 :       if (!sdb.flagRow()(irow)) {
     725           0 :         for (Int ich=0;ich<nChan;++ich) {
     726           0 :           for (Int icorr=0;icorr<nCorr;++icorr) {
     727           0 :             if (!sdb.residFlagCube()(icorr,ich,irow)) { 
     728           0 :               Float& wt(sdb.infocusWtSpec()(icorr,ich,irow));
     729           0 :               if (wt>0.0) {
     730           0 :                 Float Ra(abs(R(icorr,ich,irow)));
     731           0 :                 if (Ra>(threshold*rmsV(icorr))) {
     732           0 :                   sdb.workingFlagCube()(icorr,ich,irow)=true;
     733             :                   //sdb.workingWtSpec()(icorr,ich,irow)=0.0;
     734             :                   
     735           0 :                   if (dolog) // only on last go-round, report what baselines have been flagged
     736           0 :                     logsink << "Rejected outlier at: " << MVTime(sdb.time()(irow)/C::day).string(MVTime::YMD,7)
     737           0 :                             << " spw=" << sdb.spectralWindow()(irow) 
     738           0 :                             << " BL=" << sdb.antenna1()(irow) << "-" << sdb.antenna2()(irow)
     739             :                             << " corr=" << icorr
     740           0 :                             << ":  residual=" << Ra/rmsV(icorr) << "sigma" << " (threshold=" << threshold << ")" << LogIO::POST;
     741             : 
     742             :                 }
     743             :               }  // wt>0     
     744             :             } // !flag
     745             :           } // icorr
     746             :         } // ich
     747             :       } // !flagRow
     748             :     } // irow
     749             :     
     750           0 :   } // isdb
     751             : 
     752           0 : }
     753             : 
     754             : 
     755             : 
     756       77145 : Bool VisCalSolver2::converged() {
     757             : 
     758       77145 :   if (prtlev()>2) cout << "    VCS2::converged()" << endl;
     759             : 
     760             :   // Change in chi2
     761       77145 :   dChiSq() = chiSq()-lastChiSq();
     762       77145 :   Float fChiSq(dChiSq()/chiSq());
     763             : 
     764             :   // Consider convergence if chi2 decreases...
     765             :   //  if (dChiSq()<=0.0) {
     766       77145 :   if (fChiSq<=0.001) {
     767             : 
     768             :     // ...and the change is small:
     769       77145 :     if (abs(dChiSq()) < 0.1*chiSq()) {
     770       56484 :       ++cvrgcount_;
     771             : 
     772             :       //      if (cvrgcount_==2) lambda()=2.0;
     773             : 
     774             :     }
     775             :     
     776       77145 :     if (prtlev()>0)
     777           0 :       cout << "     Good: chiSq=" << chiSq()
     778           0 :            << " dChiSq=" << dChiSq()
     779           0 :            << " fChiSq=" << dChiSq()/chiSq()
     780           0 :            << " cvrgcnt=" << cvrgcount_
     781           0 :            << " lambda=" << lambda()
     782           0 :            << endl;
     783             : 
     784             : 
     785             :     // Five such steps we believe we have converged!
     786       77145 :     if (cvrgcount_>5)
     787        9414 :       return True;
     788             :      
     789             :   }
     790             :   else {
     791             :     // (chi2 increased)
     792             : 
     793             :     // If a large increase, don't anticipate yet
     794           0 :     if (abs(dChiSq()) > 0.1*chiSq())
     795           0 :       cvrgcount_=0;
     796             :     else {
     797             :       // anticipate a little less if upward change is small
     798             :       //  TBD:  is this right?
     799           0 :       --cvrgcount_;
     800           0 :       cvrgcount_=max(cvrgcount_,0);  // never less than zero
     801             :     }
     802             : 
     803           0 :     if (prtlev()>0)
     804           0 :       cout << "     Bad:  chiSq=" << chiSq()
     805           0 :            << " dChiSq=" << dChiSq()
     806           0 :            << " fChiSq=" << dChiSq()/chiSq()
     807           0 :            << " cvrgcnt=" << cvrgcount_
     808           0 :            << " lambda=" << lambda()
     809           0 :            << endl;
     810             : 
     811             : 
     812             :   }
     813             : 
     814             :   // Not yet converged
     815       67731 :   return False;
     816             : 
     817             : }
     818             : 
     819       76468 : void VisCalSolver2::accGradHess2() {
     820             : 
     821       76468 :   if (prtlev()>2) cout << "     VCS2::accGradHess(SDB version)" << endl;
     822             : 
     823       76468 :   grad()=0.0;
     824       76468 :   hess()=0.0;
     825             : 
     826       76468 :   Cube<Complex> R;
     827       76468 :   Array<Complex> dR;
     828             : 
     829             :   // Loop over SDBs
     830      166570 :   for (Int isdb=0;isdb<sdbs().nSDB();++isdb) {
     831             : 
     832             :     // Current SDB
     833       90102 :     SolveDataBuffer& sdb(sdbs()(isdb));
     834             : 
     835       90102 :     R.reference(sdb.residuals());
     836       90102 :     dR.reference(sdb.diffResiduals());
     837             :     
     838       90102 :     const Cube<Float>& wWS(sdb.const_workingWtSpec());
     839       90102 :     const Cube<Bool>& wFC(sdb.const_workingFlagCube());
     840             : 
     841       90102 :     IPosition dRip(dR.shape());
     842             :     
     843       90102 :     Int nRow(dRip(3));
     844       90102 :     Int nChan(dRip(2));
     845       90102 :     Int nParPerAnt(dRip(1));   // pars per antenna
     846       90102 :     Int nCorr(dRip(0));
     847             : 
     848             :     // Simple indexed accumulation
     849     6153352 :     for (Int irow=0;irow<nRow;++irow) {
     850     6063250 :       if (!sdb.flagRow()(irow)) {
     851     6063250 :         Int a1i= nParPerAnt*sdb.antenna1()(irow);
     852     6063250 :         Int a2i= nParPerAnt*sdb.antenna2()(irow);
     853    12126500 :         for (Int ichan=0;ichan<nChan;++ichan) {
     854    25688810 :           for (int icorr=0;icorr<nCorr;++icorr) {
     855             :             //if (!sdb.residFlagCube()(icorr,ichan,irow)) {  // OLD: residFlagCube
     856    19625560 :             const Bool& fl(wFC(icorr,ichan,irow));             // NEW: workingFlagCube  CORRECT?
     857    19625560 :             if (!fl) {      
     858    19440098 :               const Float& wt(wWS(icorr,ichan,irow));
     859    19440098 :               if (wt>0.0) {
     860    17181466 :                 Complex& Ri(R(icorr,ichan,irow));
     861    41093238 :                 for (Int ipar=0;ipar<nParPerAnt;++ipar) {
     862             : 
     863             :                   // Accumulate grad and hess for this icorr,ichan,irow,ipar
     864             :                   // for a1:
     865    23911772 :                   Complex& dR1(dR(IPosition(5,icorr,ipar,ichan,irow,0)));
     866    23911772 :                   grad()(a1i+ipar)+= DComplex(wt*(Ri*conj(dR1)));
     867    23911772 :                   hess()(a1i+ipar)+= Double(wt*real(dR1*conj(dR1)));
     868             :                   // for a2:
     869    23911772 :                   Complex& dR2(dR(IPosition(5,icorr,ipar,ichan,irow,1)));
     870    23911772 :                   grad()(a2i+ipar)+= DComplex(wt*(dR2*conj(Ri)));
     871    23911772 :                   hess()(a2i+ipar)+= Double(wt*real(dR2*conj(dR2)));
     872             : 
     873             :                 } // ipar
     874             :               } // wt>0
     875             :             } // !flag
     876             :           } // icorr
     877             :         } // ichan
     878             :       } // !flagRow
     879             :     } // irow
     880             : 
     881       90102 :   } // isdb
     882             : 
     883       76468 :   if (prtlev()>4) {  // grad, hess
     884           0 :     cout << "      grad= " << grad() << endl;
     885           0 :     cout << "      hess= " << hess() << endl;
     886             :   }    
     887             : 
     888       76468 : }
     889             : 
     890         677 : void VisCalSolver2::revert() {
     891             : 
     892         677 :   if (prtlev()>2) cout << "     VCS2::revert()" << endl;
     893             : 
     894             :   // Recall the last decent parameter set
     895             :   //  TBD: the OK flag?
     896         677 :   par()=lastPar();
     897             : 
     898         677 : }
     899             : 
     900       67731 : void VisCalSolver2::solveGradHess() {
     901             : 
     902       67731 :   if (prtlev()>2) cout << "      VCS2::solveGradHess()" << endl;
     903             : 
     904             :   // TBD: explicit option to avoid lmfact?
     905             :   // TBD: pointer (or MaskedArray?) optimization?
     906             : 
     907       67731 :   Double lmfact(1.0+lambda());
     908             : 
     909       67731 :   lmfact=2.0;
     910             : 
     911       67731 :   dpar()=Complex(0.0);
     912     1120767 :   for (Int ipar=0; ipar<nPar(); ipar++) {
     913     1053036 :     if ( parOK()(ipar) && hess()(ipar)!=0.0) {
     914             :       // good hess for this par:
     915     1033166 :       dpar()(ipar) = grad()(ipar)/hess()(ipar);
     916     1033166 :       dpar()(ipar)/=lmfact;
     917             :     }
     918             :     else {
     919       19870 :       dpar()(ipar)=0.0; 
     920       19870 :       parOK()(ipar)=False;
     921             :     }
     922             :   }
     923             :   
     924             :   // Negate (so updatePar() can _add_)
     925       67731 :   dpar()*=Complex(-1.0f);
     926             : 
     927       67731 : }
     928             : 
     929       67731 : void VisCalSolver2::updatePar() {
     930             : 
     931       67731 :   if (prtlev()>2) cout << "       VCS2::updatePar()" << endl;
     932             : 
     933             :   //  if (prtlev()>4) cout << "        old =" << par() << endl;
     934             : 
     935             :   //  if (prtlev()>4) cout << "        dpar=" << dpar() << endl;
     936             : 
     937             : 
     938             : 
     939             :   // Tell svc to update the par 
     940             :   //   (permits svc() to condition the current solutions)
     941       67731 :   svc().updatePar(dpar());
     942             : 
     943       67731 :   if (prtlev()>4) {
     944           0 :     cout << "        abs(dpar()) = " << amplitude(dpar()) << endl;
     945           0 :     cout << "        new amp = " << amplitude(par()) << endl
     946           0 :          << "            pha = " << phase(par()) << endl;
     947             :   }
     948             : 
     949       67731 : }
     950             : 
     951             : 
     952       67731 : void VisCalSolver2::optStepSize2() {
     953             : 
     954       67731 :   if (prtlev()>2) cout << "  VCS2::optStepSize2(SDB version)" << endl;
     955             : 
     956       67731 :   Vector<Double> x2(3,-999.0);
     957       67731 :   Float step(1.0);
     958             :   
     959             :   // Starting point is curr chiSq
     960       67731 :   x2(0)=chiSq();
     961             : 
     962             :   // take nominal step
     963       67731 :   par()+=dpar();  
     964       67731 :   residualate2();
     965       67731 :   chiSquare2();
     966       67731 :   x2(1)=chiSq();
     967             : 
     968             :   // If nominal step is an improvement...
     969       67731 :   if (x2(1)<x2(0)) {
     970             : 
     971             :     // ...double step size until x2 starts increasing
     972       63913 :     par()=dpar(); par()*=Complex(2.0*step); par()+=lastPar();
     973       63913 :     residualate2();
     974       63913 :     chiSquare2();
     975       63913 :     x2(2)=chiSq();
     976       63913 :     if (prtlev()>4)
     977           0 :       cout <<   "  down:    " << step << " " << x2-x2(0) << LogicalArray(x2>=x2(0)) <<endl;
     978       93900 :     while (x2(2)<x2(1)) {    //  && step<4.0) {
     979       29987 :       step*=2.0;
     980       29987 :       par()=dpar(); par()*=Complex(2.0*step); par()+=lastPar();
     981       29987 :       residualate2();
     982       29987 :       chiSquare2();
     983       29987 :       x2(1)=x2(2);
     984       29987 :       x2(2)=chiSq();
     985       29987 :       if (prtlev()>4)
     986           0 :         cout << "  stretch: " << step << " " << x2-x2(0) << LogicalArray(x2>=x2(0)) <<endl;
     987             : 
     988             :     }
     989             :   }
     990             :   // else nominal step too big, so...
     991             :   else {
     992             : 
     993             :     // ... contract by halves until we bracket a minimum
     994        3818 :     step*=0.5;
     995        3818 :     par()=dpar(); par()*=Complex(step); par()+=lastPar();
     996        3818 :     residualate2();
     997        3818 :     chiSquare2();
     998        3818 :     x2(2)=x2(1);
     999        3818 :     x2(1)=chiSq();
    1000        3818 :     if (prtlev()>4)
    1001           0 :       cout <<   "  up:       " << step << " " << x2-x2(0) << LogicalArray(x2>=x2(0)) <<endl;
    1002       15556 :     while (x2(1)>x2(0)) { //  && step>0.125) {
    1003       11738 :       step*=0.5;
    1004       11738 :       par()=dpar(); par()*=Complex(step); par()+=lastPar();
    1005       11738 :       residualate2();
    1006       11738 :       chiSquare2();
    1007       11738 :       x2(2)=x2(1);
    1008       11738 :       x2(1)=chiSq();
    1009       11738 :       if (prtlev()>4)
    1010           0 :         cout << "  contract: " << step << " " << x2-x2(0) << LogicalArray(x2>=x2(0)) <<endl;
    1011             :     }
    1012             : 
    1013             :   }
    1014             : 
    1015             :   // At this point   x2(0) > x2(1) < x2(2), so 
    1016             :   //   calculate (quadratic) step optimization factor
    1017       67731 :   Double optfactor(0.0);
    1018       67731 :   Double optn(x2(2)-x2(1));
    1019       67731 :   Double optd(x2(0)-2*x2(1)+x2(2));
    1020             :               
    1021       67731 :   if (abs(optd)>0.0)
    1022       67692 :     optfactor=Double(step)*(1.5-optn/optd);
    1023             :   
    1024             :   /*  
    1025             :     cout << "Optimization: " 
    1026             :        << step << " " 
    1027             :        << optfactor << " "
    1028             :        << x2 << " "
    1029             :        << "(" << min(amplitude(lastPar())) << ") "
    1030             :        << max(amplitude(dpar())/amplitude(lastPar()))*180.0/C::pi << " ";
    1031             :   */
    1032             : 
    1033             : 
    1034       67731 :   if (prtlev()>4) cout << "   optfactor=" << optfactor << endl;
    1035             : 
    1036             : 
    1037       67731 :   par()=lastPar();
    1038             :   
    1039             :   // Adjust step by the optfactor
    1040       67731 :   if (optfactor>0.0)
    1041       67692 :     dpar()*=Complex(optfactor);
    1042             : 
    1043             :   /*
    1044             :   cout << max(amplitude(dpar())/amplitude(lastPar()))*180.0/C::pi
    1045             :        << endl;
    1046             :   */
    1047       67731 : }
    1048             : 
    1049        9414 : void VisCalSolver2::getErrors() {
    1050             : 
    1051             :   // Number of *REAL* dof
    1052             :   //  Int nDOF=2*(nWt()-ntrue(parOK()));  // !!!! this is zero for 3 antennas!
    1053        9414 :   Int nDOF=max(2*(nWt()-ntrue(parOK())), 1u);
    1054             : 
    1055        9414 :   Double k2=chiSq()/Double(nDOF);
    1056             : 
    1057        9414 :   parErr()=0.0;
    1058      154342 :   for (Int i=0;i<nPar();++i) 
    1059      144928 :     if (hess()(i)>0.0) {
    1060      142172 :       parErr()(i)=1.0/sqrt(hess()(i)/k2/2.0);   // 2 is from def of Hess!
    1061             :     }
    1062             : 
    1063             : 
    1064        9414 :   if (prtlev()>4) {
    1065             : 
    1066           0 :     cout << "ChiSq  = " << chiSq() << endl;
    1067           0 :     cout << "ChiSqV = " << chiSqV() << endl;
    1068           0 :     cout << "sumWt  = " << sumWt() << endl;
    1069           0 :     cout << "nWt    = " << nWt()
    1070           0 :          << "; nPar() = " << nPar() 
    1071           0 :          << "; nParOK = " << ntrue(parOK())
    1072           0 :          << "; nDOF = " << nDOF 
    1073           0 :          << endl;
    1074             :     
    1075           0 :     cout << "rChiSq = " << k2 << endl;
    1076           0 :     cout << "max(dpar) = " << max(amplitude(dpar())) << endl;
    1077           0 :     cout << "Amplitudes = " << amplitude(par()) << endl;
    1078           0 :     cout << "Errors     = " << parErr() << endl;
    1079             :     //    cout << "Errors = " << mean(parErr()(parOK())) << endl;
    1080             :     
    1081             :   }
    1082        9414 : }
    1083             : 
    1084             : 
    1085             : } //# NAMESPACE CASA - END
    1086             : 

Generated by: LCOV version 1.16