11import numpy as np
22import emcee
3- import tqdm
43import corner
54from scipy .stats import rice , norm , uniform
65from matplotlib import pyplot as plt
98
109class MCMC :
1110 """
12- Performs sampling of exponential data
11+ Performs sampling of exponential data using MCMC
1312 """
1413 def __init__ (self ,
1514 data ,
@@ -26,6 +25,28 @@ def __init__(self,
2625 """
2726 Parameters
2827 ----------
28+ data : array_like
29+ The data to be fitted
30+ b : array_like
31+ The b-values for the data
32+ data_scale : float, optional
33+ The scale of the data, by default 1e-5
34+ parameter_scale : tuple, optional
35+ The scale of the parameters, by default (1e-7, 1e-11, 1e-9)
36+ bounds : tuple, optional
37+ The bounds for the parameters, by default ((0, 1), (0, 1), (0, 1))
38+ priors : tuple, optional
39+ The priors for the parameters, by default None
40+ gaussian_noise : bool, optional
41+ Whether the noise is gaussian, by default False
42+ nwalkers : int, optional
43+ The number of walkers, by default 16
44+ nsteps : int, optional
45+ The number of steps, by default 10000
46+ burn_in : int, optional
47+ The burn in, by default 2000
48+ progress : bool, optional
49+ Whether to show progress, by default True
2950
3051 """
3152 self .data = np .atleast_2d (np .asarray (data ))
@@ -68,60 +89,38 @@ def signal(self, f, D, D_star):
6889
6990 def biexp_loglike_gauss (self , f , D , D_star ):
7091 expected = self .signal (f , D , D_star )
71- # check this!
72- # print(f'likelihood {norm.logpdf(self.data, loc=expected/self.data_scale, scale=self.data_scale)}')
7392 return np .sum (norm .logpdf (self .data , loc = expected , scale = self .data_scale ), 1 )
7493
75- # def biexp_loglike_gauss_full(self, f, D, D_star):
76- # expected = self.signal(f, D, D_star)
77- # print(f'expected {expected}')
78- # print(f'data {self.data}')
79- # return norm.logpdf(self.data, loc=expected, scale=self.data_scale)
80-
8194 def biexp_loglike_rice (self , f , D , D_star ):
8295 expected = self .signal (f , D , D_star )
83- # print(f'expected {expected}')
8496 return np .sum (rice .logpdf (self .data , b = expected / self .data_scale , scale = self .data_scale ), 1 )
8597
8698 def posterior (self , params ):
8799 params = np .atleast_2d (params )
88100 total = self .bounds_prior (params )
89- # print(f'bounds params {total}')
90101 neginf = np .isneginf (total )
91- # print(f'neginf {neginf}')
92102 f = params [~ neginf , 0 ]
93103 D = params [~ neginf , 1 ]
94104 D_star = params [~ neginf , 2 ]
95105 prior = self .prior (params [~ neginf , :])
96- # print(f'prior {prior}')
97106 likelihood = self .likelihood (f , D , D_star )
98- # print(f'likelihood {likelihood}')
99107 total [~ neginf ] += prior + likelihood
100108 return total
101109
102110 def sample (self , initial_pos ):
103- # f = initial_pos[0]
104- # D = initial_pos[1]
105- # D_star = initial_pos[2]
106- # print(f'initial pos likelihood {self.biexp_loglike_gauss_full(f, D, D_star)}')
107- print (f'initial pos likelihood { self .posterior (initial_pos )} ' )
111+ # print(f'initial pos likelihood {self.posterior(initial_pos)}')
108112 sampler = emcee .EnsembleSampler (self .nwalkers , 3 , self .posterior , vectorize = True )
109113 pos = initial_pos + self .parameter_scale * np .random .randn (self .nwalkers , self .ndim )
110- # print(f'pos {pos}')
111- # print(f'nsteps {self.nsteps}')
112114 sampler .run_mcmc (pos , self .nsteps , progress = True )
113115 self .chain = sampler .get_chain (discard = self .burn_in , flat = True )
114116 self .means = np .mean (self .chain , 0 )
115117 self .stds = np .std (self .chain , 0 )
116- print (f'final pos likelihood { self .posterior (self .means )} ' )
117- # print(f'final pos likelihood {self.biexp_loglike_gauss_full(self.means[0], self.means[1], self.means[2])}')
118- # print(f'chain {self.chain}')
118+ # print(f'final pos likelihood {self.posterior(self.means)}')
119119 return self .means , self .stds
120120
121121 def plot (self , truths = None , labels = ('f' , 'D' , 'D*' ), overplot = None ):
122122 if truths is None :
123123 truths = self .means
124- # print(f'chain size {self.chain.shape}')
125124 fig = corner .corner (self .chain , labels = labels , truths = truths )
126125 fig .suptitle ("Sampling of the IVIM data" , fontsize = 16 )
127126 if overplot is not None :
0 commit comments