tsa_saxs/Utils.jl

160 lines
4.3 KiB
Julia
Raw Normal View History

2024-03-15 13:08:33 +01:00
module Utils
import TOML: parsefile
import DelimitedFiles: readdlm
import DSP
"""
    load_init_params(filename::String)

Read the TOML file `filename` and extract the fit parameters listed under
`metadata.parameters`.

Each parameter entry must provide the keys `symbol`, `initial_value`, `min` and `max`.

# Returns
A named tuple `(meta, params, lower_bounds, upper_bounds)`:
- `meta`: the parsed `metadata` table;
- `params`: a `NamedTuple` mapping each parameter symbol to its `initial_value`,
  with fields in the order the parameters are declared in the file;
- `lower_bounds` / `upper_bounds`: vectors of the `min` / `max` values, in that
  same declaration order (so they line up element-wise with `values(params)`).
"""
function load_init_params(filename::String)
    meta = parsefile(filename)["metadata"]
    param_specs = meta["parameters"]
    # Build the NamedTuple straight from the ordered list of parameter entries.
    # Collecting into a `Dict` first made the field order follow hash order, so
    # iterating `params` no longer matched `lower_bounds` / `upper_bounds`.
    params = NamedTuple(Symbol(p["symbol"]) => p["initial_value"] for p in param_specs)
    lower_bounds = [p["min"] for p in param_specs]
    upper_bounds = [p["max"] for p in param_specs]
    return (meta=meta, params=params, lower_bounds=lower_bounds, upper_bounds=upper_bounds)
end
"""
    select_columns(QIE_array, qmin, qmax, bin, err_mul)

Extract q, I(q) and err(q) from the columns of `QIE_array`: q in column 1, I(q) in
column 2, err(q) in column 3. I(q) and its error are multiplied by 1e-6, which
converts mm^-1 to nm^-1.

Row `r` is kept only when all of the following hold:
- `qmin <= q <= qmax`;
- `r % bin == 0`, where `r` is the row index in `QIE_array`, so only every
  `bin`-th row can survive;
- `I(q) > 0`.

The error assigned to a kept row is:
- 1% of I(q) when `err_mul == 0` (column 3 is not read in that case);
- otherwise `err_mul * err(q)`, floored at 1% of I(q).

Returns the tuple `(Q, IDATA, ERR)` of `Vector{Float64}`.
"""
function select_columns(QIE_array, qmin, qmax, bin, err_mul)
    Q = Float64[]
    IDATA = Float64[]
    ERR = Float64[]
    for row in axes(QIE_array, 1)
        q = QIE_array[row, 1]
        qmin <= q <= qmax || continue   # outside the requested q window
        row % bin == 0 || continue      # keep one row out of every `bin`
        intensity = QIE_array[row, 2]
        intensity > 0 || continue       # drop non-positive intensities
        push!(Q, q)
        push!(IDATA, 1e-6 * intensity)
        # the measured error is used only when it exceeds the 1% floor
        use_measured = err_mul != 0 && QIE_array[row, 3] * err_mul > 0.01 * intensity
        if use_measured
            push!(ERR, 1e-6 * QIE_array[row, 3] * err_mul)
        else
            push!(ERR, 0.01 * 1e-6 * intensity)
        end
    end
    return (Q, IDATA, ERR)
end
"""
compute_logscale_weights(q)
Compute the weights W, which are proportional to the distance in the log space
"""
function compute_logscale_weights(q)
diff_log_q = diff(log10.(q))
left_diff_log_q = view(diff_log_q, [firstindex(diff_log_q); eachindex(diff_log_q)])
right_diff_log_q = view(diff_log_q, [eachindex(diff_log_q); lastindex(diff_log_q)])
return left_diff_log_q + right_diff_log_q
end
"""
    load_config(filename::String)

Load a fit configuration: the parameters from the TOML file `filename`, via
`load_init_params`, and the data file it names in `metadata.data_file`.

The data file is read with `readdlm` as a tab-delimited matrix. A relative
`data_file` path is passed to `open` as-is, so it is resolved against the current
working directory.

# Returns
A named tuple `(meta, params_init, lower_bounds, upper_bounds, qie_data, qie_deta)`.
`qie_deta` is the historical, misspelled name of `qie_data` and holds the same matrix.
It is kept so existing callers keep working. Positional destructuring of the first
five fields is unchanged.
"""
function load_config(filename::String)
    meta, params_init, lower_bounds, upper_bounds = load_init_params(filename)
    # The do-block form closes the file even if `readdlm` throws; the previous
    # open / readdlm / close sequence leaked the handle on a parse error.
    qie_data = open(meta["data_file"], "r") do f
        readdlm(f, '\t')
    end
    return (meta=meta,
        params_init=params_init,
        lower_bounds=lower_bounds,
        upper_bounds=upper_bounds,
        qie_data=qie_data,
        qie_deta=qie_data)
end
"""
    lowpass_filter(i; σ=1)

Smooth the signal `i` by convolving it with a truncated, normalized Gaussian kernel.

The kernel is `exp(-x^2 / (2σ^2))` sampled at 21 equally spaced abscissae `x` on
[-5, 5]. The spacing is 0.5, so one unit of `σ` spans two samples. The kernel is then
rescaled to unit sum, which makes its constant prefactor irrelevant.

Before convolving, the signal is extended by repeating its first and last sample
10 times each. The full convolution (`DSP.conv`) is then cropped so the returned
vector has the same length as `i`.
"""
function lowpass_filter(i; σ=1)
    half_w = 10
    n_taps = 2half_w + 1
    abscissae = range(-5, 5; length=n_taps)
    kernel = exp.(-0.5 * ((abscissae / σ) .^ 2)) / (2 * σ * sqrt(2π))
    kernel ./= sum(kernel)
    # index map: half_w copies of the first index, every index, half_w copies of the last
    padded_idx = vcat(fill(firstindex(i), half_w), collect(eachindex(i)), fill(lastindex(i), half_w))
    i_padded = view(i, padded_idx)
    # the sample at original position p lands at index p + 2half_w of the full convolution
    full = DSP.conv(i_padded, kernel)
    return full[n_taps:end-2half_w]
end
"""
    chi2(I_data, I_model, err)

Return the chi-square statistic `Σ ((I_data - I_model) / err)^2`.

The arguments are broadcast against each other, so `err` may be a vector of
per-point uncertainties or a single scalar. Mismatched vector lengths raise a
`DimensionMismatch`.
"""
function chi2(I_data, I_model, err)
    # The residual must be parenthesized: `I_data .- I_model ./ err` divided only
    # the model by the error.
    return sum(abs2, (I_data .- I_model) ./ err)
end
2024-05-15 15:27:28 +02:00
"""
    add_log_barriers(f, constraints; padding_factor=0.1)

Return `x -> f(x) .+ penalty(x)`: the objective `f` augmented with a soft
logarithmic penalty for the box `constraints`.

Each coordinate `j`, with `(lower, upper) = constraints[j]`, gets a band width `w_j`:
- `padding_factor[j] * (upper - lower)` when both bounds are finite;
- `1.0` otherwise.

The scaled distances to the bounds are `s_lo = (x[j] - lower) / w_j` and
`s_up = (upper - x[j]) / w_j`. The penalty is `Σ_j b(s_lo) + b(s_up)`, with
`b(s) = max(-log(s + 1), 0)^2`.

`b` is zero for `s >= 0` and grows without bound as `s → -1`. As a result:
- points inside the box are not penalized;
- points less than `w_j` outside a bound are penalized increasingly;
- points `w_j` or more outside a bound raise an error.

An infinite bound gives an infinite `s` for finite `x`, hence no penalty.

# Arguments
- `f::Function`: objective. Its value is combined with the penalty using `.+`.
- `constraints`: one `(lower, upper)` tuple per coordinate of `x`.
- `padding_factor`: either a scalar used for every coordinate, or any
  `AbstractVector` indexed by constraint number.

# Throws
- Constructing the wrapper throws the string
  `"admissible set is empty, check constraints"` if some `upper <= lower`.
- Calling the returned function logs the offending coordinates with `@error`, then
  throws a string (`"lower bound violation ..."` / `"upper bound violation ..."`)
  when `x` lies at or beyond a band.

The returned function accepts any `AbstractVector{<:Real}`, e.g. views, `Float32`
vectors, or dual numbers for automatic differentiation.
"""
function add_log_barriers(f::Function, constraints::Vector{Tuple{T,T}}; padding_factor=0.1) where {T<:Real}
    n = length(constraints)
    # Any AbstractVector (e.g. a range) is a per-constraint list; anything else is a
    # scalar shared by all constraints.
    padding_factors = padding_factor isa AbstractVector ? padding_factor : fill(padding_factor, n)

    for (lower, upper) in constraints
        if upper <= lower
            throw("admissible set is empty, check constraints")
        end
    end

    # Non-finite lower bounds are stored as -Inf and non-finite upper bounds as +Inf,
    # so for finite x the scaled distance on that side is +Inf and its penalty is 0.
    lower_shifts = Float64[isfinite(lower) ? lower : -Inf for (lower, _) in constraints]
    upper_shifts = Float64[isfinite(upper) ? upper : Inf for (_, upper) in constraints]
    # Band width: a fraction of the interval length for finite intervals, 1 otherwise.
    barrier_widths = [
        isfinite(lower) && isfinite(upper) ? padding_factors[k] * (upper - lower) : 1.0
        for (k, (lower, upper)) in enumerate(constraints)
    ]

    # b(s) = max(-log(s + 1), 0)^2: zero for s >= 0, diverging as s → -1 from above.
    std_lower_barrier(s::Real) = max(-log(s + 1.0), 0.0)^2

    # Log every coordinate whose scaled distance is <= -1 (where the penalty is
    # infinite or `log` is undefined) and throw.
    function check_band(x, scaled, bound, relation, side)
        violations = findall(<=(-1.0), scaled)
        isempty(violations) && return nothing
        err_msg = join(
            " · x[$i] = $(x[i]) $relation $(constraints[i][bound]), width = $(barrier_widths[i])"
            for i in violations
        )
        @error err_msg
        throw("$side bound violation for indice(s) $(violations)")
    end

    function barrier(x::AbstractVector{<:Real})
        lb = (x .- lower_shifts) ./ barrier_widths
        ub = (upper_shifts .- x) ./ barrier_widths
        check_band(x, lb, 1, "<", "lower")
        check_band(x, ub, 2, ">", "upper")
        return sum(std_lower_barrier.(lb) .+ std_lower_barrier.(ub))
    end

    return (x::AbstractVector{<:Real}) -> f(x) .+ barrier(x)
end
2024-03-15 13:08:33 +01:00
end # module Utils