UQTk: Uncertainty Quantification Toolkit  3.1.1
mcmc.h
Go to the documentation of this file.
1 /* =====================================================================================
2 
3  The UQ Toolkit (UQTk) version 3.1.1
4  Copyright (2021) NTESS
5  https://www.sandia.gov/UQToolkit/
6  https://github.com/sandialabs/UQTk
7 
8  Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
9  Under the terms of Contract DE-NA0003525 with NTESS, the U.S. Government
10  retains certain rights in this software.
11 
12  This file is part of The UQ Toolkit (UQTk)
13 
14  UQTk is open source software: you can redistribute it and/or modify
15  it under the terms of BSD 3-Clause License
16 
17  UQTk is distributed in the hope that it will be useful,
18  but WITHOUT ANY WARRANTY; without even the implied warranty of
19  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
20  BSD 3 Clause License for more details.
21 
22  You should have received a copy of the BSD 3 Clause License
23  along with UQTk. If not, see https://choosealicense.com/licenses/bsd-3-clause/.
24 
25  Questions? Contact the UQTk Developers at <uqtk-developers@software.sandia.gov>
26  Sandia National Laboratories, Livermore, CA, USA
27 ===================================================================================== */
31 
32 #ifndef UQTKMCMC_H_SEEN
33 #define UQTKMCMC_H_SEEN
34 
35 #include "dsfmt_add.h"
36 #include "arrayio.h"
37 #include "arraytools.h"
38 
39 #include <iostream>
40 #include <string.h>
41 #include <stdio.h>
42 #include <sstream>
43 
44 using namespace std; // needed for python string conversion
45 
// NOTE(review): the line that opens this class is not visible in this listing;
// judging by the destructor below it is presumably `class LogPosteriorBase{` -- confirm.
47 public:
// Placeholder evaluation: ignores its (unnamed) argument and always returns
// the constant 3.14. Declared virtual, so a derived class can override it.
48  virtual double eval(Array1D<double>&){
49  return 3.14;
50  };
// Virtual destructor with an empty body.
51  virtual ~LogPosteriorBase(){};
52 };
53 
/// Minimal polymorphic interface exposing one overridable function.
class base{
public:
    /// Virtual destructor: this class has a virtual member, so objects of
    /// derived classes must be safely destructible through a pointer to base.
    virtual ~base(){}

    /// Default implementation: ignores both arguments and returns 0.0.
    /// \param x pointer to double values (not read by this version)
    /// \param n integer argument (not read by this version); presumably the
    ///          number of entries behind x -- TODO confirm against callers
    /// \return always 0.0 unless overridden in a derived class
    virtual double fun(double* x, int n){return 0.0;}
};
58 
59 class main{
60 public:
61  base* b_;
62 
63  main(base& b){
64  b_ = &b;
65  }
66 };
67 //*****************************************
68 
72 class MCMC{
73 public:
74  // Struct for the chain state
75  struct chainstate{
76  int step;
78  double alfa;
79  double post;
80  };
81 
82  // Constructors:
83 
85  MCMC(double (*logposterior)(Array1D<double>&, void *), void *postinfo);
89  MCMC();
90 
91  // Set or initialization functions:
92 
94  void setWriteFlag(int I);
96  void setFcnAccept(void (*fcnAccept)(void *));
97  void setFcnReject(void (*fcnReject)(void *));
99  void setChainDim(int chdim);
102  void initChainPropCov(Array2D<double>& propcov);
105  void initChainPropCovDiag(Array1D<double>& sig);
107  void setOutputInfo(string outtype, string file,int freq_file, int freq_screen);
109  void namesPrepended();
111  void setSeed(int seed);
113  void setLower(double lower, int i);
115  void setUpper(double upper, int i);
117  void setDefaultDomain();
119  void setPostInfo(void *postinfo);
120 
121  // Get functions:
122 
124  void getChainPropCov(Array2D<double>& propcov);
126  string getFilename();
128  int getWriteFlag();
130  void getSamples(int burnin, int every,Array2D<double>& samples);
132  void getSamples(Array2D<double>& samples);
134  void getFcnAccept(void (*fcnAccept)(void *));
135  void getFcnReject(void (*fcnReject)(void *));
137  string getOutputType();
139  int getFileFreq();
141  int getScreenFreq();
143  bool getNamesPrepended();
145  int getSeed();
147  double getLower(int i);
149  double getUpper(int i);
151  bool getDimInit();
153  void getPostInfo(void *post);
155  bool getPropCovInit();
157  bool getOutputInit();
159  int getLastWrite();
161  bool getFcnAcceptInit();
162  bool getFcnRejectInit();
164  virtual int getNSubSteps(){return 1;};
166  int getLowerFlag(int i);
167  int getUpperFlag(int i);
169  void getAcceptRatio(double * accrat);
171  double getAcceptRatio();
173  int GetChainDim() const;
174 
175  // Chain Functions:
176 
178  void resetChainState();
180  void resetChainFilename(string filename);
182  void parseBinChain(string filename, Array1D<chainstate>& readchain);
184  void writeFullChainTxt(string filename, Array1D<chainstate> fullchain);
186  void getFullChain(Array1D<chainstate>& readchain);
188  void appendMAP();
190  double getMode(Array1D<double>& MAPparams);
192  int getFullChainSize();
193 
194  // Functions to make sure the code respects the interface for chainstates:
195 
197  void setCurrentStateStep(int i);
199  void getCurrentStateState(Array1D<double>& state);
201  double getCurrentStatePost();
203  void setCurrentStateState(Array1D<double>& newState);
205  void setCurrentStatePost(double newPost);
207  void setCurrentStateAlfa(double newAlfa);
209  double getModeStatePost();
211  void getModeStateState(Array1D<double>& state);
212 
213  // Run functions:
214 
216  virtual void runOptim(Array1D<double>& start);
218  virtual void runChain(int ncalls, Array1D<double>& chstart) = 0;
220  virtual void runChain(int ncalls) = 0;
222  void runAcceptFcn();
224  void runRejectFcn();
225 
226 
227 
229  bool newModeFound();
231  double evalLogPosterior(Array1D<double>& m);
233  bool inDomain(Array1D<double>& m);
235  void writeChainTxt(string filename);
237  void writeChainBin(string filename);
238 
240  void setNewMode(bool value);
241 
242 
243 
244 protected:
246  void setAcceptRatio(double d);
248  void addCurrentState();
250  void updateMode();
252  void setLastWrite(int i);
253 
254  dsfmt_t RandomState;
255 
256 
257 private:
258  int WRITE_FLAG; // Write Flag
259  int FLAG; // Flag
260  LogPosteriorBase* L_; // Pointer to the LogPosterior base passed in through contstructor
261  struct outputInfo{
262  string outtype;
263  string filename;
266  } outputinfo_;
267  int chainDim_; // Chain dimensions
268  double (*logPosterior_)(Array1D<double>&, void *) = NULL; // Pointer to log-posterior function
269  void (*fcnAccept_)(void *) = NULL; // Pointer to accept function
270  void (*fcnReject_)(void *) = NULL; // Pointer to reject function
271  void *postInfo_ = NULL; // Void pointer to the posterior info (data)
272  Array2D<double> chcov; // Chain proposal distributions (before the adaptivity starts)
273  int seed_; // Random seed for mcmc
274 
275  virtual double probOldNew(Array1D<double>& a, Array1D<double>& b){return 0.0;}; // Evaluate old|new probabilities and new|old probabilities
276  //double evallogMVN_diag(Array1D<double>& x,Array1D<double>& mu,Array1D<double>& sig2); // Evaluate MVN
277 
278  chainstate currState_; // The current chain state
279  chainstate modeState_; // The current MAP state
280  Array1D<chainstate> fullChain_; // Array of chain states
281 
282  //void updateMode(); // Function to update the chain mode
283 
284  int lastwrite_; // Indicates up to which state
285  bool namesPrepend = false;
286 
287  bool newMode_ = false; // Flag to indicate whether a new mode is found during the last call of runChain, initalized as false
288 
289  double accRatio_ = -1.0; // Acceptance ratio of the chain, initialized as -1.0
290 
291  // Flags to indicate whether the corresponding parameters are initialized or not
292  bool chaindimInit_ = false;
293  bool propcovInit_ = false;
294  bool methodInit_ = false;
295  bool outputInit_ = false;
296 
297  bool fcnAcceptFlag_ = false; // Flag that indicates whether the accept function is given or not
298  bool fcnRejectFlag_ = false; // Flag that indicates whether the reject function is given or not
299 
304 
309 
310 };
311 
312 #endif /* UQTKMCMC_H_SEEN */
Header file for array read/write utilities.
Header file for array tools.
Definition: Array1D.h:472
Definition: Array1D.h:262
Stores data of any type T in a 1D array.
Definition: Array1D.h:61
Definition: mcmc.h:46
virtual ~LogPosteriorBase()
Definition: mcmc.h:51
virtual double eval(Array1D< double > &)
Definition: mcmc.h:48
Markov Chain Monte Carlo base class. Implements the basic and most general MCMC algorithms.
Definition: mcmc.h:72
Array2D< double > chcov
Definition: mcmc.h:272
LogPosteriorBase * L_
Definition: mcmc.h:260
chainstate currState_
Definition: mcmc.h:275
Array1D< int > lower_flag_
Lower bound existence flags.
Definition: mcmc.h:306
Array1D< int > upper_flag_
Upper bound existence flags.
Definition: mcmc.h:308
int chainDim_
Definition: mcmc.h:267
dsfmt_t RandomState
Definition: mcmc.h:254
virtual void runChain(int ncalls)=0
Start an MCMC chain with trivial initial condition.
virtual void runChain(int ncalls, Array1D< double > &chstart)=0
The actual function that generates MCMC.
int seed_
Definition: mcmc.h:273
int FLAG
Definition: mcmc.h:259
Array1D< double > Lower_
Lower bounds.
Definition: mcmc.h:301
virtual int getNSubSteps()
Get function for number of sub steps.
Definition: mcmc.h:164
virtual double probOldNew(Array1D< double > &a, Array1D< double > &b)
Definition: mcmc.h:275
Array1D< double > Upper_
Upper bounds.
Definition: mcmc.h:303
int WRITE_FLAG
Definition: mcmc.h:258
int lastwrite_
Definition: mcmc.h:284
Array1D< chainstate > fullChain_
Definition: mcmc.h:280
chainstate modeState_
Definition: mcmc.h:279
Definition: mcmc.h:54
virtual double fun(double *x, int n)
Definition: mcmc.h:56
Definition: mcmc.h:59
base * b_
Definition: mcmc.h:61
main(base &b)
Definition: mcmc.h:63
Definition: mcmc.h:75
double alfa
Definition: mcmc.h:78
Array1D< double > state
Definition: mcmc.h:77
double post
Definition: mcmc.h:79
int step
Definition: mcmc.h:76
Definition: mcmc.h:261
string filename
Definition: mcmc.h:263
int freq_file
Definition: mcmc.h:264
string outtype
Definition: mcmc.h:262
int freq_screen
Definition: mcmc.h:265