diff --git a/Project.toml b/Project.toml index 6ef345f..12db6cf 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "StochasticRounding" uuid = "3843c9a1-1f18-49ff-9d99-1b4c8a8e97ed" authors = ["Milan Kloewer and StochasticRounding.jl contributors"] -version = "0.8.3" +version = "0.9.0" [deps] BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" diff --git a/src/general.jl b/src/general.jl index 6831e51..73d9510 100644 --- a/src/general.jl +++ b/src/general.jl @@ -12,10 +12,10 @@ for t in (Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64, UInt end # other floats, irrational and rationals -(::Type{T})(x::Real) where {T<:AbstractStochasticFloat} = stochastic_float(float(T)(x)) -(::Type{T})(x::Rational) where {T<:AbstractStochasticFloat} = stochastic_float(float(T)(x)) +(::Type{T})(x::Real) where {T<:AbstractStochasticFloat} = stochastic_round(T,x) +(::Type{T})(x::Rational) where {T<:AbstractStochasticFloat} = stochastic_round(stochastic_float(float(T)),x) (::Type{T})(x::AbstractStochasticFloat) where {T<:AbstractFloat} = convert(T,float(x)) -(::Type{T})(x::AbstractStochasticFloat) where {T<:AbstractStochasticFloat} = stochastic_float(convert(float(T),float(x))) +(::Type{T})(x::AbstractStochasticFloat) where {T<:AbstractStochasticFloat} = stochastic_round(T,float(x)) DoubleFloats.Double64(x::T) where T<:AbstractStochasticFloat = Double64(float(x)) # masks same as for deterministic floats @@ -104,8 +104,8 @@ Base.abs(x::AbstractStochasticFloat) = stochastic_float(abs(float(x))) # stochastic rounding export stochastic_round -stochastic_round(T::Type{<:AbstractFloat},x::Real) = stochastic_round(T,widen(stochastic_float(T))(x)) -stochastic_round(T::Type{<:AbstractStochasticFloat},x::AbstractFloat) = stochastic_float(stochastic_round(float(T),x)) +stochastic_round(T::Type{<:AbstractFloat}, x::Real) = stochastic_round(T,widen(stochastic_float(T))(x)) +stochastic_round(T::Type{<:AbstractStochasticFloat}, x::Real) = stochastic_float(stochastic_round(float(T),x)) # Comparison for op in (:(==), :<, :<=, :isless) diff --git a/test/conversions.jl b/test/conversions.jl index 2de8f63..ad50e25 100644 --- a/test/conversions.jl +++ b/test/conversions.jl @@ -1,9 +1,9 @@ -@testset "Converting Stochastic FP to BFloat16sr" begin +@testset "Deterministic conversion of Stochastic FP to BFloat16sr" begin for k in 1:100000 trueVal = randn(Float64) - # Convert to each deterministic type to ensure the result is + # Convert to each deterministic type to ensure the result is # representable in all precisions - trueVal=Float16(Float32(BFloat16(trueVal))) + trueVal = Float16(Float32(BFloat16(trueVal))) bfloatVal = BFloat16sr(trueVal) float16Val = Float16sr(trueVal) float32Val = Float32sr(trueVal) @@ -14,12 +14,12 @@ end end -@testset "Converting Stochastic FP to Float16sr" begin +@testset "Deterministic conversion of Stochastic FP to Float16sr" begin for k in 1:100000 trueVal = randn(Float64) - # Convert to each deterministic type to ensure the result is + # Convert to each deterministic type to ensure the result is # representable in all precisions - trueVal=Float16(Float32(BFloat16(trueVal))) + trueVal = Float16(Float32(BFloat16(trueVal))) bfloatVal = BFloat16sr(trueVal) float16Val = Float16sr(trueVal) float32Val = Float32sr(trueVal) @@ -30,12 +30,12 @@ end end end -@testset "Converting Stochastic FP to Float32sr" begin +@testset "Deterministic conversion of Stochastic FP to Float32sr" begin for k in 1:100000 trueVal = randn(Float64) - # Convert to each deterministic type to ensure the result is + # Convert to each deterministic type to ensure the result is # representable in all precisions - trueVal=Float16(Float32(BFloat16(trueVal))) + trueVal = Float16(Float32(BFloat16(trueVal))) bfloatVal = BFloat16sr(trueVal) float16Val = Float16sr(trueVal) float32Val = Float32sr(trueVal) @@ -45,3 +45,33 @@ end @test float32Val == Float32sr(float16Val) end end + + +@testset "Stochastic conversion of Float64 to $(SR)" for SR in [Float16sr, BFloat16sr, Float32sr] + + # corresponding FP type + FP = float(SR) + + for k in 1:100 + trueVal = randn(Float64) + + # make sure that this is not representable by FP + if trueVal == FP(trueVal) + trueVal = nextfloat(trueVal) + end + + roundedVal = SR(trueVal) + + # check that rounding is not deterministic + is_stochastic = false + for l in 1:10000 + if roundedVal != SR(trueVal) + is_stochastic = true + break + end + end + + # the redundant "== true" produces a better error message in the test log + @test is_stochastic == true + end +end