using BenchmarkTools
using FFTW
using LinearAlgebra
using Plots
Rotation using FFT
\[ \frac{d f}{dt} + (v \frac{d f}{dx} - x \frac{d f}{dv}) = 0 \]
\[ x \in [-\pi, \pi],\qquad v \in [-\pi, \pi] \qquad \mbox{ and } \qquad t \in [0, 200\pi] \]
In the code below, the second variable \(v\) is called `y`.
Julia type for mesh information
"""
    OneDMesh(xmin, xmax, nx)

Description of a one-dimensional mesh: the interval bounds `xmin` and `xmax`
(stored as `Float64`) and the number of points `nx`.
"""
struct OneDMesh
    xmin::Float64
    xmax::Float64
    nx::Int
end
# Example: build a 128-point mesh on [-π, π] with the default constructor.
OneDMesh( -π, π, 128)
"""
    TwoDMesh(xmin, xmax, nx, ymin, ymax, ny)

Uniform two-dimensional mesh on `[xmin, xmax] × [ymin, ymax]` with `nx × ny`
points.

The constructor derives the spacings `dx = (xmax - xmin) / nx` and
`dy = (ymax - ymin) / ny` and the coordinate vectors `x` and `y`. Each
coordinate vector holds the first `nx` (resp. `ny`) of `nx + 1` (resp.
`ny + 1`) equally spaced points: the end point is dropped because the domain
is treated as periodic.
"""
struct TwoDMesh
    nx::Int
    ny::Int
    xmin::Float64
    xmax::Float64
    ymin::Float64
    ymax::Float64
    dx::Float64
    dy::Float64
    x::Vector{Float64}
    y::Vector{Float64}

    # Inner constructor: only the bounds and sizes are supplied by the caller.
    function TwoDMesh(xmin, xmax, nx, ymin, ymax, ny)
        dx, dy = (xmax - xmin) / nx, (ymax - ymin) / ny
        x = LinRange(xmin, xmax, nx + 1)[1:end-1]  # we remove the end point
        y = LinRange(ymin, ymax, ny + 1)[1:end-1]  # for periodic boundary condition
        return new(nx, ny, xmin, xmax, ymin, ymax, dx, dy, x, y)
    end
end
# 128 × 256 periodic mesh on [-π, π] × [-π, π]
mesh = TwoDMesh(-π, π, 128, -π, π, 256)
TwoDMesh(128, 256, -3.141592653589793, 3.141592653589793, -3.141592653589793, 3.141592653589793, 0.04908738521234052, 0.02454369260617026, [-3.141592653589793, -3.0925052683774523, -3.043417883165112, -2.9943304979527716, -2.9452431127404313, -2.89615572752809, -2.84706834231575, -2.7979809571034093, -2.748893571891069, -2.699806186678728 … 2.650718801466388, 2.6998061866787286, 2.748893571891069, 2.7979809571034098, 2.8470683423157497, 2.8961557275280905, 2.9452431127404317, 2.9943304979527716, 3.0434178831651124, 3.0925052683774528], [-3.141592653589793, -3.117048960983623, -3.0925052683774528, -3.0679615757712826, -3.0434178831651124, -3.0188741905589414, -2.994330497952771, -2.969786805346601, -2.945243112740431, -2.9206994201342606 … 2.8961557275280905, 2.9206994201342606, 2.945243112740431, 2.969786805346601, 2.994330497952771, 3.0188741905589414, 3.0434178831651124, 3.0679615757712826, 3.0925052683774528, 3.117048960983623])
# Print the x-direction bounds, size and spacing of the mesh.
@show mesh.xmin, mesh.xmax, mesh.nx, mesh.dx
(mesh.xmin, mesh.xmax, mesh.nx, mesh.dx) = (-3.141592653589793, 3.141592653589793, 128, 0.04908738521234052)
(-3.141592653589793, 3.141592653589793, 128, 0.04908738521234052)
Initialization of f: a 2D array of double-precision floats
# Allocate f on the mesh and fill it with a product of two Gaussians
# centred at (x, y) = (1, 1).
f = zeros(Float64, (mesh.nx, mesh.ny))

for (i, x) in enumerate(mesh.x), (j, y) in enumerate(mesh.y)
    f[i, j] = exp(-(x - 1) * (x - 1) / 0.1) * exp(-(y - 1) * (y - 1) / 0.1)
end
Julia function to compute exact solution
"""
    compute_exact_solution(final_time, mesh)

Return the `mesh.nx × mesh.ny` array of the exact solution at time
`final_time`.

Each mesh point `(x, y)` is rotated by the angle `final_time` to
`(xn, yn) = (cos(t)x - sin(t)y, sin(t)x + cos(t)y)` and the Gaussian
`exp(-(xn-1)^2/0.1) * exp(-(yn-1)^2/0.1)` is evaluated there.
`mesh` only needs the fields `nx`, `ny`, `x` and `y`.
"""
function compute_exact_solution(final_time, mesh)
    # The rotation coefficients do not depend on the mesh point: compute once.
    c = cos(final_time)
    s = sin(final_time)

    f = zeros(Float64, (mesh.nx, mesh.ny))
    for (i, x) in enumerate(mesh.x), (j, y) in enumerate(mesh.y)
        xn = c * x - s * y
        yn = s * x + c * y
        f[i, j] = exp(-(xn - 1) * (xn - 1) / 0.1) * exp(-(yn - 1) * (yn - 1) / 0.1)
    end
    return f
end
compute_exact_solution (generic function with 1 method)
# Exact solution at t = 0 (the initial condition), shown as a contour plot.
f = compute_exact_solution(0.0, mesh)
contour(f)
Create the gif to show what we are computing
"""
    create_gif_animation(mesh, nsteps)

Build a GIF animation (via `Plots.@gif`) with one contour frame for each of
`nsteps` values of `t` evenly spaced in `[0, 2π]`. Every frame plots the
rotated Gaussian on the grid `mesh.x × mesh.y`.
"""
function create_gif_animation(mesh, nsteps)
    @gif for t in LinRange(0, 2π, nsteps)
        # NOTE(review): the width used here is 0.2, whereas
        # `compute_exact_solution` uses 0.1 — confirm this is intentional.
        f(x, y) = exp(-((cos(t) * x - sin(t) * y) - 1)^2 / 0.2) *
                  exp(-((sin(t) * x + cos(t) * y) - 1)^2 / 0.2)

        p = plot(mesh.x, mesh.y, f, st = [:contour])

        plot!(p[1])
        plot!(zlims = (-0.01, 1.01))
    end
end
create_gif_animation (generic function with 1 method)
# Render the animation with 100 time samples; `;` suppresses the cell output.
create_gif_animation(mesh, 100);
[ Info: Saved animation to /home/runner/work/math-julia/math-julia/tmp.gif
Function to compute error
"""
    compute_error(f, f_exact)

Return the largest absolute element-wise difference between `f` and
`f_exact` (maximum norm of the error).
"""
function compute_error(f, f_exact)
    deviation = f .- f_exact
    return maximum(abs, deviation)
end
compute_error (generic function with 1 method)
Naive translation of a matlab code
"""
    naive_translation_from_matlab(final_time, nsteps, mesh::TwoDMesh)

Advance the initial condition `compute_exact_solution(0.0, mesh)` up to
`final_time` in `nsteps` steps of size `dt = final_time / nsteps` and return
the resulting real array.

Each step applies three one-dimensional FFT-based sweeps:
a shift of every row `f[i, :]` with phase `exp(i x ky tan(dt/2))`, a shift of
every column `f[:, j]` with phase `exp(-i y kx sin(dt))`, then the row shift
again. This is the baseline version: it loops over rows and columns, and each
slice, FFT and phase vector is a fresh array.
"""
function naive_translation_from_matlab(final_time, nsteps, mesh::TwoDMesh)
    dt = final_time / nsteps

    # wavenumbers in FFT ordering
    kx = 2π / (mesh.xmax - mesh.xmin) .* fftfreq(mesh.nx, mesh.nx)
    ky = 2π / (mesh.ymax - mesh.ymin) .* fftfreq(mesh.ny, mesh.ny)

    f = compute_exact_solution(0.0, mesh)

    for n = 1:nsteps
        for (i, x) in enumerate(mesh.x)
            f[i, :] = real(ifft(exp.(1im * x * ky * tan(dt / 2)) .* fft(f[i, :])))
        end
        for (j, y) in enumerate(mesh.y)
            f[:, j] = real(ifft(exp.(-1im * y * kx * sin(dt)) .* fft(f[:, j])))
        end
        for (i, x) in enumerate(mesh.x)
            f[i, :] = real(ifft(exp.(1im * x * ky * tan(dt / 2)) .* fft(f[i, :])))
        end
    end

    return f
end
naive_translation_from_matlab (generic function with 1 method)
# Run the naive solver, compare with the exact solution, then time it.
nsteps, final_time = 1000, 200
sol1 = naive_translation_from_matlab(final_time, nsteps, mesh)
sol2 = compute_exact_solution(final_time, mesh)
println( " error = ", compute_error(sol1, sol2))
# `$` interpolates the global variables so their lookup is not part of the timing
@btime naive_translation_from_matlab($final_time, $nsteps, $mesh);
error = 1.9817480989559044e-13
10.670 s (7168002 allocations: 7.95 GiB)
Vectorized version
- We remove the for loops over the directions `x` and `y` by creating the 2d arrays `exky` and `ekxy`.
- We save cpu time by computing them before the loop over time.
"""
    vectorized(final_time, nsteps, mesh::TwoDMesh)

Same scheme as `naive_translation_from_matlab`, without the loops over rows
and columns: the phase factors are stored once as the `nx × ny` arrays
`exky` and `ekxy`, and each sweep is a single FFT along dimension 2 or 1 of
the whole array. Returns the real solution array at `final_time`.
"""
function vectorized(final_time, nsteps, mesh::TwoDMesh)
    dt = final_time / nsteps

    kx = 2π / (mesh.xmax - mesh.xmin) .* fftfreq(mesh.nx, mesh.nx)
    ky = 2π / (mesh.ymax - mesh.ymin) .* fftfreq(mesh.ny, mesh.ny)

    f = compute_exact_solution(0.0, mesh)

    # phase factors, computed once before the time loop
    exky = exp.( 1im * tan(dt / 2) .* mesh.x  .* ky')
    ekxy = exp.(-1im * sin(dt)     .* mesh.y' .* kx )

    for n = 1:nsteps
        f .= real(ifft(exky .* fft(f, 2), 2))
        f .= real(ifft(ekxy .* fft(f, 1), 1))
        f .= real(ifft(exky .* fft(f, 2), 2))
    end

    return f
end
vectorized (generic function with 1 method)
# Run the vectorized solver, compare with the exact solution, then time it.
nsteps, final_time = 1000, 200
sol1 = vectorized(final_time, nsteps, mesh)
sol2 = compute_exact_solution(final_time, mesh)
println( " error = ", compute_error(sol1, sol2))
# `$` interpolates the global variables so their lookup is not part of the timing
@btime vectorized($final_time, $nsteps, $mesh);
error = 1.7908811801681452e-13
2.149 s (48006 allocations: 6.60 GiB)
Inplace computation
- We remove the Float64-ComplexF64 conversion by allocating the distribution function `f` as a complex array.
- Note that we need to use the in-place assignment operator “.=” to initialize the `f` array.
- We use in-place computation for the FFTs with the “bang” (`!`) functions `fft!` and `ifft!`.
"""
    inplace(final_time, nsteps, mesh::TwoDMesh)

Same scheme as `vectorized`, but `f` is stored as a `ComplexF64` array and
transformed in place with `fft!` / `ifft!`, so no real/complex conversion is
done inside the time loop. Returns `real(f)` at `final_time`.
"""
function inplace(final_time, nsteps, mesh::TwoDMesh)
    dt = final_time / nsteps

    # `fftfreq(n, n)` yields 0, 1, …, -1 in FFT ordering, like the sibling
    # solvers; it equals the hand-written `[0:n÷2-1; n÷2-n:-1]` for even `n`
    # and stays symmetric for odd `n`.
    kx = 2π / (mesh.xmax - mesh.xmin) .* fftfreq(mesh.nx, mesh.nx)
    ky = 2π / (mesh.ymax - mesh.ymin) .* fftfreq(mesh.ny, mesh.ny)

    # `.=` copies the real initial condition into the complex array
    f = zeros(ComplexF64, (mesh.nx, mesh.ny))
    f .= compute_exact_solution(0.0, mesh)

    exky = exp.( 1im * tan(dt / 2) .* mesh.x  .* ky')
    ekxy = exp.(-1im * sin(dt)     .* mesh.y' .* kx )

    for n = 1:nsteps
        fft!(f, 2)
        f .= exky .* f
        ifft!(f, 2)

        fft!(f, 1)
        f .= ekxy .* f
        ifft!(f, 1)

        fft!(f, 2)
        f .= exky .* f
        ifft!(f, 2)
    end

    return real(f)
end
inplace (generic function with 1 method)
# Run the in-place solver, compare with the exact solution, then time it.
nsteps, final_time = 1000, 200
sol1 = inplace(final_time, nsteps, mesh)
sol2 = compute_exact_solution(final_time, mesh)
println( " error = ", compute_error(sol1, sol2))
# `$` interpolates the global variables so their lookup is not part of the timing
@btime inplace($final_time, $nsteps, $mesh);
error = 1.8686623861233415e-13
1.499 s (18014 allocations: 3.56 MiB)
Use plans for fft
- When you apply several FFTs to arrays with the same shape and size, it is recommended to use an FFTW plan to speed up the computations.
- Let’s try to initialize our two fft along x and y with plans.
"""
    with_fft_plans(final_time, nsteps, mesh::TwoDMesh)

Same scheme as `inplace`, using plans `Px = plan_fft(f, 1)` and
`Py = plan_fft(f, 2)` created once before the time loop. The transforms are
applied with `P * f` and `P \\ f̂`, and their results are copied into the
work arrays `f̂` and `f`. Returns `real(f)` at `final_time`.
"""
function with_fft_plans(final_time, nsteps, mesh::TwoDMesh)
    dt = final_time / nsteps

    # `fftfreq(n, n)` yields 0, 1, …, -1 in FFT ordering, like the sibling
    # solvers; it equals the hand-written `[0:n÷2-1; n÷2-n:-1]` for even `n`
    # and stays symmetric for odd `n`.
    kx = 2π / (mesh.xmax - mesh.xmin) .* fftfreq(mesh.nx, mesh.nx)
    ky = 2π / (mesh.ymax - mesh.ymin) .* fftfreq(mesh.ny, mesh.ny)

    f = zeros(ComplexF64, (mesh.nx, mesh.ny))
    f .= compute_exact_solution(0.0, mesh)
    f̂ = similar(f)

    exky = exp.( 1im * tan(dt / 2) .* mesh.x  .* ky')
    ekxy = exp.(-1im * sin(dt)     .* mesh.y' .* kx )

    Px = plan_fft(f, 1)
    Py = plan_fft(f, 2)

    for n = 1:nsteps
        f̂ .= Py * f
        f̂ .= f̂ .* exky
        f .= Py \ f̂

        f̂ .= Px * f
        f̂ .= f̂ .* ekxy
        f .= Px \ f̂

        f̂ .= Py * f
        f̂ .= f̂ .* exky
        f .= Py \ f̂
    end

    return real(f)
end
with_fft_plans (generic function with 1 method)
# Run the plan-based solver, compare with the exact solution, then time it.
nsteps, final_time = 1000, 200
sol1 = with_fft_plans(final_time, nsteps, mesh)
sol2 = compute_exact_solution(final_time, mesh)
println( " error = ", compute_error(sol1, sol2))
# `$` interpolates the global variables so their lookup is not part of the timing
@btime with_fft_plans($final_time, $nsteps, $mesh);
error = 1.8686623861233415e-13
2.160 s (18030 allocations: 2.93 GiB)
Inplace computation and fft plans
To apply an FFT plan to an array `A`, we use a preallocated output array `Â` by calling `mul!(Â, plan, A)`. The input array `A` must be a complex floating-point array like the output `Â`. The inverse transform is computed in place by applying `inv(P)` with `ldiv!(A, P, Â)`.
"""
    with_fft_plans_inplace(final_time, nsteps, mesh::TwoDMesh)

Same scheme as `with_fft_plans`, but the planned transforms write into the
preallocated arrays: `mul!(f̂, P, f)` for the forward transform and
`ldiv!(f, P, f̂)` for the inverse. Returns `real(f)` at `final_time`.
"""
function with_fft_plans_inplace(final_time, nsteps, mesh::TwoDMesh)
    dt = final_time / nsteps

    kx = 2π / (mesh.xmax - mesh.xmin) .* fftfreq(mesh.nx, mesh.nx)
    ky = 2π / (mesh.ymax - mesh.ymin) .* fftfreq(mesh.ny, mesh.ny)

    f = zeros(ComplexF64, (mesh.nx, mesh.ny))
    f .= compute_exact_solution(0.0, mesh)
    f̂ = similar(f)

    exky = exp.( 1im * tan(dt / 2) .* mesh.x  .* ky')
    ekxy = exp.(-1im * sin(dt)     .* mesh.y' .* kx )

    Px = plan_fft(f, 1)
    Py = plan_fft(f, 2)

    for n = 1:nsteps
        mul!(f̂, Py, f)
        f̂ .= f̂ .* exky
        ldiv!(f, Py, f̂)

        mul!(f̂, Px, f)
        f̂ .= f̂ .* ekxy
        ldiv!(f, Px, f̂)

        mul!(f̂, Py, f)
        f̂ .= f̂ .* exky
        ldiv!(f, Py, f̂)
    end

    return real(f)
end
with_fft_plans_inplace (generic function with 1 method)
# Run the in-place plan-based solver, compare with the exact solution, then time it.
nsteps, final_time = 1000, 200
sol1 = with_fft_plans_inplace(final_time, nsteps, mesh)
sol2 = compute_exact_solution(final_time, mesh)
println( " error = ", compute_error(sol1, sol2 ))
# `$` interpolates the global variables so their lookup is not part of the timing
@btime with_fft_plans_inplace($final_time, $nsteps, $mesh);
error = 1.8686623861233415e-13
1.587 s (6026 allocations: 3.82 MiB)
Explicit transpose of f
- Multidimensional arrays in Julia are stored in column-major order.
- FFTs along y are slower than FFTs along x
- We can speed up the computation by allocating a transposed copy `fᵗ` of `f` and transposing `f` into it for each advection along y.
"""
    with_fft_transposed(final_time, nsteps, mesh::TwoDMesh; nthreads = 4)

Same scheme as `with_fft_plans_inplace`, but both plans transform along the
first dimension: for the sweeps along `y`, `f` is transposed into the
`ny × nx` array `fᵗ`, transformed there, and transposed back. `exky` is
therefore built with shape `ny × nx`.

`nthreads` is passed to `FFTW.set_num_threads` before the plans are created
(default 4, the value that was previously hard-coded). Returns `real(f)` at
`final_time`.
"""
function with_fft_transposed(final_time, nsteps, mesh::TwoDMesh; nthreads::Integer = 4)
    dt = final_time / nsteps

    kx = 2π / (mesh.xmax - mesh.xmin) .* fftfreq(mesh.nx, mesh.nx)
    ky = 2π / (mesh.ymax - mesh.ymin) .* fftfreq(mesh.ny, mesh.ny)

    f  = zeros(ComplexF64, (mesh.nx, mesh.ny))
    f̂  = similar(f)
    fᵗ = zeros(ComplexF64, (mesh.ny, mesh.nx))
    f̂ᵗ = similar(fᵗ)

    exky = exp.( 1im * tan(dt / 2) .* mesh.x' .* ky )
    ekxy = exp.(-1im * sin(dt)     .* mesh.y' .* kx )

    FFTW.set_num_threads(nthreads)
    Px = plan_fft(f,  1, flags = FFTW.PATIENT)
    Py = plan_fft(fᵗ, 1, flags = FFTW.PATIENT)

    # f is filled only after planning — presumably because planning with
    # FFTW.PATIENT may overwrite the arrays it is given; keep this order.
    f .= compute_exact_solution(0.0, mesh)

    for n = 1:nsteps
        transpose!(fᵗ, f)
        mul!(f̂ᵗ, Py, fᵗ)
        f̂ᵗ .= f̂ᵗ .* exky
        ldiv!(fᵗ, Py, f̂ᵗ)
        transpose!(f, fᵗ)

        mul!(f̂, Px, f)
        f̂ .= f̂ .* ekxy
        ldiv!(f, Px, f̂)

        transpose!(fᵗ, f)
        mul!(f̂ᵗ, Py, fᵗ)
        f̂ᵗ .= f̂ᵗ .* exky
        ldiv!(fᵗ, Py, f̂ᵗ)
        transpose!(f, fᵗ)
    end

    return real(f)
end
with_fft_transposed (generic function with 1 method)
# Run the transposed solver, compare with the exact solution, then time it.
nsteps, final_time = 1000, 200
sol1 = with_fft_transposed(final_time, nsteps, mesh)
sol2 = compute_exact_solution(final_time, mesh)
println( " error = ", compute_error(sol1, sol2))
# `$` interpolates the global variables so their lookup is not part of the timing
@btime with_fft_transposed($final_time, $nsteps, $mesh);
error = 1.8435917383389174e-13
634.783 ms (6038 allocations: 6.82 MiB)
# Larger problem used for the final comparison: 512 × 256 mesh, 1000 steps up to t = 400π.
final_time, nsteps = 400π, 1000
mesh = TwoDMesh(-π, π, 512, -π, π, 256)
TwoDMesh(512, 256, -3.141592653589793, 3.141592653589793, -3.141592653589793, 3.141592653589793, 0.01227184630308513, 0.02454369260617026, [-3.141592653589793, -3.129320807286708, -3.1170489609836225, -3.104777114680538, -3.0925052683774528, -3.0802334220743677, -3.067961575771282, -3.055689729468197, -3.0434178831651124, -3.0311460368620273 … 3.018874190558941, 3.031146036862027, 3.0434178831651115, 3.0556897294681966, 3.0679615757712817, 3.0802334220743672, 3.092505268377452, 3.104777114680537, 3.117048960983622, 3.1293208072867076], [-3.141592653589793, -3.117048960983623, -3.0925052683774528, -3.0679615757712826, -3.0434178831651124, -3.0188741905589414, -2.994330497952771, -2.969786805346601, -2.945243112740431, -2.9206994201342606 … 2.8961557275280905, 2.9206994201342606, 2.945243112740431, 2.969786805346601, 2.994330497952771, 3.0188741905589414, 3.0434178831651124, 3.0679615757712826, 3.0925052683774528, 3.117048960983623])
# Benchmark every variant on the same problem; `$` interpolates the globals
# so their lookup is not part of the measurement.
inplace_bench = @benchmark inplace($final_time, $nsteps, $mesh)
vectorized_bench = @benchmark vectorized($final_time, $nsteps, $mesh)
with_fft_plans_bench = @benchmark with_fft_plans($final_time, $nsteps, $mesh)
with_fft_plans_inplace_bench = @benchmark with_fft_plans_inplace($final_time, $nsteps, $mesh)
with_fft_transposed_bench = @benchmark with_fft_transposed($final_time, $nsteps, $mesh)
BenchmarkTools.Trial: 3 samples with 1 evaluation. Range (min … max): 2.432 s … 2.446 s ┊ GC (min … max): 0.00% … 0.04% Time (median): 2.438 s ┊ GC (median): 0.00% Time (mean ± σ): 2.438 s ± 7.039 ms ┊ GC (mean ± σ): 0.01% ± 0.02% █ █ █ █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁ 2.43 s Histogram: frequency by time 2.45 s < Memory estimate: 26.32 MiB, allocs estimate: 6038.
# Collect the best time of each benchmark (`times / 1e6`, i.e. milliseconds
# assuming BenchmarkTools reports nanoseconds) and print them fastest first.
d = Dict{String,Float64}()
d["vectorized"] = minimum(vectorized_bench.times) / 1e6
d["inplace"] = minimum(inplace_bench.times) / 1e6
d["with_fft_plans"] = minimum(with_fft_plans_bench.times) / 1e6
d["with_fft_plans_inplace"] = minimum(with_fft_plans_inplace_bench.times) / 1e6
d["with_fft_transposed"] = minimum(with_fft_transposed_bench.times) / 1e6;

for (key, value) in sort(collect(d), by=last)
    println(rpad(key, 25, "."), lpad(round(value, digits=1), 6, "."))
end
with_fft_transposed......2431.7
inplace..................4035.2
with_fft_plans_inplace...4137.8
with_fft_plans...........6020.9
vectorized...............6167.4
Conclusion
- Using pre-allocations of memory and inplace computation is very important
- Try to always do computation on data contiguous in memory
- Use plans for FFTs