Type stability

A function is type stable when the type of its output can be inferred from the types of its inputs alone, without knowing their values.

function square_plus_one(v::T) where T <: Number
    g = v * v
    return g + 1
end
square_plus_one (generic function with 1 method)
v = rand()
0.21337149175107006
@code_warntype square_plus_one(v)
MethodInstance for square_plus_one(::Float64)
  from square_plus_one(v::T) where T<:Number @ Main In[2]:1
Static Parameters
  T = Float64
Arguments
  #self#::Core.Const(square_plus_one)
  v::Float64
Locals
  g::Float64
Body::Float64
1 ─      (g = v * v)
│   %2 = (g + 1)::Float64
└──      return %2

w = 5
5
@code_warntype square_plus_one(w)
MethodInstance for square_plus_one(::Int64)
  from square_plus_one(v::T) where T<:Number @ Main In[2]:1
Static Parameters
  T = Int64
Arguments
  #self#::Core.Const(square_plus_one)
  v::Int64
Locals
  g::Int64
Body::Int64
1 ─      (g = v * v)
│   %2 = (g + 1)::Int64
└──      return %2

Great! In both examples above, the compiler was able to infer the type of the output from the type of the input. This is because:

function square_plus_one(v::T) where T <: Number
    g = v * v       # typeof(T * T) ==> T
    return g + 1    # typeof(T + Int) ==> promote_type(T, Int)
end

Note that the return type differed between the two calls: Float64 in the first and Int64 in the second. The function is still type stable, because for each input type the return type is uniquely determined.
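
The same check can also be done programmatically. The sketch below assumes the Test standard library is available and uses @inferred together with Base.return_types; neither appears in the text above, so treat it as one possible way to automate the check rather than the only one.

```julia
using Test

function square_plus_one(v::T) where T <: Number
    g = v * v
    return g + 1
end

# @inferred throws an error if the inferred return type does not match
# the type of the value that is actually returned.
@inferred square_plus_one(0.5)
@inferred square_plus_one(5)

# Base.return_types returns the inferred return type of every matching method.
float_type = only(Base.return_types(square_plus_one, (Float64,)))
int_type = only(Base.return_types(square_plus_one, (Int,)))
@show float_type int_type
```

Unlike @code_warntype, which is meant to be read by a person, @inferred can be placed in a test suite so that a regression in type stability fails loudly.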


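The next function, by contrast, is not type stable: depending on the value of x, it returns either x itself or the Int literal 0.

function zero_or_val(x::Real)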
    if x >= 0
        return x
    else
        return 0
    end
end
@code_warntype zero_or_val(0.2)
MethodInstance for zero_or_val(::Float64)
  from zero_or_val(x::Real) @ Main In[7]:1
Arguments
  #self#::Core.Const(zero_or_val)
  x::Float64
Body::Union{Float64, Int64}
1 ─ %1 = (x >= 0)::Bool
└──      goto #3 if not %1
2 ─      return x
3 ─      return 0

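The body type Union{Float64, Int64} reveals the instability. You can avoid type-unstable code like this by converting both branches to a common type with promote_type, which returns a type that values of both argument types can be converted to.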

function zero_or_val_stable(x::Real)
    if x >= 0
        y = x
    else
        y = 0
    end
    T = promote_type(typeof(x),Int)
    return T(y)
end
@code_warntype zero_or_val_stable(0.2)
MethodInstance for zero_or_val_stable(::Float64)
  from zero_or_val_stable(x::Real) @ Main In[8]:1
Arguments
  #self#::Core.Const(zero_or_val_stable)
  x::Float64
Locals
  T::Type{Float64}
  y::Union{Float64, Int64}
Body::Float64
1 ─       Core.NewvarNode(:(T))
│         Core.NewvarNode(:(y))
│   %3  = (x >= 0)::Bool
└──       goto #3 if not %3
2 ─       (y = x)
└──       goto #4
3 ─       (y = 0)
4 ┄ %8  = Main.typeof(x)::Core.Const(Float64)
│         (T = Main.promote_type(%8, Main.Int))
│   %10 = (T::Core.Const(Float64))(y)::Float64
└──       return %10
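
To make the role of promote_type concrete, the sketch below prints what it returns for two pairs of types and then shows a shorter variant of the function. The name zero_or_val_zero is hypothetical and not part of the text above; it relies on zero(x) having the same type as x, so both branches already agree and no conversion is needed.

```julia
using Test

# promote_type returns a common type that both inputs can be converted to.
@show promote_type(Float64, Int)
@show promote_type(Int32, Int64)

# Hypothetical variant: zero(x) has the same type as x, so both branches
# of the ternary produce a value of the same type.
zero_or_val_zero(x::Real) = x >= 0 ? x : zero(x)

# Both calls pass the inference check for Float64 and Int arguments.
@inferred zero_or_val_zero(0.2)
@inferred zero_or_val_zero(-3)
```

The two versions are not interchangeable for every input type: promote_type(typeof(x), Int) widens a Bool argument to Int, while zero(x) keeps it a Bool.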

Break functions into multiple definitions

using LinearAlgebra

function mynorm(A)
    if isa(A, Vector)
        return sqrt(real(dot(A,A)))
    elseif isa(A, Matrix)
        return maximum(svdvals(A))
    else
        error("mynorm: invalid argument")
    end
end

This can be written more concisely and efficiently as:

mynorm(x::Vector) = sqrt(real(dot(x, x)))

mynorm(A::Matrix) = maximum(svdvals(A))
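
A short usage sketch of the dispatch-based version follows. The methods are named mynorm here so that they do not collide with LinearAlgebra.norm, and the sample inputs are made up for illustration.

```julia
using LinearAlgebra

mynorm(x::Vector) = sqrt(real(dot(x, x)))   # Euclidean norm of a vector
mynorm(A::Matrix) = maximum(svdvals(A))     # largest singular value of a matrix

v = [3.0, 4.0]
A = [2.0 0.0; 0.0 1.0]
@show mynorm(v)
@show mynorm(A)

# An unsupported argument type now fails with a MethodError at the call site
# instead of reaching a hand-written error branch.
unsupported_fails = try
    mynorm("not an array")
    false
catch err
    err isa MethodError
end
@show unsupported_fails
```

Each method handles exactly one argument type, so the compiler selects the right body from the argument type without any isa checks inside the function.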

Avoid changing the type of a variable

Suppose we play the following game: I give you a vector of numbers, and you accumulate a sum as follows. For each number in the vector, you toss a coin (rand()). If it comes up heads (>= 0.5), you add 1; otherwise, you add the number itself.

function flipcoin_then_add(v::Vector{T}) where T <: Real
    s = 0                # s starts as an Int ...
    for vi in v
        r = rand()
        if r >= 0.5
            s += 1
        else
            s += vi      # ... and may change to T here
        end
    end
    return s
end
flipcoin_then_add (generic function with 1 method)

function flipcoin_then_add_typed(v::Vector{T}) where T <: Real
    s = zero(T)          # s has type T from the start
    for vi in v
        r = rand()
        if r >= 0.5
            s += one(T)
        else
            s += vi
        end
    end
    return s
end
flipcoin_then_add_typed (generic function with 1 method)

using BenchmarkTools
using Random

myvec = rand(1000)
# Each call tosses its own coins, so reseed the RNG before each call
# to make both functions see the same sequence of tosses.
Random.seed!(1)
a = flipcoin_then_add(myvec)
Random.seed!(1)
b = flipcoin_then_add_typed(myvec)
@show a == b
a == b = true
true
@btime flipcoin_then_add(rand(1000))
@btime flipcoin_then_add_typed(rand(1000))
  7.336 μs (1 allocation: 7.94 KiB)
  1.681 μs (1 allocation: 7.94 KiB)
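
One way to see where the difference may come from is to ask the compiler what it infers for each accumulator. The sketch below uses hypothetical copies of the two functions (sum_untyped and sum_typed, each returning its sum) and Base.return_types; it is an illustration under those assumptions, not a substitute for reading the @code_warntype output.

```julia
# Hypothetical copy of the first version: s starts as an Int and may
# later be rebound to a value of type T.
function sum_untyped(v::Vector{T}) where T <: Real
    s = 0
    for vi in v
        s += (rand() >= 0.5 ? 1 : vi)
    end
    return s
end

# Hypothetical copy of the second version: s has type T throughout.
function sum_typed(v::Vector{T}) where T <: Real
    s = zero(T)
    for vi in v
        s += (rand() >= 0.5 ? one(T) : vi)
    end
    return s
end

rt_untyped = only(Base.return_types(sum_untyped, (Vector{Float64},)))
rt_typed = only(Base.return_types(sum_typed, (Vector{Float64},)))

# A concrete inferred type means the accumulator never changes type.
@show rt_untyped isconcretetype(rt_untyped)
@show rt_typed isconcretetype(rt_typed)
```

When the accumulator can hold more than one type, every s += ... inside the loop has to account for each possibility, which is a plausible source of the extra time measured above.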