Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type stability of intermediates in conversions #81

Merged
merged 6 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/constants.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ bitsize(::Type{Float32}) = 32
bitsize(::Type{Float64}) = 64

# Sign masks for the unsigned integer types: only the most significant bit is set
# and every mask has the same width as the type it belongs to.
for (UIntN, mask) in (
    (UInt8, 0x80),
    (UInt16, 0x8000),
    (UInt32, 0x8000_0000),
    (UInt64, 0x8000_0000_0000_0000),
)
    @eval Base.sign_mask(::Type{$UIntN}) = $mask
end
Expand Down
129 changes: 75 additions & 54 deletions src/conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,55 +11,65 @@ Base.inttype(::Type{Posit16_1}) = Int16
Base.inttype(::Type{Posit32}) = Int32

# generic conversion to UInt/Int
Base.unsigned(x::AbstractPosit) = reinterpret(Base.uinttype(typeof(x)),x)
Base.signed(x::AbstractPosit) = reinterpret(Base.inttype(typeof(x)),x)
Base.unsigned(x::AbstractPosit) = reinterpret(Base.uinttype(typeof(x)), x)
Base.signed(x::AbstractPosit) = reinterpret(Base.inttype(typeof(x)), x)

# BOOL
for PositType in (:Posit8, :Posit16, :Posit32, :Posit16_1)
@eval begin
$PositType(x::Bool) = x ? one($PositType) : zero($PositType)
Base.promote_rule(::Type{Bool},::Type{$PositType}) = $PositType
$PositType(x::Bool) = x ? one($PositType) : zero($PositType)
Base.promote_rule(::Type{Bool}, ::Type{$PositType}) = $PositType
end
end

# easier for development purposes
Posit8(x::UInt8) = reinterpret(Posit8,x)
Posit16(x::UInt16) = reinterpret(Posit16,x)
Posit16_1(x::UInt16) = reinterpret(Posit16_1,x)
Posit32(x::UInt32) = reinterpret(Posit32,x)
Posit8(x::UInt8) = reinterpret(Posit8, x)
Posit16(x::UInt16) = reinterpret(Posit16, x)
Posit16_1(x::UInt16) = reinterpret(Posit16_1, x)
Posit32(x::UInt32) = reinterpret(Posit32, x)

# BETWEEN Posits
# upcasting: append with zeros.
Posit16(x::Posit8) = reinterpret(Posit16,(unsigned(x) % UInt16) << 8)
Posit32(x::Posit8) = reinterpret(Posit32,(unsigned(x) % UInt32) << 24)
Posit32(x::Posit16) = reinterpret(Posit32,(unsigned(x) % UInt32) << 16)
Posit16(x::Posit8) = reinterpret(Posit16, (unsigned(x) % UInt16) << 8)
Posit32(x::Posit8) = reinterpret(Posit32, (unsigned(x) % UInt32) << 24)
Posit32(x::Posit16) = reinterpret(Posit32, (unsigned(x) % UInt32) << 16)

# downcasting: apply round to nearest
Posit8(x::Posit16) = posit(Posit8,x)
Posit8(x::Posit32) = posit(Posit8,x)
Posit16(x::Posit32) = posit(Posit16,x)
Posit8(x::Posit16) = posit(Posit8, x)
Posit8(x::Posit32) = posit(Posit8, x)
Posit16(x::Posit32) = posit(Posit16, x)

# conversion to and from Posit16_1 via floats as number of exponent bits changes
Posit16_1(x::AbstractPosit) = Posit16_1(float(x))
Posit8(x::Posit16_1) = Posit8(float(x))
Posit16(x::Posit16_1) = Posit16(float(x))
Posit32(x::Posit16_1) = Posit32(float(x))

function posit(::Type{PositN1},x::PositN2) where {PositN1<:AbstractPosit,PositN2<:AbstractPosit}
return reinterpret(PositN1,bitround(Base.uinttype(PositN1),unsigned(x)))
function posit(::Type{PositN1}, x::PositN2) where {PositN1<:AbstractPosit, PositN2<:AbstractPosit}
return reinterpret(PositN1, bitround(Base.uinttype(PositN1), unsigned(x)))
end

function bitround(::Type{UIntN1},ui::UIntN2) where {UIntN1<:Unsigned,UIntN2<:Unsigned}
Δbits = bitsize(UIntN2) - bitsize(UIntN1) # difference in bits
"""Bitround an unsigned integer `ui` to another bitsize `UIntN1`.
Rounds/downcasts using round to nearest or upcasts (append with zeros)."""
function bitround(::Type{UIntN1}, ui::UIntN2) where {UIntN1<:Unsigned, UIntN2<:Unsigned}
Δbits = bitsize(UIntN2) - bitsize(UIntN1) # difference in bit sizes

# ROUND TO NEAREST, tie to even: create ulp/2 = ..007ff.. or ..0080..
ulp_half = ~Base.sign_mask(UIntN2) >> bitsize(UIntN1) # create ..007ff.. (just smaller than ulp/2)
ulp_half += ((ui >> Δbits) & 0x1) # turn into ..0080.. for odd (=round up if tie)
ui += ulp_half # +ulp/2 and
ui_trunc = (ui >> Δbits) % UIntN1 # round down via >> is round nearest

# round down via >> is round nearest, but use % UInt64 in case of upcasting to not lose any bits
# and append with zeros
ui_trunc = ((ui % UInt64) >> Δbits) % UIntN1
return ui_trunc
end

"""
    bitround(::Type{UIntN}, ui::UIntN) where {UIntN<:Unsigned}

Return `ui` unchanged. When the requested unsigned integer type `UIntN` is already
the type of `ui` there is nothing to round or to pad, so this special case of
`bitround` (from one unsigned integer size to another) is the identity.
"""
function bitround(::Type{UIntN}, ui::UIntN) where {UIntN<:Unsigned}
    return ui
end

# Due to only 1 exponent bit define Posit16_1(::AbstractPosit) via float conversion
Posit16_1(x::T) where {T<:Union{Posit8,Posit16,Posit32}} = Posit16_1(float(x))

Expand All @@ -72,54 +82,65 @@ Posit32(x::Signed) = Posit32(Float64(x))
Base.Int(x::AbstractPosit) = Int(Float64(x))

# promotions
Base.promote_rule(::Type{Int},::Type{T}) where {T<:AbstractPosit} = T
Base.promote_rule(::Type{Int}, ::Type{T}) where {T<:AbstractPosit} = T

# FROM FLOATS
Posit8(x::T) where {T<:Base.IEEEFloat} = posit(Posit8,x)
Posit16(x::T) where {T<:Base.IEEEFloat} = posit(Posit16,x)
Posit16_1(x::T) where {T<:Base.IEEEFloat} = posit(Posit16_1,x)
Posit32(x::T) where {T<:Base.IEEEFloat} = posit(Posit32,x)
Posit8(x::T) where {T<:Base.IEEEFloat} = posit(Posit8, x)
Posit16(x::T) where {T<:Base.IEEEFloat} = posit(Posit16, x)
Posit16_1(x::T) where {T<:Base.IEEEFloat} = posit(Posit16_1, x)
Posit32(x::T) where {T<:Base.IEEEFloat} = posit(Posit32, x)

function posit(::Type{PositN},x::FloatN) where {PositN<:AbstractPosit,FloatN<:Base.IEEEFloat}
"""
Convert float `x` (any size) to a posit of type `PositN` (e.g. Posit16, Posit32) via
round to nearest."""
function posit(::Type{PositN}, x::FloatN) where {PositN<:AbstractPosit, FloatN<:Base.IEEEFloat}

UIntN = Base.uinttype(FloatN) # unsigned integer corresponding to FloatN
IntN = Base.inttype(FloatN) # signed integer corresponding to FloatN
ui = reinterpret(UIntN,x) # reinterpret input
ui = reinterpret(UIntN, x) # reinterpret input

# extract exponent bits and shift to tail, then remove bias
e = (ui & Base.exponent_mask(FloatN)) >> Base.significand_bits(FloatN)
e = reinterpret(IntN,e) - IntN(Base.exponent_bias(FloatN))
e = reinterpret(IntN, e) - IntN(Base.exponent_bias(FloatN))
signbit_e = signbit(e) # sign of exponent
k = e >> Base.exponent_bits(PositN) # k-value for useed^k in posits

# ASSEMBLE POSIT REGIME, EXPONENT, MANTISSA
# get posit exponent_bits and shift to starting from bitposition 3 (they'll be shifted in later)
exponent_bits = e & Base.exponent_mask(PositN)
exponent_bits <<= bitsize(FloatN)-2-Base.exponent_bits(PositN)
# get posit exponent_bits and shift to starting from bitposition 3 (they'll be shifted in later)
# always construct with 64 bits, chop off in bitround
local regime::Int64
local exponent::Int64
local mantissa::Int64

# REGIME: create 01000... (for |x|<1) or 10000... (|x| >= 1), push in later
regime = signed(Base.sign_mask(Float64) >> signbit_e)

# create 01000... (for |x|<1) or 10000... (|x| > 1)
regime_bits = reinterpret(IntN,Base.sign_mask(FloatN) >> signbit_e)
# EXPONENT: push behind regime bits rree00... for 2 exp bits ee
exponent = signed(e & Base.exponent_mask(PositN))
exponent <<= 62 - Base.exponent_bits(PositN)

# extract mantissa bits and push to behind exponent rre..emm... (regime still hasn't been shifted)
mantissa = reinterpret(IntN,ui & Base.significand_mask(FloatN))
mantissa <<= Base.exponent_bits(FloatN) - Base.exponent_bits(PositN) - 1
# MANTISSA: extract bits and push to behind exponent rre..emm... (regime still hasn't been shifted)
mantissa = reinterpret(IntN, ui & Base.significand_mask(FloatN))
mantissa <<= 62 - Base.exponent_bits(PositN) - Base.significand_bits(FloatN)

# combine regime, exponent, mantissa and arithmetic bitshift for 11..110em or 00..001em
regime_exponent_mantissa = regime_bits | exponent_bits | mantissa
regime_exponent_mantissa = regime | exponent | mantissa
regime_exponent_mantissa >>= (abs(k+1) + signbit_e) # arithmetic bitshift
regime_exponent_mantissa &= ~Base.sign_mask(FloatN) # remove possible sign bit from arith shift
regime_exponent_mantissa &= ~Base.sign_mask(Float64) # remove possible sign bit from arith shift

# round to nearest of the result
p_rounded = bitround(Base.uinttype(PositN),reinterpret(UIntN,regime_exponent_mantissa))
# round to nearest of the result and truncate to posit bitsize
p = bitround(Base.uinttype(PositN), unsigned(regime_exponent_mantissa))

# no under or overflow rounding mode
max_k = (Base.exponent_bias(FloatN) >> Base.exponent_bits(PositN)) + 1
p_rounded -= Base.inttype(PositN)(sign(k)*(bitsize(PositN) <= abs(k) < max_k))
p -= Base.inttype(PositN)(sign(k)*(bitsize(PositN) <= abs(k) < max_k))
p = signbit(x) ? -p : p # two's complement for negative numbers

p_rounded = signbit(x) ? -p_rounded : p_rounded # two's complement for negative numbers
return reinterpret(PositN,p_rounded)
return reinterpret(PositN, p)
end

"""
    posit(::Type{PositN}, x::Float16) where {PositN<:AbstractPosit}

Convert the `Float16` value `x` to a posit of type `PositN` by widening it to
`Float32` first and passing the result on to `posit(PositN, ::Float32)`.
"""
function posit(::Type{PositN}, x::Float16) where {PositN<:AbstractPosit}
    x32 = Float32(x)
    return posit(PositN, x32)
end

## TO FLOATS
# corresponding float types for round-free conversion (they don't match in bitsize though!)
Base.floattype(::Type{Posit8}) = Float32 # Posit8, 16 are subsets of Float32
Expand All @@ -128,18 +149,18 @@ Base.floattype(::Type{Posit16_1}) = Float32
Base.floattype(::Type{Posit32}) = Float64 # Posit32 is a subset of Float64

# generic conversion to float
Base.float(x::AbstractPosit) = convert(Base.floattype(typeof(x)),x)
Base.Float32(x::AbstractPosit) = float(Float32,x)
Base.Float64(x::AbstractPosit) = float(Float64,x)
Base.float(x::AbstractPosit) = convert(Base.floattype(typeof(x)), x)
Base.Float32(x::AbstractPosit) = float(Float32, x)
Base.Float64(x::AbstractPosit) = float(Float64, x)

# The dynamic range of Float16 is smaller than Posit8/16/32
# for correct rounding convert first to Float32/64
Base.Float16(x::Posit8) = Float16(float(Float32,x))
Base.Float16(x::Posit16) = Float16(float(Float32,x))
Base.Float16(x::Posit16_1) = Float16(float(Float32,x))
Base.Float16(x::Posit32) = Float16(float(Float64,x))
Base.Float16(x::Posit8) = Float16(float(Float32, x))
Base.Float16(x::Posit16) = Float16(float(Float32, x))
Base.Float16(x::Posit16_1) = Float16(float(Float32, x))
Base.Float16(x::Posit32) = Float16(float(Float64, x))

function Base.float(::Type{FloatN},x::PositN) where {FloatN<:Base.IEEEFloat,PositN<:AbstractPosit}
function Base.float(::Type{FloatN}, x::PositN) where {FloatN<:Base.IEEEFloat, PositN<:AbstractPosit}

UIntN = Base.uinttype(FloatN) # corresponding UInt for floattype
n_bits = bitsize(PositN) # number of bits in posit format
Expand All @@ -164,20 +185,20 @@ function Base.float(::Type{FloatN},x::PositN) where {FloatN<:Base.IEEEFloat,Posi

# ASSEMBLE FLOAT EXPONENT
# useed^k * 2^e = 2^(2^n_exponent_bits*k+e), ie get k-value from number of regime bits,
# << n_exponent_bits for *2^exponent_bits, add exponent bits and Float exponent bias (=15,127,1023)
k = (-1+2sign_exponent)*n_regimebits - sign_exponent
# << n_exponent_bits for *2^exponent_bits, add exponent bits and Float exponent bias (=15,127,1023)
k = (-1 + 2sign_exponent)*n_regimebits - sign_exponent
exponent = ((k << Base.exponent_bits(PositN)) + exponent_bits + Base.exponent_bias(FloatN)) % UIntN
exponent <<= Base.significand_bits(FloatN)

# set exponent (and 1st mantissa bit) to NaN for NaR inputs
# set exponent to 0 for zero(Posit8) input
nan_ui = reinterpret(UIntN,nan(FloatN))
nan_ui = reinterpret(UIntN, nan(FloatN))
exponent = n_regimebits == n_bits ? (signbitx ? nan_ui : zero(exponent)) : exponent

# assemble sign, exponent and mantissa bits
sign = signbitx*Base.sign_mask(FloatN)
f = sign | exponent | mantissa # concatenate sign, exponent and mantissa
return reinterpret(FloatN,f)
return reinterpret(FloatN, f)
end

# BIGFLOAT
Expand Down
96 changes: 96 additions & 0 deletions test/bitround.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# bitround to the very same unsigned integer type must hand back its input unchanged
@testset "Bitround identity" begin
    n_samples = 10
    @testset for UIntN in (UInt8, UInt16, UInt32, UInt64)
        foreach(rand(UIntN, n_samples)) do sample
            @test SoftPosit.bitround(UIntN, sample) == sample
        end
    end
end

# Downcasting with bitround must round to nearest and resolve ties to even.
# Every (target, source) pair runs the same five checks; the bit patterns appended
# behind the kept bits are derived from the two bit sizes instead of being
# hand-written per pair.
@testset "Bitround round to nearest, tie to even" begin

    N = 100_000     # number of random samples for targets too large to enumerate

    # (UIntN1, UIntN2) = (target type, source type) of the downcast UIntN2 -> UIntN1
    @testset for (UIntN1, UIntN2) in ((UInt8, UInt16), (UInt16, UInt32),
                                      (UInt32, UInt64), (UInt8, UInt32))

        Δb = 8 * (sizeof(UIntN2) - sizeof(UIntN1))          # number of bits rounded away
        half = one(UIntN2) << (Δb - 1)                      # ulp/2, e.g. 0x0080 for 16 -> 8
        tail_ones = (one(UIntN2) << Δb) - one(UIntN2)       # all discarded bits set, e.g. 0x00ff
        below_half = half - one(UIntN2)                     # just less than ulp/2, e.g. 0x007f
        above_half = half + one(UIntN2)                     # just more than ulp/2, e.g. 0x0081

        # 8 and 16-bit targets are small enough to test every value (this always
        # includes 0 and typemax); larger targets are sampled randomly
        samples = sizeof(UIntN1) <= 2 ? (typemin(UIntN1):typemax(UIntN1)) : rand(UIntN1, N)

        # NOTE: the expected values use wrapping integer arithmetic, so rounding up
        # typemax(UIntN1) is expected to wrap around to zero
        for ui in samples
            padded = (ui % UIntN2) << Δb                    # pad with zeros
            @test SoftPosit.bitround(UIntN1, padded) == ui

            # pad with ones (round up)
            @test SoftPosit.bitround(UIntN1, padded | tail_ones) == ui + one(UIntN1)

            # just less than ulp/2 (all round down)
            @test SoftPosit.bitround(UIntN1, padded | below_half) == ui

            # just more than ulp/2 (all round up)
            @test SoftPosit.bitround(UIntN1, padded | above_half) == ui + one(UIntN1)

            # exactly ulp/2: tie to even, i.e. only odd values are rounded up
            @test SoftPosit.bitround(UIntN1, padded | half) == ui + (ui & one(UIntN1))
        end
    end
end

# Upcasting with bitround must not round: the input bits move to the top of the wider
# type and the new low bits are zero.
@testset "Bitround upcasting (=no rounding, pad with zeros)" begin
    N = 100000
    @testset for (UIntN1, UIntN2) in zip((UInt8, UInt16, UInt32),
                                         (UInt16, UInt32, UInt64))
        # number of bits appended on the right when widening UIntN1 -> UIntN2
        Δb = SoftPosit.bitsize(UIntN2) - SoftPosit.bitsize(UIntN1)
        for ui in rand(UIntN1, N)
            # compare the complete result (type and value) rather than only the
            # upper bits, so that non-zero padding in the low bits is caught too
            @test SoftPosit.bitround(UIntN2, ui) === (ui % UIntN2) << Δb
        end
    end
end
Loading
Loading