64 lines
1.9 KiB
Julia
64 lines
1.9 KiB
Julia
module TaylorTest
|
|
|
|
export check
|
|
|
|
import LinearAlgebra: norm, dot
|
|
import TensorOperations: @tensor
|
|
|
|
"""
    check(f, Jf, x[, constant_components]; f_kwargs...) -> Bool

Test whether `Jf` is consistent with being the derivative/gradient/Jacobian of `f` at `x`.

A random direction `d` with the shape of `x` is drawn (entries uniform in `[-1, 1)`,
rescaled to unit Euclidean norm when `x` is an array). For `ε ∈ {1e-5, …, 1e-1}` the
first-order Taylor remainder `‖f(x + εd) - f(x) - ε Jf(x)⋅d‖` is evaluated and a
least-squares line is fitted to `log(remainder)` against `log(ε)`. With a correct `Jf` the
remainder decays like `ε²`; the fitted slope minus one is reported through `@info` as the
approximation order, and the function returns `true` when that order is within `0.5` of `1`.

If every remainder is below `1e-8`, `f` is treated as linear along `d`: a warning is
logged and `true` is returned.

# Arguments
- `f`: function called as `f(x; f_kwargs...)`; may return a scalar or an array.
- `Jf`: function called as `Jf(x; f_kwargs...)`. Its result is combined with `d` as follows:
  - scalar-valued `f`: `dot(Jf(x), d)`, so a gradient of any shape matching `x` works;
  - vector-like `f` (at most one non-singleton dimension): `Jf(x) * d`;
  - any other array-valued `f`: `Jf(x)` must be an array whose leading dimensions are
    those of `f(x)` and whose remaining dimensions hold `length(x)` entries in total.
- `x`: point at which the derivative is tested (number or array).
- `constant_components`: optional vector of integer indices into `x`; the matching
  components of `d` are set to zero, so the dependency of `f` on them is not tested.

# Keywords
- `f_kwargs...`: keyword arguments forwarded unchanged to both `f` and `Jf`.

# Throws
- `ArgumentError` if `x` is a non-empty array and the direction is identically zero
  (typically because `constant_components` lists every component of `x`).

# Examples
```julia-repl
julia> f = x -> cos(x); Jf = x -> -sin(x); check(f, Jf, rand())
true
```
"""
function check(
    f, Jf, x, constant_components::AbstractVector{<:Integer}=Int[]; f_kwargs...
)
    # Random direction shaped like `x`, entries uniform in [-1, 1).
    direction = 2 * (rand(size(x)...) .- 0.5)
    for cst in constant_components
        direction[cst] = 0
    end

    if direction isa Array
        scale = norm(direction)
        # Dividing by a zero norm would fill `direction` with NaN and make the whole
        # check silently meaningless, so fail loudly instead.
        if iszero(scale) && !isempty(direction)
            throw(
                ArgumentError(
                    "test direction is identically zero; `constant_components` " *
                    "must leave at least one component of `x` free",
                ),
            )
        end
        direction ./= scale
    end

    ε_array = 10.0 .^ (-5:-1)

    f_x = f(x; f_kwargs...)
    Jf_x = Jf(x; f_kwargs...)

    if ndims(f_x) == 0
        # Scalar-valued f: `dot` works whether Jf returns a gradient (column) or a
        # Jacobian (row), so the two conventions need not be told apart.
        slope_term = dot(Jf_x, direction)
        error_array = [
            norm(f(x + ε * direction; f_kwargs...) - (f_x + ε * slope_term)) for
            ε in ε_array
        ]
    elseif ndims(f_x) == 1 || prod(size(f_x)) == maximum(size(f_x))
        # f is essentially a vector (possibly stored as a row), so simply use
        # matrix multiplication after flattening the increment of f.
        error_array = [
            norm(vec(f(x + ε * direction; f_kwargs...) - f_x) - ε * Jf_x * direction)
            for ε in ε_array
        ]
    else
        # Genuinely multi-dimensional output: contract the input dimensions of the
        # Jacobian with the direction, e.g. Jf_x[i, j, k] * direction[k] summed over k.
        Jf_x_direction = _directional_derivative(Jf_x, direction, size(f_x))
        error_array = [
            norm(f(x + ε * direction; f_kwargs...) - f_x - ε * Jf_x_direction) for
            ε in ε_array
        ]
    end

    # Remainders this small cannot be fitted reliably (and may be exactly zero, whose
    # logarithm is -Inf); they indicate that f is linear along `direction`.
    if maximum(error_array) < 1e-8
        @warn "f looks linear!"
        return true
    end

    # Least-squares fit of log(error) ≈ intercept + slope * log(ε).
    design = [ones(length(ε_array)) log.(ε_array)]
    slope = (design \ log.(error_array))[2]
    order = trunc(slope - 1; digits=2)
    @info "Approximation order ~ $order"
    return isapprox(order, 1; atol=0.5)
end

# Contract the trailing (input) dimensions of `jacobian` with `direction` and return the
# result reshaped to `output_size`, the size of the value returned by `f`.
function _directional_derivative(jacobian, direction, output_size)
    flat_jacobian = reshape(jacobian, prod(output_size), :)
    return reshape(flat_jacobian * _flatten_direction(direction), output_size)
end

# View a direction as a plain vector, whether it is a number or an array.
_flatten_direction(direction::Number) = [direction]
_flatten_direction(direction::AbstractArray) = vec(direction)
|
|
|
|
end # module TaylorTest
|