Line data Source code
1 : /**
2 : Bojan Nikolic <bojan@bnikolic.co.uk>
3 : Initial version 2009
4 :
5 : This file is part of BNMin1 and is licensed under GNU General
6 : Public License version 2
7 :
8 : \file nestedsampler.cxx
9 : Renamed to nestedsampler.cc 2023
10 :
11 : */
12 :
13 : #include <cmath>
14 : #include <iostream>
15 :
16 : #include <random>
17 :
18 : #include "nestedsampler.h"
19 : #include "priors.h"
20 : #include "minim.h"
21 : #include "prior_sampler.h"
22 : #include "nestederr.h"
23 : #include "mcmonitor.h"
24 : #include "nestedinitial.h"
25 :
26 : namespace Minim {
27 :
28 0 : NestedS::NestedS(PriorNLikelihood & ml,
29 : const std::list<MCPoint> & start,
30 0 : unsigned /*seed*/) noexcept(false):
31 : ModelDesc(ml),
32 0 : Zseq(1,0.0),
33 0 : Xseq(1,1.0),
34 0 : ml(ml),
35 0 : ps(new CSRMSSS(ml, *this, g_ss())),
36 0 : initials(new InitialWorst()),
37 0 : mon(nullptr),
38 0 : n_psample(100)
39 : {
40 0 : llPoint(ml,
41 : start,
42 0 : ss);
43 :
44 : //ps->mon= new SOutMCMon();
45 :
46 0 : if (ss.size() < 2)
47 : {
48 0 : throw NestedSmallStart(start);
49 : }
50 0 : }
51 :
52 81 : NestedS::NestedS(PriorNLikelihood & ml,
53 81 : unsigned /*seed*/):
54 : ModelDesc(ml),
55 81 : Zseq(1,0.0),
56 81 : Xseq(1,1.0),
57 81 : ml(ml),
58 81 : ps(nullptr),
59 81 : initials(new InitialWorst()),
60 81 : mon(nullptr),
61 162 : n_psample(100)
62 : {
63 81 : }
64 :
65 162 : NestedS::~NestedS(void)
66 : {
67 81 : if(ps)
68 : {
69 81 : delete ps->mon;
70 : }
71 162 : }
72 :
73 81 : void NestedS::reset(const std::list<MCPoint> &start)
74 : {
75 81 : if (start.size() < 2)
76 : {
77 0 : throw NestedSmallStart(start);
78 : }
79 81 : if (start.begin()->p.size() != NParam())
80 : {
81 0 : throw BaseErr("Dimension of start set points is not the same as number of parameters to fit");
82 : }
83 :
84 81 : ps.reset(new CSRMSSS(ml,
85 : *this,
86 81 : g_ss()));
87 81 : Zseq.clear();
88 81 : Zseq.push_back(0.0);
89 81 : Xseq.clear();
90 81 : Xseq.push_back(1.0);
91 81 : llPoint(ml,
92 : *this,
93 : start,
94 81 : ss);
95 81 : }
96 :
97 0 : void NestedS::InitalS(NestedInitial *ins)
98 : {
99 0 : initials.reset(ins);
100 0 : }
101 :
102 771024 : size_t NestedS::N(void) const
103 : {
104 771024 : return ss.size();
105 : }
106 :
107 81 : double NestedS::sample(size_t j)
108 : {
109 771101 : for (size_t i=0; i<j; ++i)
110 : {
111 771024 : std::set<MCPoint>::iterator worst( --ss.end() );
112 :
113 771024 : const double Llow=exp(-worst->ll);
114 771024 : const double X=exp(-((double)Xseq.size())/N());
115 771024 : const double w=Xseq[Xseq.size()-1]-X;
116 :
117 : // Look for the next sample
118 771024 : put((*initials)(*this).p);
119 : //put(worst->p);
120 771024 : const double newl = ps->advance(worst->ll,
121 : n_psample);
122 :
123 : // Create new point
124 771024 : MCPoint np;
125 771024 : np.p.resize(NParam());
126 771024 : get(np.p);
127 771024 : np.ll=-newl;
128 :
129 : // Is the new sample actually inside the contours of last?
130 771024 : const bool better = -newl < worst->ll;
131 :
132 771024 : if (not better )
133 : {
134 : // Can't find a better point so terminate early.
135 :
136 : // Note that this test is not definitive, since some
137 : // strategies will start from a point which not the worst
138 : // point and hence will return a "better" point event though
139 : // they haven not actually advanced their chain at all. See
140 : // below.
141 :
142 4 : break;
143 : }
144 :
145 : // Save the point about to be bumped off
146 771020 : Zseq.push_back(Zseq[Zseq.size()-1] + Llow* w);
147 771020 : Xseq.push_back(X);
148 771020 : post.push_back(WPPoint(*worst, w));
149 :
150 : // Erase old point
151 771020 : ss.erase(worst);
152 :
153 771020 : std::pair<std::set<MCPoint>::iterator, bool> r=ss.insert(np);
154 771020 : if (not r.second)
155 : {
156 : // Could not insert a point because it has identical
157 : // likelihood to an existing point. Can not contiue as we have
158 : // fewer points in the live set now.
159 :
160 : // Note that this is often due to the chain not avancing in
161 : // the constained sampler.
162 0 : break;
163 : }
164 :
165 771020 : if(mon)
166 0 : mon->accept(np);
167 :
168 771024 : }
169 81 : return Zseq[Zseq.size()-1];
170 : }
171 :
172 0 : double NestedS::Z(void) const
173 : {
174 0 : return Zseq[Zseq.size()-1];
175 : }
176 :
177 :
178 81 : const std::list<WPPoint> & NestedS::g_post(void) const
179 : {
180 81 : return post;
181 : }
182 :
183 771105 : const std::set<MCPoint> & NestedS::g_ss(void) const
184 : {
185 771105 : return ss;
186 : }
187 :
188 81 : void llPoint(PriorNLikelihood & ml,
189 : ModelDesc &md,
190 : const std::list<MCPoint> &lp,
191 : std::set<MCPoint> &res)
192 : {
193 81 : for(std::list<MCPoint>::const_iterator i(lp.begin());
194 16281 : i != lp.end();
195 16200 : ++i)
196 : {
197 16200 : MCPoint p(i->p);
198 16200 : md.put(p.p);
199 16200 : p.ll=ml.llprob();
200 16200 : res.insert(p);
201 16200 : }
202 81 : }
203 :
204 0 : void llPoint(PriorNLikelihood & ml,
205 : const std::list<MCPoint> &lp,
206 : std::set<MCPoint> &res)
207 : {
208 0 : ModelDesc m(ml);
209 0 : llPoint(ml,
210 : m,
211 : lp,
212 : res);
213 0 : }
214 :
215 81 : void startSetDirect(IndependentFlatPriors &prior,
216 : size_t n,
217 : std::list<MCPoint> &res,
218 : unsigned seed)
219 : {
220 81 : const size_t nprior=prior.npriors();
221 :
222 81 : std::mt19937 rng(seed);
223 81 : std::uniform_real_distribution<double> zo(0.,1.);
224 :
225 81 : Minim::MCPoint p(nprior);
226 16281 : for(size_t i=0; i<n; ++i)
227 : {
228 16200 : size_t j=0;
229 16200 : for (IndependentFlatPriors::priorlist_t::const_iterator dimp(prior.pbegin());
230 64800 : dimp != prior.pend();
231 48600 : ++dimp)
232 : {
233 48600 : p.p[j]=dimp->pmin+ (dimp->pmax - dimp->pmin)* zo(rng);
234 48600 : ++j;
235 : }
236 16200 : res.push_back(p);
237 : }
238 :
239 81 : }
240 :
241 0 : void printSS(const std::set<MCPoint> &ss)
242 : {
243 0 : for(std::set<MCPoint>::const_iterator i(ss.begin());
244 0 : i != ss.end();
245 0 : ++i)
246 : {
247 0 : std::cout<<"p:";
248 0 : for(size_t j=0; j<i->p.size(); ++j)
249 0 : std::cout<<i->p[j]
250 0 : <<",";
251 0 : std::cout<<i->ll<<",";
252 0 : std::cout<<std::endl;
253 : }
254 0 : }
255 :
256 : }
257 :
258 :
|