add_log_barriers
This commit is contained in:
parent
32e6a23e65
commit
29072e245f
1 changed files with 75 additions and 2 deletions
77
Utils.jl
77
Utils.jl
|
@ -73,9 +73,7 @@ function lowpass_filter(i; σ=1)
|
|||
i_padded = view(i, [fill(firstindex(i), half_w); eachindex(i); fill(lastindex(i), half_w)])
|
||||
k = exp.(-0.5 * ((x / σ) .^ 2)) / (2 * σ * sqrt(2π))
|
||||
k ./= sum(k)
|
||||
@show sum(k)
|
||||
r = DSP.conv(i_padded, k)[2half_w+1:end-2half_w]
|
||||
@show size(r)
|
||||
return r
|
||||
end
|
||||
|
||||
|
@ -83,4 +81,79 @@ function chi2(I_data, I_model, err)
|
|||
return sum((I_data .- I_model ./ err) .^ 2)
|
||||
end
|
||||
|
||||
function add_log_barriers(f::Function, constraints::Vector{Tuple{T,T}}; padding_factor=0.1) where {T<:Real}
|
||||
n = length(constraints)
|
||||
|
||||
lower_shifts = Float64[]
|
||||
upper_shifts = Float64[]
|
||||
|
||||
padding_factors = if padding_factor isa Vector
|
||||
padding_factor
|
||||
else
|
||||
fill(padding_factor, n)
|
||||
end
|
||||
|
||||
k = 0
|
||||
barrier_widths = map(constraints) do (lower, upper)
|
||||
k += 1
|
||||
if upper <= lower
|
||||
throw("admissible set is empty, check constraints")
|
||||
end
|
||||
if isfinite(lower)
|
||||
push!(lower_shifts, lower)
|
||||
if isfinite(upper)
|
||||
push!(upper_shifts, upper)
|
||||
return padding_factors[k] * (upper - lower)
|
||||
else
|
||||
push!(upper_shifts, Inf)
|
||||
return 1.0
|
||||
end
|
||||
else
|
||||
push!(lower_shifts, -Inf)
|
||||
if isfinite(upper)
|
||||
push!(upper_shifts, upper)
|
||||
return 1.0
|
||||
else
|
||||
push!(upper_shifts, Inf)
|
||||
return 1.0
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
function std_lower_barrier(s::Float64)
|
||||
return max(-log(s + 1.0), 0.0)^2
|
||||
end
|
||||
|
||||
function barrier(x::Vector{Float64})
|
||||
lb = (x .- lower_shifts) ./ barrier_widths
|
||||
ub = (upper_shifts .- x) ./ barrier_widths
|
||||
|
||||
lb_violations = findall(<=(-1.0), lb)
|
||||
ub_violations = findall(<=(-1.0), ub)
|
||||
|
||||
if !isempty(lb_violations)
|
||||
err_msg = ""
|
||||
for i in lb_violations
|
||||
err_msg *= " · x[$i] = $(x[i]) < $(constraints[i][1]), width = $(barrier_widths[i])"
|
||||
end
|
||||
@error err_msg
|
||||
throw("lower bound violation for indice(s) $(lb_violations)")
|
||||
end
|
||||
|
||||
if !isempty(ub_violations)
|
||||
err_msg = ""
|
||||
for i in ub_violations
|
||||
err_msg *= " · x[$i] = $(x[i]) > $(constraints[i][2]), width = $(barrier_widths[i])"
|
||||
end
|
||||
@error err_msg
|
||||
throw("upper bound violation for indice(s) $(ub_violations)")
|
||||
end
|
||||
|
||||
v = sum(std_lower_barrier.(lb) + std_lower_barrier.(ub))
|
||||
return v
|
||||
end
|
||||
|
||||
return x::Vector{Float64} -> f(x) .+ barrier(x)
|
||||
end
|
||||
|
||||
end # module Utils
|
||||
|
|
Loading…
Reference in a new issue