LCOV - code coverage report
Current view: top level - synthesis/MeasurementEquations - objfunc_alglib.h (source / functions) Hit Total Coverage
Test: casacpp_coverage.info Lines: 100 100 100.0 %
Date: 2024-12-11 20:54:31 Functions: 14 14 100.0 %

          Line data    Source code
       1             : #ifndef SYNTHESIS_OBJFUNCALGLIB_H
       2             : #define SYNTHESIS_OBJFUNCALGLIB_H
       3             : 
#include <algorithm>
#include <cmath>
#include <vector>

#include <casacore/ms/MeasurementSets/MeasurementSet.h>
#include <casacore/casa/Arrays/Matrix.h>
#include <casacore/casa/Arrays/IPosition.h>
#include <casacore/images/Images/ImageInterface.h>
#include <casacore/images/Images/PagedImage.h>
#include <casacore/images/Images/TempImage.h>

#include <casacore/scimath/Mathematics/FFTServer.h>
#include <casacore/scimath/Functionals/Gaussian2D.h>

#include "lbfgs/optimization.h"
      15             : 
      16             : #ifndef isnan
      17             : #define isnan(x) std::isnan(x)
      18             : #endif
      19             : 
      20             : namespace casa { //# NAMESPACE CASA - BEGIN
      21             : 
      22             : class ParamAlglibObj
      23             : {
      24             : private:
      25             :   int nX;
      26             :   int nY;
      27             :   unsigned int AspLen;
      28             :   casacore::Matrix<casacore::Float> itsMatDirty;
      29             :   casacore::Matrix<casacore::Complex> itsPsfFT;
      30             :   std::vector<casacore::IPosition> center;
      31             :   casacore::Matrix<casacore::Float> newResidual;
      32             :   casacore::Matrix<casacore::Float> AspConvPsf;
      33             :   casacore::Matrix<casacore::Float> dAspConvPsf;
      34             :   casacore::Matrix<casacore::Float> Asp;
      35             :   casacore::Matrix<casacore::Float> dAsp;
      36             : 
      37             : public:
      38             :   casacore::FFTServer<casacore::Float,casacore::Complex> fft;
      39             : 
      40         789 :   ParamAlglibObj(const casacore::Matrix<casacore::Float>& dirty,
      41             :     const casacore::Matrix<casacore::Complex>& psf,
      42             :     const std::vector<casacore::IPosition>& positionOptimum,
      43         789 :     const casacore::FFTServer<casacore::Float,casacore::Complex>& fftin) :
      44         789 :     itsMatDirty(dirty),
      45         789 :     itsPsfFT(psf),
      46         789 :     center(positionOptimum),
      47         789 :     fft(fftin)
      48             :   {
      49         789 :     nX = itsMatDirty.shape()(0);
      50         789 :     nY = itsMatDirty.shape()(1);
      51         789 :     AspLen = center.size();
      52         789 :     newResidual.resize(nX, nY);
      53         789 :     AspConvPsf.resize(nX, nY);
      54         789 :     dAspConvPsf.resize(nX, nY);
      55         789 :     Asp.resize(nX, nY);
      56         789 :     dAsp.resize(nX, nY);
      57         789 :   }
      58             : 
      59         789 :   ~ParamAlglibObj() = default;
      60             : 
      61       12700 :   casacore::Matrix<casacore::Float>  getterDirty() { return itsMatDirty; }
      62       12700 :   casacore::Matrix<casacore::Complex> getterPsfFT() { return itsPsfFT; }
      63       12700 :   std::vector<casacore::IPosition> getterCenter() {return center;}
      64       12700 :   unsigned int getterAspLen() { return AspLen; }
      65       12700 :   int getterNX() { return nX; }
      66       12700 :   int getterNY() { return nY; }
      67       12700 :   casacore::Matrix<casacore::Float>  getterRes() { return newResidual; }
      68             :   void setterRes(const casacore::Matrix<casacore::Float>& res) { newResidual = res; }
      69       12700 :   casacore::Matrix<casacore::Float>  getterAspConvPsf() { return AspConvPsf; }
      70             :   void setterAspConvPsf(const casacore::Matrix<casacore::Float>& m) { AspConvPsf = m; }
      71       12700 :   casacore::Matrix<casacore::Float>  getterDAspConvPsf() { return dAspConvPsf; }
      72       12700 :   casacore::Matrix<casacore::Float>  getterAsp() { return Asp; }
      73             :   void setterAsp(const casacore::Matrix<casacore::Float>& m) { Asp = m; }
      74       12700 :   casacore::Matrix<casacore::Float>  getterDAsp() { return dAsp; }
      75             : };
      76             : 
      77       12700 : void objfunc_alglib(const alglib::real_1d_array &x, double &func, alglib::real_1d_array &grad, void *ptr) 
      78             : {
      79             :     // retrieve params for GSL bfgs optimization
      80       12700 :     casa::ParamAlglibObj *MyP = (casa::ParamAlglibObj *) ptr; //re-cast back to ParamAlglibObj to retrieve images
      81             : 
      82       12700 :     casacore::Matrix<casacore::Float> itsMatDirty(MyP->getterDirty());
      83       12700 :     casacore::Matrix<casacore::Complex> itsPsfFT(MyP->getterPsfFT());
      84       12700 :     std::vector<casacore::IPosition> center = MyP->getterCenter();
      85       12700 :     const unsigned int AspLen = MyP->getterAspLen();
      86       12700 :     const int nX = MyP->getterNX();
      87       12700 :     const int nY = MyP->getterNY();
      88       12700 :     casacore::Matrix<casacore::Float> newResidual(MyP->getterRes());
      89       12700 :     casacore::Matrix<casacore::Float> AspConvPsf(MyP->getterAspConvPsf());
      90       12700 :     casacore::Matrix<casacore::Float> Asp(MyP->getterAsp());
      91       12700 :     casacore::Matrix<casacore::Float> dAspConvPsf(MyP->getterDAspConvPsf());
      92       12700 :     casacore::Matrix<casacore::Float> dAsp(MyP->getterDAsp());
      93             : 
      94       12700 :     func = 0;
      95       12700 :     double amp = 1;
      96             : 
      97       12700 :     const int refi = nX/2;
      98       12700 :     const int refj = nY/2;
      99             : 
     100       12700 :     int minX = nX - 1;
     101       12700 :     int maxX = 0;
     102       12700 :     int minY = nY - 1;
     103       12700 :     int maxY = 0;
     104             : 
     105             :     // First, get the amp * AspenConvPsf for each Aspen to update the residual
     106       25268 :     for (unsigned int k = 0; k < AspLen; k ++)
     107             :     {
     108       12700 :         amp = x[2*k];
     109       12700 :         double scale = x[2*k+1];
     110             :         //std::cout << "f: amp " << amp << " scale " << scale << std::endl;
     111             : 
     112       12700 :       if (isnan(amp) || scale < 0.4) // GSL scale < 0
     113             :       {
     114             :         //std::cout << "nan? " << amp << " neg scale? " << scale << std::endl;
     115             :         // If scale is small (<0.4), make it 0 scale to utilize Hogbom and save time
     116         178 :         scale = (scale = fabs(scale)) < 0.4 ? 0 : scale;
     117             :         //std::cout << "reset neg scale to " << scale << std::endl;
     118             : 
     119         178 :         if (scale <= 0)
     120         132 :           return;
     121             :       }
     122             : 
     123             :       // generate a gaussian for each Asp in the Aspen set
     124             :       // x[0]: Amplitude0,       x[1]: scale0
     125             :       // x[2]: Amplitude1,       x[3]: scale1
     126             :       // x[2k]: Amplitude(k), x[2k+1]: scale(k+1)
     127             :       //casacore::Matrix<casacore::Float> Asp(nX, nY);
     128       12568 :       Asp = 0.0;
     129       12568 :       dAsp = 0.0;
     130             : 
     131       12568 :       const double sigma5 = 5 * scale / 2;
     132       12568 :       const int minI = std::max(0, (int)(center[k][0] - sigma5));
     133       12568 :       const int maxI = std::min(nX-1, (int)(center[k][0] + sigma5));
     134       12568 :       const int minJ = std::max(0, (int)(center[k][1] - sigma5));
     135       12568 :       const int maxJ = std::min(nY-1, (int)(center[k][1] + sigma5));
     136             : 
     137       12568 :       if (minI < minX)
     138       12568 :         minX = minI;
     139       12568 :       if (maxI > maxX)
     140       12568 :         maxX = maxI;
     141       12568 :       if (minJ < minY)
     142       12568 :         minY = minJ;
     143       12568 :       if (maxJ > maxY)
     144       12568 :         maxY = maxJ;
     145             : 
     146      947370 :       for (int j = minJ; j <= maxJ; j++)
     147             :       {
     148   158804516 :         for (int i = minI; i <= maxI; i++)
     149             :         {
     150   157869714 :           const int px = i;
     151   157869714 :           const int py = j;
     152             : 
     153   157869714 :           Asp(i,j) = (1.0/(sqrt(2*M_PI)*fabs(scale)))*exp(-(pow(i-center[k][0],2) + pow(j-center[k][1],2))*0.5/pow(scale,2));
     154   157869714 :           dAsp(i,j)= Asp(i,j) * (((pow(i-center[k][0],2) + pow(j-center[k][1],2)) / pow(scale,2) - 1) / fabs(scale)); // verified by python
     155             :         }
     156             :       }
     157             : 
     158       12568 :       casacore::Matrix<casacore::Complex> AspFT;
     159       12568 :       MyP->fft.fft0(AspFT, Asp);
     160       12568 :       casacore::Matrix<casacore::Complex> cWork;
     161       12568 :       cWork = AspFT * itsPsfFT;
     162       12568 :       MyP->fft.fft0(AspConvPsf, cWork, false);
     163       12568 :       MyP->fft.flip(AspConvPsf, false, false); //need this
     164             : 
     165             :       // gradient. 0: amplitude; 1: scale
     166             :       // returns the gradient evaluated on x
     167       12568 :       casacore::Matrix<casacore::Complex> dAspFT;
     168             : 
     169             :       //auto start = std::chrono::high_resolution_clock::now();
     170       12568 :       MyP->fft.fft0(dAspFT, dAsp);
     171             :       //auto stop = std::chrono::high_resolution_clock::now();
     172             :       //auto duration = std::chrono::duration_cast<std::chrono::microseconds>(stop - start) ;
     173             :       //std::cout << "BFGS fft0 runtime " << duration.count() << " us" << std::endl;
     174             : 
     175       12568 :       casacore::Matrix<casacore::Complex> dcWork;
     176       12568 :       dcWork = dAspFT * itsPsfFT;
     177       12568 :       MyP->fft.fft0(dAspConvPsf, dcWork, false);
     178       12568 :       MyP->fft.flip(dAspConvPsf, false, false); //need this
     179       12568 :     } // end get amp * AspenConvPsf
     180             : 
     181             :     // reset grad to 0. This is important to get the correct optimization.
     182       12568 :     double dA = 0.0;
     183       12568 :     double dS = 0.0;
     184             : 
     185             :     // Update the residual using the current residual image and the latest Aspen.
     186             :     // Sanjay used, Res = OrigDirty - active-set aspen * Psf, in 2004, instead.
     187             :     // Both works but the current approach is simpler and performs well too.
     188      934802 :     for (int j = minY; j < maxY; ++j)
     189             :     {
     190   156864550 :       for(int i = minX; i < maxX; ++i)
     191             :       {
     192   155942316 :         newResidual(i, j) = itsMatDirty(i, j) - amp * AspConvPsf(i, j);
     193   155942316 :         func = func + double(pow(newResidual(i, j), 2));
     194             : 
     195             :         // derivatives of amplitude
     196   155942316 :         dA += double((-2) * newResidual(i,j) * AspConvPsf(i,j));
     197             :         // derivative of scale
     198   155942316 :         dS += double((-2) * amp * newResidual(i,j) * dAspConvPsf(i,j));
     199             :       }
     200             :     }
     201             :     //std::cout << "after f " << func << std::endl;
     202             : 
     203       12568 :     grad[0] = dA;
     204       12568 :     grad[1] = dS; 
     205       13624 : }
     206             : 
     207             : 
     208             : 
     209             : } // end namespace casa
     210             : 
     211             : #endif // SYNTHESIS_OBJFUNCALGLIB_H

Generated by: LCOV version 1.16