Subsample Na dataset with DPP and fit energies with ACE

Set up the experiment

Load packages.

using Unitful, UnitfulAtomic
using AtomsBase, InteratomicPotentials, PotentialLearning
using LinearAlgebra, Plots

Define paths.

# Base directory that contains `examples/`; override it with the BASE_PATH
# environment variable, otherwise fall back to a path relative to this page.
base_path = get(ENV, "BASE_PATH", "../../")
# joinpath inserts a separator only when needed, so a base path ending in "/"
# no longer yields a doubled separator ("../..//examples/...").
ds_path   = joinpath(base_path, "examples/data/Na/liquify_sodium.yaml");

Load datasets

Load atomistic dataset: atomistic configurations (atom positions, geometry, etc.) + DFT data (energies, forces, etc.).

# Read configurations and the accompanying thermo records for species :Na
# (units given as eV and Å), then drop everything before index 220.
confs, thermo = load_data(ds_path, YAML(:Na, u"eV", u"Å"))
first_kept = 220
confs  = confs[first_kept:end]
thermo = thermo[first_kept:end]
confs, thermo
(DataSet{num_configs = 2092} 
	 Configuration{S, AtomsBase.FlexibleSystem{3, AtomsBase.Atom{3, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}, Unitful.Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Unitful.Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Energy, Forces}
	 Configuration{S, AtomsBase.FlexibleSystem{3, AtomsBase.Atom{3, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}, Unitful.Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Unitful.Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Energy, Forces}
	 ⋮
	 Configuration{S, AtomsBase.FlexibleSystem{3, AtomsBase.Atom{3, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}, Unitful.Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Unitful.Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Energy, Forces}, Vector{Dict}[[Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[109500, -118.56754033566662, 1.2468911443758697, -117.32064919129076, 90.1529877225114, 4481.320775643779, -8.243380533904219, 8.243380533904219, -8.243380533904219, 8.243380533904219, -8.243380533904219, 8.243380533904219])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[110000, -118.8191048880354, 1.3088999144548927, -117.5102049735805, 94.63635895570786, 4445.796913398657, -8.22154069868183, 8.22154069868183, -8.22154069868183, 8.22154069868183, -8.22154069868183, 8.22154069868183])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[110500, -118.66706916398354, 1.1040005887840023, -117.56306857519954, 79.82168449524809, 4443.609994398957, -8.220192399538881, 8.220192399538881, -8.220192399538881, 8.220192399538881, -8.220192399538881, 8.220192399538881])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[111000, -118.5424954601734, 1.253547941582279, -117.28894751859112, 90.63428888462772, 4457.032418979577, -8.228460752349738, 8.228460752349738, -8.228460752349738, 8.228460752349738, -8.228460752349738, 8.228460752349738])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", 
"Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[111500, -118.60598321165922, 1.2766945282051063, -117.32928868345411, 92.30783829512603, 4472.0581396906955, -8.23769708292504, 8.23769708292504, -8.23769708292504, 8.23769708292504, -8.23769708292504, 8.23769708292504])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[112000, -118.85324930143875, 1.3425751052444483, -117.5106741961943, 97.0711497355587, 4466.470566415444, -8.234264813932816, 8.234264813932816, -8.234264813932816, 8.234264813932816, -8.234264813932816, 8.234264813932816])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[112500, -118.71245699254627, 1.1670894723589, -117.54536752018737, 84.38315032328688, 4468.368033253137, -8.235430688433015, 8.235430688433015, -8.235430688433015, 8.235430688433015, -8.235430688433015, 8.235430688433015])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[113000, -118.76694882093753, 1.2319510396827227, -117.5349977812548, 89.07278510735168, 4446.342141937924, -8.221876779064969, 8.221876779064969, -8.221876779064969, 8.221876779064969, -8.221876779064969, 8.221876779064969])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[113500, -118.51709527877453, 1.3207562411255773, -117.19633903764895, 95.49359759887054, 4437.23245692403, -8.216257934441625, 8.216257934441625, -8.216257934441625, 8.216257934441625, -8.216257934441625, 8.216257934441625])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), 
Dict{Any, Any}("data" => Real[114000, -118.72667005007821, 1.2772963700128563, -117.44937368006535, 92.35135278903331, 4410.550998548226, -8.199756461230859, 8.199756461230859, -8.199756461230859, 8.199756461230859, -8.199756461230859, 8.199756461230859])]  …  [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1150500, -111.33274572565541, 4.329760927088277, -107.00298479856714, 313.05129197672164, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1151000, -112.05085995319563, 4.005480215432024, -108.0453797377636, 289.60507925120993, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1151500, -111.93748415461936, 4.854075903442631, -107.08340825117672, 350.9604244933883, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1152000, -113.07999770701338, 4.265828336045448, -108.81416937096793, 308.4288242325647, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1152500, 
-112.41525332679798, 4.987262862658069, -107.42799046413991, 360.5901362393629, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1153000, -111.8906050438705, 4.365236171743513, -107.52536887212698, 315.61623063257963, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1153500, -111.78815670229865, 4.611141339238888, -107.17701536305977, 333.395718157298, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1154000, -112.488679423801, 4.081705353887927, -108.40697406991308, 295.11632536307604, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1154500, -111.6431739577287, 3.906342818047219, -107.73683113968148, 282.43722613942305, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1155000, -111.57970314361862, 4.291823239871932, 
-107.28787990374668, 310.3083132770473, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])]])

Split the atomistic dataset into training and test sets: the first 1000 configurations are used for training and the remaining ones for testing.

# Training set: the first `n_train` configurations; test set: everything after.
n_train = 1000
conf_train = confs[1:n_train]
conf_test  = confs[(n_train + 1):end]
conf_train, conf_test
(DataSet{num_configs = 1000} 
	 Configuration{S, AtomsBase.FlexibleSystem{3, AtomsBase.Atom{3, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}, Unitful.Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Unitful.Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Energy, Forces}
	 Configuration{S, AtomsBase.FlexibleSystem{3, AtomsBase.Atom{3, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}, Unitful.Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Unitful.Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Energy, Forces}
	 ⋮
	 Configuration{S, AtomsBase.FlexibleSystem{3, AtomsBase.Atom{3, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}, Unitful.Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Unitful.Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Energy, Forces}, DataSet{num_configs = 1092} 
	 Configuration{S, AtomsBase.FlexibleSystem{3, AtomsBase.Atom{3, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}, Unitful.Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Unitful.Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Energy, Forces}
	 Configuration{S, AtomsBase.FlexibleSystem{3, AtomsBase.Atom{3, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}, Unitful.Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Unitful.Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Energy, Forces}
	 ⋮
	 Configuration{S, AtomsBase.FlexibleSystem{3, AtomsBase.Atom{3, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}, Unitful.Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Unitful.Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Energy, Forces})

Compute descriptors

Create ACE basis.

ace = ACE(
    species = [:Na],          # chemical species in the dataset
    body_order = 4,           # 4-body basis
    polynomial_degree = 8,    # polynomial degree 8
    wL = 1.0,                 # default value; see the ACE.jl documentation
    csp = 1.0,                # default value; see the ACE.jl documentation
    r0 = 1.0,                 # minimum distance between atoms
    rcutoff = 5.0,            # cutoff radius
);

Compute the energy (local) descriptors of the training dataset.

println("Computing local descriptors of training dataset")
# One set of per-atom ACE descriptors for each training configuration.
e_descr_train = compute_local_descriptors(conf_train, ace)
1000-element Vector{LocalDescriptors}:
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 ⋮
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}

Update the training dataset by adding the energy (local) descriptors.

# Combine each training configuration with its local descriptors and wrap the result.
ds_train = (conf_train .+ e_descr_train) |> DataSet;

Subsample dataset

Create a k-DPP subselector over the training dataset that draws batches of 200 configurations.

# k-DPP over the training configurations, built from GlobalMean() features and a
# DotProduct() kernel; batch_size is the default number of items per sampled subset.
dpp = kDPP(ds_train, GlobalMean(), DotProduct(), batch_size = 200)
kDPP(L-Ensemble.
Number of items in ground set : 1000. Max. rank : 1000. Rescaling constant α=4.306186451810486e14
, 200)

Subsample the training dataset.

# Indices into ds_train of one random DPP sample.
dpp_inds = dpp |> get_random_subset;

Learn coefficients

Learn ACE coefficients based on ACE descriptors and DFT data.

lb = LBasisPotential(ace)
# Only the DPP-selected training configurations are used for the fit.
ds_dpp_train = ds_train[dpp_inds]
α = 1e-8    # passed to learn!; presumably a regularization weight — confirm
Σ = learn!(lb, ds_dpp_train, α)
UnivariateLinearProblem{T, [-0.10744744580701848, 0.4880419613416791, -1.1199953858567255, 1.6029667807401893, -1.3025081730019679, -0.027577163206558453, 1.2935327541074457, 0.7656942721550308, -0.004646566428006427, 0.015099297637311437, -0.0348753422523842, 0.048220038633658646, -0.026037122674654967, -0.06422741030422241, 0.23646780137512025, -0.009672940665872876, 0.03561313155523749, -0.10165674998479801, 0.208651792314825, -0.33920292887874837, 0.003912990507708658, -0.013710969813982565, 0.03626727002755814, -0.006402485839950067, -0.0948762985617293, 0.14418408093816026, -0.1611862264383035, -0.04112404021292093, 0.6319021574882271, -0.26239381744679235, 0.5300819731507431, -1.089330550791317, 0.09597744466446656, -0.3645689863071766, 0.0899138877538519, 0.8459873937902103, -2.14441321271579, 1.2379996078589848, -0.0003581093225604459, 0.0010420766710330653, -0.002834165873596195, 0.005421388898789736, -0.007937541452193386, 0.00899363023410249, -0.00018194099824352106, 0.0006685549975555816, -0.0019027861403679847, 0.0038885605793022027, 8.428473642890094e-5, -0.00029770996534901617, -0.005217088833979755, 0.009925330778656871, -0.018653350093210008, 0.02647645449294167, -0.004916145535921013, 0.009902069349768894, -0.037099145577460746, 0.04783470570502146, 3.953173522371935e-5, 0.0009856835104735118, -0.00019871420779452031, -0.0029805261660582616, 0.006632443001374888, -0.0001432866833209548, -0.0036223544000266727, 0.010311273519520547, 0.010954890715027421, -0.0004507829074698877, 0.044880764363558256, -0.06881129300763433, 0.12650366371652666, 0.026639410495285176, 0.17875362623119884], [0.14348649756185458]}

Post-process results

Compute the energy (local) descriptors of the test dataset and add them to it.

println("Computing local descriptors of test dataset")
# Same descriptor computation as for training, with the progress bar disabled.
e_descr_test = compute_local_descriptors(conf_test, ace; pbar = false)
ds_test = (conf_test .+ e_descr_test) |> DataSet;
Computing local descriptors of test dataset

Get true and predicted per-atom energies (assuming that all configurations have the same number of atoms).

# Atom count of the first training configuration, used to normalize every
# energy below (all configurations are assumed to have the same atom count).
n = first(size(get_system(ds_train[1])))
e_train      = get_all_energies(ds_train) / n
e_train_pred = get_all_energies(ds_train, lb) / n
e_test       = get_all_energies(ds_test) / n
e_test_pred  = get_all_energies(ds_test, lb) / n;

Compute and print energy metrics (MAE, RMSE, R²) on the training dataset.

# Training-set error metrics of the fitted potential (per-atom energies).
(e_mae, e_rmse, e_rsq) = calc_metrics(e_train, e_train_pred)
println(string("MAE: ", e_mae, ", RMSE: ", e_rmse, ", RSQ: ", e_rsq))
MAE: 0.004524948510912386, RMSE: 0.005944007860629054, RSQ: 0.8370129343720856

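Plot the per-atom energy error for the training and test datasets, highlighting a fresh DPP sample of 20 training configurations.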

# Signed per-atom energy errors (prediction minus reference).
e_err_train = e_train_pred - e_train
e_err_test  = e_test_pred - e_test
# A separate DPP draw of 20 training configurations, used only to highlight
# points in the plot (distinct from the `dpp_inds` draw used for fitting).
dpp_inds2 = get_random_subset(dpp; batch_size = 20)
p = scatter(e_train, e_err_train;
            label = "Training", color = :blue,
            markersize = 1.5, markerstrokewidth = 0)
scatter!(e_test, e_err_test;
         label = "Test", color = :yellow,
         markersize = 1.5, markerstrokewidth = 0)
scatter!(e_train[dpp_inds2], e_err_train[dpp_inds2];
         color = :darkred, label = "DPP Samples",
         markersize = 2.5, markerstrokewidth = 0)
scatter!(; xlabel = "Energy (eV/atom)", ylabel = "Error (eV/atom)",
         dpi = 1000, fontsize = 16)
Example block output

This page was generated using Literate.jl.