Line data Source code
1 : //# CASA - Common Astronomy Software Applications (http://casa.nrao.edu/)
2 : //# Copyright (C) Associated Universities, Inc. Washington DC, USA 2011, All
3 : //# rights reserved.
4 : //# Copyright (C) European Southern Observatory, 2011, All rights reserved.
5 : //#
6 : //# This library is free software; you can redistribute it and/or
7 : //# modify it under the terms of the GNU Lesser General Public
8 : //# License as published by the Free Software Foundation; either
9 : //# version 2.1 of the License, or (at your option) any later version.
10 : //#
11 : //# This library is distributed in the hope that it will be useful,
12 : //# but WITHOUT ANY WARRANTY, without even the implied warranty of
13 : //# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 : //# Lesser General Public License for more details.
15 : //#
16 : //# You should have received a copy of the GNU Lesser General Public
17 : //# License along with this library; if not, write to the Free Software
18 : //# Foundation, Inc., 59 Temple Place, Suite 330, Boston,
19 : //# MA 02111-1307 USA
20 :
21 : #include <mstransform/TVI/StatWtClassicalDataAggregator.h>
22 :
23 : #include <casacore/casa/Arrays/Cube.h>
24 : #include <casacore/scimath/StatsFramework/ClassicalStatistics.h>
25 :
26 : #ifdef _OPENMP
27 : #include <omp.h>
28 : #endif
29 :
30 : using namespace casacore;
31 : using namespace std;
32 :
33 : namespace casa {
34 :
35 : namespace vi {
36 :
37 12 : StatWtClassicalDataAggregator::StatWtClassicalDataAggregator(
38 : ViImplementation2 *const vii,
39 : // shared_ptr<Bool>& mustComputeWtSp,
40 : const map<Int, vector<StatWtTypes::ChanBin>>& chanBins,
41 : std::shared_ptr<map<uInt, pair<uInt, uInt>>>& samples,
42 : StatWtTypes::Column column, Bool noModel,
43 : const map<uInt, Cube<Bool>>& chanSelFlags,
44 : shared_ptr<
45 : ClassicalStatistics<
46 : Double, Array<Float>::const_iterator,
47 : Array<Bool>::const_iterator
48 : >
49 : >& wtStats,
50 : shared_ptr<const pair<Double, Double>> wtrange, Bool combineCorr,
51 : shared_ptr<
52 : StatisticsAlgorithm<
53 : Double, Array<Float>::const_iterator, Array<Bool>::const_iterator,
54 : Array<Double>::const_iterator
55 : >
56 : >& statAlg, Int minSamp
57 12 : ) : StatWtDataAggregator(
58 : vii, chanBins, samples, column, noModel, chanSelFlags, /* mustComputeWtSp,*/
59 : wtStats, wtrange, combineCorr, statAlg, minSamp
60 12 : ) {}
61 :
62 24 : StatWtClassicalDataAggregator::~StatWtClassicalDataAggregator() {}
63 :
64 123 : void StatWtClassicalDataAggregator::aggregate() {
65 : // Drive NEXT LOWER layer's ViImpl to gather data into allvis:
66 : // Assumes all sub-chunks in the current chunk are to be used
67 : // for the variance calculation
68 : // Essentially, we are sorting the incoming data into
69 : // allvis, to enable a convenient variance calculation
70 123 : _variances.clear();
71 123 : auto* vb = _vii->getVisBuffer();
72 123 : std::map<StatWtTypes::BaselineChanBin, Cube<Complex>> data;
73 123 : std::map<StatWtTypes::BaselineChanBin, Cube<Bool>> flags;
74 123 : std::map<StatWtTypes::BaselineChanBin, Vector<Double>> exposures;
75 123 : IPosition blc(3, 0);
76 123 : auto trc = blc;
77 123 : auto initChanSelTemplate = True;
78 123 : Cube<Bool> chanSelFlagTemplate, chanSelFlags;
79 123 : auto firstTime = True;
80 : // we cannot know the spw until we are in the subchunks loop
81 123 : Int spw = -1;
82 1483 : for (_vii->origin(); _vii->more(); _vii->next()) {
83 1372 : if (_checkFirstSubChunk(spw, firstTime, vb)) {
84 12 : return;
85 : }
86 1360 : if (! _mustComputeWtSp) {
87 20 : _mustComputeWtSp.reset(
88 : new Bool(
89 10 : vb->existsColumn(VisBufferComponent2::WeightSpectrum)
90 10 : )
91 : );
92 : }
93 1360 : const auto& ant1 = vb->antenna1();
94 1360 : const auto& ant2 = vb->antenna2();
95 : // [nCorr, nFreq, nRows)
96 1360 : const auto& dataCube = _dataCube(vb);
97 1360 : const auto& flagCube = vb->flagCube();
98 1360 : const auto dataShape = dataCube.shape();
99 1360 : const auto& exposureVector = vb->exposure();
100 1360 : const auto nrows = vb->nRows();
101 1360 : const auto npol = dataCube.nrow();
102 : const auto resultantFlags = _getResultantFlags(
103 : chanSelFlagTemplate, chanSelFlags, initChanSelTemplate,
104 : spw, flagCube
105 1360 : );
106 1360 : auto bins = _chanBins.find(spw)->second;
107 1360 : StatWtTypes::BaselineChanBin blcb;
108 1360 : blcb.spw = spw;
109 1360 : IPosition dataCubeBLC(3, 0);
110 1360 : auto dataCubeTRC = dataCube.shape() - 1;
111 35956 : for (rownr_t row=0; row<nrows; ++row) {
112 34596 : dataCubeBLC[2] = row;
113 34596 : dataCubeTRC[2] = row;
114 34596 : blcb.baseline = _baseline(ant1[row], ant2[row]);
115 34596 : auto citer = bins.cbegin();
116 34596 : auto cend = bins.cend();
117 113292 : for (; citer!=cend; ++citer) {
118 78696 : dataCubeBLC[1] = citer->start;
119 78696 : dataCubeTRC[1] = citer->end;
120 78696 : blcb.chanBin.start = citer->start;
121 78696 : blcb.chanBin.end = citer->end;
122 78696 : auto dataSlice = dataCube(dataCubeBLC, dataCubeTRC);
123 78696 : auto flagSlice = resultantFlags(dataCubeBLC, dataCubeTRC);
124 78696 : if (data.find(blcb) == data.end()) {
125 4758 : data[blcb] = dataSlice;
126 4758 : flags[blcb] = flagSlice;
127 4758 : exposures[blcb] = Vector<Double>(1, exposureVector[row]);
128 : }
129 : else {
130 73938 : auto myshape = data[blcb].shape();
131 73938 : auto nplane = myshape[2];
132 73938 : auto nchan = myshape[1];
133 73938 : data[blcb].resize(npol, nchan, nplane+1, True);
134 73938 : flags[blcb].resize(npol, nchan, nplane+1, True);
135 73938 : exposures[blcb].resize(nplane+1, True);
136 73938 : trc = myshape - 1;
137 : // because we've extended the cube by one plane since
138 : // myshape was determined.
139 73938 : ++trc[2];
140 73938 : blc[2] = trc[2];
141 73938 : data[blcb](blc, trc) = dataSlice;
142 73938 : flags[blcb](blc, trc) = flagSlice;
143 73938 : exposures[blcb][trc[2]] = exposureVector[row];
144 73938 : }
145 78696 : }
146 : }
147 1360 : }
148 111 : _computeVariances(data, flags, exposures);
149 195 : }
150 :
151 328 : void StatWtClassicalDataAggregator::weightSingleChanBin(
152 : Matrix<Float>& wtmat, Int nrows
153 : ) const {
154 328 : Vector<Int> ant1, ant2, spws;
155 328 : Vector<Double> exposures;
156 328 : _vii->antenna1(ant1);
157 328 : _vii->antenna2(ant2);
158 328 : _vii->spectralWindows(spws);
159 328 : _vii->exposure(exposures);
160 : // There is only one spw in a chunk
161 328 : auto spw = *spws.begin();
162 328 : StatWtTypes::BaselineChanBin blcb;
163 328 : blcb.spw = spw;
164 2284 : for (Int i=0; i<nrows; ++i) {
165 1956 : auto bins = _chanBins.find(spw)->second;
166 1956 : blcb.baseline = _baseline(ant1[i], ant2[i]);
167 1956 : blcb.chanBin = bins[0];
168 1956 : auto variances = _variances.find(blcb)->second;
169 1956 : if (_combineCorr) {
170 0 : wtmat.column(i) = exposures[i]/variances[0];
171 : }
172 : else {
173 1956 : auto corr = 0;
174 9780 : for (const auto variance: variances) {
175 7824 : wtmat(corr, i) = exposures[i]/variance;
176 7824 : ++corr;
177 1956 : }
178 : }
179 1956 : }
180 328 : }
181 :
182 111 : void StatWtClassicalDataAggregator::_computeVariances(
183 : const map<StatWtTypes::BaselineChanBin, Cube<Complex>>& data,
184 : const map<StatWtTypes::BaselineChanBin, Cube<Bool>>& flags,
185 : const map<StatWtTypes::BaselineChanBin, Vector<Double>>& exposures
186 : ) const {
187 111 : auto diter = data.cbegin();
188 111 : auto dend = data.cend();
189 111 : const auto nActCorr = diter->second.shape()[0];
190 111 : const auto ncorr = _combineCorr ? 1 : nActCorr;
191 : // spw will be the same for all members
192 111 : const auto& spw = data.begin()->first.spw;
193 111 : vector<StatWtTypes::BaselineChanBin> keys(data.size());
194 111 : auto idx = 0;
195 4869 : for (; diter!=dend; ++diter, ++idx) {
196 4758 : const auto& blcb = diter->first;
197 4758 : keys[idx] = blcb;
198 4758 : _variances[blcb].resize(ncorr);
199 : }
200 111 : auto n = keys.size();
201 : #ifdef _OPENMP
202 111 : #pragma omp parallel for
203 : // cout << "WARN OMP PARALLEL LOOPING IS OFF FOR DEBUGGING" << endl;
204 : #endif
205 : for (size_t i=0; i<n; ++i) {
206 : auto blcb = keys[i];
207 : auto dataForBLCB = data.find(blcb)->second;
208 : auto flagsForBLCB = flags.find(blcb)->second;
209 : auto exposuresForBLCB = exposures.find(blcb)->second;
210 : for (ssize_t corr=0; corr<ncorr; ++corr) {
211 : IPosition start(3, 0);
212 : auto end = dataForBLCB.shape() - 1;
213 : if (! _combineCorr) {
214 : start[0] = corr;
215 : end[0] = corr;
216 : }
217 : Slicer slice(start, end, Slicer::endIsLast);
218 : _variances[blcb][corr]
219 : = _varianceComputer->computeVariance(
220 : dataForBLCB(slice), flagsForBLCB(slice),
221 : exposuresForBLCB, spw
222 : );
223 : }
224 : }
225 111 : }
226 :
227 1360 : void StatWtClassicalDataAggregator::weightSpectrumFlags(
228 : Cube<Float>& wtsp, Cube<Bool>& flagCube, Bool& checkFlags,
229 : const Vector<Int>& ant1, const Vector<Int>& ant2, const Vector<Int>& spws,
230 : const Vector<Double>& exposures, const Vector<rownr_t>&
231 : ) const {
232 1360 : Slicer slice(IPosition(3, 0), flagCube.shape(), Slicer::endIsLength);
233 1360 : auto sliceStart = slice.start();
234 1360 : auto sliceEnd = slice.end();
235 1360 : auto nrows = ant1.size();
236 35956 : for (size_t i=0; i<nrows; ++i) {
237 34596 : sliceStart[2] = i;
238 34596 : sliceEnd[2] = i;
239 34596 : StatWtTypes::BaselineChanBin blcb;
240 34596 : blcb.baseline = _baseline(ant1[i], ant2[i]);
241 34596 : auto spw = spws[i];
242 34596 : blcb.spw = spw;
243 34596 : auto bins = _chanBins.find(spw)->second;
244 113292 : for (const auto& bin: bins) {
245 78696 : sliceStart[1] = bin.start;
246 78696 : sliceEnd[1] = bin.end;
247 78696 : blcb.chanBin = bin;
248 78696 : auto variances = _variances.find(blcb)->second;
249 78696 : auto ncorr = variances.size();
250 78696 : Vector<Double> weights(ncorr);
251 304380 : for (size_t corr=0; corr<ncorr; ++corr) {
252 225684 : if (! _combineCorr) {
253 195984 : sliceStart[0] = corr;
254 195984 : sliceEnd[0] = corr;
255 : }
256 451368 : weights[corr] = variances[corr] == 0
257 225684 : ? 0 : exposures[i]/variances[corr];
258 225684 : slice.setStart(sliceStart);
259 225684 : slice.setEnd(sliceEnd);
260 451368 : _updateWtSpFlags(
261 225684 : wtsp, flagCube, checkFlags, slice, weights[corr]
262 : );
263 : }
264 78696 : }
265 34596 : }
266 1360 : }
267 :
268 : }
269 :
270 : }
|