"""
    square_plus_one(v::Number)

Return `v * v + 1`.

The result type depends only on the type of `v`: the square `g = v * v` is computed
first, and adding the `Int` literal `1` promotes the result to
`promote_type(typeof(g), Int)`. For example, `Float64` stays `Float64`, while `Int8`
widens to `Int64`. Because of this, the function is type stable.
"""
function square_plus_one(v::T) where {T<:Number}
    g = v * v
    return g + 1
end
square_plus_one (generic function with 1 method)
A function is type stable when the type of its output can be inferred from the types of its arguments alone, without knowing their values.
square_plus_one (generic function with 1 method)
MethodInstance for square_plus_one(::Float64)
from square_plus_one(v::T) where T<:Number @ Main In[2]:1
Static Parameters
T = Float64
Arguments
#self#::Core.Const(square_plus_one)
v::Float64
Locals
g::Float64
Body::Float64
1 ─ (g = v * v)
│ %2 = (g + 1)::Float64
└── return %2
MethodInstance for square_plus_one(::Int64)
from square_plus_one(v::T) where T<:Number @ Main In[2]:1
Static Parameters
T = Int64
Arguments
#self#::Core.Const(square_plus_one)
v::Int64
Locals
g::Int64
Body::Int64
1 ─ (g = v * v)
│ %2 = (g + 1)::Int64
└── return %2
Great! In both examples above, the compiler was able to infer the type of the output (`Body::Float64` and `Body::Int64`). This is because:
function square_plus_one(v::T) where {T<:Number}
    g = v * v     # typeof(T * T): for the built-in number types this is T itself
    return g + 1  # result type: promote_type(typeof(g), Int) -- depends only on T
end
Note that the return type differs between the two calls (`Float64` in one, `Int64` in the other), but the function is still type stable: in each case the return type is fully determined by the type of the argument.
# Deliberately type-unstable example: a non-negative `x` is returned unchanged
# (with the type of `x`), anything else returns the `Int` literal `0`. For a
# `Float64` argument the inferred return type is therefore `Union{Float64, Int64}`.
function zero_or_val(x::Real)
    x >= 0 && return x
    return 0
end
# Show the types the compiler infers for a Float64 argument. A `Body::Union{...}`
# line in the output means the return type depends on the value of `x`, not only
# on its type. (`@code_warntype` comes from InteractiveUtils, which the REPL and
# notebooks load automatically; scripts need `using InteractiveUtils`.)
@code_warntype zero_or_val(0.2)
MethodInstance for zero_or_val(::Float64)
from zero_or_val(x::Real) @ Main In[7]:1
Arguments
#self#::Core.Const(zero_or_val)
x::Float64
Body::Union{Float64, Int64}
1 ─ %1 = (x >= 0)::Bool
└── goto #3 if not %1
2 ─ return x
3 ─ return 0
You can avoid type-unstable code by using the `promote_type` function, which returns the common type that both of its argument types promote to (for example, `promote_type(Float64, Int) == Float64`).
"""
    zero_or_val_stable(x::Real)

Return `x` when `x >= 0` and zero otherwise. Both results are converted to
`promote_type(typeof(x), Int)`, so the return type depends only on the type of `x`.
"""
function zero_or_val_stable(x::Real)
    R = promote_type(typeof(x), Int)  # common type of `x` and the literal `0`
    return x >= 0 ? R(x) : R(0)
end
# Inspect the stable version: the `Body::Float64` line in the output shows the
# return type is now inferred as a single concrete type for a Float64 argument.
@code_warntype zero_or_val_stable(0.2)
MethodInstance for zero_or_val_stable(::Float64)
from zero_or_val_stable(x::Real) @ Main In[8]:1
Arguments
#self#::Core.Const(zero_or_val_stable)
x::Float64
Locals
T::Type{Float64}
y::Union{Float64, Int64}
Body::Float64
1 ─ Core.NewvarNode(:(T))
│ Core.NewvarNode(:(y))
│ %3 = (x >= 0)::Bool
└── goto #3 if not %3
2 ─ (y = x)
└── goto #4
3 ─ (y = 0)
4 ┄ %8 = Main.typeof(x)::Core.Const(Float64)
│ (T = Main.promote_type(%8, Main.Int))
│ %10 = (T::Core.Const(Float64))(y)::Float64
└── return %10
using LinearAlgebra
"""
    mynorm(A)

Return the 2-norm `sqrt(real(dot(A, A)))` of a `Vector`, or the spectral norm
(largest singular value) of a `Matrix`. Any other argument raises an
`ErrorException`.

Each case is a separate method, so multiple dispatch picks the branch from the
argument's type instead of runtime `isa` checks inside one body.
"""
function mynorm end

# `dot` conjugates its first argument, so for complex input the result has a zero
# imaginary part; `real` makes the square root return a real number.
mynorm(A::Vector) = sqrt(real(dot(A, A)))
mynorm(A::Matrix) = maximum(svdvals(A))
mynorm(A) = error("mynorm: invalid argument")
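The idiomatic form uses multiple dispatch, with one method per argument type instead of `isa` checks inside a single function body: `mynorm(x::Vector) = sqrt(real(dot(x, x)))` and `mynorm(A::Matrix) = maximum(svdvals(A))`. This is more concise, and the compiler selects the right method from the argument's type.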
Let us say we want to play the following game: I give you a vector of numbers, and you accumulate a sum as follows. For each number in the vector, you toss a coin (`rand()`); if it is heads (`>= 0.5`), you add `1`. Otherwise, you add the number itself.
"""
    flipcoin_then_add(v::AbstractVector{<:Real})

Play the coin-flip game on `v` and return the accumulated sum. For each element,
a coin is tossed with `rand()`: on heads (`>= 0.5`) `1` is added, otherwise the
element itself is added.

This version is deliberately type-unstable. The accumulator starts as the `Int`
literal `0`, so for non-`Int` element types it can change type inside the loop,
and the returned sum may be an `Int` or the promoted element type. Compare with
`flipcoin_then_add_typed`, which starts from `zero(T)`.
"""
function flipcoin_then_add(v::AbstractVector{T}) where {T<:Real}
    s = 0
    for vi in v
        r = rand()
        if r >= 0.5
            s += 1
        else
            s += vi
        end
    end
    return s  # without this the function would return the `for` loop's `nothing`
end
flipcoin_then_add (generic function with 1 method)
"""
    flipcoin_then_add_typed(v::AbstractVector{T}) where {T<:Real}

Type-stable version of the coin-flip game: for each element, add `one(T)` on heads
(`rand() >= 0.5`) and the element itself otherwise, then return the sum.

The accumulator starts at `zero(T)` and only ever has values of type `T` added to
it. For element types closed under `+` (e.g. `Int`, `Float64`), the sum therefore
keeps type `T` throughout the loop and is returned as a `T`.
"""
function flipcoin_then_add_typed(v::AbstractVector{T}) where {T<:Real}
    s = zero(T)
    for vi in v
        s += rand() >= 0.5 ? one(T) : vi
    end
    return s  # without this the function would return the `for` loop's `nothing`
end
flipcoin_then_add_typed (generic function with 1 method)