From 29072e245f798d02218c5c3bae4769e03873eecf Mon Sep 17 00:00:00 2001 From: Gaspard Jankowiak Date: Wed, 15 May 2024 15:27:28 +0200 Subject: [PATCH] add_log_barriers --- Utils.jl | 77 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 75 insertions(+), 2 deletions(-) diff --git a/Utils.jl b/Utils.jl index 7692bbd..80748f1 100644 --- a/Utils.jl +++ b/Utils.jl @@ -73,9 +73,7 @@ function lowpass_filter(i; σ=1) i_padded = view(i, [fill(firstindex(i), half_w); eachindex(i); fill(lastindex(i), half_w)]) k = exp.(-0.5 * ((x / σ) .^ 2)) / (2 * σ * sqrt(2π)) k ./= sum(k) - @show sum(k) r = DSP.conv(i_padded, k)[2half_w+1:end-2half_w] - @show size(r) return r end @@ -83,4 +81,79 @@ function chi2(I_data, I_model, err) return sum((I_data .- I_model ./ err) .^ 2) end +function add_log_barriers(f::Function, constraints::Vector{Tuple{T,T}}; padding_factor=0.1) where {T<:Real} + n = length(constraints) + + lower_shifts = Float64[] + upper_shifts = Float64[] + + padding_factors = if padding_factor isa Vector + padding_factor + else + fill(padding_factor, n) + end + + k = 0 + barrier_widths = map(constraints) do (lower, upper) + k += 1 + if upper <= lower + throw("admissible set is empty, check constraints") + end + if isfinite(lower) + push!(lower_shifts, lower) + if isfinite(upper) + push!(upper_shifts, upper) + return padding_factors[k] * (upper - lower) + else + push!(upper_shifts, Inf) + return 1.0 + end + else + push!(lower_shifts, -Inf) + if isfinite(upper) + push!(upper_shifts, upper) + return 1.0 + else + push!(upper_shifts, Inf) + return 1.0 + end + end + end + + function std_lower_barrier(s::Float64) + return max(-log(s + 1.0), 0.0)^2 + end + + function barrier(x::Vector{Float64}) + lb = (x .- lower_shifts) ./ barrier_widths + ub = (upper_shifts .- x) ./ barrier_widths + + lb_violations = findall(<=(-1.0), lb) + ub_violations = findall(<=(-1.0), ub) + + if !isempty(lb_violations) + err_msg = "" + for i in lb_violations + err_msg *= " · x[$i] = $(x[i]) < $(constraints[i][1]), width = $(barrier_widths[i])" + end + @error err_msg + throw("lower bound violation for indice(s) $(lb_violations)") + end + + if !isempty(ub_violations) + err_msg = "" + for i in ub_violations + err_msg *= " · x[$i] = $(x[i]) > $(constraints[i][2]), width = $(barrier_widths[i])" + end + @error err_msg + throw("upper bound violation for indice(s) $(ub_violations)") + end + + v = sum(std_lower_barrier.(lb) + std_lower_barrier.(ub)) + return v + end + + return x::Vector{Float64} -> f(x) .+ barrier(x) +end + end # module Utils