add handling of nonallocating functions, add plot option

This commit is contained in:
Gaspard Jankowiak 2025-02-12 11:04:26 +01:00
commit 3b80b8c720
4 changed files with 62 additions and 8 deletions

View file

@ -1,9 +1,10 @@
module TaylorTest
export check
export check, check!
import LinearAlgebra: norm, dot
import TensorOperations: @tensor
import UnicodePlots
"""
`check(f, Jf, x[, constant_components]; f_kwargs...)`
@ -19,14 +20,18 @@ julia> f = x -> cos(x); Jf = x -> -sin(x); check(f, Jf, rand())
true
```
"""
function check(f, Jf, x, constant_components::Vector{Int}=Int[]; f_kwargs...)
direction = 2 * (rand(size(x)...) .- 0.5)
for cst in constant_components
direction[cst] = 0
end
function check(f, Jf, x, constant_components::Vector{Int}=Int[]; taylortestplot::Bool=false, taylortestdirection=nothing, f_kwargs...)
if isnothing(taylortestdirection)
direction = 2 * (rand(size(x)...) .- 0.5)
for cst in constant_components
direction[cst] = 0
end
if direction isa Array
direction ./= norm(direction)
if direction isa Array
direction ./= norm(direction)
end
else
direction = taylortestdirection
end
ε_array = 10.0 .^ (-5:1:-1)
@ -50,6 +55,10 @@ function check(f, Jf, x, constant_components::Vector{Int}=Int[]; f_kwargs...)
error_array = [norm(f(x + ε * direction; f_kwargs...) - f_x - ε * Jf_x_direction) for ε in ε_array]
end
if taylortestplot
println(UnicodePlots.lineplot(ε_array, error_array; xscale=:log10, yscale=:log10, xlabel="ε", title="|| f(x + ε·d) - f(x) - ε Jf(x)[d] ||"))
end
m = maximum(error_array)
if m < 1e-8
@warn "f looks constant or linear!"
@ -61,4 +70,23 @@ function check(f, Jf, x, constant_components::Vector{Int}=Int[]; f_kwargs...)
return isapprox(order, 1; atol=0.5)
end
function check!(f!, Jf!, x, size_f_x, size_Jf_x, constant_components::Vector{Int}=Int[];
taylortestdirection=nothing, taylortestplot::Bool=false, f_kwargs...)
f = x -> begin
# must be allocated everytime to avoid aliasing
f_x = zeros(size_f_x...)
f!(f_x, x; f_kwargs...)
f_x
end
Jf = x -> begin
# must be allocated everytime to avoid aliasing
Jf_x = zeros(size_Jf_x...)
Jf!(Jf_x, x; f_kwargs...)
Jf_x
end
return check(f, Jf, x, constant_components, taylortestdirection=taylortestdirection, taylortestplot=taylortestplot)
end
end # module TaylorTest