add handling of nonallocating functions, add plot option
This commit is contained in:
parent
422c4ab05a
commit
3b80b8c720
4 changed files with 62 additions and 8 deletions
|
@ -7,12 +7,14 @@ version = "0.1.2"
|
|||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
|
||||
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
|
||||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
|
||||
|
||||
[compat]
|
||||
LinearAlgebra = "1.11.0"
|
||||
SpecialFunctions = "2.5.0"
|
||||
TensorOperations = "5.1.3"
|
||||
Test = "1.11.0"
|
||||
UnicodePlots = "3.7.2"
|
||||
|
||||
[extras]
|
||||
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
module TaylorTest
|
||||
|
||||
export check
|
||||
export check, check!
|
||||
|
||||
import LinearAlgebra: norm, dot
|
||||
import TensorOperations: @tensor
|
||||
import UnicodePlots
|
||||
|
||||
"""
|
||||
`check(f, Jf, x[, constant_components]; f_kwargs...)`
|
||||
|
@ -19,14 +20,18 @@ julia> f = x -> cos(x); Jf = x -> -sin(x); check(f, Jf, rand())
|
|||
true
|
||||
```
|
||||
"""
|
||||
function check(f, Jf, x, constant_components::Vector{Int}=Int[]; f_kwargs...)
|
||||
direction = 2 * (rand(size(x)...) .- 0.5)
|
||||
for cst in constant_components
|
||||
direction[cst] = 0
|
||||
end
|
||||
function check(f, Jf, x, constant_components::Vector{Int}=Int[]; taylortestplot::Bool=false, taylortestdirection=nothing, f_kwargs...)
|
||||
if isnothing(taylortestdirection)
|
||||
direction = 2 * (rand(size(x)...) .- 0.5)
|
||||
for cst in constant_components
|
||||
direction[cst] = 0
|
||||
end
|
||||
|
||||
if direction isa Array
|
||||
direction ./= norm(direction)
|
||||
if direction isa Array
|
||||
direction ./= norm(direction)
|
||||
end
|
||||
else
|
||||
direction = taylortestdirection
|
||||
end
|
||||
|
||||
ε_array = 10.0 .^ (-5:1:-1)
|
||||
|
@ -50,6 +55,10 @@ function check(f, Jf, x, constant_components::Vector{Int}=Int[]; f_kwargs...)
|
|||
error_array = [norm(f(x + ε * direction; f_kwargs...) - f_x - ε * Jf_x_direction) for ε in ε_array]
|
||||
end
|
||||
|
||||
if taylortestplot
|
||||
println(UnicodePlots.lineplot(ε_array, error_array; xscale=:log10, yscale=:log10, xlabel="ε", title="|| f(x + ε·d) - f(x) - ε Jf(x)[d] ||"))
|
||||
end
|
||||
|
||||
m = maximum(error_array)
|
||||
if m < 1e-8
|
||||
@warn "f looks constant or linear!"
|
||||
|
@ -61,4 +70,23 @@ function check(f, Jf, x, constant_components::Vector{Int}=Int[]; f_kwargs...)
|
|||
return isapprox(order, 1; atol=0.5)
|
||||
end
|
||||
|
||||
function check!(f!, Jf!, x, size_f_x, size_Jf_x, constant_components::Vector{Int}=Int[];
|
||||
taylortestdirection=nothing, taylortestplot::Bool=false, f_kwargs...)
|
||||
|
||||
f = x -> begin
|
||||
# must be allocated everytime to avoid aliasing
|
||||
f_x = zeros(size_f_x...)
|
||||
f!(f_x, x; f_kwargs...)
|
||||
f_x
|
||||
end
|
||||
Jf = x -> begin
|
||||
# must be allocated everytime to avoid aliasing
|
||||
Jf_x = zeros(size_Jf_x...)
|
||||
Jf!(Jf_x, x; f_kwargs...)
|
||||
Jf_x
|
||||
end
|
||||
|
||||
return check(f, Jf, x, constant_components, taylortestdirection=taylortestdirection, taylortestplot=taylortestplot)
|
||||
end
|
||||
|
||||
end # module TaylorTest
|
||||
|
|
23
test/nonallocating.jl
Normal file
23
test/nonallocating.jl
Normal file
|
@ -0,0 +1,23 @@
|
|||
@testset "Non-allocating functions" begin
|
||||
f! = (f_x, x) -> (f_x[1] = x[1]^2 - 2x[2]^2; f_x)
|
||||
Jf! = (Jf_x, x) -> (Jf_x[1] = 2x[1]; Jf_x[2] = -4x[2]; Jf_x)
|
||||
Hf! = (Hf_x, x) -> begin
|
||||
Hf_x[1, 1] = 2.0
|
||||
Hf_x[1, 2] = Hf_x[2, 1] = 0.0
|
||||
Hf_x[2, 2] = -4.0
|
||||
Hf_x
|
||||
end
|
||||
|
||||
size_f_x = (1,)
|
||||
size_Jf_x = (1, 2)
|
||||
size_Hf_x = (2, 2)
|
||||
|
||||
f_x = zeros(size_f_x...)
|
||||
Jf_x = zeros(size_Jf_x...)
|
||||
Hf_x = zeros(size_Hf_x...)
|
||||
|
||||
x = rand(2)
|
||||
|
||||
@test TaylorTest.check!(f!, Jf!, x, size_f_x, size_Jf_x)
|
||||
@test TaylorTest.check!(Jf!, Hf!, x, size_Jf_x, size_Hf_x)
|
||||
end
|
|
@ -9,3 +9,4 @@ include("trig_functions.jl")
|
|||
include("gauss.jl")
|
||||
include("erf.jl")
|
||||
include("tensors.jl")
|
||||
include("nonallocating.jl")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue