LCOV - code coverage report
Current view: top level - msvis/MSVis - StatWT.cc (source / functions) Hit Total Coverage
Test: casacpp_coverage.info Lines: 0 147 0.0 %
Date: 2024-11-06 17:42:47 Functions: 0 6 0.0 %

          Line data    Source code
       1             : //# StatWT.cc: Estimate per-baseline sigmas and weights from the scatter of
       2             : //# the visibilities in VisBuffGroups and write them out.
       3             : //# Copyright (C) 2011
       4             : //# Associated Universities, Inc. Washington DC, USA.
       5             : //#
       6             : //# This library is free software; you can redistribute it and/or modify it
       7             : //# under the terms of the GNU Library General Public License as published by
       8             : //# the Free Software Foundation; either version 2 of the License, or (at your
       9             : //# option) any later version.
      10             : //#
      11             : //# This library is distributed in the hope that it will be useful, but WITHOUT
      12             : //# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
      13             : //# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Library General Public
      14             : //# License for more details.
      15             : //#
      16             : //# You should have received a copy of the GNU Library General Public License
      17             : //# along with this library; if not, write to the Free Software Foundation,
      18             : //# Inc., 675 Massachusetts Ave, Cambridge, MA 02139, USA.
      19             : //#
      20             : //# Correspondence concerning AIPS++ should be addressed as follows:
      21             : //#        Internet email: casa-feedback@nrao.edu.
      22             : //#        Postal address: AIPS++ Project Office
      23             : //#                        National Radio Astronomy Observatory
      24             : //#                        520 Edgemont Road
      25             : //#                        Charlottesville, VA 22903-2475 USA
      26             : //#
      27             : 
      28             : #include <msvis/MSVis/StatWT.h>
      29             : //#include <msvis/MSVis/SubMS.h>
      30             : #include <msvis/MSVis/VisBufferComponents.h>
      31             : #include <msvis/MSVis/VisBuffGroup.h>
      32             : #include <msvis/MSVis/VisBuffGroupAcc.h>
      33             : #include <casacore/casa/Exceptions/Error.h>
      34             : #include <casacore/casa/Logging/LogIO.h>
      35             : #include <casacore/ms/MSSel/MSSelection.h>
      36             : #include <casacore/casa/Arrays/ArrayMath.h>
      37             : 
      38             : using namespace casacore;
      39             : namespace casa {
      40             : 
                       // Construct a StatWT worker on top of the visibility iterator vi.
                       //
                       //   datacol  - which data column the statistics are computed from; also
                       //              decides which data cube is added to the prefetch list
                       //              (DATA, MODEL_DATA or CORRECTED_DATA; FLOAT_DATA is not
                       //              prefetched, see the commented-out branch below).
                       //   fitspw   - spw/channel selection string used to build fitmask_p,
                       //              i.e. the data the statistics are measured from.
                       //   outspw   - spw selection string; the selected spw IDs are collected
                       //              in outspws_p and decide which buffers get new weights.
                       //   dorms    - stored in dorms_p; also picks the minimum sample count
                       //              (1 when true, 2 when false).
                       //   minsamp  - requested minimum number of samples; raised to the
                       //              minimum above, with a warning, if it is too small.
                       //   selcorrs - copied into selcorrs_p (not read anywhere else in this
                       //              listing).
      41           0 : StatWT::StatWT(const ROVisibilityIterator& vi,
      42             :                const MS::PredefinedColumns datacol,
      43             :                const String& fitspw,
      44             :                const String& outspw,
      45             :                const Bool dorms,
      46             :                const uInt minsamp,
      47           0 :                const vector<uInt> selcorrs) :
      48             :   GroupWorker(vi),
      49           0 :   datacol_p(datacol),
      50           0 :   fitspw_p(fitspw),
      51           0 :   outspw_p(outspw),
      52           0 :   dorms_p(dorms),
      53           0 :   rowsdone_p(0)
      54             : {
      55           0 :   LogIO os(LogOrigin("StatWT", "StatWT()"));
      56             : 
                       // minsamp_p is assigned in exactly one of the three branches below.
      57           0 :   if(dorms && minsamp < 1){
      58             :     os << LogIO::WARN
      59             :        << "It takes at least one to measure an rms - using minsamp = 1."
      60           0 :        << LogIO::POST;
      61           0 :     minsamp_p = 1;
      62             :   }
      63           0 :   else if(!dorms && minsamp < 2){
      64             :     os << LogIO::WARN
      65             :        << "It takes at least two to measure a variance - using minsamp = 2."
      66           0 :        << LogIO::POST;
      67           0 :     minsamp_p = 2;
      68             :   }
      69             :   else
      70           0 :     minsamp_p = minsamp;
      71             : 
                       // Columns requested for prefetching; the trailing -1 ends the argument
                       // list.  The data cube matching datacol is appended just below.
      72           0 :   prefetchColumns_p = asyncio::PrefetchColumns::prefetchColumns(
      73             :                                   VisBufferComponents::Ant1,
      74             :                                   VisBufferComponents::Ant2,
      75             :                                   VisBufferComponents::ArrayId,
      76             :                                   VisBufferComponents::CorrType,
      77             :                                   VisBufferComponents::DataDescriptionId,
      78             :                                   //VisBufferComponents::Feed1,
      79             :                                   //VisBufferComponents::Feed2,
      80             :                                   VisBufferComponents::FieldId,
      81             :                                   VisBufferComponents::FlagCube,
      82             :                                   VisBufferComponents::Flag,
      83             :                                   VisBufferComponents::FlagRow,
      84             :                                   VisBufferComponents::ObservationId,
      85             :                                   //VisBufferComponents::NChannel,
      86             :                                   VisBufferComponents::NCorr,
      87             :                                   VisBufferComponents::NRow,
      88             :                                   //VisBufferComponents::ProcessorId,
      89             :                                   VisBufferComponents::Scan,
      90             :                                   VisBufferComponents::SpW,
      91             :                                   VisBufferComponents::SigmaMat,
      92             :                                   VisBufferComponents::StateId,
      93             :                                   //VisBufferComponents::Time,
      94             :                                   //VisBufferComponents::TimeCentroid,
      95             :                                   //VisBufferComponents::TimeInterval,
      96             :                                   VisBufferComponents::WeightMat,
      97           0 :                                   -1);
      98           0 :   if(datacol == MS::DATA)
      99           0 :     prefetchColumns_p.insert(VisBufferComponents::ObservedCube);
     100           0 :   else if(datacol == MS::MODEL_DATA)
     101           0 :     prefetchColumns_p.insert(VisBufferComponents::ModelCube);
     102           0 :   else if(datacol == MS::CORRECTED_DATA)
     103           0 :     prefetchColumns_p.insert(VisBufferComponents::CorrectedCube);
     104             :   //  else if(datacol == MS::FLOAT_DATA)    // Not in VisBufferComponents yet.
     105             :   //  prefetchColumns_p.insert(VisBufferComponents::FloatCube);
     106             : 
                       // Build the channel mask for the fit from fitspw; it is released again
                       // in the destructor.
     107           0 :   VisBuffGroupAcc::fillChanMask(fitmask_p, fitspw, invi_p.ms());
     108             : 
                       // Parse outspw with MSSelection; column 0 of the channel list is used
                       // as the list of selected spw IDs.
     109           0 :   MSSelection mssel;
     110           0 :   mssel.setSpwExpr(outspw);
     111           0 :   Matrix<Int> chansel = mssel.getChanList(&(invi_p.ms()), 1);
     112           0 :   Vector<Int> spws(chansel.column(0));
     113           0 :   uInt nselspws = spws.nelements();
     114           0 :   selcorrs_p = selcorrs;
     115             : 
     116           0 :   for(uInt i = 0; i < nselspws; ++i)
     117           0 :     outspws_p.insert(spws[i]);
     118           0 : }
     119             : 
                       // Destructor: hands fitmask_p (filled by fillChanMask in the
                       // constructor) to VisBuffGroupAcc::clearChanMask.
     120           0 : StatWT::~StatWT()
     121             : {
     122           0 :   VisBuffGroupAcc::clearChanMask(fitmask_p);
     123           0 : }
     124             : 
     125             : // StatWT& StatWT::operator=(const StatWT &other)
     126             : // {
     127             : //   // trivial so far.
     128             : //   vi_p = other.vi_p;
     129             : //   return *this;
     130             : // }
     131             : 
                       // Process one group of VisBuffers in three passes:
                       //   1. Scan every buffer for the largest antenna count and correlation
                       //      count, and refuse groups whose correlation setups differ.
                       //   2. For buffers whose spw is in fitmask_p, accumulate per-baseline,
                       //      per-correlation sample counts and sums of squared differences
                       //      (update_variances), then normalise the sums into variances.
                       //   3. For every buffer, add its row count to rowsdone_p, apply the
                       //      variances if its spw is in outspws_p (apply_variances), and
                       //      advance outvi_p by one step.
                       //
                       // Returns false if the correlation setups differ or update_variances
                       // fails (in both cases before pass 3 runs); otherwise the AND of all
                       // apply_variances results.
     132           0 : Bool StatWT::process(VisBuffGroup& vbg)
     133             : {
     134           0 :   LogIO os(LogOrigin("StatWT", "process()"));
     135           0 :   Bool worked = true;
     136           0 :   uInt nvbs = vbg.nBuf();
     137           0 :   Int maxAnt = 0;
     138           0 :   Int maxNCorr = 0;
     139           0 :   Int maxSpw = 0;   // VisBuffGroupAcc is 1 of those things that uses SpW when
     140             :                     // it should use DDID.
                       // Note: maxSpw is updated in the loop below but not read afterwards
                       // in this function.
     141             :   
     142           0 :   for(uInt bufnum = 0; bufnum < nvbs; ++bufnum){
     143           0 :     if(vbg(bufnum).numberAnt() > maxAnt)        // Record maxAnt even for buffers
     144           0 :       maxAnt = vbg(bufnum).numberAnt();         // that won't be used in the fit.
     145           0 :     if(vbg(bufnum).nCorr() > maxNCorr)
     146           0 :       maxNCorr = vbg(bufnum).nCorr();
     147             : 
                       // Compare the lengths before the element-wise !=: casacore's array
                       // comparison requires conformant shapes, so without the length test a
                       // buffer with a different number of correlations would raise a
                       // conformance error instead of producing the message below.
     148           0 :     if(bufnum > 0 && (vbg(bufnum).corrType().nelements() != vbg(0).corrType().nelements() || anyTrue(vbg(bufnum).corrType() != vbg(0).corrType()))){
     149             :       os << LogIO::SEVERE
     150             :          << "statwt does not yet support combining data description IDs with different correlation setups."
     151           0 :          << LogIO::POST;
     152           0 :       return false;
     153             :     }
     154             : 
     155           0 :     Int spw = vbg(bufnum).spectralWindow();
     156           0 :     if(fitmask_p.count(spw) > 0){               // This requires fitspw to
     157             :                                                 // follow the '' = nothing,
     158             :                                                 // '*' = everything convention.
     159           0 :       if(spw > maxSpw)
     160           0 :         maxSpw = vbg(bufnum).spectralWindow();
     161             :     }
     162             :   }
     163             : 
     164           0 :   Cube<Bool> chanmaskedflags;
     165             : 
     166             :   // Map from hashFunction(ant1, ant2) to running number of visibilities[corr]
     167           0 :   std::map<uInt, Vector<uInt> > ns;
     168             : 
     169             :   // Map from hashFunction(ant1, ant2) to running mean[corr]
     170           0 :   std::map<uInt, Vector<Complex> > means;
     171             : 
     172             :   // Map from hashFunction(ant1, ant2) to variance[corr], initially stored as
     173             :   // running sums of squared differences.
     174           0 :   std::map<uInt, Vector<Double> > variances;
     175             : 
     176             :   // The accumulation of sums for the variances could be parallelized.
     177             :   // See Chan, Tony F.; Golub, Gene H.; LeVeque, Randall J. (1979), "Updating
     178             :   // Formulae and a Pairwise Algorithm for Computing Sample Variances.",
     179             :   // Technical Report STAN-CS-79-773, Department of Computer Science, Stanford
     180             :   // University.
     181             : 
                       // Pass 2: only buffers whose spw has an entry in fitmask_p contribute.
     182           0 :   for(uInt bufnum = 0; bufnum < nvbs; ++bufnum){
     183           0 :     Int spw = vbg(bufnum).spectralWindow();
     184             : 
     185           0 :     if(fitmask_p.count(spw) > 0){
     186           0 :       VisBuffGroup::applyChanMask(chanmaskedflags, fitmask_p[spw], vbg(bufnum));
     187             : 
     188           0 :       if(!update_variances(ns, means, variances, vbg(bufnum), chanmaskedflags,
     189             :                            maxAnt))
     190           0 :         return false;
     191             :     }
     192             :   }
                       // Turn the accumulated sums into variances by dividing by (2n - 1).
                       // NOTE(review): presumably 2n because each complex sample supplies a
                       // real and an imaginary part - confirm, in particular for dorms_p.
                       // A count of 0 leaves a zero sum, which apply_variances rejects.
     193           0 :   for(std::map<uInt, Vector<Double> >::iterator it = variances.begin();
     194           0 :       it != variances.end(); ++it)
     195           0 :     for(Int corr = 0; corr < maxNCorr; ++corr)
     196           0 :       it->second[corr] /= (2.*ns[it->first][corr] - 1);
     197             : 
     198             :   // TODO
     199             :   // if(byantenna_p){
     200             :   // // The formula for the variance of antenna k is
     201             :   // // \sigma_k^2 = \frac{1}{n - 1} \sum_{i \notequal k} \left(
     202             :   // // \sigma_{ik}^2 \frac{\sum_{j \notequal i,k}^{k - 1} \sigma_{jk}^2}
     203             :   // // {\sum_{j \notequal i,k} \sigma_{ij}^2}\right)
     204             :   // // where \sigma_{ij}^2 is the already calculated variance of baseline ij.
     205             :   //
     206             :   // // So, get the antenna based variances, take their sqrts \sigma_k, and
     207             :   // // update variances to \sigma_i \sigma_j, taking sepacs_p into account all
     208             :   // // along.
     209             :   // }
     210             : 
                       // Pass 3: rowsdone_p counts every buffer's rows, whether or not the
                       // buffer's spw is selected for output.
     211             :   //uInt oldrowsdone = rowsdone_p;
     212           0 :   for(uInt bufnum = 0; bufnum < nvbs; ++bufnum){
     213           0 :     uInt spw = vbg(bufnum).spectralWindow();
     214             : 
     215           0 :     rowsdone_p += vbg(bufnum).nRow();
     216           0 :     if(outspws_p.find(spw) != outspws_p.end()){
     217           0 :       worked &= apply_variances(vbg(bufnum), ns, variances, maxAnt);      
     218             :       //cerr << "Wrote out row IDs " << oldrowsdone << " - " << rowsdone_p - 1 << ",";
     219             :     }
     220             :     //else
     221             :     //  cerr << "No output for";
     222             :     //cerr << " spw " << spw << endl;
     223             :     //oldrowsdone = rowsdone_p;
     224             : 
     225             :     // Catch outvi_p up with invi_p.
     226           0 :     if(vbg.chunkEnd(bufnum) && outvi_p.moreChunks()){
     227           0 :       outvi_p.nextChunk();
     228           0 :       outvi_p.origin();
     229             :     }
     230           0 :     else if(outvi_p.more())
     231           0 :       ++outvi_p;
     232             :   }
     233             :   
     234           0 :   return worked;
     235           0 : }
     236             : 
                       // Fold the visibilities of one VisBuffer into the running per-baseline
                       // statistics.
                       //
                       //   ns, means, variances - maps keyed by hashFunction(ant1, ant2, maxAnt)
                       //                          holding, per correlation, the sample count,
                       //                          the running mean and the running sum of
                       //                          squared differences.  Entries are created
                       //                          (zero-filled) the first time a baseline is
                       //                          seen in an unflagged row.
                       //   vb                   - buffer supplying data (column datacol_p),
                       //                          antenna IDs, row flags and the flag cube.
                       //   chanmaskedflags      - extra per-sample mask; a sample is used only
                       //                          if both this mask and vb.flagCube() are false.
                       //   maxAnt               - passed through to hashFunction.
                       //
                       // Returns false if the data and chanmaskedflags shapes differ; otherwise
                       // true (retval is never changed after its initialisation).
     237           0 : Bool StatWT::update_variances(std::map<uInt, Vector<uInt> >& ns,
     238             :                               std::map<uInt, Vector<Complex> >& means,
     239             :                               std::map<uInt, Vector<Double> >& variances,
     240             :                               const VisBuffer& vb,
     241             :                               const Cube<Bool>& chanmaskedflags, const uInt maxAnt)
     242             : {
     243           0 :   Cube<Complex> data(vb.dataCube(datacol_p));
     244             : 
     245           0 :   if(data.shape() != chanmaskedflags.shape())
     246           0 :     return false;
     247             : 
     248           0 :   Bool retval = true;
                       // The cube axes are read as (correlation, channel, row).
     249           0 :   uInt nCorr = data.shape()[0];
     250           0 :   uInt nChan = data.shape()[1];
     251           0 :   rownr_t nRows = data.shape()[2];
     252           0 :   Vector<uInt> unflagged(nChan);  // Not read anywhere in this function.
     253           0 :   Vector<Int> a1(vb.antenna1());
     254           0 :   Vector<Int> a2(vb.antenna2());
     255             : 
     256           0 :   for(rownr_t r = 0; r < nRows; ++r){
     257           0 :     if(!vb.flagRow()[r]){
     258           0 :       uInt hr = hashFunction(a1[r], a2[r], maxAnt);
     259             :       // setup defaults, clear on all-flagged not needed as variances == 0 is
     260             :       // skipped in apply_variances
     261           0 :       if(!ns.count(hr)){
     262           0 :         ns[hr] = Vector<uInt>(nCorr, 0);
     263           0 :         means[hr] = Vector<Complex>(nCorr, 0);
     264           0 :         variances[hr] = Vector<Double>(nCorr, 0);
     265             :       }
     266           0 :       Vector<uInt> & vns = ns[hr];
     267           0 :       Vector<Complex> & vmeans = means[hr];
     268           0 :       Vector<Double> & vvariances = variances[hr];
     269             : 
     270           0 :       for(uInt corr = 0; corr < nCorr; ++corr){
     271           0 :         for(uInt ch = 0; ch < nChan; ++ch){
     272           0 :           if(!chanmaskedflags(corr, ch, r) && !vb.flagCube()(corr,ch,r)){
                       // One-pass update: vmoldmean is the deviation from the mean before
                       // this sample, vmmean the deviation from the mean after it.
     273           0 :             Complex vis, vmoldmean, vmmean;
     274           0 :             ++vns[corr];
     275           0 :             vis = data(corr, ch, r);
     276           0 :             vmoldmean = vis - vmeans[corr];
     277             : 
                       // With dorms_p set the mean is never updated, so it keeps its initial
                       // value of 0 and the sum below accumulates re^2 + im^2 of each sample.
     278           0 :             if(!dorms_p)  // It's not that Complex / Int isn't defined, it's
     279             :                           // that it is, along with Complex / Double, creating
     280             :                           // an ambiguity.
     281           0 :               vmeans[corr] += vmoldmean / static_cast<Double>(vns[corr]);
     282             : 
     283             :             // This term is guaranteed to have its parts be nonnegative.
     284           0 :             vmmean = vis - vmeans[corr];
     285           0 :             vvariances[corr] += vmmean.real() * vmoldmean.real() +
     286           0 :                                 vmmean.imag() * vmoldmean.imag();
     287             :           }
     288             :         }
     289             :       }
     290             :     }
     291             :   }
     292           0 :   return retval;
     293           0 : }
     294             : 
                       // Write sigmas, weights and flags derived from the per-baseline
                       // variances into vb and push them to outvi_p.
                       //
                       //   vb        - buffer to update; its sigmaMat, weightMat, flagCube and
                       //               flagRow are modified in place.
                       //   ns        - per-baseline, per-correlation sample counts.
                       //   variances - per-baseline, per-correlation variances.
                       //   maxAnt    - passed through to hashFunction.
                       //
                       // For each row and correlation: if the baseline has statistics with at
                       // least minsamp_p samples and a positive variance, sigma = sqrt(var) and
                       // weight = 1 / var.  Otherwise sigma = -1, weight = 0 and every channel
                       // of that correlation is flagged in the flag cube.  A row is flagged
                       // only if none of its correlations was usable.
                       //
                       // Only vb.flag(), vb.sigmaMat() and vb.weightMat() are handed to outvi_p
                       // here; the flag-cube write is commented out below.
                       //
                       // Always returns true (retval is never changed).
     295           0 : Bool StatWT::apply_variances(VisBuffer& vb,
     296             :                              std::map<uInt, Vector<uInt> >& ns,
     297             :                              std::map<uInt, Vector<Double> >& variances,
     298             :                              const uInt maxAnt)
     299             : {
     300           0 :   Bool retval = true;
     301           0 :   IPosition shp(vb.flagCube().shape());
     302           0 :   uInt nCorr = shp[0];
     303           0 :   uInt nChan = shp[1];
     304           0 :   rownr_t nRows = shp[2];
     305           0 :   Vector<Int> a1(vb.antenna1());
     306           0 :   Vector<Int> a2(vb.antenna2());
     307             : 
     308           0 :   for(rownr_t r = 0; r < nRows; ++r){
     309           0 :     uInt hr = hashFunction(a1[r], a2[r], maxAnt);
     310           0 :     Bool unflagged = false;   // Becomes true once any correlation is usable.
     311           0 :     Bool havevar = ns.count(hr) > 0;
     312             : 
     313           0 :     for(uInt corr = 0; corr < nCorr; ++corr){
     314           0 :         if(havevar &&
     315           0 :            (ns[hr][corr] >= minsamp_p) &&
     316           0 :            (0.0 < variances[hr][corr])){ // For some reason emacs likes 0 < v,
     317             :                                          // but not v > 0.
     318           0 :           Double var = variances[hr][corr];
     319             : 
     320           0 :           unflagged = true;
     321           0 :           vb.sigmaMat()(corr, r) = sqrt(var);
     322           0 :           vb.weightMat()(corr, r) = 1.0 / var;
     323             :         }
     324             :         else{
     325           0 :           vb.sigmaMat()(corr, r) = -1.0;
     326           0 :           vb.weightMat()(corr, r) = 0.0;
     327           0 :           for(uInt ch = 0; ch < nChan; ++ch){
     328           0 :             vb.flagCube()(corr, ch, r) = true;
     329             :           }
     330             :         }
     331             :     }
                       // Decide the row flag only after every correlation has been examined.
                       // Checking inside the loop flagged the row as soon as the first
                       // correlation was unusable, even when a later one was fine.
     332           0 :     if(!unflagged)
     333           0 :       vb.flagRow()[r] = true;
     334             :   }
     335             :   
     336             :   // argh
     337             :   // outvi_p.setFlagCube(vb.flagCube());
     338           0 :   outvi_p.setFlag(vb.flag());
     339             : 
     340           0 :   outvi_p.setSigmaMat(vb.sigmaMat());
     341           0 :   outvi_p.setWeightMat(vb.weightMat());
     342           0 :   return retval;
     343           0 : }
     344             : 
     345             : using namespace casacore;
     346             : } // end namespace casa

Generated by: LCOV version 1.16