Subsample Na dataset with DPP and fit energies with ACE

Setup experiment

Load packages.

using Unitful, UnitfulAtomic
using AtomsBase, InteratomicPotentials, PotentialLearning
using LinearAlgebra, Plots

Define paths.

# Base directory under which `examples/data/...` is looked up. It can be overridden
# through the BASE_PATH environment variable and defaults to two levels up.
base_path = get(ENV, "BASE_PATH", "../../")
# Build the dataset path with joinpath so that a trailing separator in `base_path`
# (as in the default "../../") does not yield a doubled "//" in the result.
ds_path   = joinpath(base_path, "examples", "data", "Na", "liquify_sodium.yaml");

Load datasets

Load atomistic dataset: atomistic configurations (atom positions, geometry, etc.) + DFT data (energies, forces, etc.).

# Read the YAML dataset for species :Na, passing eV and Å as the units to the loader.
# `load_data` yields two values: the configurations and the matching thermo records.
confs, thermo = load_data(ds_path, YAML(:Na, u"eV", u"Å"))
# Keep entries 220 through the end of both collections, so they stay aligned.
# NOTE(review): the reason for skipping the first 219 entries is not stated here.
confs, thermo = map(collection -> collection[220:end], (confs, thermo))
(DataSet{num_configs = 2092} 
	 Configuration{S, Energy, AtomsBase.FlexibleSystem{3, AtomsBase.Atom{3, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}, Unitful.Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Unitful.Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Forces}
	 Configuration{S, Energy, AtomsBase.FlexibleSystem{3, AtomsBase.Atom{3, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}, Unitful.Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Unitful.Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Forces}
	 ⋮
	 Configuration{S, Energy, AtomsBase.FlexibleSystem{3, AtomsBase.Atom{3, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}, Unitful.Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Unitful.Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Forces}, Vector{Dict}[[Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[109500, -118.56754033566662, 1.2468911443758697, -117.32064919129076, 90.1529877225114, 4481.320775643779, -8.243380533904219, 8.243380533904219, -8.243380533904219, 8.243380533904219, -8.243380533904219, 8.243380533904219])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[110000, -118.8191048880354, 1.3088999144548927, -117.5102049735805, 94.63635895570786, 4445.796913398657, -8.22154069868183, 8.22154069868183, -8.22154069868183, 8.22154069868183, -8.22154069868183, 8.22154069868183])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[110500, -118.66706916398354, 1.1040005887840023, -117.56306857519954, 79.82168449524809, 4443.609994398957, -8.220192399538881, 8.220192399538881, -8.220192399538881, 8.220192399538881, -8.220192399538881, 8.220192399538881])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[111000, -118.5424954601734, 1.253547941582279, -117.28894751859112, 90.63428888462772, 4457.032418979577, -8.228460752349738, 8.228460752349738, -8.228460752349738, 8.228460752349738, -8.228460752349738, 8.228460752349738])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", 
"Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[111500, -118.60598321165922, 1.2766945282051063, -117.32928868345411, 92.30783829512603, 4472.0581396906955, -8.23769708292504, 8.23769708292504, -8.23769708292504, 8.23769708292504, -8.23769708292504, 8.23769708292504])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[112000, -118.85324930143875, 1.3425751052444483, -117.5106741961943, 97.0711497355587, 4466.470566415444, -8.234264813932816, 8.234264813932816, -8.234264813932816, 8.234264813932816, -8.234264813932816, 8.234264813932816])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[112500, -118.71245699254627, 1.1670894723589, -117.54536752018737, 84.38315032328688, 4468.368033253137, -8.235430688433015, 8.235430688433015, -8.235430688433015, 8.235430688433015, -8.235430688433015, 8.235430688433015])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[113000, -118.76694882093753, 1.2319510396827227, -117.5349977812548, 89.07278510735168, 4446.342141937924, -8.221876779064969, 8.221876779064969, -8.221876779064969, 8.221876779064969, -8.221876779064969, 8.221876779064969])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[113500, -118.51709527877453, 1.3207562411255773, -117.19633903764895, 95.49359759887054, 4437.23245692403, -8.216257934441625, 8.216257934441625, -8.216257934441625, 8.216257934441625, -8.216257934441625, 8.216257934441625])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), 
Dict{Any, Any}("data" => Real[114000, -118.72667005007821, 1.2772963700128563, -117.44937368006535, 92.35135278903331, 4410.550998548226, -8.199756461230859, 8.199756461230859, -8.199756461230859, 8.199756461230859, -8.199756461230859, 8.199756461230859])]  …  [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1150500, -111.33274572565541, 4.329760927088277, -107.00298479856714, 313.05129197672164, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1151000, -112.05085995319563, 4.005480215432024, -108.0453797377636, 289.60507925120993, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1151500, -111.93748415461936, 4.854075903442631, -107.08340825117672, 350.9604244933883, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1152000, -113.07999770701338, 4.265828336045448, -108.81416937096793, 308.4288242325647, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1152500, 
-112.41525332679798, 4.987262862658069, -107.42799046413991, 360.5901362393629, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1153000, -111.8906050438705, 4.365236171743513, -107.52536887212698, 315.61623063257963, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1153500, -111.78815670229865, 4.611141339238888, -107.17701536305977, 333.395718157298, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1154000, -112.488679423801, 4.081705353887927, -108.40697406991308, 295.11632536307604, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1154500, -111.6431739577287, 3.906342818047219, -107.73683113968148, 282.43722613942305, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])], [Dict{Any, Any}("keywords" => ["Step", "PotEng", "KinEng", "TotEng", "Temp", "Volume", "Xlo", "Xhi", "Ylo", "Yhi", "Zlo", "Zhi"]), Dict{Any, Any}("data" => Real[1155000, -111.57970314361862, 4.291823239871932, 
-107.28787990374668, 310.3083132770473, 4881.2260943202855, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883, -8.481636303280883, 8.481636303280883])]])

Split atomistic dataset into training and test.

# Sequential hold-out split: entries 1-1000 form the training set and
# everything from entry 1001 onward forms the test set.
conf_train, conf_test = map(idx -> confs[idx], (1:1000, 1001:lastindex(confs)))
(DataSet{num_configs = 1000} 
	 Configuration{S, Energy, AtomsBase.FlexibleSystem{3, AtomsBase.Atom{3, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}, Unitful.Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Unitful.Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Forces}
	 Configuration{S, Energy, AtomsBase.FlexibleSystem{3, AtomsBase.Atom{3, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}, Unitful.Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Unitful.Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Forces}
	 ⋮
	 Configuration{S, Energy, AtomsBase.FlexibleSystem{3, AtomsBase.Atom{3, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}, Unitful.Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Unitful.Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Forces}, DataSet{num_configs = 1092} 
	 Configuration{S, Energy, AtomsBase.FlexibleSystem{3, AtomsBase.Atom{3, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}, Unitful.Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Unitful.Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Forces}
	 Configuration{S, Energy, AtomsBase.FlexibleSystem{3, AtomsBase.Atom{3, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}, Unitful.Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Unitful.Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Forces}
	 ⋮
	 Configuration{S, Energy, AtomsBase.FlexibleSystem{3, AtomsBase.Atom{3, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}, Unitful.Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Unitful.Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Forces})

Compute descriptors

Create ACE basis.

# ACE basis specification for a single-species (Na) system.
ace = ACE(;
    species = [:Na],          # chemical species covered by the basis
    body_order = 4,           # 4-body terms
    polynomial_degree = 8,    # polynomials of degree 8
    wL = 1.0,                 # default value; see the ACE.jl documentation
    csp = 1.0,                # default value; see the ACE.jl documentation
    r0 = 1.0,                 # minimum distance between atoms
    rcutoff = 5.0,            # cutoff radius
);

Compute energy (local) descriptors of the training dataset.

# Announce the step, then evaluate the ACE basis on every training configuration.
# The result holds one LocalDescriptors entry per configuration in `conf_train`.
println("Computing local descriptors of training dataset")
e_descr_train = compute_local_descriptors(conf_train, ace) # NOTE(review): unverified alternative, load precomputed values: JLD.load("data/sodium_empirical_full.jld", "descriptors")
1000-element Vector{LocalDescriptors}:
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 ⋮
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}
 LocalDescriptors{n = 108, d = 73}

Update training dataset by adding energy descriptors.

# Combine each training configuration with its descriptors via broadcast `.+`, then wrap
# the result in a DataSet. Only the energy (local) descriptors computed above are used.
# NOTE(review): `+` presumably attaches the descriptors to the configuration — confirm.
ds_train = DataSet(conf_train .+ e_descr_train);

Subsample dataset

Create DPP subselector.

# Build a k-DPP subset selector over the training DataSet with a batch size of 200.
# NOTE(review): GlobalMean() and DotProduct() presumably choose the per-configuration
# feature and the similarity kernel — confirm against the PotentialLearning docs.
dpp = kDPP(ds_train, GlobalMean(), DotProduct(); batch_size = 200)
kDPP(L-Ensemble.
Number of items in ground set : 1000. Max. rank : 1000. Rescaling constant α=4.306186451810486e14
, 200)

Subsample training dataset.

# Draw a random subset from the DPP selector. The result is used below to index
# `ds_train`, i.e. it identifies the configurations chosen for fitting.
dpp_inds = get_random_subset(dpp);

Learn coefficients

Learn ACE coefficients based on ACE descriptors and DFT data.

# Potential object built from the ACE basis; it is passed to `learn!` and later used
# to predict energies.
lb = LBasisPotential(ace)
# Third argument of `learn!`.
# NOTE(review): presumably a regularization/weighting parameter — confirm.
α = 1e-8
# Fit on the DPP-selected training configurations only; Σ keeps the value returned by
# `learn!`. The trailing `!` signals that `lb` is expected to be modified in place.
Σ = learn!(lb, ds_train[dpp_inds], α)
UnivariateLinearProblem{T, [-0.060920287099087034, 0.28126329047309184, -0.663808325527981, 1.002277393205702, -0.9380998761609476, 0.2754590055905055, 0.5009256466271381, 0.2768811043230812, -0.0019644408615833164, 0.006177941821742472, -0.013256072850065692, 0.014547539223734418, 0.005369658334572591, -0.06364670654744886, 0.16789486016351773, -0.006699275307994412, 0.02497159558545725, -0.07252603118840995, 0.15236160628866458, -0.255454961108299, 0.0021459643026797703, -0.007680566028986036, 0.020996762882038666, -0.0021276567953840697, -0.03682871019848122, 0.04886403934833794, -0.026472475106274374, -0.12674757891505095, 0.5015459801461404, -0.18624330217375018, 0.3827146854069099, -0.8046570243287192, 0.05499051922402387, -0.07851331420394558, -0.12999819538190138, 0.774436056844575, -1.574228861560755, 1.2239182767414327, -0.00016079411227490348, 0.00046623396229761305, -0.0012594123073991215, 0.0023774559957074892, -0.003387554888251376, 0.0036006881007213387, -0.00011659361750585712, 0.00043378904560845144, -0.0012564443750768817, 0.0026294141065525967, 4.112569702240457e-5, -0.00014751090934282673, -0.002322193281508955, 0.0043754235601802804, -0.008070576876900937, 0.011010503995036558, -0.0032292237511242217, 0.0066177554171491515, -0.016111074841191453, 0.02016333613430632, 2.048418882037717e-5, 0.0006368483986647355, -0.00010356904393500621, -0.0019467972057095606, 0.004391173391249392, -7.509234744007554e-5, -0.0023697491441625473, 0.006865248409550162, 0.007245530908623898, -0.00022250940862431068, 0.019833268873735932, -0.02999964715568737, 0.05370155931120024, 0.017643471506834985, 0.07630096488195193], [0.25655607496637123]}

Post-process results

Update test dataset by adding energy descriptors.

# Announce the step, then compute descriptors for the held-out configurations with the
# same ACE basis used for the training set.
println("Computing local descriptors of test dataset")
# NOTE(review): `pbar = false` presumably suppresses a progress bar — confirm.
e_descr_test = compute_local_descriptors(conf_test, ace, pbar = false)
# Combine each test configuration with its descriptors (broadcast `.+`) and wrap the
# result in a DataSet.
ds_test = DataSet(conf_test .+ e_descr_test);
Computing local descriptors of test dataset

Get true and predicted energy values (assuming that all configurations have the same no. of atoms).

# Number of atoms, read from the first training configuration and reused for every
# configuration (all configurations are assumed to have the same atom count).
n = first(size(get_system(ds_train[1])))

# Per-atom energies: reference values (no potential argument) and values predicted by `lb`.
e_train      = get_all_energies(ds_train) / n
e_train_pred = get_all_energies(ds_train, lb) / n
e_test       = get_all_energies(ds_test) / n
e_test_pred  = get_all_energies(ds_test, lb) / n;

Compute and print metrics.

# Error metrics for the training set only: reference vs. predicted per-atom energies.
# The three returned values are reported as MAE, RMSE and RSQ; test-set metrics are
# not computed here.
e_mae, e_rmse, e_rsq = calc_metrics(e_train, e_train_pred)
println("MAE: $e_mae, RMSE: $e_rmse, RSQ: $e_rsq")
MAE: 0.004586768589844355, RMSE: 0.006003550961632418, RSQ: 0.8295454924878514

Plot energy error.

# Signed per-atom energy error (predicted minus reference) for both splits.
e_err_train = e_train_pred - e_train
e_err_test  = e_test_pred - e_test

# Fresh 20-element draw from the DPP, used only for highlighting points in the plot.
# This is a separate draw from `dpp_inds`, the subset used for fitting.
dpp_inds2 = get_random_subset(dpp; batch_size = 20)

# Layer 1: all training points.
p = scatter(e_train, e_err_train;
            label = "Training", color = :blue,
            markersize = 1.5, markerstrokewidth = 0)

# Layer 2: all test points.
scatter!(e_test, e_err_test;
         label = "Test", color = :yellow,
         markersize = 1.5, markerstrokewidth = 0)

# Layer 3: the highlighted DPP draw, taken from the training points.
scatter!(e_train[dpp_inds2], e_err_train[dpp_inds2];
         label = "DPP Samples", color = :darkred,
         markersize = 2.5, markerstrokewidth = 0)

# Axis labels and rendering attributes, applied to the current plot.
scatter!(; xlabel = "Energy (eV/atom)", ylabel = "Error (eV/atom)",
         dpi = 1000, fontsize = 16)
Example block output

This page was generated using Literate.jl.