2024-12-18 17:07:43 +01:00
|
|
|
module OptimMixture
|
|
|
|
|
2024-12-20 11:16:24 +01:00
|
|
|
import GLMakie as GLM
|
2024-12-18 17:07:43 +01:00
|
|
|
import DelimitedFiles as DF
|
|
|
|
import BenchmarkTools as BT
|
|
|
|
import UnicodePlots as UP
|
2024-12-20 11:16:24 +01:00
|
|
|
import NLopt
|
|
|
|
import LinearAlgebra: norm, mul!, axpy!, axpby!
|
|
|
|
|
|
|
|
import Logging
|
|
|
|
# Send log records at Info level and above to stderr. This replaces the process-wide
# global logger as a side effect of loading this file.
# NOTE(review): if this module is ever turned into a precompiled package, top-level side
# effects run at precompile time only; this call would then belong in `__init__`.
let console_logger = Logging.ConsoleLogger(stderr, Logging.Info)
  Logging.global_logger(console_logger)
end
|
2024-12-18 17:07:43 +01:00
|
|
|
|
|
|
|
"""
Spectral positions of the slabs read by `load_slabs`, in increasing order.

Each value `e` selects the data file `…_<e>cm-1.txt`, so the unit is presumably cm⁻¹
(as encoded in the file names).
"""
const ENERGIES = [
  2803, 2811, 2819, 2826, 2834, 2842, 2850, 2858,
  2866, 2874, 2882, 2890, 2897, 2905, 2913, 2921,
  2929, 2937, 2945, 2953, 2961, 2969, 2977, 2985,
  2993, 3001, 3009, 3018, 3026, 3034, 3042, 3050,
]
|
|
|
|
|
2024-12-20 11:16:24 +01:00
|
|
|
"""
    Slab{T<:Real}

One 2-D image (`data`) associated with a single spectral position `energy`.
`load_slabs` builds one `Slab` per entry of `ENERGIES`.

Construct with keywords, e.g. `Slab{Float64}(energy=2803, data=A)`.
"""
Base.@kwdef struct Slab{T<:Real}
  energy::Int      # spectral position used to pick the data file
  data::Matrix{T}  # image samples parsed from that file
end
|
|
|
|
|
2024-12-20 11:16:24 +01:00
|
|
|
"""
    load_slabs(; T=Int, width=512, height=512)

Read one comma-delimited text file per entry of `ENERGIES`, parsing the values as `T`.
Files are looked up as
`test_sample/HeLa_F-SRS_<width>x<height>_<energy>cm-1.txt`,
relative to the current working directory.

Returns `(slabs, i_min, i_max)`:
- `slabs`: a `Vector{Slab{T}}` in `ENERGIES` order;
- `i_min`, `i_max`: the smallest and largest value found across all slabs (of type `T`).
"""
function load_slabs(; T::Type=Int, width::Int=512, height::Int=512)
  slabs = Slab{T}[]
  sizehint!(slabs, length(ENERGIES))
  # Per-slab (min, max), reduced after loading. Unlike fixed seeds such as 0 or
  # typemax(Int), this is correct for data of any sign and keeps the bounds of type T.
  bounds = Tuple{T,T}[]

  for e in ENERGIES
    # Use the requested dimensions; they were previously ignored in favor of "512x512".
    path = "test_sample/HeLa_F-SRS_$(width)x$(height)_$(e)cm-1.txt"
    slab = Slab{T}(energy=e, data=DF.readdlm(path, ',', T))
    push!(slabs, slab)
    push!(bounds, extrema(slab.data))
  end

  i_min = minimum(first, bounds)
  i_max = maximum(last, bounds)

  s = join(size(slabs[1].data), "x")
  @info "Loaded $(length(slabs)) $s slabs"

  return slabs, i_min, i_max
end
|
|
|
|
|
|
|
|
"""
    E_loop(X, Y)

Loop-based reference evaluation of the least-squares objective

    E(X) = Σ_{i=1}^{N} Σ_{j=1}^{M} (Σ_k X[i, k] * X[N + j, k] - Y[i, j])^2

where `N, M = size(Y)`. Row `i ≤ N` of `X` is paired with row `N + j`. Slow but
straightforward; `test_E` and `check_derivative_E` use it to validate `E!`.

Throws `DimensionMismatch` unless `size(X, 1) == N + M`. The `@inbounds` loops below
rely on that shape.
"""
function E_loop(X::Matrix{Float64}, Y::Matrix{Float64})
  N, M = size(Y)
  size(X, 1) == N + M || throw(DimensionMismatch(
    "X has $(size(X, 1)) rows, expected N + M = $(N + M) for Y of size $(size(Y))"))
  m = size(X, 2)

  # Float64 accumulators keep the loops type-stable. An Int `0` seed widened to
  # Float64 on the first addition, with the same numerical result.
  s = 0.0
  for i in 1:N
    for j in N+1:N+M
      x = 0.0
      for k in 1:m
        @inbounds x += X[i, k] * X[j, k]
      end
      @inbounds s += (x - Y[i, j-N])^2
    end
  end
  return s
end
|
|
|
|
|
2024-12-20 11:16:24 +01:00
|
|
|
"""
    E!(X, Y, tmp)

Evaluate `E(X) = ‖Q(X) - Y‖²` (squared Frobenius norm) with BLAS-style kernels.
`Q(X)` is computed by `Q!` with `N = size(Y, 1)`. `tmp` is scratch storage of size
`size(Y)`; on return it holds the residual `Q(X) - Y`.
"""
function E!(X::Matrix{Float64}, Y::Matrix{Float64}, tmp::Matrix{Float64})
  n_rows = size(Y, 1)

  # tmp ← Q(X)
  Q!(tmp, X, n_rows)

  # tmp ← tmp - Y. Roughly 6ms at N = 512*512, M = 32, m = 2; measured slightly
  # faster than `tmp .-= Y` or `tmp .= tmp .- Y` (about 7ms each).
  axpy!(-1.0, Y, tmp)

  # Sum of squared residuals (roughly 2.5ms), measured slightly faster than a
  # norm-based computation.
  return sum(abs2, tmp)
end
|
|
|
|
|
|
|
|
"""
    test_E(N, M, m; benchmark=true)

Check `E!` against the reference `E_loop` on random `X` ((N + M) × m) and `Y` (N × M),
printing their absolute difference. When `benchmark` is true, also time both with
BenchmarkTools.
"""
function test_E(N::Int, M::Int, m::Int; benchmark::Bool=true)
  X = rand(N + M, m)
  Y = rand(N, M)
  tmp = zeros(N, M)

  E_from_loop = E_loop(X, Y)
  E_from_mul = E!(X, Y, tmp)
  @show abs(E_from_mul - E_from_loop)

  benchmark || return nothing

  @info "Running benchmarks..., this may take a while"
  show(stdout, "text/plain", BT.@benchmark E_loop($X, $Y))
  println()
  show(stdout, "text/plain", BT.@benchmark E!($X, $Y, $tmp))
  return nothing
end
|
|
|
|
|
|
|
|
|
|
|
|
"""
    Q!(dst, X, N)

Write `X[1:N, :] * X[N+1:end, :]'` into `dst` via `mul!`, i.e.
`dst[i, j] = Σ_k X[i, k] * X[N + j, k]`, and return `dst`.
"""
function Q!(dst::Matrix{Float64}, X::Matrix{Float64}, N::Int)
  last_row = size(X, 1)
  top = view(X, 1:N, :)
  bottom = view(X, (N + 1):last_row, :)
  # roughly 9ms with N = 512*512, M = 32, m = 2 (could be sped up using StrideArray?)
  return mul!(dst, top, bottom')
end
|
|
|
|
|
2024-12-18 17:07:43 +01:00
|
|
|
"""
    Q_loop!(dst, X, N)

Loop-based reference for `Q!`. Writes `dst[i, j] = Σ_k X[i, k] * X[N + j, k]` for
`i ∈ 1:N` and `j ∈ 1:(size(X, 1) - N)`, then returns `dst` (like `Q!`).

Throws:
- `ArgumentError` unless `0 ≤ N ≤ size(X, 1)`;
- `DimensionMismatch` unless `size(dst) == (N, size(X, 1) - N)`.

The `@inbounds` loops rely on both conditions.
"""
function Q_loop!(dst::Matrix{Float64}, X::Matrix{Float64}, N::Int)
  N_plus_M, m = size(X)
  0 <= N <= N_plus_M ||
    throw(ArgumentError("N = $N must lie in 0:$N_plus_M (the number of rows of X)"))
  size(dst) == (N, N_plus_M - N) ||
    throw(DimensionMismatch("dst has size $(size(dst)), expected $((N, N_plus_M - N))"))

  # roughly 23.5ms with N = 512*512, M = 32, m = 2
  for i in 1:N
    for j in N+1:N_plus_M
      # Float64 seed: type-stable, and the same value as the former Int 0 seed.
      x = 0.0
      for k in 1:m
        @inbounds x += X[i, k] * X[j, k]
      end
      @inbounds dst[i, j-N] = x
    end
  end
  return dst
end
|
|
|
|
|
2024-12-20 11:16:24 +01:00
|
|
|
"""
    test_Q(N, M, m; benchmark=true)

Compare `Q!` with the reference `Q_loop!` on a random (N + M) × m matrix `X`, printing
the norm of their difference. When `benchmark` is true, also time both.
"""
function test_Q(N::Int, M::Int, m::Int; benchmark::Bool=true)
  X = rand(N + M, m)

  Q_from_loop, Q_from_mul = zeros(N, M), zeros(N, M)
  Q!(Q_from_mul, X, N)
  Q_loop!(Q_from_loop, X, N)

  @show norm(Q_from_mul - Q_from_loop)

  if !benchmark
    return nothing
  end

  @info "Running benchmarks..., this may take a while"
  show(stdout, "text/plain", BT.@benchmark Q!($Q_from_mul, $X, $N))
  println()
  show(stdout, "text/plain", BT.@benchmark Q_loop!($Q_from_loop, $X, $N))
  return nothing
end
|
|
|
|
|
2024-12-20 11:16:24 +01:00
|
|
|
"""
    DE_loop!(dst, X, Y)

Convenience wrapper around `DE_loop!!` that allocates the `size(Y)` scratch buffer
itself. Writes the gradient into `dst` and returns whatever `DE_loop!!` returns.
"""
function DE_loop!(dst::Matrix{Float64}, X::Matrix{Float64}, Y::Matrix{Float64})
  scratch = zeros(Float64, size(Y))
  return DE_loop!!(dst, scratch, X, Y)
end
|
2024-12-18 17:07:43 +01:00
|
|
|
|
2024-12-20 11:16:24 +01:00
|
|
|
"""
    DE!!(dst, dst_W, X, Y)

Gradient of `E(X) = ‖Q(X) - Y‖²` (see `E!`), written into `dst`, which has the same
size as `X`.

With `N = size(Y, 1)`, `λ = X[1:N, :]`, `s = X[N+1:end, :]` and `W = 2(Q(X) - Y)`:
- `dst[1:N, :]     = W  * s`
- `dst[N+1:end, :] = W' * λ`

`dst_W` is scratch storage of size `size(Y)`; on return it holds `W`.
Returns `dst`, consistent with `DE_loop!!`. It previously returned a view of the lower
block only, as a by-product of the last `mul!`.
"""
function DE!!(dst::Matrix{Float64}, dst_W::Matrix{Float64}, X::Matrix{Float64}, Y::Matrix{Float64})
  N_plus_M = size(X, 1)
  N = size(Y, 1)
  λ = view(X, 1:N, :)
  s = view(X, N+1:N_plus_M, :)

  # dst_W ← W = 2(Q(X) - Y)
  Q!(dst_W, X, N)
  axpby!(-2.0, Y, 2.0, dst_W)

  mul!(view(dst, 1:N, :), dst_W, s)
  mul!(view(dst, N+1:N_plus_M, :), dst_W', λ)
  return dst
end
|
|
|
|
|
2024-12-20 11:16:24 +01:00
|
|
|
"""
    DE_loop!!(dst, dst_W, X, Y)

Loop-based reference for `DE!!`: writes the gradient of `E(X) = ‖Q(X) - Y‖²` into `dst`
and returns `dst`. With `N, M = size(Y)` and `W = Q(X) - Y` (left in `dst_W`):
- `dst[i, k]     = 2 Σ_j W[i, j] * X[N + j, k]`
- `dst[N + j, k] = 2 Σ_i W[i, j] * X[i, k]`

Throws `DimensionMismatch` unless all of the following hold:
- `size(X, 1) == N + M`;
- `size(dst) == size(X)`;
- `size(dst_W) == size(Y)`.

The `@inbounds` loops rely on these shapes.
"""
function DE_loop!!(dst::Matrix{Float64}, dst_W::Matrix{Float64}, X::Matrix{Float64}, Y::Matrix{Float64})
  m = size(X, 2)
  N, M = size(Y)
  size(X, 1) == N + M ||
    throw(DimensionMismatch("X has $(size(X, 1)) rows, expected N + M = $(N + M)"))
  size(dst) == size(X) ||
    throw(DimensionMismatch("dst has size $(size(dst)), expected $(size(X))"))
  size(dst_W) == (N, M) ||
    throw(DimensionMismatch("dst_W has size $(size(dst_W)), expected $((N, M))"))

  # dst_W ← W = Q(X) - Y
  Q_loop!(dst_W, X, N)
  dst_W .-= Y

  for k in 1:m
    # Rows 1:N of the gradient: W * X[N+1:N+M, k]
    for i in 1:N
      x = 0.0  # Float64 seed keeps the accumulator type-stable
      for j in 1:M
        @inbounds x += dst_W[i, j] * X[N+j, k]
      end
      @inbounds dst[i, k] = x
    end
    # Rows N+1:N+M of the gradient: W' * X[1:N, k]
    for j in 1:M
      x = 0.0
      for i in 1:N
        @inbounds x += dst_W[i, j] * X[i, k]
      end
      @inbounds dst[N+j, k] = x
    end
  end

  # ∂E/∂Q = 2(Q - Y): the factor 2 is applied once at the end.
  dst .*= 2
  return dst
end
|
|
|
|
|
2024-12-20 11:16:24 +01:00
|
|
|
"""
    check_derivative_E(N, M, m; loop=false)

Finite-difference check of the analytic gradient of `E` on random data.

For a random direction `V` and steps `h = 10^1, 10^0, …, 10^-12`, it compares `E(X + hV)`
with the first-order model `E(X) + h⟨∇E(X), V⟩`. It prints a log–log plot of the
absolute error against `h`, with the line `error = h` drawn for reference. For a correct
gradient, the error should shrink roughly like `h²` until round-off dominates.

With `loop=true`, the gradient and the perturbed energies use `DE_loop!!`/`E_loop`;
otherwise they use `DE!!`/`E!`. `E(X)` itself is always evaluated with `E_loop`.
"""
function check_derivative_E(N::Int, M::Int, m::Int; loop::Bool=false)
  X = rand(N + M, m)
  Y = rand(N, M)
  V = rand(N + M, m) .- 0.5   # random direction, entries in [-0.5, 0.5)

  EX = E_loop(X, Y)
  grad_EX = zeros(size(X))
  tmp = zeros(size(Y))
  gradient_fn = loop ? DE_loop!! : DE!!
  gradient_fn(grad_EX, tmp, X, Y)
  directional = sum(grad_EX .* V)   # ⟨∇E(X), V⟩

  steps = [10.0^k for k in 1:-1:-12]
  errors = map(steps) do h
    perturbed = X + h * V
    E_true = loop ? E_loop(perturbed, Y) : E!(perturbed, Y, tmp)
    abs(E_true - (EX + h * directional))
  end

  plt = UP.lineplot(steps, errors, xscale=:log10, yscale=:log10)
  UP.lineplot!(plt, steps, steps)
  println(plt)
end
|
|
|
|
|
2024-12-20 11:16:24 +01:00
|
|
|
"""
    test_DE(N, M, m; benchmark=true)

Compare the gradient from `DE!!` with the loop-based `DE_loop!!` on random `X`
((N + M) × m) and `Y` (N × M), printing the norm of the difference. When `benchmark` is
true, also time both.
"""
function test_DE(N::Int, M::Int, m::Int; benchmark::Bool=true)
  X = rand(N + M, m)
  Y = rand(N, M)

  tmp_W = zeros(N, M)            # shared scratch of size(Y)
  ∇E_loop = zeros(N + M, m)
  ∇E = zeros(N + M, m)

  DE_loop!!(∇E_loop, tmp_W, X, Y)
  DE!!(∇E, tmp_W, X, Y)

  @show norm(∇E_loop - ∇E)

  benchmark || return nothing

  @info "Running benchmarks..., this may take a while"
  show(stdout, "text/plain", BT.@benchmark DE_loop!!($∇E_loop, $tmp_W, $X, $Y))
  println()
  show(stdout, "text/plain", BT.@benchmark DE!!($∇E, $tmp_W, $X, $Y))
  return nothing
end
|
|
|
|
|
|
|
|
"""
    benchmark()

Load the data set with `load_slabs(T=Float64)` and time `DE_loop!!` and `E_loop` on it
with BenchmarkTools. The starting point is `X = 0.5` everywhere, with `m = 2` components.
"""
function benchmark()
  slabs, _, _ = load_slabs(T=Float64)
  columns = [vec(slab.data) for slab in slabs]   # one column per slab
  Y::Matrix{Float64} = hcat(columns...)
  @show size(Y)

  N::Int, M::Int = size(Y)
  m::Int = 2
  n_vars = (N + M) * m

  X::Vector{Float64} = fill(0.5, n_vars)
  ∇E::Vector{Float64} = zeros(n_vars)
  tmp = zeros(size(Y))

  gradient_trial = BT.@benchmark DE_loop!!(reshape($∇E, $N + $M, $m), $tmp, reshape($X, $N + $M, $m), reshape($Y, $N, $M))
  show(stdout, "text/plain", gradient_trial)
  println()
  energy_trial = BT.@benchmark E_loop(reshape($X, $N + $M, $m), reshape($Y, $N, $M))
  show(stdout, "text/plain", energy_trial)
end
|
|
|
|
|
2024-12-20 11:16:24 +01:00
|
|
|
"""
    plot_result_mixture_1(X, N, M, m)

Display a fitted mixture stored in the flat vector `X`, which uses the layout from
`optim`: `reshape(X, N + M, m) == [λ; s]`.

The figure shows:
- one image per component density `λ[:, k]`, reshaped to `n × n` with `n = √N`;
- a last panel overlaying all component spectra `s[:, k]`.

It also prints the sizes of `λ` and `s`. When `m ≥ 2`, it prints the distances between
the first two components.

Throws:
- `DimensionMismatch` if `length(X) != (N + M) * m`;
- `ArgumentError` if `N` is not a perfect square.
"""
function plot_result_mixture_1(X::Vector{Float64}, N::Int, M::Int, m::Int)
  length(X) == (N + M) * m || throw(DimensionMismatch(
    "length(X) = $(length(X)), expected (N + M) * m = $((N + M) * m)"))
  n = isqrt(N)
  n * n == N || throw(ArgumentError("N = $N pixels is not a perfect square; cannot form n×n images"))

  X_ = reshape(X, N + M, m)
  λ = view(X_, 1:N, :)
  s = view(X_, N+1:N+M, :)

  @show size(λ)
  @show size(s)

  fig = GLM.Figure()
  # Axes 1..m: density images with square pixels; axis m+1: spectra (automatic aspect).
  axs = map(1:m+1) do i
    aspect = i <= m ? GLM.DataAspect() : nothing
    GLM.Axis(fig[1, i], aspect=aspect)
  end

  # Component densities, one axis each
  for k in 1:m
    GLM.image!(axs[k], reshape(view(λ, :, k), n, n))
  end

  # Spectra, all on the last axis
  for k in 1:m
    GLM.lines!(axs[m+1], s[:, k])
  end

  if m >= 2
    # Near-zero values mean the first two components collapsed onto each other.
    @show norm(λ[:, 1] - λ[:, 2])
    @show norm(s[:, 1] - s[:, 2])
  end

  return display(fig)
end
|
|
|
|
|
|
|
|
"""
    optim(algo::Symbol=:LD_MMA; m=2, maxeval=0, ftol_rel=0.0)

Fit `Y ≈ λ * s'` by minimizing `E` (see `E!`) with NLopt.

- `Y` (pixels × energies) is built from `load_slabs(T=Float64)`, one column per slab.
- `λ` (pixels × `m`) holds the component densities, bounded to `[0, 1]`.
- `s` (energies × `m`) holds the component spectra, bounded below by 0.

# Arguments
- `algo`: NLopt algorithm symbol, e.g. `:LD_MMA` or `:LD_LBFGS`. The objective fills
  the gradient whenever NLopt requests one.

# Keywords
- `m`: number of mixture components (default 2).
- `maxeval`: stop after this many objective evaluations. Values `≤ 0` leave NLopt's
  setting untouched.
- `ftol_rel`: relative tolerance on the objective. Values `≤ 0` leave NLopt's setting
  untouched.

# Returns
`(res, opt)`, where `res = NLopt.optimize(opt, X0)`, i.e. `(minf, minx, ret)`. `minx` is
the flattened `(N + M) × m` matrix `[λ; s]`.
"""
function optim(algo::Symbol=:LD_MMA; m::Int=2, maxeval::Int=0, ftol_rel::Float64=0.0)
  m >= 1 || throw(ArgumentError("m must be a positive number of components, got $m"))

  slabs, _, _ = load_slabs(T=Float64)
  Y::Matrix{Float64} = hcat([vec(slab.data) for slab in slabs]...)
  @show size(Y)

  N::Int, M::Int = size(Y)
  n_vars = (N + M) * m

  # Initial choice of X. Observations so far:
  # * X = 0 is a stationary point (∇E(0) = 0), so a gradient method cannot leave it.
  # * X = 0.5 everywhere with BFGS converges relatively quickly, but to identical
  #   components.
  # A small random perturbation around 0.5 is meant to break that symmetry.
  X::Vector{Float64} = 0.5 .+ 0.1 * rand(n_vars)
  tmp = zeros(size(Y))   # size(Y) scratch, shared by DE!! and E!

  function objective(X::Vector{Float64}, grad::Vector{Float64})
    X_ = reshape(X, N + M, m)
    if length(grad) > 0
      DE!!(reshape(grad, N + M, m), tmp, X_, Y)
    else
      @debug "No gradient requested!"
    end
    # Largest total density over pixels; diagnostic only.
    max_Σ_λ = maximum(sum(view(X_, 1:N, :); dims=2))
    value = E!(X_, Y, tmp)
    @info "E = $value, |∇E| = $(norm(grad)), max_Σ_λ = $max_Σ_λ"
    return value
  end

  opt = NLopt.Opt(algo, n_vars)
  NLopt.lower_bounds!(opt, fill(0.0, n_vars))
  # Column-major layout of [λ; s]: the first N rows of each column (λ) are capped at 1,
  # while the spectra s are unbounded above.
  NLopt.upper_bounds!(opt, vec([fill(1.0, N, m); fill(Inf, M, m)]))
  maxeval > 0 && NLopt.maxeval!(opt, maxeval)
  ftol_rel > 0 && NLopt.ftol_rel!(opt, ftol_rel)
  NLopt.min_objective!(opt, objective)

  res = NLopt.optimize(opt, X)

  @show opt

  return res, opt
end
|
|
|
|
|
2024-12-20 11:16:24 +01:00
|
|
|
|
2024-12-18 17:07:43 +01:00
|
|
|
end # module
|
|
|
|
|
2024-12-20 12:42:30 +01:00
|
|
|
# Gradient-based (LD_*) NLopt algorithms that can be passed to `OptimMixture.optim`:
|
|
|
|
# LD_LBFGS_NOCEDAL
|
|
|
|
# LD_LBFGS
|
|
|
|
# LD_VAR1
|
|
|
|
# LD_VAR2
|
|
|
|
# LD_TNEWTON
|
|
|
|
# LD_TNEWTON_RESTART
|
|
|
|
# LD_TNEWTON_PRECOND
|
|
|
|
# LD_TNEWTON_PRECOND_RESTART
|
|
|
|
# LD_MMA
|
|
|
|
# LD_AUGLAG
|
|
|
|
# LD_AUGLAG_EQ
|
|
|
|
# LD_SLSQP
|
|
|
|
# LD_CCSAQ
|
|
|
|
|
|
|
|
# Default run: fit the data with L-BFGS, then plot the resulting components.
# (Alternative: res, opt = OptimMixture.optim(:LD_TNEWTON))
res, opt = OptimMixture.optim(:LD_LBFGS);

# The sizes must match what `optim` used: 512×512 pixels, 32 energies, 2 components.
# `res[2]` is the optimized flat X; NLopt.optimize returns (minf, minx, ret).
let pixels = 512 * 512, energies = 32, components = 2
  OptimMixture.plot_result_mixture_1(res[2], pixels, energies, components)
end

# Other diagnostics (uncomment to run):
# OptimMixture.benchmark();
# OptimMixture.test_E(512 * 512, 32, 2)
# OptimMixture.test_DE(512 * 512, 32, 2)
# OptimMixture.test_Q(512 * 512, 32, 2)
# OptimMixture.check_derivative_E(512 * 512, 32, 2)
# OptimMixture.check_derivative_E(512 * 512, 32, 2; loop=true)
|
2024-12-18 17:07:43 +01:00
|
|
|
|
|
|
|
# vim: ts=2:sw=2:sts=2
|