From 3b80b8c7202985bfff7e25b0091086ba652a63b1 Mon Sep 17 00:00:00 2001 From: Gaspard Jankowiak Date: Wed, 12 Feb 2025 11:04:26 +0100 Subject: [PATCH] add handling of nonallocating functions, add plot option --- Project.toml | 2 ++ src/TaylorTest.jl | 44 +++++++++++++++++++++++++++++++++++-------- test/nonallocating.jl | 23 ++++++++++++++++++++++ test/runtests.jl | 1 + 4 files changed, 62 insertions(+), 8 deletions(-) create mode 100644 test/nonallocating.jl diff --git a/Project.toml b/Project.toml index f7c06c9..e71b04b 100644 --- a/Project.toml +++ b/Project.toml @@ -7,12 +7,14 @@ version = "0.1.2" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" [compat] LinearAlgebra = "1.11.0" SpecialFunctions = "2.5.0" TensorOperations = "5.1.3" Test = "1.11.0" +UnicodePlots = "3.7.2" [extras] SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" diff --git a/src/TaylorTest.jl b/src/TaylorTest.jl index 96575ef..419bb26 100644 --- a/src/TaylorTest.jl +++ b/src/TaylorTest.jl @@ -1,9 +1,10 @@ module TaylorTest -export check +export check, check! import LinearAlgebra: norm, dot import TensorOperations: @tensor +import UnicodePlots """ `check(f, Jf, x[, constant_components]; f_kwargs...)` @@ -19,14 +20,18 @@ julia> f = x -> cos(x); Jf = x -> -sin(x); check(f, Jf, rand()) true ``` """ -function check(f, Jf, x, constant_components::Vector{Int}=Int[]; f_kwargs...) - direction = 2 * (rand(size(x)...) .- 0.5) - for cst in constant_components - direction[cst] = 0 - end +function check(f, Jf, x, constant_components::Vector{Int}=Int[]; taylortestplot::Bool=false, taylortestdirection=nothing, f_kwargs...) + if isnothing(taylortestdirection) + direction = 2 * (rand(size(x)...) .- 0.5) + for cst in constant_components + direction[cst] = 0 + end - if direction isa Array - direction ./= norm(direction) + if direction isa Array + direction ./= norm(direction) + end + else + direction = taylortestdirection end ε_array = 10.0 .^ (-5:1:-1) @@ -50,6 +55,10 @@ function check(f, Jf, x, constant_components::Vector{Int}=Int[]; f_kwargs...) error_array = [norm(f(x + ε * direction; f_kwargs...) - f_x - ε * Jf_x_direction) for ε in ε_array] end + if taylortestplot + println(UnicodePlots.lineplot(ε_array, error_array; xscale=:log10, yscale=:log10, xlabel="ε", title="|| f(x + ε·d) - f(x) - ε Jf(x)[d] ||")) + end + m = maximum(error_array) if m < 1e-8 @warn "f looks constant or linear!" @@ -61,4 +70,23 @@ function check(f, Jf, x, constant_components::Vector{Int}=Int[]; f_kwargs...) return isapprox(order, 1; atol=0.5) end +function check!(f!, Jf!, x, size_f_x, size_Jf_x, constant_components::Vector{Int}=Int[]; + taylortestdirection=nothing, taylortestplot::Bool=false, f_kwargs...) + + f = x -> begin + # must be allocated everytime to avoid aliasing + f_x = zeros(size_f_x...) + f!(f_x, x; f_kwargs...) + f_x + end + Jf = x -> begin + # must be allocated everytime to avoid aliasing + Jf_x = zeros(size_Jf_x...) + Jf!(Jf_x, x; f_kwargs...) + Jf_x + end + + return check(f, Jf, x, constant_components, taylortestdirection=taylortestdirection, taylortestplot=taylortestplot) +end + end # module TaylorTest diff --git a/test/nonallocating.jl b/test/nonallocating.jl new file mode 100644 index 0000000..22226c7 --- /dev/null +++ b/test/nonallocating.jl @@ -0,0 +1,23 @@ +@testset "Non-allocating functions" begin + f! = (f_x, x) -> (f_x[1] = x[1]^2 - 2x[2]^2; f_x) + Jf! = (Jf_x, x) -> (Jf_x[1] = 2x[1]; Jf_x[2] = -4x[2]; Jf_x) + Hf! = (Hf_x, x) -> begin + Hf_x[1, 1] = 2.0 + Hf_x[1, 2] = Hf_x[2, 1] = 0.0 + Hf_x[2, 2] = -4.0 + Hf_x + end + + size_f_x = (1,) + size_Jf_x = (1, 2) + size_Hf_x = (2, 2) + + f_x = zeros(size_f_x...) + Jf_x = zeros(size_Jf_x...) + Hf_x = zeros(size_Hf_x...) + + x = rand(2) + + @test TaylorTest.check!(f!, Jf!, x, size_f_x, size_Jf_x) + @test TaylorTest.check!(Jf!, Hf!, x, size_Jf_x, size_Hf_x) +end diff --git a/test/runtests.jl b/test/runtests.jl index 09a9ec0..bfbcbbe 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,3 +9,4 @@ include("trig_functions.jl") include("gauss.jl") include("erf.jl") include("tensors.jl") +include("nonallocating.jl")