LCOV - code coverage report
Current view: top level - bnmin1/src - gradientminim.cc (source / functions) Hit Total Coverage
Test: casacpp_coverage.info Lines: 0 73 0.0 %
Date: 2024-10-12 00:35:29 Functions: 0 10 0.0 %

          Line data    Source code
       1             : /**
       2             :    Bojan Nikolic <bojan@bnikolic.co.uk> 
       3             :    Initial version 2008
       4             : 
       5             :    This file is part of BNMin1 and is licensed under GNU General
       6             :    Public License version 2.
       7             : 
       8             :    \file gradientminim.cxx
       9             :    Renamed to gradientminim.cc 2023
      10             : 
      11             : */
      12             : 
      13             : #include "gradientminim.h"
      14             : #include "gradientmodel.h"
      15             : 
      16             : #include <gsl/gsl_multimin.h>
      17             : 
      18             : 
      19             : //#include <functional> //<boost/mem_fn.hpp>, was not used
      20             : 
      21             : #include <iostream>
      22             : #include <stdexcept>
      23             : 
      24             : namespace Minim {
      25             : 
      26             :   struct GSLGradWrap
      27             :   {
      28             :     
      29             :     ModelDesc &minim;
      30             :     LGradientModel &model;
      31             :     size_t n;
      32             : 
      33           0 :     GSLGradWrap(ModelDesc &minim,
      34           0 :                 LGradientModel &model):
      35           0 :       minim(minim),
      36           0 :       model(model),
      37           0 :       n(minim.NParam())
      38             :     {
      39           0 :     }
      40             : 
      41           0 :     void setpars(const gsl_vector * X)
      42             :     {
      43           0 :       std::vector<double> pars(n);
      44           0 :       for(size_t i = 0; i<n; ++i)
      45             :       {
      46           0 :         pars[i]=gsl_vector_get(X,i);
      47             :       }
      48           0 :       minim.put(pars);
      49           0 :     }
      50             : 
      51           0 :     double f(const gsl_vector * X)
      52             :     {
      53           0 :       setpars(X);
      54           0 :       const double res=model.lLikely();
      55           0 :       return res;
      56             :     }
      57             : 
      58           0 :     void df(const gsl_vector * X, 
      59             :             gsl_vector * G)
      60             :     {
      61           0 :       setpars(X);
      62           0 :       std::vector<double> g;
      63           0 :       model.lGrd(g);
      64           0 :       for(size_t i =0 ; i < g.size(); ++i)
      65             :       {
      66           0 :         gsl_vector_set(G,i,g[i]);
      67             :       }
      68             : 
      69           0 :     }
      70             : 
      71           0 :     void fdf (const gsl_vector * X, 
      72             :               double * f, 
      73             :               gsl_vector * G)
      74             :     {
      75           0 :       setpars(X);
      76           0 :       *f=model.lLikely();
      77           0 :       std::vector<double> g;
      78           0 :       model.lGrd(g);
      79           0 :       for(size_t i =0 ; i < g.size(); ++i)
      80             :       {
      81           0 :         gsl_vector_set(G,i,g[i]);
      82             :       }
      83           0 :     }
      84             : 
      85             :   };
      86             : 
      87           0 :   double bngsl_f(const gsl_vector * X, 
      88             :                  void * PARAMS)
      89             :   {
      90           0 :     if (not PARAMS)
      91             :     {
      92           0 :       throw std::runtime_error("BNGSL not passed a pointer to data class");
      93             :     }
      94           0 :     return reinterpret_cast<GSLGradWrap*>(PARAMS)->f(X);
      95             :   }
      96             : 
      97           0 :   void bngsl_df(const gsl_vector * X, 
      98             :                 void * PARAMS, 
      99             :                 gsl_vector * G)  
     100             :   {
     101           0 :     if (not PARAMS)
     102             :     {
     103           0 :       throw std::runtime_error("BNGSL not passed a pointer to data class");
     104             :     }
     105           0 :     reinterpret_cast<GSLGradWrap*>(PARAMS)->df(X,G);
     106           0 :   }
     107             : 
     108           0 :   void bngsl_fdf (const gsl_vector * X, 
     109             :                   void * PARAMS, 
     110             :                   double * f, 
     111             :                   gsl_vector * G)
     112             :   {
     113           0 :     if (not PARAMS)
     114             :     {
     115           0 :       throw std::runtime_error("BNGSL not passed a pointer to data class");
     116             :     }
     117           0 :     reinterpret_cast<GSLGradWrap*>(PARAMS)->fdf(X,f,G);
     118           0 :   }
     119             :   
     120             :     
     121             : 
     122           0 :   BFGS2Minim::BFGS2Minim(LGradientModel &pm):
     123             :     ModelDesc(pm),
     124           0 :     lgm(pm)
     125             :   {
     126           0 :   }
     127             : 
     128           0 :   void BFGS2Minim::solve(void)
     129             :   {
     130           0 :     GSLGradWrap w(*this, lgm);
     131             :     // Function to minimise
     132             :     gsl_multimin_function_fdf mfunc;
     133             : 
     134           0 :     mfunc.n=NParam();
     135           0 :     mfunc.f=&bngsl_f;
     136           0 :     mfunc.df=&bngsl_df;
     137           0 :     mfunc.fdf=&bngsl_fdf;
     138           0 :     mfunc.params=reinterpret_cast<void*>(&w);
     139             :     
     140           0 :     gsl_vector *startp = gsl_vector_alloc(NParam());
     141           0 :     std::vector<double> startv(NParam());
     142           0 :     get(startv);
     143           0 :     for (size_t i =0; i<startv.size(); ++i)
     144             :     {
     145           0 :       gsl_vector_set(startp,i,startv[i]);
     146             :     }
     147             : 
     148           0 :     gsl_multimin_fdfminimizer *s=gsl_multimin_fdfminimizer_alloc(gsl_multimin_fdfminimizer_vector_bfgs2,
     149           0 :                                                                  NParam());
     150           0 :     gsl_multimin_fdfminimizer_set(s,
     151             :                                   &mfunc,
     152             :                                   startp, 
     153             :                                   0.01,
     154             :                                   1e-4);
     155             : 
     156           0 :     size_t iter=0;
     157             :     int status;
     158             :     do
     159             :     {
     160           0 :       ++iter;
     161           0 :       status = gsl_multimin_fdfminimizer_iterate(s);
     162             : 
     163           0 :       if (status)
     164             :       {
     165           0 :         break;
     166             :         //throw std::runtime_error("Problem in minimisation iteration");
     167             :       }
     168             :       
     169           0 :       status = gsl_multimin_test_gradient(s->gradient,
     170             :                                           1e-3);
     171             :       
     172             :     }
     173           0 :     while (status == GSL_CONTINUE && iter < 100);
     174             :     
     175           0 :     gsl_multimin_fdfminimizer_free(s);
     176           0 :     gsl_vector_free(startp);
     177             :     
     178           0 :   }    
     179             : }
     180             : 
     181             : 
     182             : 

Generated by: LCOV version 1.16