@@ -12,20 +12,18 @@ using ProgressMeter, LinearAlgebra
1212using .. Turing: PROGRESS, NamedDist, NoDist, Turing
1313using StatsFuns: logsumexp
1414using Random: GLOBAL_RNG, AbstractRNG, randexp
15- using AbstractMCMC, DynamicPPL
15+ using DynamicPPL
1616using Bijectors: _debug
1717
1818import MCMCChains: Chains
1919import AdvancedHMC; const AHMC = AdvancedHMC
2020import .. Core: getchunksize, getADtype
21- import AbstractMCMC: AbstractTransition, sample, step!, sample_init!,
22- transitions_init, sample_end!, AbstractSampler, transition_type,
23- callback, init_callback, AbstractCallback, psample
21+ import AbstractMCMC
22+ using AbstractMCMC: AbstractModel, AbstractCallback, AbstractSampler
2423import DynamicPPL: tilde, dot_tilde, getspace, get_matching_type
2524
2625export InferenceAlgorithm,
2726 Hamiltonian,
28- AbstractGibbs,
2927 GibbsComponent,
3028 StaticHamiltonian,
3129 AdaptiveHamiltonian,
@@ -44,20 +42,8 @@ export InferenceAlgorithm,
4442 SMC,
4543 CSMC,
4644 PG,
47- PIMH,
48- PMMH,
49- IPMCMC, # particle-based sampling
5045 assume,
5146 observe,
52- step,
53- WelfordVar,
54- WelfordCovar,
55- NaiveCovar,
56- get_var,
57- get_covar,
58- add_sample!,
59- reset!,
60- step!,
6147 resume
6248
6349# ######################
9581# Default Transition #
9682# #####################
9783
98- struct Transition{T, F<: AbstractFloat } <: AbstractTransition
84+ struct Transition{T, F<: AbstractFloat }
9985 θ :: T
10086 lp :: F
10187end
@@ -147,19 +133,19 @@ function AbstractMCMC.sample(
147133 chain_type= Chains,
148134 kwargs...
149135)
150- return sample (rng, model, Sampler (alg, model), N; progress= PROGRESS[], chain_type= chain_type, kwargs... )
136+ return AbstractMCMC . sample (rng, model, Sampler (alg, model), N; progress= PROGRESS[], chain_type= chain_type, kwargs... )
151137end
152138
153139function AbstractMCMC. sample (
154- model:: AbstractModel ,
140+ model:: Model ,
155141 alg:: InferenceAlgorithm ,
156142 N:: Integer ;
157143 resume_from= nothing ,
158144 chain_type= Chains,
159145 kwargs...
160146)
161147 if resume_from === nothing
162- return sample (model, Sampler (alg, model), N; progress= PROGRESS[], chain_type= chain_type, kwargs... )
148+ return AbstractMCMC . sample (model, Sampler (alg, model), N; progress= PROGRESS[], chain_type= chain_type, kwargs... )
163149 else
164150 return resume (resume_from, N)
165151 end
@@ -174,7 +160,7 @@ function AbstractMCMC.psample(
174160 chain_type= Chains,
175161 kwargs...
176162)
177- return psample (GLOBAL_RNG, model, alg, N, n_chains; progress= false , chain_type= chain_type, kwargs... )
163+ return AbstractMCMC . psample (GLOBAL_RNG, model, alg, N, n_chains; progress= false , chain_type= chain_type, kwargs... )
178164end
179165
180166function AbstractMCMC. psample (
@@ -186,7 +172,7 @@ function AbstractMCMC.psample(
186172 chain_type= Chains,
187173 kwargs...
188174)
189- return psample (rng, model, Sampler (alg, model), N, n_chains; progress= false , chain_type= chain_type, kwargs... )
175+ return AbstractMCMC . psample (rng, model, Sampler (alg, model), N, n_chains; progress= false , chain_type= chain_type, kwargs... )
190176end
191177
192178function AbstractMCMC. sample_init! (
206192function AbstractMCMC. sample_end! (
207193 :: AbstractRNG ,
208194 :: Model ,
209- :: AbstractSampler ,
195+ :: Sampler ,
210196 :: Integer ,
211- :: Vector{<:AbstractTransition} ;
197+ :: Vector ;
212198 kwargs...
213199)
214200 # Silence the default API function.
244230# Chain making utilities #
245231# #########################
246232
247- function _params_to_array (ts:: Vector{<:AbstractTransition} , spl:: Sampler )
233+ function _params_to_array (ts:: Vector , spl:: Sampler )
248234 names_set = Set {String} ()
249235 # Extract the parameter names and values from each transition.
250236 dicts = map (ts) do t
@@ -276,7 +262,7 @@ function flatten_namedtuple(nt::NamedTuple)
276262 return [vn[1 ] for vn in names_vals], [vn[2 ] for vn in names_vals]
277263end
278264
279- function get_transition_extras (ts:: Vector{<:AbstractTransition} )
265+ function get_transition_extras (ts:: Vector )
280266 # Get the extra field names from the sampler state type.
281267 # This handles things like :lp or :weight.
282268 extra_params = additional_parameters (eltype (ts))
@@ -322,8 +308,8 @@ function AbstractMCMC.bundle_samples(
322308 model:: AbstractModel ,
323309 spl:: Sampler ,
324310 N:: Integer ,
325- ts:: Vector{<:AbstractTransition} ,
326- ct :: Type{Chains} ;
311+ ts:: Vector ,
312+ :: Type{Chains} ;
327313 discard_adapt:: Bool = true ,
328314 save_state= true ,
329315 kwargs...
@@ -384,7 +370,7 @@ function resume(c::Chains, n_iter::Int; chain_type=Chains, kwargs...)
384370 @assert ! isempty (c. info) " [Turing] cannot resume from a chain without state info"
385371
386372 # Sample a new chain.
387- newchain = sample (
373+ newchain = AbstractMCMC . sample (
388374 c. info[:range ],
389375 c. info[:model ],
390376 c. info[:spl ],
@@ -432,13 +418,12 @@ include("is.jl")
432418include (" AdvancedSMC.jl" )
433419include (" gibbs.jl" )
434420include (" ../contrib/inference/sghmc.jl" )
435- include (" ../contrib/inference/AdvancedSMCExtensions.jl" )
436421
437422# ###############
438423# Typing tools #
439424# ###############
440425
441- for alg in (:SMC , :PG , :PMMH , :IPMCMC , : MH , :IS , :ESS , :Gibbs )
426+ for alg in (:SMC , :PG , :MH , :IS , :ESS , :Gibbs )
442427 @eval getspace (:: $alg{space} ) where {space} = space
443428end
444429for alg in (:HMC , :HMCDA , :NUTS , :SGLD , :SGHMC )
494479# # Fallback functions
495480
496481alg_str (spl:: Sampler ) = string (nameof (typeof (spl. alg)))
497- transition_type (spl:: Sampler ) = typeof (Transition (spl))
498482
499483# utility funcs for querying sampler information
500484require_gradient (spl:: Sampler ) = false
0 commit comments