Skip to content

Commit a65c30b

Browse files
authored
Merge pull request #149 from JuliaMath/nh/backends
Backend based AbstractNFFT implementation
2 parents 26fc18d + d38aaab commit a65c30b

File tree

13 files changed

+257
-55
lines changed

13 files changed

+257
-55
lines changed

.github/workflows/ci.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ jobs:
4747
using Pkg
4848
Pkg.add(PackageSpec(path=pwd(), subdir="AbstractNFFTs"))
4949
Pkg.add(PackageSpec(path=pwd(), subdir="NFFTTools"))
50-
Pkg.add(PackageSpec(path=pwd(), subdir="CuNFFT"))
5150
- uses: julia-actions/julia-buildpkg@v1
5251
- uses: julia-actions/julia-runtest@v1
5352
env:
@@ -70,7 +69,6 @@ jobs:
7069
using Pkg
7170
Pkg.add(PackageSpec(path=pwd(), subdir="AbstractNFFTs"))
7271
Pkg.add(PackageSpec(path=pwd(), subdir="NFFTTools"))
73-
Pkg.add(PackageSpec(path=pwd(), subdir="CuNFFT"))
7472
- uses: julia-actions/julia-buildpkg@latest
7573
- uses: julia-actions/julia-docdeploy@latest
7674
env:

AbstractNFFTs/Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
name = "AbstractNFFTs"
22
uuid = "7f219486-4aa7-41d6-80a7-e08ef20ceed7"
33
author = ["Tobias Knopp <[email protected]>"]
4-
version = "0.8.3"
4+
version = "0.9.0"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
99
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
10+
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
1011

1112
[weakdeps]
1213
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -16,4 +17,5 @@ AbstractNFFTsChainRulesCoreExt = "ChainRulesCore"
1617

1718
[compat]
1819
julia = "1.6"
19-
ChainRulesCore = "1"
20+
ChainRulesCore = "1"
21+
ScopedValues = "1"

AbstractNFFTs/src/AbstractNFFTs.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,16 @@ module AbstractNFFTs
33
using LinearAlgebra
44
using Printf
55

6+
# Remove this difference once 1.11 or higher becomes lower bound
7+
if VERSION >= v"1.11"
8+
using Base.ScopedValues
9+
else
10+
using ScopedValues
11+
end
12+
13+
614
# interface
15+
export AbstractNFFTBackend, nfft_backend, with
716
export AbstractFTPlan, AbstractRealFTPlan, AbstractComplexFTPlan,
817
AbstractNFFTPlan, AbstractNFCTPlan, AbstractNFSTPlan, AbstractNNFFTPlan,
918
plan_nfft, plan_nfct, plan_nfst, mul!, size_in, size_out, nodes!
@@ -25,6 +34,7 @@ include("misc.jl")
2534
include("interface.jl")
2635
include("derived.jl")
2736

37+
2838
@static if !isdefined(Base, :get_extension)
2939
import Requires
3040
end

AbstractNFFTs/src/derived.jl

Lines changed: 152 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -10,64 +10,189 @@ planfunc = Symbol("plan_"*"$op")
1010

1111
# The following automatically call the plan_* version for type Array
1212

13-
$(planfunc)(k::AbstractArray, N::Union{Integer,NTuple{D,Int}}, args...; kargs...) where {D} =
14-
$(planfunc)(Array, k, N, args...; kargs...)
13+
$(planfunc)(b::AbstractNFFTBackend, k::AbstractArray, N::Union{Integer,NTuple{D,Int}}, args...; kargs...) where {D} =
14+
$(planfunc)(b, Array, k, N, args...; kargs...)
1515

16-
$(planfunc)(k::AbstractArray, y::AbstractArray, args...; kargs...) =
17-
$(planfunc)(Array, k, y, args...; kargs...)
16+
$(planfunc)(b::AbstractNFFTBackend, k::AbstractArray, y::AbstractArray, args...; kargs...) =
17+
$(planfunc)(b, Array, k, y, args...; kargs...)
18+
19+
$(planfunc)(k::AbstractArray, args...; kargs...) = $(planfunc)(active_backend(), k, args...; kargs...)
1820

1921
# The follow convert 1D parameters into the format required by the plan
2022

21-
$(planfunc)(Q::Type, k::AbstractVector, N::Integer, rest...; kwargs...) =
22-
$(planfunc)(Q, collect(reshape(k,1,length(k))), (N,), rest...; kwargs...)
23+
$(planfunc)(b::AbstractNFFTBackend, Q::Type, k::AbstractVector, N::Integer, rest...; kwargs...) =
24+
$(planfunc)(b, Q, collect(reshape(k,1,length(k))), (N,), rest...; kwargs...)
25+
26+
$(planfunc)(b::AbstractNFFTBackend, Q::Type, k::AbstractVector, N::NTuple{D,Int}, rest...; kwargs...) where {D} =
27+
$(planfunc)(b, Q, collect(reshape(k,1,length(k))), N, rest...; kwargs...)
2328

24-
$(planfunc)(Q::Type, k::AbstractVector, N::NTuple{D,Int}, rest...; kwargs...) where {D} =
25-
$(planfunc)(Q, collect(reshape(k,1,length(k))), N, rest...; kwargs...)
29+
$(planfunc)(b::AbstractNFFTBackend, Q::Type, k::AbstractMatrix, N::NTuple{D,Int}, rest...; kwargs...) where {D} =
30+
$(planfunc)(b, Q, collect(k), N, rest...; kwargs...)
2631

27-
$(planfunc)(Q::Type, k::AbstractMatrix, N::NTuple{D,Int}, rest...; kwargs...) where {D} =
28-
$(planfunc)(Q, collect(k), N, rest...; kwargs...)
32+
$(planfunc)(Q::Type, args...; kwargs...) = $(planfunc)(active_backend(), Q, args...; kwargs...)
2933

34+
$(planfunc)(::Missing, args...; kwargs...) = no_backend_error()
3035
end
3136
end
3237

3338
## NNFFT constructor
34-
plan_nnfft(Q::Type, k::AbstractVector, y::AbstractVector, rest...; kwargs...) =
35-
plan_nnfft(Q, collect(reshape(k,1,length(k))), collect(reshape(y,1,length(k))), rest...; kwargs...)
36-
39+
plan_nnfft(Q::Type, args...; kwargs...) = plan_nnfft(active_backend(), Q, args...; kwargs...)
40+
plan_nnfft(b::AbstractNFFTBackend, Q::Type, k::AbstractVector, y::AbstractVector, rest...; kwargs...) =
41+
plan_nnfft(b, Q, collect(reshape(k,1,length(k))), collect(reshape(y,1,length(k))), rest...; kwargs...)
42+
plan_nnfft(::Missing, args...; kwargs...) = no_backend_error()
3743

3844

3945
###############################################
4046
# Allocating trafo functions with plan creation
4147
###############################################
4248

49+
"""
50+
nfft(k, f, rest...; kwargs...)
51+
nfft(backend, k, f, rest...; kwargs...)
52+
53+
calculates the nfft of the array `f` for the nodes contained in the matrix `k`
54+
The output is a vector of length M=`size(nodes,2)`.
55+
56+
Uses the active AbstractNFFTs `backend` if no `backend` argument is provided. Backends can be activated with `BackendModule.activate!()`.
57+
Backends can also be set with a scoped value overriding the current active backend within a scope:
58+
59+
```julia
60+
julia> NFFT.activate!()
61+
62+
julia> nfft(k, f, rest...; kwargs...) # uses NFFT
63+
64+
julia> with(nfft_backend => NonuniformFFTs.backend()) do
65+
nfft(k, f, rest...; kwargs...) # uses NonuniformFFTs
66+
end
67+
```
68+
"""
69+
nfft
70+
"""
71+
nfft_adjoint(k, N, fHat, rest...; kwargs...)
72+
nfft_adjoint(backend, k, N, fHat, rest...; kwargs...)
73+
74+
calculates the adjoint nfft of the vector `fHat` for the nodes contained in the matrix `k`.
75+
The output is an array of size `N`.
76+
77+
Uses the active AbstractNFFTs `backend` if no `backend` argument is provided. Backends can be activated with `BackendModule.activate!()`.
78+
Backends can also be set with a scoped value overriding the current active backend within a scope:
79+
80+
```julia
81+
julia> NFFT.activate!()
82+
83+
julia> nfft_adjoint(k, N, fHat, rest...; kwargs...) # uses NFFT
84+
85+
julia> with(nfft_backend => NonuniformFFTs.backend()) do
86+
nfft_adjoint(k, N, fHat, rest...; kwargs...) # uses NonuniformFFTs
87+
end
88+
```
89+
"""
90+
nfft_adjoint
91+
"""
92+
nfft_transpose(k, N, fHat, rest...; kwargs...)
93+
nfft_transpose(backend, k, N, fHat, rest...; kwargs...)
94+
95+
calculates the transpose nfft of the vector `fHat` for the nodes contained in the matrix `k`.
96+
The output is an array of size `N`.
97+
98+
Uses the active AbstractNFFTs `backend` if no `backend` argument is provided. Backends can be activated with `BackendModule.activate!()`.
99+
Backends can also be set with a scoped value overriding the current active backend within a scope:
100+
101+
```julia
102+
julia> NFFT.activate!()
103+
104+
julia> nfft_transpose(k, N, fHat, rest...; kwargs...) # uses NFFT
105+
106+
julia> with(nfft_backend => NonuniformFFTs.backend()) do
107+
nfft_transpose(k, N, fHat, rest...; kwargs...) # uses NonuniformFFTs
108+
end
109+
```
110+
"""
111+
nfft_transpose
112+
113+
"""
114+
nfct(k, f, rest...; kwargs...)
115+
nfct(backend, k, f, rest...; kwargs...)
116+
117+
calculates the nfct of the array `f` for the nodes contained in the matrix `k`
118+
The output is a vector of length M=`size(nodes,2)`.
119+
120+
Uses the active AbstractNFFTs `backend` if no `backend` argument is provided. Backends can be activated with `BackendModule.activate!()`.
121+
"""
122+
nfct
123+
"""
124+
nfct_adjoint(k, N, fHat, rest...; kwargs...)
125+
nfct_adjoint(backend, k, N, fHat, rest...; kwargs...)
126+
127+
calculates the adjoint nfct of the vector `fHat` for the nodes contained in the matrix `k`.
128+
The output is an array of size `N`.
129+
130+
Uses the active AbstractNFFTs `backend` if no `backend` argument is provided. Backends can be activated with `BackendModule.activate!()`.
131+
"""
132+
nfct_adjoint
133+
"""
134+
nfct_transpose(k, N, fHat, rest...; kwargs...)
135+
nfct_transpose(backend, k, N, fHat, rest...; kwargs...)
136+
137+
calculates the transpose nfct of the vector `fHat` for the nodes contained in the matrix `k`.
138+
The output is an array of size `N`.
139+
140+
Uses the active AbstractNFFTs `backend` if no `backend` argument is provided. Backends can be activated with `BackendModule.activate!()`.
141+
"""
142+
nfct_transpose
143+
144+
"""
145+
nfst(k, f, rest...; kwargs...)
146+
nfst(backend, k, f, rest...; kwargs...)
147+
148+
calculates the nfst of the array `f` for the nodes contained in the matrix `k`
149+
The output is a vector of length M=`size(nodes,2)`.
150+
151+
Uses the active AbstractNFFTs `backend` if no `backend` argument is provided. Backends can be activated with `BackendModule.activate!()`.
152+
"""
153+
nfst
154+
"""
155+
nfst_adjoint(k, N, fHat, rest...; kwargs...)
156+
nfst_adjoint(backend, k, N, fHat, rest...; kwargs...)
157+
158+
calculates the adjoint nfst of the vector `fHat` for the nodes contained in the matrix `k`.
159+
The output is an array of size `N`.
160+
161+
Uses the active AbstractNFFTs `backend` if no `backend` argument is provided. Backends can be activated with `BackendModule.activate!()`.
162+
"""
163+
nfst_adjoint
164+
"""
165+
nfst_transpose(k, N, fHat, rest...; kwargs...)
166+
nfst_transpose(backend, k, N, fHat, rest...; kwargs...)
167+
168+
calculates the transpose nfst of the vector `fHat` for the nodes contained in the matrix `k`.
169+
The output is an array of size `N`.
170+
171+
Uses the active AbstractNFFTs `backend` if no `backend` argument is provided. Backends can be activated with `BackendModule.activate!()`.
172+
"""
173+
nfst_transpose
174+
43175
for (op,trans) in zip([:nfft, :nfct, :nfst],
44176
[:adjoint, :transpose, :transpose])
45177
planfunc = Symbol("plan_$(op)")
46178
tfunc = Symbol("$(op)_$(trans)")
47179
@eval begin
48180

49-
# TODO fix comments (how?)
50-
"""
51-
nfft(k, f, rest...; kwargs...)
52-
53-
calculates the nfft of the array `f` for the nodes contained in the matrix `k`
54-
The output is a vector of length M=`size(nodes,2)`
55-
"""
56-
function $(op)(k, f::AbstractArray; kargs...)
181+
$(op)(k, f::AbstractArray; kargs...) = $(op)(active_backend(), k, f::AbstractArray; kargs...)
182+
function $(op)(b::AbstractNFFTBackend, k, f::AbstractArray; kargs...)
57183
p = $(planfunc)(k, size(f); kargs... )
58184
return p * f
59185
end
186+
$(op)(::Missing, k, f::AbstractArray; kargs...) = no_backend_error()
60187

61-
"""
62-
nfft_adjoint(k, N, fHat, rest...; kwargs...)
63188

64-
calculates the adjoint nfft of the vector `fHat` for the nodes contained in the matrix `k`.
65-
The output is an array of size `N`
66-
"""
67-
function $(tfunc)(k, N, fHat; kargs...)
189+
$(tfunc)(k, N, fHat; kargs...) = $(tfunc)(active_backend(), k, N, fHat; kargs...)
190+
function $(tfunc)(b::AbstractNFFTBackend, k, N, fHat; kargs...)
68191
p = $(planfunc)(k, N; kargs...)
69192
return $(trans)(p) * fHat
70193
end
194+
$(tfunc)(::Missing, k, N, fHat; kargs...) = no_backend_error()
195+
71196

72197
end
73198
end

AbstractNFFTs/src/interface.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,33 @@
1+
abstract type AbstractNFFTBackend end
2+
struct BackendReference
3+
ref::Ref{Union{Missing, AbstractNFFTBackend}}
4+
BackendReference(val::Union{Missing, AbstractNFFTBackend}) = new(Ref{Union{Missing, AbstractNFFTBackend}}(val))
5+
end
6+
Base.setindex!(ref::BackendReference, val::Union{Missing, AbstractNFFTBackend}) = ref.ref[] = val
7+
Base.setindex!(ref::BackendReference, val::Module) = setindex!(ref, val.backend())
8+
Base.getindex(ref::BackendReference) = getindex(ref.ref)::Union{Missing, AbstractNFFTBackend}
9+
Base.convert(::Type{BackendReference}, val::AbstractNFFTBackend) = BackendReference(val)
10+
const nfft_backend = ScopedValue(BackendReference(missing))
11+
12+
"""
13+
set_active_backend!(back::Union{Missing, Module, AbstractNFFTBackend})
14+
15+
Set the default NFFT plan backend. A module `back` must implement `back.backend()`.
16+
"""
17+
set_active_backend!(back::Module) = set_active_backend!(back.backend())
18+
function set_active_backend!(back::Union{Missing, AbstractNFFTBackend})
19+
nfft_backend[][] = back
20+
end
21+
active_backend() = nfft_backend[][]
22+
function no_backend_error()
23+
error(
24+
"""
25+
No default backend available!
26+
Make sure to also "import/using" an NFFT backend such as NFFT or NonuniformFFTs.
27+
"""
28+
)
29+
end
30+
131
"""
232
AbstractFTPlan{T,D,R}
333

NFFTTools/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@ FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
1212
[compat]
1313
julia = "1.6"
1414
AbstractFFTs = "1.0"
15-
AbstractNFFTs = "0.6, 0.7, 0.8"
15+
AbstractNFFTs = "0.6, 0.7, 0.8, 0.9"
1616
FFTW = "1"

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NFFT"
22
uuid = "efe261a4-0d2b-5849-be55-fc731d526b0d"
33
authors = ["Tobias Knopp <[email protected]>"]
4-
version = "0.13.7"
4+
version = "0.14"
55

66
[deps]
77
AbstractNFFTs = "7f219486-4aa7-41d6-80a7-e08ef20ceed7"
@@ -19,7 +19,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1919

2020
[compat]
2121
Adapt = "3, 4"
22-
AbstractNFFTs = "0.8"
22+
AbstractNFFTs = "0.9"
2323
BasicInterpolators = "0.6.5, 0.7"
2424
DataFrames = "1.3.1, 1.4.1"
2525
FFTW = "1.5"

0 commit comments

Comments
 (0)