LCOV - code coverage report
Current view: top level - mstransform/TVI - StatWtClassicalDataAggregator.cc (source / functions) Hit Total Coverage
Test: casacpp_coverage.info Lines: 0 143 0.0 %
Date: 2024-10-29 13:38:20 Functions: 0 7 0.0 %

          Line data    Source code
       1             : //#  CASA - Common Astronomy Software Applications (http://casa.nrao.edu/)
       2             : //#  Copyright (C) Associated Universities, Inc. Washington DC, USA 2011, All
       3             : //#  rights reserved.
       4             : //#  Copyright (C) European Southern Observatory, 2011, All rights reserved.
       5             : //#
       6             : //#  This library is free software; you can redistribute it and/or
       7             : //#  modify it under the terms of the GNU Lesser General Public
       8             : //#  License as published by the Free software Foundation; either
       9             : //#  version 2.1 of the License, or (at your option) any later version.
      10             : //#
      11             : //#  This library is distributed in the hope that it will be useful,
      12             : //#  but WITHOUT ANY WARRANTY, without even the implied warranty of
      13             : //#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
      14             : //#  Lesser General Public License for more details.
      15             : //#
      16             : //#  You should have received a copy of the GNU Lesser General Public
      17             : //#  License along with this library; if not, write to the Free Software
      18             : //#  Foundation, Inc., 59 Temple Place, Suite 330, Boston,
      19             : //#  MA 02111-1307  USA
      20             : 
      21             : #include <mstransform/TVI/StatWtClassicalDataAggregator.h>
      22             : 
      23             : #include <casacore/casa/Arrays/Cube.h>
      24             : #include <casacore/scimath/StatsFramework/ClassicalStatistics.h>
      25             : 
      26             : #ifdef _OPENMP
      27             : #include <omp.h>
      28             : #endif
      29             : 
      30             : using namespace casacore;
      31             : using namespace std;
      32             : 
      33             : namespace casa {
      34             : 
      35             : namespace vi {
      36             : 
      37           0 : StatWtClassicalDataAggregator::StatWtClassicalDataAggregator(
      38             :     ViImplementation2 *const vii,
      39             :     // shared_ptr<Bool>& mustComputeWtSp,
      40             :     const map<Int, vector<StatWtTypes::ChanBin>>& chanBins,
      41             :     std::shared_ptr<map<uInt, pair<uInt, uInt>>>& samples,
      42             :     StatWtTypes::Column column, Bool noModel,
      43             :     const map<uInt, Cube<Bool>>& chanSelFlags,
      44             :     shared_ptr<
      45             :         ClassicalStatistics<
      46             :             Double, Array<Float>::const_iterator,
      47             :             Array<Bool>::const_iterator
      48             :         >
      49             :     >& wtStats,
      50             :     shared_ptr<const pair<Double, Double>> wtrange, Bool combineCorr,
      51             :     shared_ptr<
      52             :         StatisticsAlgorithm<
      53             :             Double, Array<Float>::const_iterator, Array<Bool>::const_iterator,
      54             :             Array<Double>::const_iterator
      55             :         >
      56             :     >& statAlg, Int minSamp
      57           0 : ) : StatWtDataAggregator(
      58             :        vii, chanBins, samples, column, noModel, chanSelFlags, /* mustComputeWtSp,*/
      59             :        wtStats, wtrange, combineCorr, statAlg, minSamp
      60           0 :     ) {}
      61             : 
      62           0 : StatWtClassicalDataAggregator::~StatWtClassicalDataAggregator() {}
      63             : 
      64           0 : void StatWtClassicalDataAggregator::aggregate() {
      65             :     // Drive NEXT LOWER layer's ViImpl to gather data into allvis:
      66             :     // Assumes all sub-chunks in the current chunk are to be used
      67             :     // for the variance calculation
      68             :     // Essentially, we are sorting the incoming data into
      69             :     // allvis, to enable a convenient variance calculation
      70           0 :     _variances.clear();
      71           0 :     auto* vb = _vii->getVisBuffer();
      72           0 :     std::map<StatWtTypes::BaselineChanBin, Cube<Complex>> data;
      73           0 :     std::map<StatWtTypes::BaselineChanBin, Cube<Bool>> flags;
      74           0 :     std::map<StatWtTypes::BaselineChanBin, Vector<Double>> exposures;
      75           0 :     IPosition blc(3, 0);
      76           0 :     auto trc = blc;
      77           0 :     auto initChanSelTemplate = True;
      78           0 :     Cube<Bool> chanSelFlagTemplate, chanSelFlags;
      79           0 :     auto firstTime = True;
      80             :     // we cannot know the spw until we are in the subchunks loop
      81           0 :     Int spw = -1;
      82           0 :     for (_vii->origin(); _vii->more(); _vii->next()) {
      83           0 :         if (_checkFirstSubChunk(spw, firstTime, vb)) {
      84           0 :             return;
      85             :         }
      86           0 :         if (! _mustComputeWtSp) {
      87           0 :             _mustComputeWtSp.reset(
      88             :                 new Bool(
      89           0 :                     vb->existsColumn(VisBufferComponent2::WeightSpectrum)
      90           0 :                 )
      91             :             );
      92             :         }
      93           0 :         const auto& ant1 = vb->antenna1();
      94           0 :         const auto& ant2 = vb->antenna2();
      95             :         // [nCorr, nFreq, nRows)
      96           0 :         const auto& dataCube = _dataCube(vb);
      97           0 :         const auto& flagCube = vb->flagCube();
      98           0 :         const auto dataShape = dataCube.shape();
      99           0 :         const auto& exposureVector = vb->exposure();
     100           0 :         const auto nrows = vb->nRows();
     101           0 :         const auto npol = dataCube.nrow();
     102             :         const auto resultantFlags = _getResultantFlags(
     103             :             chanSelFlagTemplate, chanSelFlags, initChanSelTemplate,
     104             :             spw, flagCube
     105           0 :         );
     106           0 :         auto bins = _chanBins.find(spw)->second;
     107           0 :         StatWtTypes::BaselineChanBin blcb;
     108           0 :         blcb.spw = spw;
     109           0 :         IPosition dataCubeBLC(3, 0);
     110           0 :         auto dataCubeTRC = dataCube.shape() - 1;
     111           0 :         for (rownr_t row=0; row<nrows; ++row) {
     112           0 :             dataCubeBLC[2] = row;
     113           0 :             dataCubeTRC[2] = row;
     114           0 :             blcb.baseline = _baseline(ant1[row], ant2[row]);
     115           0 :             auto citer = bins.cbegin();
     116           0 :             auto cend = bins.cend();
     117           0 :             for (; citer!=cend; ++citer) {
     118           0 :                 dataCubeBLC[1] = citer->start;
     119           0 :                 dataCubeTRC[1] = citer->end;
     120           0 :                 blcb.chanBin.start = citer->start;
     121           0 :                 blcb.chanBin.end = citer->end;
     122           0 :                 auto dataSlice = dataCube(dataCubeBLC, dataCubeTRC);
     123           0 :                 auto flagSlice = resultantFlags(dataCubeBLC, dataCubeTRC);
     124           0 :                 if (data.find(blcb) == data.end()) {
     125           0 :                     data[blcb] = dataSlice;
     126           0 :                     flags[blcb] = flagSlice;
     127           0 :                     exposures[blcb] = Vector<Double>(1, exposureVector[row]);
     128             :                 }
     129             :                 else {
     130           0 :                     auto myshape = data[blcb].shape();
     131           0 :                     auto nplane = myshape[2];
     132           0 :                     auto nchan = myshape[1];
     133           0 :                     data[blcb].resize(npol, nchan, nplane+1, True);
     134           0 :                     flags[blcb].resize(npol, nchan, nplane+1, True);
     135           0 :                     exposures[blcb].resize(nplane+1, True);
     136           0 :                     trc = myshape - 1;
     137             :                     // because we've extended the cube by one plane since
     138             :                     // myshape was determined.
     139           0 :                     ++trc[2];
     140           0 :                     blc[2] = trc[2];
     141           0 :                     data[blcb](blc, trc) = dataSlice;
     142           0 :                     flags[blcb](blc, trc) = flagSlice;
     143           0 :                     exposures[blcb][trc[2]] = exposureVector[row];
     144           0 :                 }
     145           0 :             }
     146             :         }
     147           0 :     }
     148           0 :     _computeVariances(data, flags, exposures);
     149           0 : }
     150             : 
     151           0 : void StatWtClassicalDataAggregator::weightSingleChanBin(
     152             :     Matrix<Float>& wtmat, Int nrows
     153             : ) const {
     154           0 :     Vector<Int> ant1, ant2, spws;
     155           0 :     Vector<Double> exposures;
     156           0 :     _vii->antenna1(ant1);
     157           0 :     _vii->antenna2(ant2);
     158           0 :     _vii->spectralWindows(spws);
     159           0 :     _vii->exposure(exposures);
     160             :     // There is only one spw in a chunk
     161           0 :     auto spw = *spws.begin();
     162           0 :     StatWtTypes::BaselineChanBin blcb;
     163           0 :     blcb.spw = spw;
     164           0 :     for (Int i=0; i<nrows; ++i) {
     165           0 :         auto bins = _chanBins.find(spw)->second;
     166           0 :         blcb.baseline = _baseline(ant1[i], ant2[i]);
     167           0 :         blcb.chanBin = bins[0];
     168           0 :         auto variances = _variances.find(blcb)->second;
     169           0 :         if (_combineCorr) {
     170           0 :             wtmat.column(i) = exposures[i]/variances[0];
     171             :         }
     172             :         else {
     173           0 :             auto corr = 0;
     174           0 :             for (const auto variance: variances) {
     175           0 :                 wtmat(corr, i) = exposures[i]/variance;
     176           0 :                 ++corr;
     177           0 :             }
     178             :         }
     179           0 :     }
     180           0 : }
     181             : 
     182           0 : void StatWtClassicalDataAggregator::_computeVariances(
     183             :     const map<StatWtTypes::BaselineChanBin, Cube<Complex>>& data,
     184             :     const map<StatWtTypes::BaselineChanBin, Cube<Bool>>& flags,
     185             :     const map<StatWtTypes::BaselineChanBin, Vector<Double>>& exposures
     186             : ) const {
     187           0 :     auto diter = data.cbegin();
     188           0 :     auto dend = data.cend();
     189           0 :     const auto nActCorr = diter->second.shape()[0];
     190           0 :     const auto ncorr = _combineCorr ? 1 : nActCorr;
     191             :     // spw will be the same for all members
     192           0 :     const auto& spw = data.begin()->first.spw;
     193           0 :     vector<StatWtTypes::BaselineChanBin> keys(data.size());
     194           0 :     auto idx = 0;
     195           0 :     for (; diter!=dend; ++diter, ++idx) {
     196           0 :         const auto& blcb = diter->first;
     197           0 :         keys[idx] = blcb;
     198           0 :         _variances[blcb].resize(ncorr);
     199             :     }
     200           0 :     auto n = keys.size();
     201             : #ifdef _OPENMP
     202           0 : #pragma omp parallel for
     203             :     // cout << "WARN OMP PARALLEL LOOPING IS OFF FOR DEBUGGING" << endl;
     204             : #endif
     205             :     for (size_t i=0; i<n; ++i) {
     206             :         auto blcb = keys[i];
     207             :         auto dataForBLCB = data.find(blcb)->second;
     208             :         auto flagsForBLCB = flags.find(blcb)->second;
     209             :         auto exposuresForBLCB = exposures.find(blcb)->second;
     210             :         for (ssize_t corr=0; corr<ncorr; ++corr) {
     211             :             IPosition start(3, 0);
     212             :             auto end = dataForBLCB.shape() - 1;
     213             :             if (! _combineCorr) {
     214             :                 start[0] = corr;
     215             :                 end[0] = corr;
     216             :             }
     217             :             Slicer slice(start, end, Slicer::endIsLast);
     218             :             _variances[blcb][corr]
     219             :                 = _varianceComputer->computeVariance(
     220             :                     dataForBLCB(slice), flagsForBLCB(slice),
     221             :                     exposuresForBLCB, spw
     222             :                 );
     223             :         }
     224             :     }
     225           0 : }
     226             : 
     227           0 : void StatWtClassicalDataAggregator::weightSpectrumFlags(
     228             :     Cube<Float>& wtsp, Cube<Bool>& flagCube, Bool& checkFlags,
     229             :     const Vector<Int>& ant1, const Vector<Int>& ant2, const Vector<Int>& spws,
     230             :     const Vector<Double>& exposures, const Vector<rownr_t>&
     231             : ) const {
     232           0 :     Slicer slice(IPosition(3, 0), flagCube.shape(), Slicer::endIsLength);
     233           0 :     auto sliceStart = slice.start();
     234           0 :     auto sliceEnd = slice.end();
     235           0 :     auto nrows = ant1.size();
     236           0 :     for (size_t i=0; i<nrows; ++i) {
     237           0 :         sliceStart[2] = i;
     238           0 :         sliceEnd[2] = i;
     239           0 :         StatWtTypes::BaselineChanBin blcb;
     240           0 :         blcb.baseline = _baseline(ant1[i], ant2[i]);
     241           0 :         auto spw = spws[i];
     242           0 :         blcb.spw = spw;
     243           0 :         auto bins = _chanBins.find(spw)->second;
     244           0 :         for (const auto& bin: bins) {
     245           0 :             sliceStart[1] = bin.start;
     246           0 :             sliceEnd[1] = bin.end;
     247           0 :             blcb.chanBin = bin;
     248           0 :             auto variances = _variances.find(blcb)->second;
     249           0 :             auto ncorr = variances.size();
     250           0 :             Vector<Double> weights(ncorr);
     251           0 :             for (size_t corr=0; corr<ncorr; ++corr) {
     252           0 :                 if (! _combineCorr) {
     253           0 :                     sliceStart[0] = corr;
     254           0 :                     sliceEnd[0] = corr;
     255             :                 }
     256           0 :                 weights[corr] = variances[corr] == 0
     257           0 :                     ? 0 : exposures[i]/variances[corr];
     258           0 :                 slice.setStart(sliceStart);
     259           0 :                 slice.setEnd(sliceEnd);
     260           0 :                 _updateWtSpFlags(
     261           0 :                     wtsp, flagCube, checkFlags, slice, weights[corr]
     262             :                 );
     263             :             }
     264           0 :         }
     265           0 :     }
     266           0 : }
     267             : 
     268             : }
     269             : 
     270             : }

Generated by: LCOV version 1.16