Subsample Na dataset with DPP and fit energies with ACE
Set up experiment
Load packages.
using Unitful, UnitfulAtomic
using AtomsBase, InteratomicPotentials, PotentialLearning
using LinearAlgebra, Plots
Define paths.
# Base directory under which `examples/data/...` is resolved; override it with the
# BASE_PATH environment variable.
base_path = get(ENV, "BASE_PATH", "../../")
# `joinpath` avoids the doubled separator that string interpolation produced when
# `base_path` ends with '/' (as the default "../../" does), and does not turn an
# empty BASE_PATH into an absolute "/examples/..." path.
ds_path = joinpath(base_path, "examples", "data", "Na", "liquify_sodium.yaml");
Load datasets
Load the atomistic dataset: atomistic configurations (atom positions, geometry, etc.) + DFT data (energies, forces, etc.). Only configurations from index 220 onward are kept.
# Read the configurations and the per-step thermodynamic records from the YAML file
# (energy unit eV, length unit Å as passed to the `YAML` reader).
confs, thermo = load_data(ds_path, YAML(:Na, u"eV", u"Å"))
# Keep configurations from index 220 onward; both collections are trimmed the same way.
# NOTE(review): presumably the leading 219 steps are an equilibration segment — confirm.
confs = confs[220:end]
thermo = thermo[220:end]
(DataSet{num_configs = 2092}
Configuration{S, AtomsBase.FlexibleSystem{3, AtomsBase.Atom{3, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}, Unitful.Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Unitful.Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Forces, Energy}
Configuration{S, AtomsBase.FlexibleSystem{3, AtomsBase.Atom{3, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}, Unitful.Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Unitful.Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Forces, Energy}
⋮
Configuration{S, AtomsBase.FlexibleSystem{3, AtomsBase.Atom{3, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}, Unitful.Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Unitful.Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Forces, Energy}, Vector{Dict}[[Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[109500, -118.56754033566662, 1.2468911443758697, -117.32064919129076, 90.1529877225114, 4481.320775643779, -8.243380533904219, 8.243380533904219, -8.243380533904219, 8.243380533904219, -8.243380533904219, 8.243380533904219])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[110000, -118.8191048880354, 1.3088999144548927, -117.5102049735805, 94.63635895570786, 4445.796913398657, -8.22154069868183, 8.22154069868183, -8.22154069868183, 8.22154069868183, -8.22154069868183, 8.22154069868183])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[110500, -118.66706916398354, 1.1040005887840023, -117.56306857519954, 79.82168449524809, 4443.609994398957, -8.220192399538881, 8.220192399538881, -8.220192399538881, 8.220192399538881, -8.220192399538881, 8.220192399538881])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[111000, -118.5424954601734, 1.253547941582279, -117.28894751859112, 90.63428888462772, 4457.032418979577, -8.228460752349738, 8.228460752349738, -8.228460752349738, 8.228460752349738, -8.228460752349738, 8.228460752349738])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", 
"Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[111500, -118.60598321165922, 1.2766945282051063, -117.32928868345411, 92.30783829512603, 4472.0581396906955, -8.23769708292504, 8.23769708292504, -8.23769708292504, 8.23769708292504, -8.23769708292504, 8.23769708292504])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[112000, -118.85324930143875, 1.3425751052444483, -117.5106741961943, 97.0711497355587, 4466.470566415444, -8.234264813932816, 8.234264813932816, -8.234264813932816, 8.234264813932816, -8.234264813932816, 8.234264813932816])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[112500, -118.71245699254627, 1.1670894723589, -117.54536752018737, 84.38315032328688, 4468.368033253137, -8.235430688433015, 8.235430688433015, -8.235430688433015, 8.235430688433015, -8.235430688433015, 8.235430688433015])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[113000, -118.76694882093753, 1.2319510396827227, -117.5349977812548, 89.07278510735168, 4446.342141937924, -8.221876779064969, 8.221876779064969, -8.221876779064969, 8.221876779064969, -8.221876779064969, 8.221876779064969])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[113500, -118.51709527877453, 1.3207562411255773, -117.19633903764895, 95.49359759887054, 4437.23245692403, -8.216257934441625, 8.216257934441625, -8.216257934441625, 8.216257934441625, -8.216257934441625, 8.216257934441625])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), 
Dict{Any, Any}("data" => Real[114000, -118.72667005007821, 1.2772963700128563, -117.44937368006535, 92.35135278903331, 4410.550998548226, -8.199756461230859, 8.199756461230859, -8.199756461230859, 8.199756461230859, -8.199756461230859, 8.199756461230859])] … [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1150500, -111.33274572565541, 4.329760927088277, -107.00298479856714, 313.05129197672164, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1151000, -112.05085995319563, 4.005480215432024, -108.0453797377636, 289.60507925120993, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1151500, -111.93748415461936, 4.854075903442631, -107.08340825117672, 350.9604244933883, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1152000, -113.07999770701338, 4.265828336045448, -108.81416937096793, 308.4288242325647, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1152500, 
-112.41525332679798, 4.987262862658069, -107.42799046413991, 360.5901362393629, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1153000, -111.8906050438705, 4.365236171743513, -107.52536887212698, 315.61623063257963, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1153500, -111.78815670229865, 4.611141339238888, -107.17701536305977, 333.395718157298, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1154000, -112.488679423801, 4.081705353887927, -108.40697406991308, 295.11632536307604, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1154500, -111.6431739577287, 3.906342818047219, -107.73683113968148, 282.43722613942305, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1155000, -111.57970314361862, 4.291823239871932, 
-107.28787990374668, 310.3083132770473, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])]])
Split the atomistic dataset into a training set (first 1000 configurations) and a test set (remaining configurations).
# The first 1000 configurations are used for training; everything after is held out.
conf_train = confs[1:1000]
conf_test = confs[1001:end]
(DataSet{num_configs = 1000}
Configuration{S, AtomsBase.FlexibleSystem{3, AtomsBase.Atom{3, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}, Unitful.Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Unitful.Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Forces, Energy}
Configuration{S, AtomsBase.FlexibleSystem{3, AtomsBase.Atom{3, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}, Unitful.Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Unitful.Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Forces, Energy}
⋮
Configuration{S, AtomsBase.FlexibleSystem{3, AtomsBase.Atom{3, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}, Unitful.Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Unitful.Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Forces, Energy}, DataSet{num_configs = 1092}
Configuration{S, AtomsBase.FlexibleSystem{3, AtomsBase.Atom{3, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}, Unitful.Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Unitful.Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Forces, Energy}
Configuration{S, AtomsBase.FlexibleSystem{3, AtomsBase.Atom{3, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}, Unitful.Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Unitful.Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Forces, Energy}
⋮
Configuration{S, AtomsBase.FlexibleSystem{3, AtomsBase.Atom{3, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}, Unitful.Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Unitful.Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Forces, Energy})
Compute descriptors
Create ACE basis.
# ACE basis for single-species sodium. Keyword arguments are grouped by role;
# their order does not matter.
ace = ACE(
    species = [:Na],          # chemical species covered by the basis
    body_order = 4,           # 4-body
    polynomial_degree = 8,    # 8-degree polynomials
    rcutoff = 5.0,            # cutoff radius
    r0 = 1.0,                 # NOTE(review): previously described as the minimum distance between atoms; confirm against ACE.jl docs
    wL = 1.0,                 # see ACE.jl documentation
    csp = 1.0,                # see ACE.jl documentation
);
Compute energy (local) descriptors of the training dataset.
# Compute the local (per-atom) ACE descriptors of every training configuration.
println("Computing local descriptors of training dataset")
e_descr_train = compute_local_descriptors(conf_train, ace)
1000-element Vector{LocalDescriptors}:
LocalDescriptors{n = 108, d = 73}
LocalDescriptors{n = 108, d = 73}
LocalDescriptors{n = 108, d = 73}
LocalDescriptors{n = 108, d = 73}
LocalDescriptors{n = 108, d = 73}
LocalDescriptors{n = 108, d = 73}
LocalDescriptors{n = 108, d = 73}
LocalDescriptors{n = 108, d = 73}
LocalDescriptors{n = 108, d = 73}
LocalDescriptors{n = 108, d = 73}
⋮
LocalDescriptors{n = 108, d = 73}
LocalDescriptors{n = 108, d = 73}
LocalDescriptors{n = 108, d = 73}
LocalDescriptors{n = 108, d = 73}
LocalDescriptors{n = 108, d = 73}
LocalDescriptors{n = 108, d = 73}
LocalDescriptors{n = 108, d = 73}
LocalDescriptors{n = 108, d = 73}
LocalDescriptors{n = 108, d = 73}
Update training dataset by adding energy descriptors.
# Combine (`+`) each training configuration with its energy descriptors and wrap
# the resulting collection as a DataSet.
ds_train = DataSet(broadcast(+, conf_train, e_descr_train));
Subsample dataset
Create DPP subselector.
# k-DPP subselector over the training set, using the `GlobalMean()` and
# `DotProduct()` components, with batches of 200 configurations.
dpp = kDPP(
    ds_train,
    GlobalMean(),
    DotProduct();
    batch_size = 200,
)
kDPP(L-Ensemble.
Number of items in ground set : 1000. Max. rank : 1000. Rescaling constant α=4.306186451810486e14
, 200)
Subsample the training dataset.
# Draw one DPP batch of training indices.
# NOTE(review): presumably uses the `batch_size` given to `kDPP` (200) — confirm.
dpp_inds = dpp |> get_random_subset;
Learn coefficients
Learn ACE coefficients based on ACE descriptors and DFT data.
# Weight passed to `learn!`.
# NOTE(review): presumably a regularization parameter — confirm in PotentialLearning docs.
α = 1e-8
# Linear potential over the ACE basis; `learn!` fits its coefficients using only the
# DPP-selected training configurations. The returned object is kept in `Σ`.
lb = LBasisPotential(ace)
Σ = learn!(lb, ds_train[dpp_inds], α)
UnivariateLinearProblem{T, [-0.08273713775021818, 0.37819839276295397, -0.8777281958431047, 1.2847171185742354, -1.1127498943486955, 0.14315555344973063, 0.8593470514175579, 0.49789621148454444, -0.0034131801626555713, 0.010987098725977695, -0.02483526704351391, 0.032202842849628155, -0.009564484287341912, -0.07004422522778664, 0.22015100084357897, -0.009393509723934823, 0.034881382814299755, -0.10078942035718086, 0.21036686662968407, -0.34993070952597805, 0.002589784750054114, -0.009070490586691182, 0.02398256152985745, -0.003314367653771022, -0.06806104086562925, 0.0997939174953757, -0.09652412523026044, -0.09066579234406191, 0.6026250788700338, -0.2592123608849022, 0.5300771091706622, -1.1076826321980775, 0.06347046789255373, -0.22983858537847368, -0.01973458597549893, 0.8422200235268953, -2.170628964645971, 1.2654386964170001, -0.00027592899613082533, 0.0008031416055883089, -0.002183662311987966, 0.0041689644671645265, -0.006065514359706814, 0.006742908635569922, -0.00017060210131653042, 0.0006320672648969239, -0.0018203041612477877, 0.0037814800966073925, 5.7281837505689104e-5, -0.00020252282863296445, -0.00402132967886959, 0.007645799808912557, -0.014331289342802478, 0.020174624897030924, -0.004686390378841221, 0.009551748241190872, -0.02854324509518591, 0.03665089806494848, 3.307553431170819e-5, 0.0009280233847975827, -0.0001671057241870305, -0.002821332014792066, 0.006320026535190593, -0.00012065498084036508, -0.003438877198171486, 0.009906274931903436, 0.010457212611695077, -0.0003062355804316075, 0.03459158016536534, -0.05298255428925813, 0.09705245700021692, 0.025501746149979354, 0.13737821669880645], [0.2425046155083757]}
Post-process results
Update test dataset by adding energy descriptors.
println("Computing local descriptors of test dataset")
# Local descriptors of the held-out configurations (progress bar disabled).
e_descr_test = compute_local_descriptors(conf_test, ace; pbar = false)
# Combine (`+`) each test configuration with its descriptors, as done for training.
ds_test = DataSet(broadcast(+, conf_test, e_descr_test));
Computing local descriptors of test dataset
Get true and predicted per-atom energy values (assuming that all configurations have the same number of atoms).
# Atom count taken from the first training configuration; every configuration
# (training and test) is assumed to contain this many atoms.
n = first(size(get_system(ds_train[1])))
# Per-atom energies: reference values and predictions of the fitted potential `lb`.
e_train = get_all_energies(ds_train) / n
e_train_pred = get_all_energies(ds_train, lb) / n
e_test = get_all_energies(ds_test) / n
e_test_pred = get_all_energies(ds_test, lb) / n;
Compute and print metrics.
# Per-atom energy metrics on the training set (MAE, RMSE, RSQ as returned by `calc_metrics`).
e_mae, e_rmse, e_rsq = calc_metrics(e_train, e_train_pred)
println("MAE: $e_mae, RMSE: $e_rmse, RSQ: $e_rsq")
# The same metrics on the held-out test set, so generalization is reported too.
e_test_mae, e_test_rmse, e_test_rsq = calc_metrics(e_test, e_test_pred)
println("Test MAE: $e_test_mae, RMSE: $e_test_rmse, RSQ: $e_test_rsq")
MAE: 0.004606389186659766, RMSE: 0.006040815917254314, RSQ: 0.8187078227365816
Plot energy error.
# Signed per-atom energy errors (prediction minus reference).
e_err_train = e_train_pred - e_train
e_err_test = e_test_pred - e_test
# A separate, smaller DPP draw (20 indices) used only to highlight samples in the plot.
dpp_inds2 = get_random_subset(dpp; batch_size = 20)

# Build the figure explicitly on `p` rather than relying on the implicit current plot.
p = scatter(e_train, e_err_train;
            label = "Training", color = :blue,
            markersize = 1.5, markerstrokewidth = 0)
scatter!(p, e_test, e_err_test;
         label = "Test", color = :yellow,
         markersize = 1.5, markerstrokewidth = 0)
scatter!(p, e_train[dpp_inds2], e_err_train[dpp_inds2];
         color = :darkred, label = "DPP Samples",
         markersize = 2.5, markerstrokewidth = 0)
# NOTE(review): `fontsize` may not be a recognized Plots.jl attribute
# (Plots uses e.g. `guidefontsize`/`tickfontsize`) — confirm.
scatter!(p; xlabel = "Energy (eV/atom)", ylabel = "Error (eV/atom)",
         dpi = 1000, fontsize = 16)
This page was generated using Literate.jl.