Rotation using FFT

\[ \frac{\partial f}{\partial t} + \left( y \frac{\partial f}{\partial x} - x \frac{\partial f}{\partial y} \right) = 0 \]

\[ x \in [-\pi, \pi],\qquad y \in [-\pi, \pi] \qquad \mbox{ and } \qquad t \in [0, 200\pi] \]

using BenchmarkTools
using FFTW
using LinearAlgebra
using Plots

Julia type for mesh information

"""
    OneDMesh(xmin, xmax, nx)

Minimal description of a one-dimensional mesh: the two bounds `xmin` and
`xmax` of the interval, stored as `Float64`, and an integer size `nx`.
"""
struct OneDMesh
    xmin::Float64
    xmax::Float64
    nx::Int
end

OneDMesh( -π, π, 128)
"""
    TwoDMesh(xmin, xmax, nx, ymin, ymax, ny)

Uniform two-dimensional grid on `[xmin, xmax] × [ymin, ymax]` with `nx` nodes
along x and `ny` nodes along y.

The constructor derives the spacings `dx = (xmax - xmin) / nx` and
`dy = (ymax - ymin) / ny` and the node coordinates `x` and `y`. Each coordinate
vector starts at the lower bound and stops one spacing short of the upper
bound: the right end point is dropped, as needed for periodic boundary
conditions.
"""
struct TwoDMesh
    nx::Int
    ny::Int
    xmin::Float64
    xmax::Float64
    ymin::Float64
    ymax::Float64
    dx::Float64
    dy::Float64
    x::Vector{Float64}
    y::Vector{Float64}

    function TwoDMesh(xmin, xmax, nx, ymin, ymax, ny)
        dx = (xmax - xmin) / nx
        dy = (ymax - ymin) / ny
        # Sample nx + 1 (resp. ny + 1) equispaced nodes including both ends and
        # keep the first nx (resp. ny): the last node is the periodic image of
        # the first one.
        x = LinRange(xmin, xmax, nx + 1)[1:nx]
        y = LinRange(ymin, ymax, ny + 1)[1:ny]
        return new(nx, ny, xmin, xmax, ymin, ymax, dx, dy, x, y)
    end
end
mesh = TwoDMesh(-π, π, 128, -π, π, 256)
TwoDMesh(128, 256, -3.141592653589793, 3.141592653589793, -3.141592653589793, 3.141592653589793, 0.04908738521234052, 0.02454369260617026, [-3.141592653589793, -3.0925052683774523, -3.043417883165112, -2.9943304979527716, -2.9452431127404313, -2.89615572752809, -2.84706834231575, -2.7979809571034093, -2.748893571891069, -2.699806186678728  …  2.650718801466388, 2.6998061866787286, 2.748893571891069, 2.7979809571034098, 2.8470683423157497, 2.8961557275280905, 2.9452431127404317, 2.9943304979527716, 3.0434178831651124, 3.0925052683774528], [-3.141592653589793, -3.117048960983623, -3.0925052683774528, -3.0679615757712826, -3.0434178831651124, -3.0188741905589414, -2.994330497952771, -2.969786805346601, -2.945243112740431, -2.9206994201342606  …  2.8961557275280905, 2.9206994201342606, 2.945243112740431, 2.969786805346601, 2.994330497952771, 3.0188741905589414, 3.0434178831651124, 3.0679615757712826, 3.0925052683774528, 3.117048960983623])
@show mesh.xmin, mesh.xmax, mesh.nx, mesh.dx
(mesh.xmin, mesh.xmax, mesh.nx, mesh.dx) = (-3.141592653589793, 3.141592653589793, 128, 0.04908738521234052)
(-3.141592653589793, 3.141592653589793, 128, 0.04908738521234052)

Initialization of f: a 2D array of double-precision floats

# Sample the initial condition on the grid: a product of two Gaussians
# (denominator 0.1) centred at x = 1 and y = 1. Rows follow mesh.x and
# columns follow mesh.y.
f = [exp(-(x-1)*(x-1)/0.1)*exp(-(y-1)*(y-1)/0.1) for x in mesh.x, y in mesh.y]

Julia function to compute exact solution

"""
    compute_exact_solution(final_time, mesh)

Return the `mesh.nx × mesh.ny` matrix of the exact solution at time
`final_time`.

The initial condition is a product of two Gaussians centred at `(1, 1)`. The
solution at a grid node `(x, y)` is that initial condition evaluated at the
point obtained by rotating `(x, y)` by the angle `final_time`.

# Arguments
- `final_time`: rotation angle (time) at which the solution is evaluated.
- `mesh`: any object providing the sizes `nx`, `ny` and the node coordinates
  `x`, `y` (for instance a `TwoDMesh`).
"""
function compute_exact_solution(final_time, mesh)
    # The rotation coefficients are the same for every grid node: evaluate the
    # trigonometric functions once instead of four times per node.
    c, s = cos(final_time), sin(final_time)

    f = zeros(Float64, (mesh.nx, mesh.ny))
    # Julia arrays are column-major, so the first index `i` runs in the
    # innermost loop to write memory contiguously.
    for (j, y) in enumerate(mesh.y), (i, x) in enumerate(mesh.x)
        xn = c * x - s * y
        yn = s * x + c * y
        f[i, j] = exp(-(xn - 1) * (xn - 1) / 0.1) * exp(-(yn - 1) * (yn - 1) / 0.1)
    end
    return f
end
compute_exact_solution (generic function with 1 method)
f = compute_exact_solution(0.0, mesh)
contour(f)

Create the gif to show what we are computing

# Build a GIF animation of a Gaussian bump rotating around the origin.
#
# `mesh` only supplies the plotting coordinates `mesh.x` and `mesh.y`.
# `nsteps` is the number of frames, i.e. the number of angles sampled by
# `LinRange` between 0 and 2π (both end points included).
function create_gif_animation(mesh, nsteps)
    
    # `@gif` (Plots) records the plot produced by each loop iteration as a frame.
    @gif for t in LinRange(0, 2π, nsteps)

        # Product of two Gaussians centred at (1, 1), evaluated at the
        # coordinates rotated by the angle t.
        # NOTE(review): the denominator here is 0.2 whereas
        # compute_exact_solution uses 0.1, so the animated bump is wider than
        # the computed one — confirm whether this is intentional.
        f(x,y) = exp(-((cos(t)*x-sin(t)*y)-1)^2/0.2)*exp(-((sin(t)*x+cos(t)*y)-1)^2/0.2)
        
        # Contour plot of f sampled on the grid coordinates.
        p = plot(mesh.x, mesh.y, f, st = [:contour])
    
        # NOTE(review): presumably re-draws the first subplot of p before the
        # z-limits are fixed — confirm that this call is needed.
        plot!(p[1])
        plot!(zlims=(-0.01,1.01))
    
    end
end
create_gif_animation (generic function with 1 method)
create_gif_animation(mesh, 100);
[ Info: Saved animation to /home/runner/work/math-julia/math-julia/tmp.gif

Function to compute error

"""
    compute_error(f, f_exact)

Return the largest absolute entrywise difference between `f` and `f_exact`,
i.e. the maximum norm of `f - f_exact`.
"""
function compute_error(f, f_exact)
    residual = f .- f_exact
    return maximum(abs, residual)
end
compute_error (generic function with 1 method)

Naive translation of a MATLAB code

"""
    naive_translation_from_matlab(final_time, nsteps, mesh::TwoDMesh)

Advance the initial condition to `final_time` in `nsteps` steps and return the
resulting real matrix.

Each time step is a sequence of three one-dimensional spectral shifts: along y
(phase factor `tan(dt/2)`), along x (phase factor `sin(dt)`), then along y
again. Every grid line is extracted, transformed and written back
individually, as in a line-by-line MATLAB code.
"""
function naive_translation_from_matlab(final_time, nsteps, mesh::TwoDMesh)
    dt = final_time / nsteps

    # Wavenumbers in FFT ordering for each direction.
    kx = 2π / (mesh.xmax - mesh.xmin) .* fftfreq(mesh.nx, mesh.nx)
    ky = 2π / (mesh.ymax - mesh.ymin) .* fftfreq(mesh.ny, mesh.ny)

    f = compute_exact_solution(0.0, mesh)

    # Spectral shift of one grid line: multiply its FFT by exp.(phase).
    shift(line, phase) = real(ifft(exp.(phase) .* fft(line)))

    for _ in 1:nsteps
        # Half step along y, one row at a time.
        for i in eachindex(mesh.x)
            f[i, :] = shift(f[i, :], 1im * mesh.x[i] * ky * tan(dt / 2))
        end

        # Full step along x, one column at a time.
        for j in eachindex(mesh.y)
            f[:, j] = shift(f[:, j], -1im * mesh.y[j] * kx * sin(dt))
        end

        # Second half step along y.
        for i in eachindex(mesh.x)
            f[i, :] = shift(f[i, :], 1im * mesh.x[i] * ky * tan(dt / 2))
        end
    end

    return f
end
naive_translation_from_matlab (generic function with 1 method)
nsteps, final_time = 1000, 200
sol1 = naive_translation_from_matlab(final_time, nsteps, mesh)
sol2 = compute_exact_solution(final_time, mesh)
println( " error = ", compute_error(sol1, sol2))
# Interpolate the global variables with `$` so that the benchmark measures the
# function call itself and not the access to untyped globals.
@btime naive_translation_from_matlab($final_time, $nsteps, $mesh);
 error = 1.9817480989559044e-13
  10.670 s (7168002 allocations: 7.95 GiB)

Vectorized version

  • We remove the for loops over direction x and y by creating the 2d arrays exky and ekxy.
  • We save cpu time by computing them before the loop over time
"""
    vectorized(final_time, nsteps, mesh::TwoDMesh)

Same scheme as the line-by-line version, with the loops over grid lines
replaced by FFTs applied to the whole array along one dimension.

The two-dimensional phase arrays `exky` (shift along y) and `ekxy` (shift
along x) are built once, before the time loop. Returns the real solution
matrix at `final_time`.
"""
function vectorized(final_time, nsteps, mesh::TwoDMesh)
    dt = final_time / nsteps

    # Wavenumbers in FFT ordering for each direction.
    kx = 2π / (mesh.xmax - mesh.xmin) .* fftfreq(mesh.nx, mesh.nx)
    ky = 2π / (mesh.ymax - mesh.ymin) .* fftfreq(mesh.ny, mesh.ny)

    f = compute_exact_solution(0.0, mesh)

    # Phase factors of size (nx, ny), independent of the time step index.
    exky = exp.(1im * tan(dt / 2) .* mesh.x .* ky')
    ekxy = exp.(-1im * sin(dt) .* mesh.y' .* kx)

    # Shift along dimension `dim` by multiplying the 1D spectra by `phase`.
    advect!(g, phase, dim) = (g .= real(ifft(phase .* fft(g, dim), dim)))

    for _ in 1:nsteps
        advect!(f, exky, 2)
        advect!(f, ekxy, 1)
        advect!(f, exky, 2)
    end

    return f
end
vectorized (generic function with 1 method)
nsteps, final_time = 1000, 200
sol1 = vectorized(final_time, nsteps, mesh)
sol2 = compute_exact_solution(final_time, mesh)
println( " error = ", compute_error(sol1, sol2))
# Interpolate the global variables with `$` so that the benchmark measures the
# function call itself and not the access to untyped globals.
@btime vectorized($final_time, $nsteps, $mesh);
 error = 1.7908811801681452e-13
  2.149 s (48006 allocations: 6.60 GiB)

Inplace computation

  • We remove the Float64-ComplexF64 conversion by allocating the distribution function f as a complex array
  • Note that we need to use the in-place assignment operator “.=” to initialize the f array.
  • We use in-place computation for the FFTs with the “bang” functions fft! and ifft!
"""
    inplace(final_time, nsteps, mesh::TwoDMesh)

Same scheme as `vectorized`, with the distribution function stored as a
complex array so that the FFTs (`fft!`, `ifft!`) and the phase multiplications
are all performed in place. Returns the real part of the solution at
`final_time`.

The wavenumbers are built with `fftfreq`, like in the other variants. For even
sizes this gives exactly the values `[0:n÷2-1; -n÷2:-1]`; for odd sizes it
also gives the correct symmetric ordering, which the hand-written
concatenation did not.
"""
function inplace(final_time, nsteps, mesh::TwoDMesh)

    dt = final_time/nsteps

    kx = 2π/(mesh.xmax-mesh.xmin) .* fftfreq(mesh.nx, mesh.nx)
    ky = 2π/(mesh.ymax-mesh.ymin) .* fftfreq(mesh.ny, mesh.ny)

    # `.=` copies the real initial condition into the complex array.
    f  = zeros(ComplexF64,(mesh.nx,mesh.ny))
    f .= compute_exact_solution(0.0, mesh)

    # Phase factors of size (nx, ny), computed once before the time loop.
    exky = exp.( 1im*tan(dt/2) .* mesh.x  .* ky')
    ekxy = exp.(-1im*sin(dt)   .* mesh.y' .* kx )

    for n = 1:nsteps
        # Half step along y.
        fft!(f, 2)
        f .*= exky
        ifft!(f, 2)
        # Full step along x.
        fft!(f, 1)
        f .*= ekxy
        ifft!(f, 1)
        # Half step along y.
        fft!(f, 2)
        f .*= exky
        ifft!(f, 2)
    end

    return real(f)
end
inplace (generic function with 1 method)
nsteps, final_time = 1000, 200
sol1 = inplace(final_time, nsteps, mesh)
sol2 = compute_exact_solution(final_time, mesh)
println( " error = ", compute_error(sol1, sol2))
# Interpolate the global variables with `$` so that the benchmark measures the
# function call itself and not the access to untyped globals.
@btime inplace($final_time, $nsteps, $mesh);
 error = 1.8686623861233415e-13
  1.499 s (18014 allocations: 3.56 MiB)

Use plans for fft

  • When you apply several FFTs to arrays with the same shape and size, it is recommended to use an FFTW plan to speed up the computations.
  • Let’s try to initialize our two FFTs along x and y with plans.
"""
    with_fft_plans(final_time, nsteps, mesh::TwoDMesh)

Same scheme as `inplace`, using precomputed FFT plans `Px` (along x) and `Py`
(along y). The forward transform is applied with `plan * f` and the inverse
transform with `plan \\ f̂`; both return a new array that is then copied into
the preallocated work array `f̂` or into `f`. Returns the real part of the
solution at `final_time`.
"""
function with_fft_plans(final_time, nsteps, mesh::TwoDMesh)

    dt = final_time/nsteps

    # Wavenumbers in FFT ordering; `fftfreq` matches the other variants and is
    # also correct for odd sizes.
    kx = 2π/(mesh.xmax-mesh.xmin) .* fftfreq(mesh.nx, mesh.nx)
    ky = 2π/(mesh.ymax-mesh.ymin) .* fftfreq(mesh.ny, mesh.ny)

    f  = zeros(ComplexF64,(mesh.nx,mesh.ny))
    f .= compute_exact_solution(0.0, mesh)
    f̂  = similar(f)   # work array holding the transformed data

    exky = exp.( 1im*tan(dt/2) .* mesh.x  .* ky')
    ekxy = exp.(-1im*sin(dt)   .* mesh.y' .* kx )

    Px = plan_fft(f, 1)
    Py = plan_fft(f, 2)

    for n = 1:nsteps

        # Half step along y.
        f̂ .= Py * f
        f̂ .= f̂ .* exky
        f .= Py \ f̂

        # Full step along x.
        f̂ .= Px * f
        f̂ .= f̂ .* ekxy
        f .= Px \ f̂

        # Half step along y.
        f̂ .= Py * f
        f̂ .= f̂ .* exky
        f .= Py \ f̂

    end

    return real(f)
end
with_fft_plans (generic function with 1 method)
nsteps, final_time = 1000, 200
sol1 = with_fft_plans(final_time, nsteps, mesh)
sol2 = compute_exact_solution(final_time, mesh)
println( " error = ", compute_error(sol1, sol2))
# Interpolate the global variables with `$` so that the benchmark measures the
# function call itself and not the access to untyped globals.
@btime with_fft_plans($final_time, $nsteps, $mesh);
 error = 1.8686623861233415e-13
  2.160 s (18030 allocations: 2.93 GiB)

Inplace computation and fft plans

To apply an FFT plan to an array A, we use a preallocated output array Â by calling mul!(Â, plan, A). The input array A must be a complex floating-point array like the output Â. The inverse transform is computed in place by applying inv(P) with ldiv!(A, P, Â).

"""
    with_fft_plans_inplace(final_time, nsteps, mesh::TwoDMesh)

Same scheme as `with_fft_plans`, but the planned transforms write directly
into preallocated arrays: `mul!(f̂, plan, f)` stores the forward transform of
`f` in `f̂`, and `ldiv!(f, plan, f̂)` stores the inverse transform of `f̂` in
`f`. Returns the real part of the solution at `final_time`.
"""
function with_fft_plans_inplace(final_time, nsteps, mesh::TwoDMesh)

    dt = final_time/nsteps

    kx = 2π/(mesh.xmax-mesh.xmin) .* fftfreq(mesh.nx, mesh.nx)
    ky = 2π/(mesh.ymax-mesh.ymin) .* fftfreq(mesh.ny, mesh.ny)

    f  = zeros(ComplexF64,(mesh.nx,mesh.ny))
    f .= compute_exact_solution(0.0, mesh)
    f̂  = similar(f)   # preallocated output of the forward transforms

    exky = exp.( 1im*tan(dt/2) .* mesh.x  .* ky')
    ekxy = exp.(-1im*sin(dt)   .* mesh.y' .* kx )

    Px = plan_fft(f, 1)
    Py = plan_fft(f, 2)

    for n = 1:nsteps

        # Half step along y.
        mul!(f̂, Py, f)
        f̂ .= f̂ .* exky
        ldiv!(f, Py, f̂)

        # Full step along x.
        mul!(f̂, Px, f)
        f̂ .= f̂ .* ekxy
        ldiv!(f, Px, f̂)

        # Half step along y.
        mul!(f̂, Py, f)
        f̂ .= f̂ .* exky
        ldiv!(f, Py, f̂)

    end

    return real(f)
end
with_fft_plans_inplace (generic function with 1 method)
nsteps, final_time = 1000, 200
sol1 = with_fft_plans_inplace(final_time, nsteps, mesh)
sol2 = compute_exact_solution(final_time, mesh)
println( " error = ", compute_error(sol1, sol2 ))
# Interpolate the global variables with `$` so that the benchmark measures the
# function call itself and not the access to untyped globals.
@btime with_fft_plans_inplace($final_time, $nsteps, $mesh);
 error = 1.8686623861233415e-13
  1.587 s (6026 allocations: 3.82 MiB)

Explicit transpose of f

  • Multidimensional arrays in Julia are stored in column-major order.
  • FFTs along y are slower than FFTs along x
  • We can speed-up the computation by allocating the transposed f and transpose f for each advection along y.
"""
    with_fft_transposed(final_time, nsteps, mesh::TwoDMesh)

Same scheme as `with_fft_plans_inplace`, but every transform is taken along
the first array dimension. For the steps along y, `f` is copied into the
transposed array `fᵗ` (size `ny × nx`), transformed there, and copied back.
Returns the real part of the solution at `final_time`.

Note that the function calls `FFTW.set_num_threads(4)`, a setting that is not
restored on exit.
"""
function with_fft_transposed(final_time, nsteps, mesh::TwoDMesh)

    dt = final_time/nsteps

    kx = 2π/(mesh.xmax-mesh.xmin) .* fftfreq(mesh.nx, mesh.nx)
    ky = 2π/(mesh.ymax-mesh.ymin) .* fftfreq(mesh.ny, mesh.ny)

    f  = zeros(ComplexF64,(mesh.nx,mesh.ny))
    f̂  = similar(f)
    fᵗ = zeros(ComplexF64,(mesh.ny,mesh.nx))
    f̂ᵗ = similar(fᵗ)

    # `exky` has the transposed shape (ny, nx) because it multiplies f̂ᵗ;
    # `ekxy` has shape (nx, ny) and multiplies f̂.
    exky = exp.( 1im*tan(dt/2) .* mesh.x' .* ky )
    ekxy = exp.(-1im*sin(dt)   .* mesh.y' .* kx )

    FFTW.set_num_threads(4)
    Px = plan_fft(f,  1, flags=FFTW.PATIENT)
    Py = plan_fft(fᵗ, 1, flags=FFTW.PATIENT)

    # The initial condition is written after the plans are created.
    # NOTE(review): presumably because planning with FFTW.PATIENT may overwrite
    # the array passed to plan_fft — confirm against the FFTW documentation.
    f .= compute_exact_solution(0.0, mesh)

    for n = 1:nsteps

        # Half step along y, performed on the transposed array.
        transpose!(fᵗ,f)
        mul!(f̂ᵗ, Py, fᵗ)
        f̂ᵗ .= f̂ᵗ .* exky
        ldiv!(fᵗ, Py, f̂ᵗ)
        transpose!(f,fᵗ)

        # Full step along x.
        mul!(f̂, Px, f)
        f̂ .= f̂ .* ekxy
        ldiv!(f, Px, f̂)

        # Half step along y, performed on the transposed array.
        transpose!(fᵗ,f)
        mul!(f̂ᵗ, Py, fᵗ)
        f̂ᵗ .= f̂ᵗ .* exky
        ldiv!(fᵗ, Py, f̂ᵗ)
        transpose!(f,fᵗ)

    end

    return real(f)

end
with_fft_transposed (generic function with 1 method)
nsteps, final_time = 1000, 200
sol1 = with_fft_transposed(final_time, nsteps, mesh)
sol2 = compute_exact_solution(final_time, mesh)
println( " error = ", compute_error(sol1, sol2))
# Interpolate the global variables with `$` so that the benchmark measures the
# function call itself and not the access to untyped globals.
@btime with_fft_transposed($final_time, $nsteps, $mesh);
 error = 1.8435917383389174e-13
  634.783 ms (6038 allocations: 6.82 MiB)
final_time, nsteps = 400π, 1000
mesh = TwoDMesh(-π, π, 512, -π, π, 256)
TwoDMesh(512, 256, -3.141592653589793, 3.141592653589793, -3.141592653589793, 3.141592653589793, 0.01227184630308513, 0.02454369260617026, [-3.141592653589793, -3.129320807286708, -3.1170489609836225, -3.104777114680538, -3.0925052683774528, -3.0802334220743677, -3.067961575771282, -3.055689729468197, -3.0434178831651124, -3.0311460368620273  …  3.018874190558941, 3.031146036862027, 3.0434178831651115, 3.0556897294681966, 3.0679615757712817, 3.0802334220743672, 3.092505268377452, 3.104777114680537, 3.117048960983622, 3.1293208072867076], [-3.141592653589793, -3.117048960983623, -3.0925052683774528, -3.0679615757712826, -3.0434178831651124, -3.0188741905589414, -2.994330497952771, -2.969786805346601, -2.945243112740431, -2.9206994201342606  …  2.8961557275280905, 2.9206994201342606, 2.945243112740431, 2.969786805346601, 2.994330497952771, 3.0188741905589414, 3.0434178831651124, 3.0679615757712826, 3.0925052683774528, 3.117048960983623])
# Benchmark every variant on the same problem. The global variables are
# interpolated with `$` so that each trial measures the function call itself
# and not the access to untyped globals.
inplace_bench = @benchmark inplace($final_time, $nsteps, $mesh)
vectorized_bench = @benchmark vectorized($final_time, $nsteps, $mesh)
with_fft_plans_bench = @benchmark with_fft_plans($final_time, $nsteps, $mesh)
with_fft_plans_inplace_bench = @benchmark with_fft_plans_inplace($final_time, $nsteps, $mesh)
with_fft_transposed_bench = @benchmark with_fft_transposed($final_time, $nsteps, $mesh)
BenchmarkTools.Trial: 3 samples with 1 evaluation.
 Range (min … max):  2.432 s … 2.446 s   GC (min … max): 0.00% … 0.04%
 Time  (median):     2.438 s              GC (median):    0.00%
 Time  (mean ± σ):   2.438 s ± 7.039 ms   GC (mean ± σ):  0.01% ± 0.02%
   ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁
  2.43 s        Histogram: frequency by time        2.45 s <
 Memory estimate: 26.32 MiB, allocs estimate: 6038.
# Best (minimum) time of each variant. BenchmarkTools stores the sample times
# in nanoseconds, so dividing by 1e6 converts them to milliseconds.
# The dictionary is given concrete key and value types instead of Dict{Any,Any}.
d = Dict{String,Float64}()
d["vectorized"] = minimum(vectorized_bench.times) / 1e6
d["inplace"] = minimum(inplace_bench.times) / 1e6
d["with_fft_plans"] = minimum(with_fft_plans_bench.times) / 1e6
d["with_fft_plans_inplace"] = minimum(with_fft_plans_inplace_bench.times) / 1e6
d["with_fft_transposed"] = minimum(with_fft_transposed_bench.times) / 1e6;
# Print the variants from fastest to slowest, padded with dots into two columns.
for (key, value) in sort(collect(d), by=last)
    println(rpad(key, 25, "."), lpad(round(value, digits=1), 6, "."))
end
with_fft_transposed......2431.7
inplace..................4035.2
with_fft_plans_inplace...4137.8
with_fft_plans...........6020.9
vectorized...............6167.4

Conclusion

  • Using pre-allocations of memory and inplace computation is very important
  • Try to always do computation on data contiguous in memory
  • Use plans for fft