tsa_saxs/Utils.jl

159 lines
4.3 KiB
Julia
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

module Utils
import TOML: parsefile
import DelimitedFiles: readdlm
import DSP
"""
    load_init_params(filename::String)

Read the fit-parameter definitions from the `[metadata]` table of the TOML file
`filename`. Every entry of `metadata.parameters` must provide the keys `symbol`,
`initial_value`, `min` and `max`.

# Returns
A named tuple `(meta, params, lower_bounds, upper_bounds)`:
- `meta`: the parsed `metadata` table.
- `params`: a `NamedTuple` mapping each parameter symbol to its initial value, with the
  fields in the same order as the parameters appear in the file.
- `lower_bounds`, `upper_bounds`: vectors of the `min` / `max` values, in that same order,
  so `params`, `lower_bounds` and `upper_bounds` line up element by element.

# Throws
`ArgumentError` if two parameters share the same `symbol`.
"""
function load_init_params(filename::String)
    meta = parsefile(filename)["metadata"]
    specs = meta["parameters"]
    # Build the NamedTuple from ordered tuples, not from a Dict: Dict iteration order is
    # hash-based, which would let the field order drift away from the bounds vectors.
    syms = Tuple(Symbol(p["symbol"]) for p in specs)
    if !allunique(syms)
        throw(ArgumentError("duplicate parameter symbols in $filename: $syms"))
    end
    params = NamedTuple{syms}(Tuple(p["initial_value"] for p in specs))
    lower_bounds = [p["min"] for p in specs]
    upper_bounds = [p["max"] for p in specs]
    return (meta=meta, params=params, lower_bounds=lower_bounds, upper_bounds=upper_bounds)
end
"""
    select_columns(QIE_array, qmin, qmax, bin, err_mul)

Extract the q, I(q) and err(q) columns (columns 1, 2 and 3) from the matrix `QIE_array`,
keeping only usable rows, and rescale intensities by `1e-6` (mm^-1 to nm^-1).

A row `r` is kept when all of the following hold:
- its q value lies in the closed interval `[qmin, qmax]`;
- `r % bin == 0` (the row index counts over the whole array, not just the q window,
  so `bin = 1` keeps every row);
- its intensity is strictly positive.

Error of a kept row: when `err_mul == 0` it is 1 % of the intensity and column 3 is never
read (so a two-column array is accepted). Otherwise it is `err_mul` times column 3, unless
that does not exceed 1 % of the intensity, in which case the 1 % value is used. Errors get
the same `1e-6` scaling as the intensity.

Returns the tuple `(Q, IDATA, ERR)` of `Vector{Float64}`s.
"""
function select_columns(QIE_array, qmin, qmax, bin, err_mul)
    q_kept = Float64[]
    intensity_kept = Float64[]
    error_kept = Float64[]
    for row in axes(QIE_array, 1)
        q = QIE_array[row, 1]
        qmin <= q <= qmax || continue      # outside the q window
        row % bin == 0 || continue         # thin the data: keep every bin-th row only
        intensity = QIE_array[row, 2]
        intensity > 0 || continue          # non-positive intensities are discarded
        push!(q_kept, q)
        push!(intensity_kept, 1e-6 * intensity)
        # Short-circuit keeps column 3 untouched when err_mul == 0.
        use_measured = err_mul != 0 && QIE_array[row, 3] * err_mul > 0.01 * intensity
        row_error = use_measured ? 1e-6 * QIE_array[row, 3] * err_mul : 0.01 * 1e-6 * intensity
        push!(error_kept, row_error)
    end
    return (q_kept, intensity_kept, error_kept)
end
"""
    compute_logscale_weights(q)

Return one weight per point of `q`, proportional to the spacing of the points in
log10 space.

With `Δ = diff(log10.(q))`, an interior point gets `Δ[k-1] + Δ[k]` (the gaps on both
sides of it). Each end point has only one neighbouring gap and counts it twice:
`2Δ[1]` and `2Δ[end]`. The result has the same length as `q`. An input with fewer than
two points has no gap at all and raises a `BoundsError`.
"""
function compute_logscale_weights(q)
    gaps = diff(log10.(q))
    lo, hi = firstindex(gaps), lastindex(gaps)
    # Clamping the left gap index at `lo` and the right one at `hi` replicates the edge
    # gaps for the first and last points.
    return [gaps[max(k - 1, lo)] + gaps[min(k, hi)] for k in lo:(hi + 1)]
end
"""
    load_config(filename::String)

Load a fit configuration. This reads the parameter definitions from the TOML file
`filename` (via `load_init_params`) and the tab-delimited data file named by its
`metadata.data_file` entry.

The data-file path is used exactly as written in the TOML file. A relative path is
therefore resolved against the current working directory, not the directory of `filename`.

# Returns
A named tuple `(meta, params_init, lower_bounds, upper_bounds, qie_data, qie_deta)`.
`qie_data` is the matrix parsed by `readdlm`. `qie_deta` is the same matrix under the
field's former misspelled name, kept only so existing callers continue to work.
"""
function load_config(filename::String)
    meta, params_init, lower_bounds, upper_bounds = load_init_params(filename)
    # The do-block closes the file even if parsing throws (the manual open/close leaked it).
    qie_data = open(meta["data_file"], "r") do io
        readdlm(io, '\t')
    end
    return (meta=meta,
            params_init=params_init,
            lower_bounds=lower_bounds,
            upper_bounds=upper_bounds,
            qie_data=qie_data,
            qie_deta=qie_data)
end
"""
    lowpass_filter(i; σ=1)

Smooth the signal `i` (presumably a 1-D vector) with a normalized Gaussian kernel.
Returns a vector with one output sample per input sample.

The kernel has 21 taps sampled on an evenly spaced grid over `[-5, 5]` (spacing 0.5).
`σ` is measured in grid units, so `σ = 1` is a standard deviation of two samples.
Before convolving, the signal is extended by repeating its first and last samples ten
times each. The retained output window therefore sees only real or replicated samples,
never the implicit zero padding of the full convolution, and stays aligned with the input.
"""
function lowpass_filter(i; σ=1)
    pad = 10
    ntaps = 2pad + 1
    grid = range(-5, 5; length=ntaps)
    # The 1/(2σ√(2π)) prefactor is cancelled by the normalization on the next line.
    kernel = exp.(-0.5 * ((grid / σ) .^ 2)) / (2 * σ * sqrt(2π))
    kernel ./= sum(kernel)
    head = fill(firstindex(i), pad)
    tail = fill(lastindex(i), pad)
    padded = view(i, vcat(head, eachindex(i), tail))
    smoothed = DSP.conv(padded, kernel)
    # The full convolution is ntaps - 1 samples longer on each side than the padded
    # signal's centred window; trim so the output lines up with `i`.
    return smoothed[ntaps:end-(ntaps - 1)]
end
"""
    chi2(I_data, I_model, err)

Return the chi-squared misfit `Σ ((I_data - I_model) / err)^2`.

The arguments are combined by broadcasting, so `err` may be a vector of per-point
uncertainties or a single scalar.
"""
function chi2(I_data, I_model, err)
    # The residual must be parenthesized: `I_data .- I_model ./ err` would scale only the
    # model by the uncertainty.
    return sum(abs2, (I_data .- I_model) ./ err)
end
"""
    add_log_barriers(f, constraints; padding_factor=0.1)

Wrap the objective `f` so that each evaluation also pays a soft logarithmic penalty for
coordinates lying outside the box described by `constraints`.

# Arguments
- `f`: objective, called as `f(x)` with the parameter vector `x`.
- `constraints`: one `(lower, upper)` pair per coordinate of `x`. A non-finite bound is
  treated as absent on that side.

# Keywords
- `padding_factor=0.1`: a scalar, or a vector with one entry per constraint. Its effect
  depends on the bounds:
  - When both bounds of a coordinate are finite, the penalty zone extends
    `padding_factor * (upper - lower)` beyond each bound.
  - When at most one bound is finite, the zone width is `1.0`.

# Returns
The function `x -> f(x) .+ barrier(x)`. The barrier is zero inside the box. A coordinate
lying a distance `d` outside one of its bounds, in a zone of width `w`, contributes
`log(1 - d/w)^2`; this diverges as `d → w`.

# Throws
- `ArgumentError` if some constraint has `upper <= lower`, or if a vector
  `padding_factor` does not have one entry per constraint.
- `DomainError`, from the returned function, when a coordinate lies at or beyond the outer
  edge of its penalty zone. The offending coordinates are logged with `@error` first.
"""
function add_log_barriers(f::Function, constraints::AbstractVector{<:Tuple{Real,Real}};
                          padding_factor=0.1)
    n = length(constraints)
    padding_factors = padding_factor isa AbstractVector ? padding_factor : fill(padding_factor, n)
    if length(padding_factors) != n
        throw(ArgumentError("padding_factor has $(length(padding_factors)) entries, expected $n"))
    end

    lower_shifts = Vector{Float64}(undef, n)
    upper_shifts = Vector{Float64}(undef, n)
    barrier_widths = Vector{Float64}(undef, n)
    for (k, (lower, upper)) in enumerate(constraints)
        if upper <= lower
            throw(ArgumentError("admissible set is empty, check constraints"))
        end
        lower_shifts[k] = isfinite(lower) ? lower : -Inf
        upper_shifts[k] = isfinite(upper) ? upper : Inf
        # Two finite bounds give a natural length scale for the penalty zone; with an
        # open side there is none, so a unit width is used instead.
        barrier_widths[k] = if isfinite(lower) && isfinite(upper)
            padding_factors[k] * (upper - lower)
        else
            1.0
        end
    end

    # Penalty for a scaled signed distance `s` from a bound: zero for s >= 0 (inside),
    # log(1 + s)^2 for -1 < s < 0, divergent as s -> -1.
    std_lower_barrier(s) = max(-log(s + 1.0), 0.0)^2

    # Log every offending coordinate, then abort: at s <= -1 the logarithm is undefined.
    function check_violations(idx, x, relation, side, kind)
        isempty(idx) && return nothing
        err_msg = join(" · x[$i] = $(x[i]) $relation $(constraints[i][side]), width = $(barrier_widths[i])"
                       for i in idx)
        @error err_msg
        throw(DomainError(x, "$kind bound violation for indice(s) $(idx)"))
    end

    function barrier(x::AbstractVector{<:Real})
        lb = (x .- lower_shifts) ./ barrier_widths
        ub = (upper_shifts .- x) ./ barrier_widths
        check_violations(findall(<=(-1.0), lb), x, "<", 1, "lower")
        check_violations(findall(<=(-1.0), ub), x, ">", 2, "upper")
        return sum(std_lower_barrier.(lb) + std_lower_barrier.(ub))
    end

    return (x::AbstractVector{<:Real}) -> f(x) .+ barrier(x)
end
end # module Utils