tsa_saxs/Utils.jl

236 lines
6.6 KiB
Julia
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

module Utils
import TOML: parsefile
import DelimitedFiles: readdlm
import DSP
import Printf: @sprintf
import Crayons as C
# Terminal styles for console output: bold yellow and bold light-red highlights,
# plus a crayon that resets every attribute back to the terminal default.
const BYELLOW = C.Crayon(; bold=true, foreground=:yellow)
const BRED = C.Crayon(; bold=true, foreground=:light_red)
const RESET = C.Crayon(; reset=true)
"""
    get_free_parameters_idx(metadata::Dict)

Return, in increasing order, the indices of the entries of `metadata["parameters"]`
whose `"fixed"` flag is `false`.
"""
function get_free_parameters_idx(metadata::Dict)
    parameters = metadata["parameters"]
    return findall(entry -> !entry["fixed"], parameters)
end
"""
    reduce_to_free_parameters(metadata::Dict, f::Function, params_init::NamedTuple, lb, ub, q)

Restrict the model `f(params, q)` to the free (non-fixed) parameters of `metadata`.

Returns `(f_reduced, P_reduced, lb_reduced, ub_reduced)`:
- `f_reduced(P)` copies `P` into the free slots of a private `Float64` copy of the
  initial parameter values, then evaluates `f` on that full vector and the captured `q`;
- `P_reduced` is a view of those free slots of the same copy (initially the free initial
  values); it therefore changes every time `f_reduced` is called;
- `lb_reduced` / `ub_reduced` are `lb` / `ub` indexed by the free indices.
"""
function reduce_to_free_parameters(metadata::Dict, f::Function, params_init::NamedTuple, lb, ub, q)
    full_params = collect(Float64, params_init)
    free = get_free_parameters_idx(metadata)
    free_slots = view(full_params, free)

    function reduced_model(P::AbstractArray{Float64})
        # Overwrite only the free entries; fixed entries keep their initial values.
        free_slots .= P
        return f(full_params, q)
    end

    return reduced_model, free_slots, lb[free], ub[free]
end
"""
    reduce_to_free_parameters(metadata::Dict, params::Vector{Float64})

Return a new vector holding only the entries of `params` at the free (non-fixed)
parameter indices described by `metadata`.
"""
reduce_to_free_parameters(metadata::Dict, params::Vector{Float64}) =
    params[get_free_parameters_idx(metadata)]
"""
    load_init_params(filename::String)

Read the `metadata` table of the TOML file `filename`, extract its `parameters`
entries and print a summary table of them.

Each parameter entry provides `symbol`, `initial_value`, `min` and `max` (its `fixed`
flag is read through `get_free_parameters_idx`). In the printed table, the first column
(bold yellow) is the parameter's position among the free parameters and is blank for
fixed ones. The second column is its position among all parameters.

Returns the named tuple `(meta, params, lower_bounds, upper_bounds)`, where `params`
is a `NamedTuple` mapping each parameter symbol to its initial value.
"""
function load_init_params(filename::String)
    meta = parsefile(filename)["metadata"]
    parameters = meta["parameters"]
    param_names = [p["symbol"] for p in parameters]
    param_initial = [p["initial_value"] for p in parameters]
    params = NamedTuple{Tuple(map(Symbol, param_names))}(param_initial)
    lower_bounds = [p["min"] for p in parameters]
    upper_bounds = [p["max"] for p in parameters]

    # A Set makes the per-row membership test O(1).
    free_set = Set(get_free_parameters_idx(meta))
    free_count = 0
    println("Got $(length(param_names)) parameters:")
    for i in eachindex(param_names)
        lb_str = @sprintf "%.3g" lower_bounds[i]
        ub_str = @sprintf "%.3g" upper_bounds[i]
        p = @sprintf "%.3g" param_initial[i]
        if i in free_set
            free_count += 1
            free_idx_str = lpad(free_count, 2)
        else
            # Blank of the same width as `lpad(free_count, 2)` keeps the rows aligned.
            free_idx_str = lpad("", 2)
        end
        println("[", BYELLOW, free_idx_str, RESET, "/$(lpad(i, 2))] $(lpad(param_names[i], 7)): $(lpad(lb_str, 7)) <= $(lpad(p, 7)) <= $(lpad(ub_str, 7)) ")
    end
    return (meta=meta, params=params, lower_bounds=lower_bounds, upper_bounds=upper_bounds)
end
"""
    select_columns(QIE_array, qmin, qmax, bin, err_mul; scale=1e-6, rel_err=0.01)

Select the q / I(q) / err(q) points to fit from `QIE_array` (columns: q, I, err) and
rescale intensities by `scale` (default `1e-6`, i.e. from mm^-1 to nm^-1).

Row `i` is kept when `qmin <= q <= qmax`, `i % bin == 0` (only every `bin`-th row of the
whole array is considered) and `I > 0`.

The error of a kept point is `scale * err * err_mul` when `err * err_mul > rel_err * I`,
otherwise the floor `rel_err * scale * I`. When `err_mul == 0` the err column is never
read and the floor is always used, so a two-column array is accepted.

Returns the tuple `(Q, IDATA, ERR)` of `Vector{Float64}`.
"""
function select_columns(QIE_array, qmin, qmax, bin, err_mul; scale=1e-6, rel_err=0.01)
    Q = Float64[]
    IDATA = Float64[]
    ERR = Float64[]
    for i in axes(QIE_array, 1)
        q = QIE_array[i, 1]
        # keep only q within bounds, and only every `bin`-th row
        (qmin <= q <= qmax && i % bin == 0) || continue
        intensity = QIE_array[i, 2]
        intensity > 0 || continue
        push!(Q, q)
        push!(IDATA, scale * intensity)
        # The measured error (times err_mul) is used only when it exceeds the relative floor;
        # `err_mul != 0` is checked first so column 3 is not touched in that case.
        if err_mul != 0 && QIE_array[i, 3] * err_mul > rel_err * intensity
            push!(ERR, scale * QIE_array[i, 3] * err_mul)
        else
            push!(ERR, rel_err * scale * intensity)
        end
    end
    return (Q, IDATA, ERR)
end
"""
    compute_logscale_weights(q)

Return one weight per point of `q`, proportional to the spacing of the points on a
base-10 logarithmic scale.

With `d = diff(log10.(q))`, an interior point `i` gets `d[i-1] + d[i]` (its log-distance
to both neighbours). The first and last points get `2d[1]` and `2d[end]` (their single
neighbour gap counted twice). The values of `q` should be positive.

Throws an `ArgumentError` when `q` has fewer than two elements.
"""
function compute_logscale_weights(q)
    length(q) >= 2 ||
        throw(ArgumentError("at least two q values are required, got $(length(q))"))
    diff_log_q = diff(log10.(q))
    # Duplicate the first / last gap so the end points also receive two terms.
    left_diff_log_q = [first(diff_log_q); diff_log_q]
    right_diff_log_q = [diff_log_q; last(diff_log_q)]
    return left_diff_log_q .+ right_diff_log_q
end
"""
    load_config(filename::String)

Load the fit configuration from the TOML file `filename` (via `load_init_params`) and the
tab-delimited data file named by its `metadata.data_file` entry.

The data path is passed to `readdlm` as-is, so a relative path is resolved against the
current working directory, not against the directory of `filename`.

Returns the named tuple `(meta, params_init, lower_bounds, upper_bounds, qie_data,
qie_deta)`. `qie_deta` is the same matrix as `qie_data` and is kept only for
compatibility with the previously misspelled field name.
"""
function load_config(filename::String)
    meta, params_init, lower_bounds, upper_bounds = load_init_params(filename)
    # Passing the path lets readdlm open and close the file itself, so no handle leaks
    # if parsing throws.
    qie_data = readdlm(meta["data_file"], '\t')
    return (meta=meta,
            params_init=params_init,
            lower_bounds=lower_bounds,
            upper_bounds=upper_bounds,
            qie_data=qie_data,
            qie_deta=qie_data)  # deprecated misspelling, kept for existing callers
end
"""
    lowpass_filter(i; σ=1)

Smooth the vector `i` by convolving it with a unit-sum, truncated Gaussian kernel.

The kernel has 21 taps sampled on `range(-5, 5; length=21)`. Neighbouring taps are 0.5
apart in that coordinate, so `σ` is expressed in those units (`σ = 1` spans two
samples). The kernel is always truncated at ±5 regardless of `σ`.

Before convolving, `i` is extended with 10 copies of its first value in front and 10
copies of its last value behind. The returned vector has the same length as `i` and is
aligned with it. An empty input returns an empty vector.
"""
function lowpass_filter(i; σ=1)
    half_w = 10
    isempty(i) && return promote_type(eltype(i), Float64)[]
    x = range(-5, 5; length=2half_w + 1)
    i_padded = view(i, [fill(firstindex(i), half_w); eachindex(i); fill(lastindex(i), half_w)])
    # Unnormalized Gaussian weights: any constant prefactor is irrelevant because the
    # kernel is rescaled to unit sum right below.
    k = exp.(-0.5 .* (x ./ σ) .^ 2)
    k ./= sum(k)
    # The full convolution has length(i) + 4half_w samples; drop 2half_w at each end
    # to keep the part aligned with the original, unpadded samples.
    return DSP.conv(i_padded, k)[2half_w+1:end-2half_w]
end
"""
    residuals(I_data, I_model, err)

Return the error-normalized residuals `(I_data - I_model) / err`, element-wise
(arguments broadcast, so scalars are accepted for any of them).
"""
function residuals(I_data, I_model, err)
    # `./` binds tighter than `.-`; the parentheses make the whole difference,
    # not just the model, get divided by the error.
    return (I_data .- I_model) ./ err
end
"""
    chi2(I_data, I_model, err)

Return the chi-squared statistic `sum(((I_data - I_model) / err)^2)` over all points
(arguments broadcast, so scalars are accepted for any of them).
"""
function chi2(I_data, I_model, err)
    # Parenthesize the difference: `./` binds tighter than `.-`.
    return sum(abs2, (I_data .- I_model) ./ err)
end
"""
    add_log_barriers(f, constraints; padding_factors=0.1, mode=:outer)

Wrap the objective `f` with a logarithmic barrier penalty for box constraints and
return the function `x -> f(x) .+ barrier(x)`.

For each coordinate `i`, with scaled distances `s = (x[i] - lower) / width` and
`s = (upper - x[i]) / width`, the barrier adds `max(-log(s + sign), 0)^2` for both.

# Arguments
- `f`: objective taking the parameter vector `x`.
- `constraints`: one `(lower, upper)` tuple per coordinate; either bound may be infinite.

# Keywords
- `padding_factors`: a scalar or a per-constraint vector. When both bounds are finite,
  the barrier width is `padding_factors[i] * (upper - lower)`; otherwise it is `1.0`.
- `mode`: `:outer` (sign = +1) penalizes only points outside `[lower, upper]` and
  diverges one width beyond a bound. `:inner` (sign = -1) penalizes points inside the
  box within two widths of a bound and diverges one width inside it.

# Throws
- `ArgumentError` for an unknown `mode`, a `padding_factors` vector whose length differs
  from `length(constraints)`, or a constraint with `upper <= lower`.
- `DomainError`, from the returned function, when `x` lies at or beyond the point where
  the barrier diverges. The offending coordinates are also logged with `@error`.
"""
function add_log_barriers(f::Function, constraints::Vector{Tuple{T,T}};
                          padding_factors=0.1, mode::Symbol=:outer) where {T<:Real}
    mode in (:outer, :inner) ||
        throw(ArgumentError("mode must be :outer or :inner, got :$mode"))
    n = length(constraints)
    pf = padding_factors isa AbstractVector ? padding_factors : fill(padding_factors, n)
    length(pf) == n || throw(ArgumentError(
        "padding_factors has $(length(pf)) entries but there are $n constraints"))

    lower_shifts = Vector{Float64}(undef, n)
    upper_shifts = Vector{Float64}(undef, n)
    barrier_widths = Vector{Float64}(undef, n)
    for (k, (lower, upper)) in enumerate(constraints)
        upper <= lower && throw(ArgumentError("admissible set is empty, check constraints"))
        lower_shifts[k] = isfinite(lower) ? lower : -Inf
        upper_shifts[k] = isfinite(upper) ? upper : Inf
        # Only a bounded interval provides a natural length scale; half-open or
        # unbounded coordinates use a unit barrier width.
        barrier_widths[k] = isfinite(lower) && isfinite(upper) ? pf[k] * (upper - lower) : 1.0
    end

    barrier_sign = mode == :outer ? 1.0 : -1.0
    # Penalty for a scaled distance `s` to one bound: zero once `s + sign >= 1`,
    # diverging as `s` approaches `-sign`.
    std_lower_barrier(s::Real) = max(-log(s + barrier_sign), 0.0)^2

    function barrier(x::AbstractVector{<:Real})
        lb = (x .- lower_shifts) ./ barrier_widths
        ub = (upper_shifts .- x) ./ barrier_widths
        # At or past the divergence point `log` would return -Inf or throw on a negative
        # argument; report those coordinates explicitly instead.
        lb_violations = findall(<=(-barrier_sign), lb)
        if !isempty(lb_violations)
            @error join("\n · x[$i] = $(x[i]) < $(constraints[i][1]), width = $(barrier_widths[i])"
                        for i in lb_violations)
            throw(DomainError(x, "lower bound violation for indice(s) $(lb_violations)"))
        end
        ub_violations = findall(<=(-barrier_sign), ub)
        if !isempty(ub_violations)
            @error join("\n · x[$i] = $(x[i]) > $(constraints[i][2]), width = $(barrier_widths[i])"
                        for i in ub_violations)
            throw(DomainError(x, "upper bound violation for indice(s) $(ub_violations)"))
        end
        return sum(std_lower_barrier.(lb) .+ std_lower_barrier.(ub))
    end

    return x::AbstractVector{<:Real} -> f(x) .+ barrier(x)
end
"""
    print_check_bounds(param_names, param_values, lb, ub)

Print one line per parameter in the form `name: lb <= value <= ub`, with numbers
formatted as `%.3g`. The numbers of a parameter outside `[lb[i], ub[i]]` are printed
in bold light red.
"""
function print_check_bounds(param_names, param_values, lb, ub)
    for (i, v) in enumerate(param_values)
        lb_str = @sprintf "%.3g" lb[i]
        ub_str = @sprintf "%.3g" ub[i]
        p = @sprintf "%.3g" v
        crayon = lb[i] <= v <= ub[i] ? RESET : BRED
        # Reset at the end of every line so a red highlight does not bleed into the
        # next parameter's name or into later terminal output.
        println("$(lpad(param_names[i], 7)): ", crayon, "$(lpad(lb_str, 7)) <= $(lpad(p, 7)) <= $(lpad(ub_str, 7)) ", RESET)
    end
end
"""
    box_projector(lower::Vector{T}, upper::Vector{T}) where {T<:Real}

Return a function that maps a `Vector{Float64}` `x` onto the box `[lower, upper]`,
computing `max(min(x[i], upper[i]), lower[i])` for every coordinate.
"""
function box_projector(lower::Vector{T}, upper::Vector{T}) where {T<:Real}
    project(x::Vector{Float64}) = max.(min.(x, upper), lower)
    return project
end
end # module Utils