LCOV - code coverage report
Current view: top level - synthesis/MeasurementEquations - objfunc_alglib.h (source / functions) Hit Total Coverage
Test: casacpp_coverage.info Lines: 0 100 0.0 %
Date: 2024-10-04 16:51:10 Functions: 0 14 0.0 %

          Line data    Source code
       1             : #ifndef SYNTHESIS_OBJFUNCALGLIB_H
       2             : #define SYNTHESIS_OBJFUNCALGLIB_H
       3             : 
       4             : #include <casacore/ms/MeasurementSets/MeasurementSet.h>
       5             : #include <casacore/casa/Arrays/Matrix.h>
       6             : #include <casacore/casa/Arrays/IPosition.h>
       7             : #include <casacore/images/Images/ImageInterface.h>
       8             : #include <casacore/images/Images/PagedImage.h>
       9             : #include <casacore/images/Images/TempImage.h>
      10             : 
      11             : #include <casacore/scimath/Mathematics/FFTServer.h>
      12             : #include <casacore/scimath/Functionals/Gaussian2D.h>
      13             : 
      14             : #include "lbfgs/optimization.h"
      15             : 
      16             : #ifndef isnan
      17             : #define isnan(x) std::isnan(x)
      18             : #endif
      19             : 
      20             : namespace casa { //# NAMESPACE CASA - BEGIN
      21             : 
      22             : class ParamAlglibObj
      23             : {
      24             : private:
      25             :   int nX;
      26             :   int nY;
      27             :   unsigned int AspLen;
      28             :   casacore::Matrix<casacore::Float> itsMatDirty;
      29             :   casacore::Matrix<casacore::Complex> itsPsfFT;
      30             :   std::vector<casacore::IPosition> center;
      31             :   casacore::Matrix<casacore::Float> newResidual;
      32             :   casacore::Matrix<casacore::Float> AspConvPsf;
      33             :   casacore::Matrix<casacore::Float> dAspConvPsf;
      34             :   casacore::Matrix<casacore::Float> Asp;
      35             :   casacore::Matrix<casacore::Float> dAsp;
      36             : 
      37             : public:
      38             :   casacore::FFTServer<casacore::Float,casacore::Complex> fft;
      39             : 
      40           0 :   ParamAlglibObj(const casacore::Matrix<casacore::Float>& dirty,
      41             :     const casacore::Matrix<casacore::Complex>& psf,
      42             :     const std::vector<casacore::IPosition>& positionOptimum,
      43           0 :     const casacore::FFTServer<casacore::Float,casacore::Complex>& fftin) :
      44           0 :     itsMatDirty(dirty),
      45           0 :     itsPsfFT(psf),
      46           0 :     center(positionOptimum),
      47           0 :     fft(fftin)
      48             :   {
      49           0 :     nX = itsMatDirty.shape()(0);
      50           0 :     nY = itsMatDirty.shape()(1);
      51           0 :     AspLen = center.size();
      52           0 :     newResidual.resize(nX, nY);
      53           0 :     AspConvPsf.resize(nX, nY);
      54           0 :     dAspConvPsf.resize(nX, nY);
      55           0 :     Asp.resize(nX, nY);
      56           0 :     dAsp.resize(nX, nY);
      57           0 :   }
      58             : 
      59           0 :   ~ParamAlglibObj() = default;
      60             : 
      61           0 :   casacore::Matrix<casacore::Float>  getterDirty() { return itsMatDirty; }
      62           0 :   casacore::Matrix<casacore::Complex> getterPsfFT() { return itsPsfFT; }
      63           0 :   std::vector<casacore::IPosition> getterCenter() {return center;}
      64           0 :   unsigned int getterAspLen() { return AspLen; }
      65           0 :   int getterNX() { return nX; }
      66           0 :   int getterNY() { return nY; }
      67           0 :   casacore::Matrix<casacore::Float>  getterRes() { return newResidual; }
      68             :   void setterRes(const casacore::Matrix<casacore::Float>& res) { newResidual = res; }
      69           0 :   casacore::Matrix<casacore::Float>  getterAspConvPsf() { return AspConvPsf; }
      70             :   void setterAspConvPsf(const casacore::Matrix<casacore::Float>& m) { AspConvPsf = m; }
      71           0 :   casacore::Matrix<casacore::Float>  getterDAspConvPsf() { return dAspConvPsf; }
      72           0 :   casacore::Matrix<casacore::Float>  getterAsp() { return Asp; }
      73             :   void setterAsp(const casacore::Matrix<casacore::Float>& m) { Asp = m; }
      74           0 :   casacore::Matrix<casacore::Float>  getterDAsp() { return dAsp; }
      75             : };
      76             : 
      77           0 : void objfunc_alglib(const alglib::real_1d_array &x, double &func, alglib::real_1d_array &grad, void *ptr) 
      78             : {
      79             :     // retrieve params for GSL bfgs optimization
      80           0 :     casa::ParamAlglibObj *MyP = (casa::ParamAlglibObj *) ptr; //re-cast back to ParamAlglibObj to retrieve images
      81             : 
      82           0 :     casacore::Matrix<casacore::Float> itsMatDirty(MyP->getterDirty());
      83           0 :     casacore::Matrix<casacore::Complex> itsPsfFT(MyP->getterPsfFT());
      84           0 :     std::vector<casacore::IPosition> center = MyP->getterCenter();
      85           0 :     const unsigned int AspLen = MyP->getterAspLen();
      86           0 :     const int nX = MyP->getterNX();
      87           0 :     const int nY = MyP->getterNY();
      88           0 :     casacore::Matrix<casacore::Float> newResidual(MyP->getterRes());
      89           0 :     casacore::Matrix<casacore::Float> AspConvPsf(MyP->getterAspConvPsf());
      90           0 :     casacore::Matrix<casacore::Float> Asp(MyP->getterAsp());
      91           0 :     casacore::Matrix<casacore::Float> dAspConvPsf(MyP->getterDAspConvPsf());
      92           0 :     casacore::Matrix<casacore::Float> dAsp(MyP->getterDAsp());
      93             : 
      94           0 :     func = 0;
      95           0 :     double amp = 1;
      96             : 
      97           0 :     const int refi = nX/2;
      98           0 :     const int refj = nY/2;
      99             : 
     100           0 :     int minX = nX - 1;
     101           0 :     int maxX = 0;
     102           0 :     int minY = nY - 1;
     103           0 :     int maxY = 0;
     104             : 
     105             :     // First, get the amp * AspenConvPsf for each Aspen to update the residual
     106           0 :     for (unsigned int k = 0; k < AspLen; k ++)
     107             :     {
     108           0 :         amp = x[2*k];
     109           0 :         double scale = x[2*k+1];
     110             :         //std::cout << "f: amp " << amp << " scale " << scale << std::endl;
     111             : 
     112           0 :       if (isnan(amp) || scale < 0.4) // GSL scale < 0
     113             :       {
     114             :         //std::cout << "nan? " << amp << " neg scale? " << scale << std::endl;
     115             :         // If scale is small (<0.4), make it 0 scale to utilize Hogbom and save time
     116           0 :         scale = (scale = fabs(scale)) < 0.4 ? 0 : scale;
     117             :         //std::cout << "reset neg scale to " << scale << std::endl;
     118             : 
     119           0 :         if (scale <= 0)
     120           0 :           return;
     121             :       }
     122             : 
     123             :       // generate a gaussian for each Asp in the Aspen set
     124             :       // x[0]: Amplitude0,       x[1]: scale0
     125             :       // x[2]: Amplitude1,       x[3]: scale1
     126             :       // x[2k]: Amplitude(k), x[2k+1]: scale(k+1)
     127             :       //casacore::Matrix<casacore::Float> Asp(nX, nY);
     128           0 :       Asp = 0.0;
     129           0 :       dAsp = 0.0;
     130             : 
     131           0 :       const double sigma5 = 5 * scale / 2;
     132           0 :       const int minI = std::max(0, (int)(center[k][0] - sigma5));
     133           0 :       const int maxI = std::min(nX-1, (int)(center[k][0] + sigma5));
     134           0 :       const int minJ = std::max(0, (int)(center[k][1] - sigma5));
     135           0 :       const int maxJ = std::min(nY-1, (int)(center[k][1] + sigma5));
     136             : 
     137           0 :       if (minI < minX)
     138           0 :         minX = minI;
     139           0 :       if (maxI > maxX)
     140           0 :         maxX = maxI;
     141           0 :       if (minJ < minY)
     142           0 :         minY = minJ;
     143           0 :       if (maxJ > maxY)
     144           0 :         maxY = maxJ;
     145             : 
     146           0 :       for (int j = minJ; j <= maxJ; j++)
     147             :       {
     148           0 :         for (int i = minI; i <= maxI; i++)
     149             :         {
     150           0 :           const int px = i;
     151           0 :           const int py = j;
     152             : 
     153           0 :           Asp(i,j) = (1.0/(sqrt(2*M_PI)*fabs(scale)))*exp(-(pow(i-center[k][0],2) + pow(j-center[k][1],2))*0.5/pow(scale,2));
     154           0 :           dAsp(i,j)= Asp(i,j) * (((pow(i-center[k][0],2) + pow(j-center[k][1],2)) / pow(scale,2) - 1) / fabs(scale)); // verified by python
     155             :         }
     156             :       }
     157             : 
     158           0 :       casacore::Matrix<casacore::Complex> AspFT;
     159           0 :       MyP->fft.fft0(AspFT, Asp);
     160           0 :       casacore::Matrix<casacore::Complex> cWork;
     161           0 :       cWork = AspFT * itsPsfFT;
     162           0 :       MyP->fft.fft0(AspConvPsf, cWork, false);
     163           0 :       MyP->fft.flip(AspConvPsf, false, false); //need this
     164             : 
     165             :       // gradient. 0: amplitude; 1: scale
     166             :       // returns the gradient evaluated on x
     167           0 :       casacore::Matrix<casacore::Complex> dAspFT;
     168             : 
     169             :       //auto start = std::chrono::high_resolution_clock::now();
     170           0 :       MyP->fft.fft0(dAspFT, dAsp);
     171             :       //auto stop = std::chrono::high_resolution_clock::now();
     172             :       //auto duration = std::chrono::duration_cast<std::chrono::microseconds>(stop - start) ;
     173             :       //std::cout << "BFGS fft0 runtime " << duration.count() << " us" << std::endl;
     174             : 
     175           0 :       casacore::Matrix<casacore::Complex> dcWork;
     176           0 :       dcWork = dAspFT * itsPsfFT;
     177           0 :       MyP->fft.fft0(dAspConvPsf, dcWork, false);
     178           0 :       MyP->fft.flip(dAspConvPsf, false, false); //need this
     179           0 :     } // end get amp * AspenConvPsf
     180             : 
     181             :     // reset grad to 0. This is important to get the correct optimization.
     182           0 :     double dA = 0.0;
     183           0 :     double dS = 0.0;
     184             : 
     185             :     // Update the residual using the current residual image and the latest Aspen.
     186             :     // Sanjay used, Res = OrigDirty - active-set aspen * Psf, in 2004, instead.
     187             :     // Both works but the current approach is simpler and performs well too.
     188           0 :     for (int j = minY; j < maxY; ++j)
     189             :     {
     190           0 :       for(int i = minX; i < maxX; ++i)
     191             :       {
     192           0 :         newResidual(i, j) = itsMatDirty(i, j) - amp * AspConvPsf(i, j);
     193           0 :         func = func + double(pow(newResidual(i, j), 2));
     194             : 
     195             :         // derivatives of amplitude
     196           0 :         dA += double((-2) * newResidual(i,j) * AspConvPsf(i,j));
     197             :         // derivative of scale
     198           0 :         dS += double((-2) * amp * newResidual(i,j) * dAspConvPsf(i,j));
     199             :       }
     200             :     }
     201             :     //std::cout << "after f " << func << std::endl;
     202             : 
     203           0 :     grad[0] = dA;
     204           0 :     grad[1] = dS; 
     205           0 : }
     206             : 
     207             : 
     208             : 
     209             : } // end namespace casa
     210             : 
     211             : #endif // SYNTHESIS_OBJFUNCALGLIB_H

Generated by: LCOV version 1.16