Line data Source code
1 : //# StatWtTVI.cc: This file contains the implementation of the StatWtTVI class.
2 : //#
3 : //# CASA - Common Astronomy Software Applications (http://casa.nrao.edu/)
4 : //# Copyright (C) Associated Universities, Inc. Washington DC, USA 2011, All
5 : //# rights reserved.
6 : //# Copyright (C) European Southern Observatory, 2011, All rights reserved.
7 : //#
8 : //# This library is free software; you can redistribute it and/or
9 : //# modify it under the terms of the GNU Lesser General Public
10 : //# License as published by the Free software Foundation; either
11 : //# version 2.1 of the License, or (at your option) any later version.
12 : //#
13 : //# This library is distributed in the hope that it will be useful,
14 : //# but WITHOUT ANY WARRANTY, without even the implied warranty of
15 : //# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16 : //# Lesser General Public License for more details.
17 : //#
18 : //# You should have received a copy of the GNU Lesser General Public
19 : //# License along with this library; if not, write to the Free Software
20 : //# Foundation, Inc., 59 Temple Place, Suite 330, Boston,
21 : //# MA 02111-1307 USA
22 :
23 : #include <mstransform/TVI/StatWtTVI.h>
24 :
25 : #include <casacore/casa/Arrays/ArrayLogical.h>
26 : #include <casacore/casa/Quanta/QuantumHolder.h>
27 : #include <casacore/ms/MSOper/MSMetaData.h>
28 : #include <casacore/tables/Tables/ArrColDesc.h>
29 :
30 : #include <mstransform/TVI/StatWtClassicalDataAggregator.h>
31 : #include <mstransform/TVI/StatWtFloatingWindowDataAggregator.h>
32 : #include <mstransform/TVI/StatWtVarianceAndWeightCalculator.h>
33 :
34 : #ifdef _OPENMP
35 : #include <omp.h>
36 : #endif
37 :
38 : #include <iomanip>
39 :
40 : using namespace casacore;
41 : using namespace casac;
42 :
43 : namespace casa {
44 : namespace vi {
45 :
46 : const String StatWtTVI::CHANBIN = "stchanbin";
47 :
48 0 : StatWtTVI::StatWtTVI(ViImplementation2 * inputVii, const Record &configuration)
49 0 : : TransformingVi2 (inputVii) {
50 : // Parse and check configuration parameters
51 : // Note: if a constructor finishes by throwing an exception, the memory
52 : // associated with the object itself is cleaned up there is no memory leak.
53 0 : ThrowIf(
54 : ! _parseConfiguration(configuration),
55 : "Error parsing StatWtTVI configuration"
56 : );
57 0 : LogIO log(LogOrigin("StatWtTVI", __func__));
58 0 : log << LogIO::NORMAL << "Using " << StatWtTypes::asString(_column)
59 0 : << " to compute weights" << LogIO::POST;
60 : // FIXME when the TVI framework has methods to
61 : // check for metadata, like the existence of
62 : // columns, remove references to the original MS
63 0 : const auto& origMS = ms();
64 : // FIXME uses original MS explicitly
65 0 : ThrowIf(
66 : (_column == StatWtTypes::CORRECTED || _column == StatWtTypes::RESIDUAL)
67 : && ! origMS.isColumn(MSMainEnums::CORRECTED_DATA),
68 : "StatWtTVI requires the MS to have a CORRECTED_DATA column. This MS "
69 : "does not"
70 : );
71 : // FIXME uses original MS explicitly
72 0 : ThrowIf(
73 : (_column == StatWtTypes::DATA || _column == StatWtTypes::RESIDUAL_DATA)
74 : && ! origMS.isColumn(MSMainEnums::DATA),
75 : "StatWtTVI requires the MS to have a DATA column. This MS does not"
76 : );
77 0 : _mustComputeSigma = (
78 0 : _column == StatWtTypes::DATA || _column == StatWtTypes::RESIDUAL_DATA
79 : );
80 : // FIXME uses original MS explicitly
81 0 : _updateWeight = ! _mustComputeSigma
82 0 : || (_mustComputeSigma && ! origMS.isColumn(MSMainEnums::CORRECTED_DATA));
83 0 : _noModel = (
84 0 : _column == StatWtTypes::RESIDUAL || _column == StatWtTypes::RESIDUAL_DATA
85 0 : ) && ! origMS.isColumn(MSMainEnums::MODEL_DATA)
86 0 : && ! origMS.source().isColumn(MSSourceEnums::SOURCE_MODEL);
87 : // Initialize attached VisBuffer
88 0 : setVisBuffer(createAttachedVisBuffer(VbRekeyable));
89 0 : }
90 :
91 0 : StatWtTVI::~StatWtTVI() {}
92 :
93 0 : Bool StatWtTVI::_parseConfiguration(const Record& config) {
94 0 : String field = CHANBIN;
95 0 : if (config.isDefined(field)) {
96 : // channel binning
97 0 : auto fieldNum = config.fieldNumber(field);
98 0 : switch (config.type(fieldNum)) {
99 0 : case DataType::TpArrayBool:
100 : // because this is the actual default variant type, no matter
101 : // what is specified in the xml
102 0 : ThrowIf(
103 : ! config.asArrayBool(field).empty(),
104 : "Unsupported data type for " + field
105 : );
106 0 : _setDefaultChanBinMap();
107 0 : break;
108 0 : case DataType::TpInt:
109 : Int binWidth;
110 0 : config.get(CHANBIN, binWidth);
111 0 : _setChanBinMap(binWidth);
112 0 : break;
113 0 : case DataType::TpString:
114 : {
115 0 : auto chanbin = config.asString(field);
116 0 : if (chanbin == "spw") {
117 : // bin using entire spws
118 0 : _setDefaultChanBinMap();
119 0 : break;
120 : }
121 : else {
122 0 : QuantumHolder qh(casaQuantity(chanbin));
123 0 : _setChanBinMap(qh.asQuantity());
124 0 : }
125 0 : break;
126 0 : }
127 0 : default:
128 0 : ThrowCc("Unsupported data type for " + field);
129 : }
130 : }
131 : else {
132 0 : _setDefaultChanBinMap();
133 : }
134 0 : field = "minsamp";
135 0 : if (config.isDefined(field)) {
136 0 : config.get(field, _minSamp);
137 0 : ThrowIf(_minSamp < 2, "Minimum size of sample must be >= 2.");
138 : }
139 0 : field = "combine";
140 0 : if (config.isDefined(field)) {
141 0 : ThrowIf(
142 : config.type(config.fieldNumber(field)) != TpString,
143 : "Unsupported data type for combine"
144 : );
145 0 : _combineCorr = config.asString(field).contains("corr");
146 : }
147 0 : field = "wtrange";
148 0 : if (config.isDefined(field)) {
149 0 : ThrowIf(
150 : config.type(config.fieldNumber(field)) != TpArrayDouble,
151 : "Unsupported type for field '" + field + "'"
152 : );
153 0 : auto myrange = config.asArrayDouble(field);
154 0 : if (! myrange.empty()) {
155 0 : ThrowIf(
156 : myrange.size() != 2,
157 : "Array specified in '" + field
158 : + "' must have exactly two values"
159 : );
160 0 : ThrowIf(
161 : casacore::anyLT(myrange, 0.0),
162 : "Both values specified in '" + field
163 : + "' array must be non-negative"
164 : );
165 0 : std::set<Double> rangeset(myrange.begin(), myrange.end());
166 0 : ThrowIf(
167 : rangeset.size() == 1, "Values specified in '" + field
168 : + "' array must be unique"
169 : );
170 0 : auto iter = rangeset.begin();
171 0 : _wtrange.reset(new std::pair<Double, Double>(*iter, *(++iter)));
172 0 : }
173 0 : }
174 0 : auto excludeChans = False;
175 0 : field = "excludechans";
176 0 : if (config.isDefined(field)) {
177 0 : ThrowIf(
178 : config.type(config.fieldNumber(field)) != TpBool,
179 : "Unsupported type for field '" + field + "'"
180 : );
181 0 : excludeChans = config.asBool(field);
182 : }
183 0 : field = "fitspw";
184 0 : if (config.isDefined(field)) {
185 0 : ThrowIf(
186 : config.type(config.fieldNumber(field)) != TpString,
187 : "Unsupported type for field '" + field + "'"
188 : );
189 0 : auto val = config.asString(field);
190 0 : if (! val.empty()) {
191 : // FIXME references underlying MS
192 0 : const auto& myms = ms();
193 0 : MSSelection sel(myms);
194 0 : sel.setSpwExpr(val);
195 0 : auto chans = sel.getChanList();
196 0 : auto nrows = chans.nrow();
197 0 : MSMetaData md(&myms, 50);
198 0 : auto nchans = md.nChans();
199 0 : IPosition start(3, 0);
200 0 : IPosition stop(3, 0);
201 0 : IPosition step(3, 1);
202 0 : for (size_t i=0; i<nrows; ++i) {
203 0 : auto row = chans.row(i);
204 0 : const auto& spw = row[0];
205 0 : if (_chanSelFlags.find(spw) == _chanSelFlags.end()) {
206 0 : _chanSelFlags[spw]
207 0 : = Cube<Bool>(1, nchans[spw], 1, ! excludeChans);
208 : }
209 0 : start[1] = row[1];
210 0 : ThrowIf(
211 : start[1] < 0, "Invalid channel selection in spw "
212 : + String::toString(spw))
213 : ;
214 0 : stop[1] = row[2];
215 0 : step[1] = row[3];
216 0 : Slicer slice(start, stop, step, Slicer::endIsLast);
217 0 : _chanSelFlags[spw](slice) = excludeChans;
218 0 : }
219 0 : }
220 0 : }
221 0 : field = "datacolumn";
222 0 : if (config.isDefined(field)) {
223 0 : ThrowIf(
224 : config.type(config.fieldNumber(field)) != TpString,
225 : "Unsupported type for field '" + field + "'"
226 : );
227 0 : auto val = config.asString(field);
228 0 : if (! val.empty()) {
229 0 : val.downcase();
230 0 : ThrowIf (
231 : ! (
232 : val.startsWith("c") || val.startsWith("d")
233 : || val.startsWith("residual") || val.startsWith("residual_")
234 : ),
235 : "Unsupported value for " + field + ": " + val
236 : );
237 0 : _column = val.startsWith("c") ? StatWtTypes::CORRECTED
238 0 : : val.startsWith("d") ? StatWtTypes::DATA
239 0 : : val.startsWith("residual_") ? StatWtTypes::RESIDUAL_DATA
240 : : StatWtTypes::RESIDUAL;
241 :
242 : }
243 0 : }
244 0 : field = "slidetimebin";
245 0 : ThrowIf(
246 : ! config.isDefined(field), "Config param " + field + " must be defined"
247 : );
248 0 : ThrowIf(
249 : config.type(config.fieldNumber(field)) != TpBool,
250 : "Unsupported type for field '" + field + "'"
251 : );
252 0 : _timeBlockProcessing = ! config.asBool(field);
253 0 : field = "timebin";
254 0 : ThrowIf(
255 : ! config.isDefined(field), "Config param " + field + " must be defined"
256 : );
257 0 : auto mytype = config.type(config.fieldNumber(field));
258 0 : ThrowIf(
259 : ! (
260 : mytype == TpString || mytype == TpDouble
261 : || mytype == TpInt
262 : ),
263 : "Unsupported type for field '" + field + "'"
264 : );
265 0 : switch(mytype) {
266 0 : case TpDouble: {
267 0 : _binWidthInSeconds.reset(new Double(config.asDouble(field)));
268 0 : break;
269 : }
270 0 : case TpInt:
271 0 : _nTimeStampsInBin.reset(new Int(config.asInt(field)));
272 0 : ThrowIf(
273 : *_nTimeStampsInBin <= 0,
274 : "Logic Error: nTimeStamps must be positive"
275 : );
276 0 : break;
277 0 : case TpString: {
278 0 : QuantumHolder qh(casaQuantity(config.asString(field)));
279 0 : _binWidthInSeconds.reset(
280 0 : new Double(getTimeBinWidthInSec(qh.asQuantity()))
281 : );
282 0 : break;
283 0 : }
284 0 : default:
285 0 : ThrowCc("Logic Error: Unhandled type for timebin");
286 :
287 : }
288 0 : _doClassicVIVB = _binWidthInSeconds && _timeBlockProcessing;
289 0 : _configureStatAlg(config);
290 0 : if (_doClassicVIVB) {
291 0 : _dataAggregator.reset(
292 : new StatWtClassicalDataAggregator(
293 0 : getVii(), _chanBins, _samples, _column, _noModel, _chanSelFlags,
294 0 : _wtStats, _wtrange, _combineCorr, _statAlg, _minSamp
295 0 : )
296 : );
297 : }
298 : else {
299 0 : _dataAggregator.reset(
300 : new StatWtFloatingWindowDataAggregator(
301 0 : getVii(), _chanBins, _samples, _column, _noModel, _chanSelFlags,
302 0 : _combineCorr, _wtStats, _wtrange, _binWidthInSeconds,
303 0 : _nTimeStampsInBin, _timeBlockProcessing, _statAlg, _minSamp
304 0 : )
305 : );
306 : }
307 0 : _dataAggregator->setMustComputeWtSp(_mustComputeWtSp);
308 0 : return True;
309 0 : }
310 :
311 0 : void StatWtTVI::_configureStatAlg(const Record& config) {
312 0 : String field = "statalg";
313 0 : if (config.isDefined(field)) {
314 0 : ThrowIf(
315 : config.type(config.fieldNumber(field)) != TpString,
316 : "Unsupported type for field '" + field + "'"
317 : );
318 0 : auto alg = config.asString(field);
319 0 : alg.downcase();
320 0 : if (alg.startsWith("cl")) {
321 0 : _statAlg.reset(
322 : new ClassicalStatistics<
323 : Double, Array<Float>::const_iterator,
324 : Array<Bool>::const_iterator, Array<Double>::const_iterator
325 0 : >()
326 : );
327 : }
328 : else {
329 : casacore::StatisticsAlgorithmFactory<
330 : Double, Array<Float>::const_iterator,
331 : Array<Bool>::const_iterator, Array<Double>::const_iterator
332 0 : > saf;
333 0 : if (alg.startsWith("ch")) {
334 0 : Int maxiter = -1;
335 0 : field = "maxiter";
336 0 : if (config.isDefined(field)) {
337 0 : ThrowIf(
338 : config.type(config.fieldNumber(field)) != TpInt,
339 : "Unsupported type for field '" + field + "'"
340 : );
341 0 : maxiter = config.asInt(field);
342 : }
343 0 : Double zscore = -1;
344 0 : field = "zscore";
345 0 : if (config.isDefined(field)) {
346 0 : ThrowIf(
347 : config.type(config.fieldNumber(field)) != TpDouble,
348 : "Unsupported type for field '" + field + "'"
349 : );
350 0 : zscore = config.asDouble(field);
351 : }
352 0 : saf.configureChauvenet(zscore, maxiter);
353 : }
354 0 : else if (alg.startsWith("f")) {
355 0 : auto center = FitToHalfStatisticsData::CMEAN;
356 0 : field = "center";
357 0 : if (config.isDefined(field)) {
358 0 : ThrowIf(
359 : config.type(config.fieldNumber(field)) != TpString,
360 : "Unsupported type for field '" + field + "'"
361 : );
362 0 : auto cs = config.asString(field);
363 0 : cs.downcase();
364 0 : if (cs == "mean") {
365 0 : center = FitToHalfStatisticsData::CMEAN;
366 : }
367 0 : else if (cs == "median") {
368 0 : center = FitToHalfStatisticsData::CMEDIAN;
369 : }
370 0 : else if (cs == "zero") {
371 0 : center = FitToHalfStatisticsData::CVALUE;
372 : }
373 : else {
374 0 : ThrowCc("Unsupported value for '" + field + "'");
375 : }
376 0 : }
377 0 : field = "lside";
378 0 : auto ud = FitToHalfStatisticsData::LE_CENTER;
379 0 : if (config.isDefined(field)) {
380 0 : ThrowIf(
381 : config.type(config.fieldNumber(field)) != TpBool,
382 : "Unsupported type for field '" + field + "'"
383 : );
384 0 : ud = config.asBool(field)
385 0 : ? FitToHalfStatisticsData::LE_CENTER
386 : : FitToHalfStatisticsData::GE_CENTER;
387 : }
388 0 : saf.configureFitToHalf(center, ud, 0);
389 : }
390 0 : else if (alg.startsWith("h")) {
391 0 : Double fence = -1;
392 0 : field = "fence";
393 0 : if (config.isDefined(field)) {
394 0 : ThrowIf(
395 : config.type(config.fieldNumber(field)) != TpDouble,
396 : "Unsupported type for field '" + field + "'"
397 : );
398 0 : fence = config.asDouble(field);
399 : }
400 0 : saf.configureHingesFences(fence);
401 : }
402 : else {
403 0 : ThrowCc("Unsupported value for 'statalg'");
404 : }
405 : // clone needed for CountedPtr -> shared_ptr hand off
406 0 : _statAlg.reset(saf.createStatsAlgorithm()->clone());
407 0 : }
408 0 : }
409 : else {
410 0 : _statAlg.reset(
411 : new ClassicalStatistics<
412 : Double, Array<Float>::const_iterator,
413 : Array<Bool>::const_iterator, Array<Double>::const_iterator
414 0 : >());
415 : }
416 0 : std::set<StatisticsData::STATS> stats {StatisticsData::VARIANCE};
417 0 : _statAlg->setStatsToCalculate(stats);
418 : // also configure the _wtStats object here
419 : // FIXME? Does not include exposure weighting
420 0 : _wtStats.reset(
421 : new ClassicalStatistics<
422 : Double, Array<Float>::const_iterator,
423 : Array<Bool>::const_iterator
424 0 : >()
425 : );
426 0 : stats.insert(StatisticsData::MEAN);
427 0 : _wtStats->setStatsToCalculate(stats);
428 0 : _wtStats->setCalculateAsAdded(True);
429 0 : }
430 :
431 0 : void StatWtTVI::_logUsedChannels() const {
432 : // FIXME uses underlying MS
433 0 : MSMetaData msmd(&ms(), 100.0);
434 0 : const auto nchan = msmd.nChans();
435 0 : LogIO log(LogOrigin("StatWtTVI", __func__));
436 0 : log << LogIO::NORMAL << "Weights are being computed using ";
437 0 : const auto cend = _chanSelFlags.cend();
438 0 : const auto nspw = _samples->size();
439 0 : uInt spwCount = 0;
440 0 : for (const auto& kv: *_samples) {
441 0 : const auto spw = kv.first;
442 0 : log << "SPW " << spw << ", channels ";
443 0 : const auto flagCube = _chanSelFlags.find(spw);
444 0 : if (flagCube == cend) {
445 0 : log << "0~" << (nchan[spw] - 1);
446 : }
447 : else {
448 0 : vector<pair<uInt, uInt>> startEnd;
449 0 : const auto flags = flagCube->second.tovector();
450 0 : bool started = false;
451 0 : std::unique_ptr<pair<uInt, uInt>> curPair;
452 0 : for (uInt j=0; j<nchan[spw]; ++j) {
453 0 : if (started) {
454 0 : if (flags[j]) {
455 : // found a bad channel, end current range
456 0 : startEnd.push_back(*curPair);
457 0 : started = false;
458 : }
459 : else {
460 : // found a "good" channel, update end of current range
461 0 : curPair->second = j;
462 : }
463 : }
464 0 : else if (! flags[j]) {
465 : // found a good channel, start new range
466 0 : started = true;
467 0 : curPair.reset(new pair<uInt, uInt>(j, j));
468 : }
469 : }
470 0 : if (curPair) {
471 0 : if (started) {
472 : // The last pair won't get added inside the previous loop,
473 : // so add it here
474 0 : startEnd.push_back(*curPair);
475 : }
476 0 : auto nPairs = startEnd.size();
477 0 : for (uInt i=0; i<nPairs; ++i) {
478 0 : log << startEnd[i].first << "~" << startEnd[i].second;
479 0 : if (i < nPairs - 1) {
480 0 : log << ", ";
481 : }
482 : }
483 : }
484 : else {
485 : // if the pointer never got set, all the channels are bad
486 0 : log << "no channels";
487 : }
488 0 : }
489 0 : if (spwCount < (nspw - 1)) {
490 0 : log << ";";
491 : }
492 0 : ++spwCount;
493 : }
494 0 : log << LogIO::POST;
495 0 : }
496 :
497 0 : void StatWtTVI::_setChanBinMap(const casacore::Quantity& binWidth) {
498 0 : if (! binWidth.isConform(Unit("Hz"))) {
499 0 : ostringstream oss;
500 : oss << "If specified as a quantity, channel bin width must have "
501 0 : << "frequency units. " << binWidth << " does not.";
502 0 : ThrowCc(oss.str());
503 0 : }
504 0 : ThrowIf(binWidth.getValue() <= 0, "channel bin width must be positive");
505 0 : MSMetaData msmd(&ms(), 100.0);
506 0 : auto chanFreqs = msmd.getChanFreqs();
507 0 : auto nspw = chanFreqs.size();
508 0 : auto binWidthHz = binWidth.getValue("Hz");
509 0 : for (uInt i=0; i<nspw; ++i) {
510 0 : auto cfs = chanFreqs[i].getValue("Hz");
511 0 : auto citer = cfs.begin();
512 0 : auto cend = cfs.end();
513 0 : StatWtTypes::ChanBin bin;
514 0 : bin.start = 0;
515 0 : bin.end = 0;
516 0 : uInt chanNum = 0;
517 0 : auto startFreq = *citer;
518 0 : auto nchan = cfs.size();
519 0 : for (; citer!=cend; ++citer, ++chanNum) {
520 : // both could be true, in which case both conditionals
521 : // must be executed
522 0 : if (abs(*citer - startFreq) > binWidthHz) {
523 : // add bin to list
524 0 : bin.end = chanNum - 1;
525 0 : _chanBins[i].push_back(bin);
526 0 : bin.start = chanNum;
527 0 : startFreq = *citer;
528 : }
529 0 : if (chanNum + 1 == nchan) {
530 : // add last bin
531 0 : bin.end = chanNum;
532 0 : _chanBins[i].push_back(bin);
533 : }
534 : }
535 0 : }
536 : // weight spectrum must be computed
537 0 : _mustComputeWtSp.reset(new Bool(True));
538 0 : }
539 :
540 0 : void StatWtTVI::_setChanBinMap(Int binWidth) {
541 0 : ThrowIf(binWidth < 1, "Channel bin width must be positive");
542 0 : MSMetaData msmd(&ms(), 100.0);
543 0 : auto nchans = msmd.nChans();
544 0 : auto nspw = nchans.size();
545 0 : StatWtTypes::ChanBin bin;
546 0 : for (uInt i=0; i<nspw; ++i) {
547 0 : auto lastChan = nchans[i]-1;
548 0 : for (uInt j=0; j<nchans[i]; j += binWidth) {
549 0 : bin.start = j;
550 0 : bin.end = min(j+binWidth-1, lastChan);
551 0 : _chanBins[i].push_back(bin);
552 : }
553 : }
554 : // weight spectrum must be computed
555 0 : _mustComputeWtSp.reset(new Bool(True));
556 0 : }
557 :
558 0 : void StatWtTVI::_setDefaultChanBinMap() {
559 : // FIXME uses underlying MS
560 0 : MSMetaData msmd(&ms(), 0.0);
561 0 : auto nchans = msmd.nChans();
562 0 : auto niter = nchans.begin();
563 0 : auto nend = nchans.end();
564 0 : Int i = 0;
565 0 : StatWtTypes::ChanBin bin;
566 0 : bin.start = 0;
567 0 : for (; niter!=nend; ++niter, ++i) {
568 0 : bin.end = *niter - 1;
569 0 : _chanBins[i].push_back(bin);
570 : }
571 0 : }
572 :
573 0 : Double StatWtTVI::getTimeBinWidthInSec(const casacore::Quantity& binWidth) {
574 0 : ThrowIf(
575 : ! binWidth.isConform(Unit("s")),
576 : "Time bin width unit must be a unit of time"
577 : );
578 0 : auto v = binWidth.getValue("s");
579 0 : checkTimeBinWidth(v);
580 0 : return v;
581 : }
582 :
583 0 : void StatWtTVI::checkTimeBinWidth(Double binWidth) {
584 0 : ThrowIf(binWidth <= 0, "time bin width must be positive");
585 0 : }
586 :
587 0 : void StatWtTVI::sigmaSpectrum(Cube<Float>& sigmaSp) const {
588 0 : if (_mustComputeSigma) {
589 : {
590 0 : Cube<Float> wtsp;
591 : // this computes _newWtsp, ignore wtsp
592 0 : weightSpectrum(wtsp);
593 0 : }
594 0 : sigmaSp = Float(1.0)/sqrt(_newWtSp);
595 0 : if (anyEQ(_newWtSp, Float(0))) {
596 0 : auto iter = sigmaSp.begin();
597 0 : auto end = sigmaSp.end();
598 0 : auto witer = _newWtSp.cbegin();
599 0 : for ( ; iter != end; ++iter, ++witer) {
600 0 : if (*witer == 0) {
601 0 : *iter = -1;
602 : }
603 : }
604 0 : }
605 : }
606 : else {
607 0 : TransformingVi2::sigmaSpectrum(sigmaSp);
608 : }
609 0 : }
610 :
611 0 : void StatWtTVI::weightSpectrum(Cube<Float>& newWtsp) const {
612 0 : ThrowIf(! _weightsComputed, "Weights have not been computed yet");
613 0 : if (! _dataAggregator->mustComputeWtSp()) {
614 0 : newWtsp.resize(IPosition(3, 0));
615 0 : return;
616 : }
617 0 : if (! _newWtSp.empty()) {
618 : // already calculated
619 0 : if (_updateWeight) {
620 0 : newWtsp = _newWtSp.copy();
621 : }
622 : else {
623 0 : TransformingVi2::weightSpectrum(newWtsp);
624 : }
625 0 : return;
626 : }
627 0 : _computeWeightSpectrumAndFlags();
628 0 : if (_updateWeight) {
629 0 : newWtsp = _newWtSp.copy();
630 : }
631 : else {
632 0 : TransformingVi2::weightSpectrum(newWtsp);
633 : }
634 : }
635 :
636 0 : void StatWtTVI::_computeWeightSpectrumAndFlags() const {
637 : size_t nOrigFlagged;
638 0 : auto mypair = _getLowerLayerWtSpFlags(nOrigFlagged);
639 0 : auto& wtsp = mypair.first;
640 0 : auto& flagCube = mypair.second;
641 0 : if (_dataAggregator->mustComputeWtSp() && wtsp.empty()) {
642 : // This can happen in preview mode if
643 : // WEIGHT_SPECTRUM doesn't exist or is empty
644 0 : wtsp.resize(flagCube.shape());
645 : }
646 0 : auto checkFlags = False;
647 0 : Vector<Int> ant1, ant2, spws;
648 0 : antenna1(ant1);
649 0 : antenna2(ant2);
650 0 : spectralWindows(spws);
651 0 : Vector<rownr_t> rowIDs;
652 0 : getRowIds(rowIDs);
653 0 : Vector<Double> exposures;
654 0 : exposure(exposures);
655 0 : _dataAggregator->weightSpectrumFlags(
656 : wtsp, flagCube, checkFlags, ant1, ant2, spws, exposures, rowIDs
657 : );
658 0 : if (checkFlags) {
659 0 : _nNewFlaggedPts += ntrue(flagCube) - nOrigFlagged;
660 : }
661 0 : _newWtSp = wtsp;
662 0 : _newFlag = flagCube;
663 0 : }
664 :
665 0 : std::pair<Cube<Float>, Cube<Bool>> StatWtTVI::_getLowerLayerWtSpFlags(
666 : size_t& nOrigFlagged
667 : ) const {
668 0 : auto mypair = std::make_pair(Cube<Float>(), Cube<Bool>());
669 0 : if (_dataAggregator->mustComputeWtSp()) {
670 0 : getVii()->weightSpectrum(mypair.first);
671 : }
672 0 : getVii()->flag(mypair.second);
673 0 : _nTotalPts += mypair.second.size();
674 0 : nOrigFlagged = ntrue(mypair.second);
675 0 : _nOrigFlaggedPts += nOrigFlagged;
676 0 : return mypair;
677 0 : }
678 :
679 0 : void StatWtTVI::sigma(Matrix<Float>& sigmaMat) const {
680 0 : if (_mustComputeSigma) {
681 0 : if (_newWt.empty()) {
682 0 : Matrix<Float> wtmat;
683 0 : weight(wtmat);
684 0 : }
685 0 : sigmaMat = Float(1.0)/sqrt(_newWt);
686 0 : if (anyEQ(_newWt, Float(0))) {
687 0 : Matrix<Float>::iterator iter = sigmaMat.begin();
688 0 : Matrix<Float>::iterator end = sigmaMat.end();
689 0 : Matrix<Float>::iterator witer = _newWt.begin();
690 0 : for ( ; iter != end; ++iter, ++witer) {
691 0 : if (*witer == 0) {
692 0 : *iter = -1;
693 : }
694 : }
695 0 : }
696 : }
697 : else {
698 0 : TransformingVi2::sigma(sigmaMat);
699 : }
700 0 : }
701 :
702 0 : void StatWtTVI::weight(Matrix<Float> & wtmat) const {
703 0 : ThrowIf(! _weightsComputed, "Weights have not been computed yet");
704 0 : if (! _newWt.empty()) {
705 0 : if (_updateWeight) {
706 0 : wtmat = _newWt.copy();
707 : }
708 : else {
709 0 : TransformingVi2::weight(wtmat);
710 : }
711 0 : return;
712 : }
713 0 : auto nrows = nRows();
714 0 : getVii()->weight(wtmat);
715 0 : if (_dataAggregator->mustComputeWtSp()) {
716 : // always use classical algorithm to get median for weights
717 : ClassicalStatistics<
718 : Double, Array<Float>::const_iterator, Array<Bool>::const_iterator
719 0 : > cs;
720 0 : Cube<Float> wtsp;
721 0 : Cube<Bool> flagCube;
722 : // this computes _newWtsP which is what we will use, so
723 : // just ignore wtsp
724 0 : weightSpectrum(wtsp);
725 0 : flag(flagCube);
726 0 : IPosition blc(3, 0);
727 0 : IPosition trc = _newWtSp.shape() - 1;
728 0 : const auto ncorr = _newWtSp.shape()[0];
729 0 : for (rownr_t i=0; i<nrows; ++i) {
730 0 : blc[2] = i;
731 0 : trc[2] = i;
732 0 : if (_combineCorr) {
733 0 : auto flags = flagCube(blc, trc);
734 0 : if (allTrue(flags)) {
735 0 : wtmat.column(i) = 0;
736 : }
737 : else {
738 0 : auto weights = _newWtSp(blc, trc);
739 0 : auto mask = ! flags;
740 0 : cs.setData(weights.begin(), mask.begin(), weights.size());
741 0 : wtmat.column(i) = cs.getMedian();
742 0 : }
743 0 : }
744 : else {
745 0 : for (uInt corr=0; corr<ncorr; ++corr) {
746 0 : blc[0] = corr;
747 0 : trc[0] = corr;
748 0 : auto weights = _newWtSp(blc, trc);
749 0 : auto flags = flagCube(blc, trc);
750 0 : if (allTrue(flags)) {
751 0 : wtmat(corr, i) = 0;
752 : }
753 : else {
754 0 : auto mask = ! flags;
755 0 : cs.setData(
756 0 : weights.begin(), mask.begin(), weights.size()
757 : );
758 0 : wtmat(corr, i) = cs.getMedian();
759 0 : }
760 0 : }
761 : }
762 : }
763 0 : }
764 : else {
765 : // the only way this can happen is if there is a single channel bin
766 : // for each baseline/spw pair
767 0 : _dataAggregator->weightSingleChanBin(wtmat, nrows);
768 : }
769 0 : _newWt = wtmat.copy();
770 0 : if (! _updateWeight) {
771 0 : wtmat = Matrix<Float>(wtmat.shape());
772 0 : TransformingVi2::weight(wtmat);
773 : }
774 : }
775 :
776 0 : void StatWtTVI::flag(Cube<Bool>& flagCube) const {
777 0 : ThrowIf(! _weightsComputed, "Weights have not been computed yet");
778 0 : if (! _newFlag.empty()) {
779 0 : flagCube = _newFlag.copy();
780 0 : return;
781 : }
782 0 : _computeWeightSpectrumAndFlags();
783 0 : flagCube = _newFlag.copy();
784 : }
785 :
786 0 : void StatWtTVI::flagRow(Vector<Bool>& flagRow) const {
787 0 : ThrowIf(! _weightsComputed, "Weights have not been computed yet");
788 0 : if (! _newFlagRow.empty()) {
789 0 : flagRow = _newFlagRow.copy();
790 0 : return;
791 : }
792 0 : Cube<Bool> flags;
793 0 : flag(flags);
794 0 : getVii()->flagRow(flagRow);
795 0 : auto nrows = nRows();
796 0 : for (rownr_t i=0; i<nrows; ++i) {
797 0 : flagRow[i] = allTrue(flags.xyPlane(i));
798 : }
799 0 : _newFlagRow = flagRow.copy();
800 0 : }
801 :
802 0 : void StatWtTVI::originChunks(Bool forceRewind) {
803 : // Drive next lower layer
804 0 : getVii()->originChunks(forceRewind);
805 0 : _weightsComputed = False;
806 0 : _dataAggregator->aggregate();
807 0 : _weightsComputed = True;
808 0 : _clearCache();
809 : // re-origin this chunk in next layer
810 : // (ensures wider scopes see start of the this chunk)
811 0 : getVii()->origin();
812 0 : }
813 :
814 0 : void StatWtTVI::nextChunk() {
815 : // Drive next lower layer
816 0 : getVii()->nextChunk();
817 0 : _weightsComputed = False;
818 0 : _dataAggregator->aggregate();
819 0 : _weightsComputed = True;
820 0 : _clearCache();
821 : // re-origin this chunk next layer
822 : // (ensures wider scopes see start of the this chunk)
823 0 : getVii()->origin();
824 0 : }
825 :
826 0 : void StatWtTVI::_clearCache() {
827 0 : _newWtSp.resize(0, 0, 0);
828 0 : _newWt.resize(0, 0);
829 0 : _newFlag.resize(0, 0, 0);
830 0 : _newFlagRow.resize(0);
831 0 : }
832 :
833 0 : Cube<Bool> StatWtTVI::_getResultantFlags(
834 : Cube<Bool>& chanSelFlagTemplate, Cube<Bool>& chanSelFlags,
835 : Bool& initTemplate, Int spw, const Cube<Bool>& flagCube
836 : ) const {
837 0 : if (_chanSelFlags.find(spw) == _chanSelFlags.cend()) {
838 : // no selection of channels to ignore
839 0 : return flagCube;
840 : }
841 0 : if (initTemplate) {
842 : // this can be done just once per chunk because all the rows
843 : // in the chunk are guaranteed to have the same spw
844 : // because each subchunk is guaranteed to have a single
845 : // data description ID.
846 0 : chanSelFlagTemplate = _chanSelFlags.find(spw)->second;
847 0 : initTemplate = False;
848 : }
849 0 : auto dataShape = flagCube.shape();
850 0 : chanSelFlags.resize(dataShape, False);
851 0 : auto ncorr = dataShape[0];
852 0 : auto nrows = dataShape[2];
853 0 : IPosition start(3, 0);
854 0 : IPosition end = dataShape - 1;
855 0 : Slicer sl(start, end, Slicer::endIsLast);
856 0 : for (uInt corr=0; corr<ncorr; ++corr) {
857 0 : start[0] = corr;
858 0 : end[0] = corr;
859 0 : for (Int row=0; row<nrows; ++row) {
860 0 : start[2] = row;
861 0 : end[2] = row;
862 0 : sl.setStart(start);
863 0 : sl.setEnd(end);
864 0 : chanSelFlags(sl) = chanSelFlagTemplate;
865 : }
866 : }
867 0 : return flagCube || chanSelFlags;
868 0 : }
869 :
870 0 : void StatWtTVI::initWeightSpectrum (const Cube<Float>& wtspec) {
871 : // Pass to next layer down
872 0 : getVii()->initWeightSpectrum(wtspec);
873 0 : }
874 :
875 0 : void StatWtTVI::initSigmaSpectrum (const Cube<Float>& sigspec) {
876 : // Pass to next layer down
877 0 : getVii()->initSigmaSpectrum(sigspec);
878 0 : }
879 :
880 :
881 0 : void StatWtTVI::writeBackChanges(VisBuffer2 *vb) {
882 : // Pass to next layer down
883 0 : getVii()->writeBackChanges(vb);
884 0 : }
885 :
886 0 : void StatWtTVI::summarizeFlagging() const {
887 0 : auto orig = (Double)_nOrigFlaggedPts/(Double)_nTotalPts*100;
888 0 : auto stwt = (Double)_nNewFlaggedPts/(Double)_nTotalPts*100;
889 0 : auto total = orig + stwt;
890 0 : LogIO log(LogOrigin("StatWtTVI", __func__));
891 : log << LogIO::NORMAL << "Originally, " << orig
892 : << "% of the data were flagged. StatWtTVI flagged an "
893 0 : << "additional " << stwt << "%." << LogIO::POST;
894 : log << LogIO::NORMAL << "TOTAL FLAGGED DATA AFTER RUNNING STATWT: "
895 0 : << total << "%" << LogIO::POST;
896 0 : log << LogIO::NORMAL << std::endl << LogIO::POST;
897 0 : if (_nOrigFlaggedPts == _nTotalPts) {
898 : log << LogIO::WARN << "IT APPEARS THAT ALL THE DATA IN THE INPUT "
899 : << "MS/SELECTION WERE FLAGGED PRIOR TO RUNNING STATWT"
900 0 : << LogIO::POST;
901 0 : log << LogIO::NORMAL << std::endl << LogIO::POST;
902 : }
903 0 : else if (_nOrigFlaggedPts + _nNewFlaggedPts == _nTotalPts) {
904 : log << LogIO::WARN << "IT APPEARS THAT STATWT FLAGGED ALL THE DATA "
905 : "IN THE REQUESTED SELECTION THAT WASN'T ORIGINALLY FLAGGED"
906 0 : << LogIO::POST;
907 0 : log << LogIO::NORMAL << std::endl << LogIO::POST;
908 : }
909 0 : String col0 = "SPECTRAL_WINDOW";
910 0 : String col1 = "SAMPLES_WITH_NON-ZERO_VARIANCE";
911 : String col2 = "SAMPLES_WHERE_REAL_PART_VARIANCE_DIFFERS_BY_>50%_FROM_"
912 0 : "IMAGINARY_PART";
913 0 : log << LogIO::NORMAL << col0 << " " << col1 << " " << col2 << LogIO::POST;
914 0 : auto n0 = col0.size();
915 0 : auto n1 = col1.size();
916 0 : auto n2 = col2.size();
917 0 : for (const auto& sample: *_samples) {
918 0 : ostringstream oss;
919 0 : oss << std::setw(n0) << sample.first << " " << std::setw(n1)
920 0 : << sample.second.first << " " << std::setw(n2)
921 0 : << sample.second.second;
922 0 : log << LogIO::NORMAL << oss.str() << LogIO::POST;
923 0 : }
924 0 : }
925 :
926 0 : void StatWtTVI::summarizeStats(Double& mean, Double& variance) const {
927 0 : LogIO log(LogOrigin("StatWtTVI", __func__));
928 0 : _logUsedChannels();
929 : try {
930 0 : mean = _wtStats->getStatistic(StatisticsData::MEAN);
931 0 : variance = _wtStats->getStatistic(StatisticsData::VARIANCE);
932 : log << LogIO::NORMAL << "The mean of the computed weights is "
933 0 : << mean << LogIO::POST;
934 : log << LogIO::NORMAL << "The variance of the computed weights is "
935 0 : << variance << LogIO::POST;
936 : log << LogIO::NORMAL << "Weights which had corresponding flags of True "
937 : << "prior to running this application were not used to compute these "
938 0 : << "stats." << LogIO::POST;
939 : }
940 0 : catch (const AipsError& x) {
941 : log << LogIO::WARN << "There was a problem calculating the mean and "
942 : << "variance of the weights computed by this application. Perhaps there "
943 : << "was something amiss with the input MS and/or the selection criteria. "
944 : << "Examples of such issues are that all the data were originally flagged "
945 : << "or that the sample size was consistently too small for computations "
946 0 : << "of variances" << LogIO::POST;
947 0 : setNaN(mean);
948 0 : setNaN(variance);
949 0 : }
950 0 : }
951 :
952 0 : void StatWtTVI::origin() {
953 : // Drive underlying ViImplementation2
954 0 : getVii()->origin();
955 : // Synchronize own VisBuffer
956 0 : configureNewSubchunk();
957 0 : _clearCache();
958 0 : }
959 :
960 0 : void StatWtTVI::next() {
961 : // Drive underlying ViImplementation2
962 0 : getVii()->next();
963 : // Synchronize own VisBuffer
964 0 : configureNewSubchunk();
965 0 : _clearCache();
966 0 : }
967 :
968 : }
969 :
970 : }
|