diff --git a/CITATION.cff b/CITATION.cff index 9b2157f3e..f1b639f71 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -10,7 +10,7 @@ authors: title: "TensorKit.jl" version: "0.16.0" doi: "10.5281/zenodo.8421339" -date-released: "2025-12-05" +date-released: "2025-12-08" url: "https://github.com/QuantumKitHub/TensorKit.jl" preferred-citation: type: article diff --git a/docs/src/Changelog.md b/docs/src/Changelog.md index a7cc8fc0d..f2824dcde 100644 --- a/docs/src/Changelog.md +++ b/docs/src/Changelog.md @@ -20,7 +20,7 @@ When releasing a new version, move the "Unreleased" changes to a new version sec ## [Unreleased](https://github.com/QuantumKitHub/TensorKit.jl/compare/v0.16.0...HEAD) -## [0.16.0](https://github.com/QuantumKitHub/TensorKit.jl/releases/tag/v0.16.0) - 2025-12-05 +## [0.16.0](https://github.com/QuantumKitHub/TensorKit.jl/releases/tag/v0.16.0) - 2025-12-08 ### Added @@ -38,6 +38,7 @@ When releasing a new version, move the "Unreleased" changes to a new version sec - Major documentation update/overhaul ([#289](https://github.com/QuantumKitHub/TensorKit.jl/pull/289)) - Added symmetric tensor tutorial as appendix ([#316](https://github.com/QuantumKitHub/TensorKit.jl/pull/316)) - Improved error messages throughout codebase ([#309](https://github.com/QuantumKitHub/TensorKit.jl/pull/309)) +- `eigvals` and `svdvals` now output `SectorVector` objects, which do behave as `AbstractVector` but also have the option of iterating the blocks through `Base.pairs`. ([#324](https://github.com/QuantumKitHub/TensorKit.jl/pull/309) ### Deprecated @@ -52,6 +53,7 @@ When releasing a new version, move the "Unreleased" changes to a new version sec - Avoid unnecessary copy in `twist` for tensors with bosonic braiding ([#305](https://github.com/QuantumKitHub/TensorKit.jl/pull/305)) - Small fixes and typos ([#295](https://github.com/QuantumKitHub/TensorKit.jl/pull/295)) +- `eig_vals`, `svd_vals`, etc now all output `SectorVector` objects instead of `DiagonalTensorMap`s, in line with how MatrixAlgebraKit returns `Vector`s instead of `Diagonal`s ([#324](https://github.com/QuantumKitHub/TensorKit.jl/pull/309) ## [0.15.3](https://github.com/QuantumKitHub/TensorKit.jl/releases/tag/v0.15.3) - 2025-10-30 diff --git a/ext/TensorKitFiniteDifferencesExt.jl b/ext/TensorKitFiniteDifferencesExt.jl index 63c9711e3..8a19ae62a 100644 --- a/ext/TensorKitFiniteDifferencesExt.jl +++ b/ext/TensorKitFiniteDifferencesExt.jl @@ -1,7 +1,7 @@ module TensorKitFiniteDifferencesExt using TensorKit -using TensorKit: sqrtdim, invsqrtdim +using TensorKit: sqrtdim, invsqrtdim, SectorVector using VectorInterface: scale! using FiniteDifferences @@ -31,6 +31,25 @@ function FiniteDifferences.to_vec(t::DiagonalTensorMap) return x_vec, DiagonalTensorMap_from_vec end +function FiniteDifferences.to_vec(v::SectorVector{T, <:Sector}) where {T} + v_normalized = similar(v) + for (c, b) in pairs(v) + scale!(v_normalized[c], b, sqrtdim(c)) + end + vec = parent(v_normalized) + vec_real = T <: Real ? vec : collect(reinterpret(real(T), vec)) + + function from_vec(x_real) + x = T <: Real ? x_real : reinterpret(T, x_real) + v_result = SectorVector(x, v.structure) + for (c, b) in pairs(v_result) + scale!(b, invsqrtdim(c)) + end + return v_result + end + return vec_real, from_vec +end + end # TODO: Investigate why the approach below doesn't work diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 580052ec7..8e49a368a 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -79,7 +79,7 @@ export left_orth, right_orth, left_null, right_null, left_polar, left_polar!, right_polar, right_polar!, qr_full, qr_compact, qr_null, lq_full, lq_compact, lq_null, qr_full!, qr_compact!, qr_null!, lq_full!, lq_compact!, lq_null!, - svd_compact!, svd_full!, svd_trunc!, svd_compact, svd_full, svd_trunc, + svd_compact!, svd_full!, svd_trunc!, svd_compact, svd_full, svd_trunc, svd_vals, svd_vals!, exp, exp!, eigh_full!, eigh_full, eigh_trunc!, eigh_trunc, eigh_vals!, eigh_vals, eig_full!, eig_full, eig_trunc!, eig_trunc, eig_vals!, eig_vals, @@ -222,6 +222,7 @@ end include("tensors/abstracttensor.jl") include("tensors/backends.jl") include("tensors/blockiterator.jl") +include("tensors/sectorvector.jl") include("tensors/tensor.jl") include("tensors/adjoint.jl") include("tensors/linalg.jl") diff --git a/src/factorizations/diagonal.jl b/src/factorizations/diagonal.jl index cea915995..c54bef00e 100644 --- a/src/factorizations/diagonal.jl +++ b/src/factorizations/diagonal.jl @@ -2,6 +2,8 @@ # ----------------- _repack_diagonal(d::DiagonalTensorMap) = Diagonal(d.data) +MAK.diagview(t::DiagonalTensorMap) = SectorVector(t.data, TensorKit.diagonalblockstructure(space(t))) + for f in ( :svd_compact, :svd_full, :svd_trunc, :svd_vals, :qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null, :eig_full, :eig_trunc, :eig_vals, :eigh_full, diff --git a/src/factorizations/factorizations.jl b/src/factorizations/factorizations.jl index dce2067ee..73fbc45eb 100644 --- a/src/factorizations/factorizations.jl +++ b/src/factorizations/factorizations.jl @@ -6,7 +6,7 @@ module Factorizations export copy_oftype, factorisation_scalartype, one!, truncspace using ..TensorKit -using ..TensorKit: AdjointTensorMap, SectorDict, blocktype, foreachblock, one! +using ..TensorKit: AdjointTensorMap, SectorDict, SectorVector, blocktype, foreachblock, one! using LinearAlgebra: LinearAlgebra, BlasFloat, Diagonal, svdvals, svdvals!, eigen, eigen!, isposdef, isposdef!, ishermitian @@ -44,13 +44,13 @@ function LinearAlgebra.eigvals(t::AbstractTensorMap; kwargs...) tcopy = copy_oftype(t, factorisation_scalartype(LinearAlgebra.eigen, t)) return LinearAlgebra.eigvals!(tcopy; kwargs...) end -LinearAlgebra.eigvals!(t::AbstractTensorMap; kwargs...) = diagview(eig_vals!(t)) +LinearAlgebra.eigvals!(t::AbstractTensorMap; kwargs...) = eig_vals!(t) function LinearAlgebra.svdvals(t::AbstractTensorMap) tcopy = copy_oftype(t, factorisation_scalartype(svd_vals!, t)) return LinearAlgebra.svdvals!(tcopy) end -LinearAlgebra.svdvals!(t::AbstractTensorMap) = diagview(svd_vals!(t)) +LinearAlgebra.svdvals!(t::AbstractTensorMap) = svd_vals!(t) #--------------------------------------------------# # Checks for hermiticity and positive definiteness # diff --git a/src/factorizations/matrixalgebrakit.jl b/src/factorizations/matrixalgebrakit.jl index 412fd8761..4564a8137 100644 --- a/src/factorizations/matrixalgebrakit.jl +++ b/src/factorizations/matrixalgebrakit.jl @@ -44,7 +44,11 @@ for f! in ( end # Handle these separately because single output instead of tuple -for f! in (:qr_null!, :lq_null!, :project_hermitian!, :project_antihermitian!, :project_isometric!) +for f! in ( + :qr_null!, :lq_null!, + :svd_vals!, :eig_vals!, :eigh_vals!, + :project_hermitian!, :project_antihermitian!, :project_isometric!, + ) @eval function MAK.$f!(t::AbstractTensorMap, N, alg::AbstractAlgorithm) foreachblock(t, N) do _, (tblock, Nblock) Nblock′ = $f!(tblock, Nblock, alg) @@ -56,19 +60,6 @@ for f! in (:qr_null!, :lq_null!, :project_hermitian!, :project_antihermitian!, : end end -# Handle these separately because single output instead of tuple -for f! in (:svd_vals!, :eig_vals!, :eigh_vals!) - @eval function MAK.$f!(t::AbstractTensorMap, N, alg::AbstractAlgorithm) - foreachblock(t, N) do _, (tblock, Nblock) - Nblock′ = $f!(tblock, diagview(Nblock), alg) - # deal with the case where the output is not the same as the input - diagview(Nblock) === Nblock′ || copy!(diagview(Nblock), Nblock′) - return nothing - end - return N - end -end - # Singular value decomposition # ---------------------------- function MAK.initialize_output(::typeof(svd_full!), t::AbstractTensorMap, ::AbstractAlgorithm) @@ -90,7 +81,8 @@ end function MAK.initialize_output(::typeof(svd_vals!), t::AbstractTensorMap, alg::AbstractAlgorithm) V_cod = infimum(fuse(codomain(t)), fuse(domain(t))) - return DiagonalTensorMap{real(scalartype(t))}(undef, V_cod) + T = real(scalartype(t)) + return SectorVector{T}(undef, V_cod) end # Eigenvalue decomposition @@ -114,13 +106,13 @@ end function MAK.initialize_output(::typeof(eigh_vals!), t::AbstractTensorMap, alg::AbstractAlgorithm) V_D = fuse(domain(t)) T = real(scalartype(t)) - return D = DiagonalTensorMap{Tc}(undef, V_D) + return SectorVector{T}(undef, V_D) end function MAK.initialize_output(::typeof(eig_vals!), t::AbstractTensorMap, alg::AbstractAlgorithm) V_D = fuse(domain(t)) Tc = complex(scalartype(t)) - return D = DiagonalTensorMap{Tc}(undef, V_D) + return SectorVector{Tc}(undef, V_D) end # QR decomposition diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index 84f2569e0..dd2e663c6 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -78,10 +78,10 @@ end function MAK.truncate( ::typeof(left_null!), (U, S)::NTuple{2, AbstractTensorMap}, strategy::TruncationStrategy ) - extended_S = SectorDict( - c => vcat(diagview(b), zeros(eltype(b), max(0, size(b, 1) - size(b, 2)))) - for (c, b) in blocks(S) - ) + extended_S = zerovector!(SectorVector{eltype(S)}(undef, fuse(codomain(U)))) + for (c, b) in blocks(S) + copyto!(extended_S[c], diagview(b)) # copyto! since `b` might be shorter + end ind = MAK.findtruncated(extended_S, strategy) V_truncated = truncate_space(space(S, 1), ind) Ũ = similar(U, codomain(U) ← V_truncated) @@ -91,10 +91,10 @@ end function MAK.truncate( ::typeof(right_null!), (S, Vᴴ)::NTuple{2, AbstractTensorMap}, strategy::TruncationStrategy ) - extended_S = SectorDict( - c => vcat(diagview(b), zeros(eltype(b), max(0, size(b, 2) - size(b, 1)))) - for (c, b) in blocks(S) - ) + extended_S = zerovector!(SectorVector{eltype(S)}(undef, fuse(domain(Vᴴ)))) + for (c, b) in blocks(S) + copyto!(extended_S[c], diagview(b)) # copyto! since `b` might be shorter + end ind = MAK.findtruncated(extended_S, strategy) V_truncated = truncate_space(dual(space(S, 2)), ind) Ṽᴴ = similar(Vᴴ, V_truncated ← domain(Vᴴ)) @@ -177,26 +177,40 @@ function _findnexttruncvalue( end end +function _sort_and_perm(values::SectorVector; by = identity, rev::Bool = false) + values_sorted = similar(values) + perms = SectorDict( + ( + begin + p = sortperm(v; by, rev) + vs = values_sorted[c] + vs .= view(v, p) + c => p + end + ) for (c, v) in pairs(values) + ) + return values_sorted, perms +end + # findtruncated # ------------- # Generic fallback -function MAK.findtruncated_svd(values::SectorDict, strategy::TruncationStrategy) +function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationStrategy) return MAK.findtruncated(values, strategy) end -function MAK.findtruncated(values::SectorDict, ::NoTruncation) - return SectorDict(c => Colon() for (c, b) in values) +function MAK.findtruncated(values::SectorVector, ::NoTruncation) + return SectorDict(c => Colon() for c in keys(values)) end -function MAK.findtruncated(values::SectorDict, strategy::TruncationByOrder) - perms = SectorDict(c => (sortperm(d; strategy.by, strategy.rev)) for (c, d) in values) - values_sorted = SectorDict(c => d[perms[c]] for (c, d) in values) +function MAK.findtruncated(values::SectorVector, strategy::TruncationByOrder) + values_sorted, perms = _sort_and_perm(values; strategy.by, strategy.rev) inds = MAK.findtruncated_svd(values_sorted, truncrank(strategy.howmany)) return SectorDict(c => perms[c][I] for (c, I) in inds) end -function MAK.findtruncated_svd(values::SectorDict, strategy::TruncationByOrder) +function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationByOrder) I = keytype(values) - truncdim = SectorDict{I, Int}(c => length(d) for (c, d) in values) + truncdim = SectorDict{I, Int}(c => length(d) for (c, d) in pairs(values)) totaldim = sum(dim(c) * d for (c, d) in truncdim; init = 0) while totaldim > strategy.howmany next = _findnexttruncvalue(values, truncdim; strategy.by, strategy.rev) @@ -209,32 +223,31 @@ function MAK.findtruncated_svd(values::SectorDict, strategy::TruncationByOrder) return SectorDict(c => Base.OneTo(d) for (c, d) in truncdim) end -function MAK.findtruncated(values::SectorDict, strategy::TruncationByFilter) - return SectorDict(c => findall(strategy.filter, d) for (c, d) in values) +function MAK.findtruncated(values::SectorVector, strategy::TruncationByFilter) + return SectorDict(c => findall(strategy.filter, d) for (c, d) in pairs(values)) end -function MAK.findtruncated(values::SectorDict, strategy::TruncationByValue) +function MAK.findtruncated(values::SectorVector, strategy::TruncationByValue) atol = rtol_to_atol(values, strategy.p, strategy.atol, strategy.rtol) strategy′ = trunctol(; atol, strategy.by, strategy.keep_below) - return SectorDict(c => MAK.findtruncated(d, strategy′) for (c, d) in values) + return SectorDict(c => MAK.findtruncated(d, strategy′) for (c, d) in pairs(values)) end -function MAK.findtruncated_svd(values::SectorDict, strategy::TruncationByValue) +function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationByValue) atol = rtol_to_atol(values, strategy.p, strategy.atol, strategy.rtol) strategy′ = trunctol(; atol, strategy.by, strategy.keep_below) - return SectorDict(c => MAK.findtruncated_svd(d, strategy′) for (c, d) in values) + return SectorDict(c => MAK.findtruncated_svd(d, strategy′) for (c, d) in pairs(values)) end -function MAK.findtruncated(values::SectorDict, strategy::TruncationByError) - perms = SectorDict(c => sortperm(d; by = abs, rev = true) for (c, d) in values) - values_sorted = SectorDict(c => d[perms[c]] for (c, d) in Sd) +function MAK.findtruncated(values::SectorVector, strategy::TruncationByError) + values_sorted, perms = _sort_and_perm(values; strategy.by, strategy.rev) inds = MAK.findtruncated_svd(values_sorted, truncrank(strategy.howmany)) return SectorDict(c => perms[c][I] for (c, I) in inds) end -function MAK.findtruncated_svd(values::SectorDict, strategy::TruncationByError) +function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationByError) I = keytype(values) - truncdim = SectorDict{I, Int}(c => length(d) for (c, d) in values) + truncdim = SectorDict{I, Int}(c => length(d) for (c, d) in pairs(values)) by(c, v) = abs(v)^strategy.p * dim(c) - Nᵖ = sum(((c, v),) -> sum(Base.Fix1(by, c), v), values) + Nᵖ = sum(((c, v),) -> sum(Base.Fix1(by, c), v), pairs(values)) ϵᵖ = max(strategy.atol^strategy.p, strategy.rtol^strategy.p * Nᵖ) truncerrᵖ = zero(real(scalartype(valtype(values)))) next = _findnexttruncvalue(values, truncdim) @@ -248,16 +261,16 @@ function MAK.findtruncated_svd(values::SectorDict, strategy::TruncationByError) return SectorDict{I, Base.OneTo{Int}}(c => Base.OneTo(d) for (c, d) in truncdim) end -function MAK.findtruncated(values::SectorDict, strategy::TruncationSpace) +function MAK.findtruncated(values::SectorVector, strategy::TruncationSpace) blockstrategy(c) = truncrank(dim(strategy.space, c); strategy.by, strategy.rev) return SectorDict(c => MAK.findtruncated(d, blockstrategy(c)) for (c, d) in values) end -function MAK.findtruncated_svd(values::SectorDict, strategy::TruncationSpace) +function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationSpace) blockstrategy(c) = truncrank(dim(strategy.space, c); strategy.by, strategy.rev) - return SectorDict(c => MAK.findtruncated_svd(d, blockstrategy(c)) for (c, d) in values) + return SectorDict(c => MAK.findtruncated_svd(d, blockstrategy(c)) for (c, d) in pairs(values)) end -function MAK.findtruncated(values::SectorDict, strategy::TruncationIntersection) +function MAK.findtruncated(values::SectorVector, strategy::TruncationIntersection) inds = map(Base.Fix1(MAK.findtruncated, values), strategy.components) return SectorDict( c => mapreduce( @@ -266,7 +279,7 @@ function MAK.findtruncated(values::SectorDict, strategy::TruncationIntersection) ) for c in intersect(map(keys, inds)...) ) end -function MAK.findtruncated_svd(values::SectorDict, strategy::TruncationIntersection) +function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationIntersection) inds = map(Base.Fix1(MAK.findtruncated_svd, values), strategy.components) return SectorDict( c => mapreduce( @@ -278,13 +291,12 @@ end # Truncation error # ---------------- -MAK.truncation_error(values::SectorDict, ind) = - MAK.truncation_error!(SectorDict(c => copy(v) for (c, v) in values), ind) +MAK.truncation_error(values::SectorVector, ind) = MAK.truncation_error!(copy(values), ind) -function MAK.truncation_error!(values::SectorDict, ind) +function MAK.truncation_error!(values::SectorVector, ind) for (c, ind_c) in ind v = values[c] v[ind_c] .= zero(eltype(v)) end - return TensorKit._norm(values, 2, zero(real(eltype(valtype(values))))) + return norm(values) end diff --git a/src/tensors/blockiterator.jl b/src/tensors/blockiterator.jl index d33cb0f4a..40a0e1edb 100644 --- a/src/tensors/blockiterator.jl +++ b/src/tensors/blockiterator.jl @@ -30,7 +30,7 @@ for c in union(blocksectors.(ts)...) end ``` """ -function foreachblock(f, t::AbstractTensorMap, ts::AbstractTensorMap...; scheduler = nothing) +function foreachblock(f, t, ts...; scheduler = nothing) tensors = (t, ts...) allsectors = union(blocksectors.(tensors)...) foreach(allsectors) do c @@ -38,7 +38,7 @@ function foreachblock(f, t::AbstractTensorMap, ts::AbstractTensorMap...; schedul end return nothing end -function foreachblock(f, t::AbstractTensorMap; scheduler = nothing) +function foreachblock(f, t; scheduler = nothing) foreach(blocks(t)) do (c, b) return f(c, (b,)) end diff --git a/src/tensors/linalg.jl b/src/tensors/linalg.jl index 747a44320..44e7be002 100644 --- a/src/tensors/linalg.jl +++ b/src/tensors/linalg.jl @@ -300,7 +300,7 @@ function LinearAlgebra.rank( dim(t) == 0 && return r S = LinearAlgebra.svdvals(t) tol = max(atol, rtol * maximum(first, values(S))) - for (c, b) in S + for (c, b) in pairs(S) if !isempty(b) r += dim(c) * count(>(tol), b) end diff --git a/src/tensors/sectorvector.jl b/src/tensors/sectorvector.jl new file mode 100644 index 000000000..a2a081058 --- /dev/null +++ b/src/tensors/sectorvector.jl @@ -0,0 +1,105 @@ +""" + struct SectorVector{T, I, A <: AbstractVector{T}} <: AbstractVector{T} + +A representation of a vector with values of type `T`, with certain regions labeled by keys of type `I`. +These objects behave as their underlying parent vectors of type `A`, but additionally can be indexed through +keys of type `I` to produce the appropriate views. +Intuitively, these objects can be thought of as the combination of an `AbstractVector` and an `AbstractDict`. +""" +struct SectorVector{T, I, A <: AbstractVector{T}} <: AbstractVector{T} + data::A + structure::SectorDict{I, UnitRange{Int}} +end + +function SectorVector{T}(::UndefInitializer, V::ElementarySpace) where {T} + data = Vector{T}(undef, reduceddim(V)) + structure = diagonalblockstructure(V ← V) + return SectorVector(data, structure) +end + +Base.parent(v::SectorVector) = v.data + +# AbstractVector interface +# ------------------------ +Base.eltype(::Type{SectorVector{T, I, A}}) where {T, I, A} = T +Base.IndexStyle(::Type{SectorVector{T, I, A}}) where {T, I, A} = Base.IndexLinear() + +@inline Base.getindex(v::SectorVector, i::Int) = getindex(parent(v), i) +@inline Base.setindex!(v::SectorVector, val, i::Int) = setindex!(parent(v), val, i) + +Base.size(v::SectorVector, args...) = size(parent(v), args...) + +Base.similar(v::SectorVector) = SectorVector(similar(v.data), v.structure) +Base.similar(v::SectorVector, ::Type{T}) where {T} = SectorVector(similar(v.data, T), v.structure) + +Base.copy(v::SectorVector) = SectorVector(copy(v.data), v.structure) + +# AbstractDict interface +# ---------------------- +Base.keytype(v::SectorVector) = keytype(typeof(v)) +Base.keytype(::Type{SectorVector{T, I, A}}) where {T, I, A} = I +Base.valtype(v::SectorVector) = valtype(typeof(v)) +Base.valtype(::Type{SectorVector{T, I, A}}) where {T, I, A} = SubArray{T, 1, A, Tuple{UnitRange{Int}}, true} + +@inline Base.getindex(v::SectorVector{<:Any, I}, key::I) where {I} = view(v.data, v.structure[key]) +@inline Base.setindex!(v::SectorVector{<:Any, I}, val, key::I) where {I} = copy!(view(v.data, v.structure[key]), val) + +Base.keys(v::SectorVector) = keys(v.structure) +Base.values(v::SectorVector) = (v[c] for c in keys(v)) +Base.pairs(v::SectorVector) = SectorDict(c => v[c] for c in keys(v)) + +# TensorKit interface +# ------------------- +sectortype(::Type{T}) where {T <: SectorVector} = keytype(T) + +Base.similar(v::SectorVector, V::ElementarySpace) = SectorVector(undef, V) + +blocksectors(v::SectorVector) = keys(v) +blocks(v::SectorVector) = pairs(v) +block(v::SectorVector{T, I, A}, c::I) where {T, I, A} = Base.getindex(v, c) + +# VectorInterface and LinearAlgebra interface +# ---------------------------------------------- +VectorInterface.zerovector(v::SectorVector, ::Type{S}) where {S <: Number} = + SectorVector(zerovector(parent(v), S), v.structure) +VectorInterface.zerovector!(v::SectorVector) = (zerovector!(parent(v)); return v) +VectorInterface.zerovector!!(v::SectorVector) = (zerovector!!(parent(v)); return v) + +VectorInterface.scale(v::SectorVector, α::Number) = SectorVector(scale(parent(v), α), v.structure) +VectorInterface.scale!(v::SectorVector, α::Number) = (scale!(parent(v), α); return v) +VectorInterface.scale!!(v::SectorVector, α::Number) = (scale!!(parent(v), α); return v) + +function VectorInterface.add(v1::SectorVector, v2::SectorVector, α::Number, β::Number) + return SectorVector(add(parent(v1), parent(v2), α::Number, β::Number), v1.structure) +end +function VectorInterface.add!(v1::SectorVector, v2::SectorVector, α::Number, β::Number) + add!(parent(v1), parent(v2), α, β) + return v1 +end +function VectorInterface.add!!(v1::SectorVector, v2::SectorVector, α::Number, β::Number) + add!!(parent(v1), parent(v2), α, β) + return v1 +end + +function VectorInterface.inner(v1::SectorVector, v2::SectorVector) + v1.structure == v2.structure || throw(SpaceMismatch("Sector structures do not match")) + I = sectortype(v1) + if FusionStyle(I) isa UniqueFusion # all quantum dimensions are one + return inner(parent(v1), parent(v2)) + else + T = VectorInterface.promote_inner(v1, v2) + s = zero(T) + for c in blocksectors(v1) + b1 = block(v1, c) + b2 = block(v2, c) + s += convert(T, dim(c)) * inner(b1, b2) + end + end + return s +end + +LinearAlgebra.dot(v1::SectorVector, v2::SectorVector) = inner(v1, v2) + +function LinearAlgebra.norm(v::SectorVector, p::Real = 2) + return _norm(blocks(v), p, float(zero(real(scalartype(v))))) +end diff --git a/test/autodiff/ad.jl b/test/autodiff/ad.jl index 0625ee46c..7f956af50 100644 --- a/test/autodiff/ad.jl +++ b/test/autodiff/ad.jl @@ -36,29 +36,6 @@ function ChainRulesTestUtils.test_approx( return nothing end -# make sure that norms are computed correctly: -function FiniteDifferences.to_vec(t::SectorDict) - T = scalartype(valtype(t)) - vec = mapreduce(vcat, t; init = T[]) do (c, b) - return reshape(b, :) .* sqrt(dim(c)) - end - vec_real = T <: Real ? vec : collect(reinterpret(real(T), vec)) - - function from_vec(x_real) - x = T <: Real ? x_real : reinterpret(T, x_real) - ctr = 0 - return SectorDict( - c => ( - n = length(b); - b′ = reshape(view(x, ctr .+ (1:n)), size(b)) ./ sqrt(dim(c)); - ctr += n; - b′ - ) for (c, b) in t - ) - end - return vec_real, from_vec -end - # Float32 and finite differences don't mix well precision(::Type{<:Union{Float32, Complex{Float32}}}) = 1.0e-2 precision(::Type{<:Union{Float64, Complex{Float64}}}) = 1.0e-5 diff --git a/test/tensors/factorizations.jl b/test/tensors/factorizations.jl index b25a9dbbd..e63312f9c 100644 --- a/test/tensors/factorizations.jl +++ b/test/tensors/factorizations.jl @@ -192,7 +192,7 @@ for V in spacelist @test isisometric(vᴴ; side = :right) s′ = LinearAlgebra.diag(s) - for (c, b) in LinearAlgebra.svdvals(t) + for (c, b) in pairs(LinearAlgebra.svdvals(t)) @test b ≈ s′[c] end @@ -312,7 +312,7 @@ for V in spacelist @test t * v ≈ v * d d′ = LinearAlgebra.diag(d) - for (c, b) in LinearAlgebra.eigvals(t) + for (c, b) in pairs(LinearAlgebra.eigvals(t)) @test sort(b; by = abs) ≈ sort(d′[c]; by = abs) end