add constrained GaussNewton

This commit is contained in:
Gaspard Jankowiak 2024-05-22 13:23:16 +02:00
parent b2a7f1ba13
commit 80b7fe9c8d
2 changed files with 25 additions and 7 deletions

View file

@ -125,7 +125,7 @@ function lowpass_filter(i; σ=1)
end end
function residuals(I_data, I_model, err) function residuals(I_data, I_model, err)
return (I_data .- I_model ./ err) .^ 2 return (I_data .- I_model ./ err)
end end
function chi2(I_data, I_model, err) function chi2(I_data, I_model, err)
@ -199,7 +199,7 @@ function add_log_barriers(f::Function, constraints::Vector{Tuple{T,T}}; padding_
if !isempty(ub_violations) if !isempty(ub_violations)
err_msg = "" err_msg = ""
for i in ub_violations for i in ub_violations
err_msg *= " · x[$i] = $(x[i]) > $(constraints[i][2]), width = $(barrier_widths[i])" err_msg *= "\n · x[$i] = $(x[i]) > $(constraints[i][2]), width = $(barrier_widths[i])"
end end
@error err_msg @error err_msg
throw("upper bound violation for indice(s) $(ub_violations)") throw("upper bound violation for indice(s) $(ub_violations)")
@ -226,4 +226,11 @@ function print_check_bounds(param_names, param_values, lb, ub)
end end
end end
function box_projector(lower::Vector{T}, upper::Vector{T}) where {T<:Real}
function clamp(x::Vector{Float64})
return max.(min.(x, upper), lower)
end
return clamp
end
end # module Utils end # module Utils

21
test.jl
View file

@ -59,6 +59,8 @@ w_all = Utils.compute_logscale_weights(q_all)
intensity_reduced, P_reduced, lb_reduced, ub_reduced = Utils.reduce_to_free_parameters(meta, PLUV.intensity, params_init, lower_bounds, upper_bounds, q) intensity_reduced, P_reduced, lb_reduced, ub_reduced = Utils.reduce_to_free_parameters(meta, PLUV.intensity, params_init, lower_bounds, upper_bounds, q)
projector = Utils.box_projector(lb_reduced, ub_reduced)
simple_bounds = collect(zip(lb_reduced, ub_reduced)) simple_bounds = collect(zip(lb_reduced, ub_reduced))
simple_init = 0.5 * (lb_reduced .+ ub_reduced) simple_init = 0.5 * (lb_reduced .+ ub_reduced)
@ -74,6 +76,8 @@ scaling_factors = fill(0.0, length(simple_init))
# scaling_factors[6] = 1e1 # scaling_factors[6] = 1e1
#scaling_factors[7] = 1e3 #scaling_factors[7] = 1e3
# scaling_factors[18] = 1e4
bounds = MH.boxconstraints(lb=lb_reduced, ub=ub_reduced) bounds = MH.boxconstraints(lb=lb_reduced, ub=ub_reduced)
function obj_χ2(P) function obj_χ2(P)
@ -107,11 +111,13 @@ I_mean_5k, _ = PLUV.intensity(mean_5k_full, q_all)
# Gauss-Newton # Gauss-Newton
if true if true
initial_guess = Utils.reduce_to_free_parameters(meta, collect(Float64, values(params_init))) data_file_init = Utils.reduce_to_free_parameters(meta, collect(Float64, values(params_init)))
initial_guess = simple_init #initial_guess = simple_init
# initial_guess = best_5k_params initial_guess = best_5k_params
_, result = GN.optimize(barriered_obj, initial_guess; show_trace=true, iscale=1, ZCP=1e-2, ZCPMIN=1e-2) _, result = GN.optimize(obj_residuals, initial_guess; projector=projector, autodiff=:central)
# _, result = GN.optimize(barriered_obj, initial_guess; show_trace=true, iscale=2, D0=scaling_factors, ZCP=1e-2, ZCPMIN=1e-2)
# _, result = GN.optimize(barriered_obj, initial_guess; show_trace=true, iscale=1, ZCP=1e-2, ZCPMIN=1e-2)
# _, result = GN.optimize(barriered_obj, initial_guess) # _, result = GN.optimize(barriered_obj, initial_guess)
if GN.has_converged(result) if GN.has_converged(result)
@ -128,6 +134,11 @@ if true
else else
@error "Gauss-Newton did not converge" @error "Gauss-Newton did not converge"
end end
@info "Best result GN: Σr² = $(sum(x -> x^2, obj_residuals(P_best)))"
@info "Best result SA: Σr² = $(sum(x -> x^2, obj_residuals(best_5k_params)))"
@info "Result box middle: Σr² = $(sum(x -> x^2, obj_residuals(simple_init)))"
@info "Result data file: Σr² = $(sum(x -> x^2, obj_residuals(data_file_init)))"
end end
# Metaheuristics # Metaheuristics
@ -147,7 +158,7 @@ if true
I_initial, _ = intensity_reduced(initial_guess) I_initial, _ = intensity_reduced(initial_guess)
M.lines!(ax, q, I_initial, label="initial", linestyle=:dash, linewidth=2) M.lines!(ax, q, I_initial, label="initial", linestyle=:dash, linewidth=2)
M.lines!(ax, q, I_best, label="MH best (julia)") M.lines!(ax, q, I_best, label="GN best (julia)")
M.scatter!(ax, q, I_data, label="data") M.scatter!(ax, q, I_data, label="data")
M.lines!(ax, q, I_best_5k, label="TSA best (5k)") M.lines!(ax, q, I_best_5k, label="TSA best (5k)")
M.lines!(ax, q_all, I_mean_5k, label="TSA mean (5k)") M.lines!(ax, q_all, I_mean_5k, label="TSA mean (5k)")