Skip to content

Commit df11892

Browse files
authored
Add missing BroadcastStyle and MemoryLayouts for cache styles (#387)
* Add missing BroadcastStyle and MemoryLayouts for cache styles * Fix cacheddata/complex caches * Review * Loosen test * Fix cached access * StackOverflow issue * Fix resizing where layout is a DualLayout{<:AbstractCachedLayout} * The change * No need to specialise on N anymore * Broken test
1 parent d7d4f63 commit df11892

File tree

3 files changed

+47
-3
lines changed

3 files changed

+47
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "LazyArrays"
22
uuid = "5078a376-72f3-5289-bfd5-ec5146d43c02"
3-
version = "2.9"
3+
version = "2.9.1"
44

55
[deps]
66
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/cache.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ cache_layout(::AbstractStridedLayout, O::AbstractArray) = copy(O)
5757

5858
const _cache = cache_layout # TODO: deprecate
5959
cacheddata(A::AbstractCachedArray) = view(A.data,OneTo.(A.datasize)...)
60+
cacheddata(A::Adjoint) = adjoint(cacheddata(parent(A)))
61+
cacheddata(A::Transpose) = transpose(cacheddata(parent(A)))
6062

6163
maybe_cacheddata(A::AbstractCachedArray) = cacheddata(A)
6264
maybe_cacheddata(A::SubArray{<:Any,N,<:AbstractCachedArray}) where N = cacheddata(A)
@@ -333,6 +335,9 @@ MemoryLayout(C::Type{CachedArray{T,N,DAT,ARR}}) where {T,N,DAT,ARR} = cachedlayo
333335

334336
MemoryLayout(::Type{<:AbstractCachedArray}) = GenericCachedLayout()
335337

338+
transposelayout(::AbstractCachedLayout) = GenericCachedLayout()
339+
conjlayout(::Type{<:Complex}, ::AbstractCachedLayout) = GenericCachedLayout()
340+
336341
#####
337342
# broadcasting
338343
#
@@ -345,10 +350,13 @@ CachedArrayStyle(::Val{N}) where N = CachedArrayStyle{N}()
345350
CachedArrayStyle{M}(::Val{N}) where {N,M} = CachedArrayStyle{N}()
346351

347352
BroadcastStyle(::Type{<:AbstractCachedArray{<:Any,N}}) where N = CachedArrayStyle{N}()
353+
BroadcastStyle(::Type{<:AdjOrTrans{<:Any, <:AbstractCachedArray{<:Any,N}}}) where N = CachedArrayStyle{N}()
354+
BroadcastStyle(::Type{<:AdjOrTrans{<:Any, <:SubArray{<:Any,N,<:AbstractCachedArray{<:Any,M}}}}) where {N,M} = CachedArrayStyle{M}()
348355
BroadcastStyle(::Type{<:SubArray{<:Any,N,<:AbstractCachedArray{<:Any,M}}}) where {N,M} = CachedArrayStyle{M}()
356+
BroadcastStyle(::Type{<:SubArray{<:Any,N,<:AdjOrTrans{<:Any, <:AbstractCachedArray{<:Any,M}}}}) where {N,M} = CachedArrayStyle{M}()
357+
BroadcastStyle(::Type{<:AdjOrTrans{<:Any, <:SubArray{<:Any,N,<:AdjOrTrans{<:Any,<:AbstractCachedArray{<:Any,M}}}}}) where {N,M} = CachedArrayStyle{M}()
349358
BroadcastStyle(::CachedArrayStyle{N}, ::LazyArrayStyle{M}) where {N,M} = CachedArrayStyle{max(M, N)}()
350359

351-
352360
broadcasted(::AbstractLazyArrayStyle, op, A::CachedArray) = CachedArray(broadcast(op, cacheddata(A)), broadcast(op, A.array))
353361
layout_broadcasted(::CachedLayout, _, op, A::AbstractArray, c::Number) = CachedArray(broadcast(op, cacheddata(A), c), broadcast(op, A.array, c))
354362
layout_broadcasted(_, ::CachedLayout, op, c::Number, A::CachedArray) = CachedArray(broadcast(op, c, cacheddata(A)), broadcast(op, c, A.array))
@@ -394,6 +402,7 @@ function _bc_resizecacheddata!(::AbstractCachedLayout, a)
394402
resizedata!(a, size(a)...)
395403
view(cacheddata(a), axes(a)...)
396404
end
405+
_bc_resizecacheddata!(::DualLayout{ML}, a) where {ML<:AbstractCachedLayout} = _bc_resizecacheddata!(ML(), a)
397406
_bc_resizecacheddata!(_, a) = a
398407
_bc_resizecacheddata!(a) = _bc_resizecacheddata!(MemoryLayout(a), a)
399408
resize_bcargs!(bc::Broadcasted{<:CachedArrayStyle}) = broadcasted(bc.f, map(_bc_resizecacheddata!, bc.args)...)

test/cachetests.jl

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ using LazyArrays, FillArrays, LinearAlgebra, ArrayLayouts, SparseArrays, Test
44
using StaticArrays
55
import LazyArrays: CachedArray, CachedMatrix, CachedVector, PaddedLayout, CachedLayout, resizedata!, zero!,
66
CachedAbstractArray, CachedAbstractVector, CachedAbstractMatrix, AbstractCachedArray, AbstractCachedMatrix,
7-
PaddedColumns, cacheddata, maybe_cacheddata
7+
PaddedColumns, cacheddata, LazyArrayStyle, maybe_cacheddata, Accumulate, CachedArrayStyle, GenericCachedLayout,
8+
AccumulateAbstractVector
89

910
using ..InfiniteArrays
1011
using .InfiniteArrays: OneToInf
@@ -529,6 +530,18 @@ using Infinities
529530
copyto!(dest, src)
530531
@test dest == res
531532
end
533+
534+
@testset "Avoid StackOverflow for recursive CachedArrayStyles" begin
535+
@test Matrix(view((1:5)', :, 1:1) .* view(Accumulate(*, 1:5)', :, 1:1)) == [1;;] # used to StackOverflow
536+
end
537+
538+
@testset "DualLayout{<:AbstractCachedLayout}" begin
539+
arg1 = view((1:100)', :, 1:10)
540+
arg2 = view(AccumulateAbstractVector(*, 1:100)', :, 1:10)
541+
bc = Base.Broadcast.Broadcasted(CachedArrayStyle{2}(), *, (arg1, arg2))
542+
rsz_bc = LazyArrays.resize_bcargs!(bc);
543+
@test rsz_bc.args[2] == view(arg2.parent.parent.data', :, 1:10)
544+
end
532545
end
533546

534547
@testset "maybe_cacheddata" begin
@@ -539,6 +552,28 @@ using Infinities
539552
C = [1, 2, 3]
540553
@test maybe_cacheddata(C) === C
541554
end
555+
556+
@testset "Missing BroadcastStyles/MemoryLayouts/cacheddata with CachedArrayStyles" begin
557+
A = view(Accumulate(*, [1, 2, 3])', 1:1, 1:2)
558+
B = view(transpose(Accumulate(*, [1, 2im, 3])), 1:1, 1:2)
559+
C = Accumulate(*, [1, 2im, 3])'
560+
D = transpose(Accumulate(*, [1, 2im, 3]))
561+
E = view(Accumulate(*, [1, 2im, 3])', 1:1, 1:2)
562+
F = view(Accumulate(*, [1, 2, 3]), 1:2)'
563+
G = view(Accumulate(*, [1, 2im, 3])', 1:1, 1:2)'
564+
@test all(==(CachedArrayStyle{1}()), Base.BroadcastStyle.(typeof.((A, B, C, D, E, F, G))))
565+
@test all(==(GenericCachedLayout()), MemoryLayout.(typeof.((A, B, E, G))))
566+
@test all(==(DualLayout{GenericCachedLayout}()), MemoryLayout.(typeof.((C, D, F))))
567+
@test MemoryLayout(typeof(C)) == DualLayout{GenericCachedLayout}()
568+
@test MemoryLayout(typeof(D)) == DualLayout{GenericCachedLayout}()
569+
@test cacheddata(A) === view(cacheddata(parent(parent(A)))', 1:1, 1:1)
570+
@test cacheddata(B) === view(transpose(cacheddata(parent(parent(B)))), 1:1, 1:1)
571+
@test cacheddata(C) === cacheddata(parent(C))'
572+
@test cacheddata(D) === transpose(cacheddata(parent(D)))
573+
@test cacheddata(E) === view(cacheddata(parent(parent(E)))', 1:1, 1:1)
574+
@test cacheddata(F) === view(cacheddata(parent(parent(F))), 1:1)'
575+
@test cacheddata(G) === adjoint(view(cacheddata(parent(G)), 1:1, 1:1))
576+
end
542577
end
543578

544579
end # module

0 commit comments

Comments
 (0)