push!(Base.LOAD_PATH, "../../")

using InteratomicPotentials
using AtomsBase
using UnitfulAtomic
using Unitful
using StaticArrays
using LinearAlgebra
using Statistics

Load the LJ-cluster example data

# Locate the LJ-cluster example directory relative to the installed
# InteratomicPotentials package. Every path component is passed to `joinpath`
# separately (no hard-coded "/" separators) so the path is valid on every OS.
example_path = joinpath(dirname(pathof(InteratomicPotentials)),
                        "..", "examples", "LJCluster")
# The included script defines `load_data`, which returns (systems, energies, forces).
include(joinpath(example_path, "load_data.jl"))
# NOTE(review): `num_entries` presumably caps how many configurations are read
# from the .xyz file — confirm against load_data.jl.
systems, energies, forces =
    load_data(; num_entries = 100,
                file = joinpath(example_path, "curated_lj_cluster.xyz"));
101

Split into training and test sets

# The first `n_train` configurations are used for fitting; the remainder form
# the held-out test set. Change `n_train` in one place to move the split point.
n_train = 50
n_train < length(systems) || throw(ArgumentError(
    "need more than $n_train systems to build a test set, got $(length(systems))"))
train_systems, train_energies, train_forces =
    systems[1:n_train], energies[1:n_train], forces[1:n_train];
test_systems, test_energies, test_forces =
    systems[(n_train + 1):end], energies[(n_train + 1):end], forces[(n_train + 1):end];

Create an ACE basis (2-body, polynomial degree 8)

# ACE basis for argon (:Ar) with body order 2 and polynomial degree 8.
# NOTE(review): rcutoff presumably sets the radial cutoff (in the data's length
# units); wL, csp and r0 are forwarded to the ACE basis constructor — see the
# InteratomicPotentials/ACE1 docs for their exact meaning.
ace = ACE(; species = [:Ar], body_order = 2, polynomial_degree = 8,
          rcutoff = 2.0, r0 = 1.0, wL = 1.0, csp = 1.0)
ACE([:Ar], 2, 8, 2.0, 1.0, 1.0, 1.0, ACE1.RPI.RPIBasis{Float64, ACE1.RPI.BasicPSH1pBasis{Float64, 1, ACE1.OrthPolys.TransformedPolys{Float64, ACE1.Transforms.PolyTransform{Int64, Float64}, ACE1.OrthPolys.OrthPolyBasis{Float64}, ACE1.OrthPolys.OneEnvelope}}, 1, ACE1.DAGEvaluator}(ACE1.PIBasis{ACE1.RPI.BasicPSH1pBasis{Float64, 1, ACE1.OrthPolys.TransformedPolys{Float64, ACE1.Transforms.PolyTransform{Int64, Float64}, ACE1.OrthPolys.OrthPolyBasis{Float64}, ACE1.OrthPolys.OneEnvelope}}, 1, ACE1.DAGEvaluator}(ACE1.RPI.BasicPSH1pBasis{Float64, 1, ACE1.OrthPolys.TransformedPolys{Float64, ACE1.Transforms.PolyTransform{Int64, Float64}, ACE1.OrthPolys.OrthPolyBasis{Float64}, ACE1.OrthPolys.OneEnvelope}}(ACE1.OrthPolys.TransformedPolys{Float64, ACE1.Transforms.PolyTransform{Int64, Float64}, ACE1.OrthPolys.OrthPolyBasis{Float64}, ACE1.OrthPolys.OneEnvelope}(ACE1.OrthPolys.OrthPolyBasis{Float64}(2, 0.4444444444444444, 0, 1.4692378328741966, [2.1291808597507096, 6.927551862540507, 5.163501330270071, 4.623432881123967, 4.37589341829516, 4.239373572278735, 4.155366284472768, 4.099731135119782], [0.0, -8.995000596148365, -5.822564729770663, -4.897686822019796, -4.485978541670284, -4.263265659159493, -4.128082315742125, -4.039459785741758], [0.0, 0.0, -0.7453572968815858, -0.8954065440091757, -0.946459812612521, -0.9688018347417651, -0.9801840327648093, -0.9866112526443569], [0.4449568411386593, 0.44598163452708905, 0.4470064279155188, 0.44803122130394857, 0.4490560146923783, 0.4500808080808081, 0.45110560146923784, 0.45213039485766754, 0.4531551882460973, 0.45417998163452705  …  1.459502295684114, 1.4605270890725437, 1.4615518824609735, 1.4625766758494032, 1.463601469237833, 1.4646262626262627, 1.4656510560146925, 1.4666758494031222, 1.467700642791552, 1.4687254361799817], [0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001  …  0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001]), ACE1.Transforms.PolyTransform{Int64, Float64}(2, 1.0, 1.0), 0.65, 
2.0, ACE1.OrthPolys.OneEnvelope()), ACE1.SphericalHarmonics.SHBasis{Float64}(7, ACE1.SphericalHarmonics.ALPCoefficients{Float64}([5.0e-324, 1.0e-323, 1.5e-323, 1.9364916731037085, 2.5e-323, 3.0e-323, 1.9720265943665387, 2.091650066335189, 4.4e-323, 5.0e-323  …  1.33e-322, 1.4e-322, 1.9948914348241344, 2.0155644370746373, 2.0816659994661326, 2.207940216581962, 2.4308621740219887, 2.850438562747845, 1.73e-322, 1.8e-322], [5.0e-324, 1.0e-323, 1.5e-323, -0.5773502691896257, 2.5e-323, 3.0e-323, -0.5163977794943222, -0.4472135954999579, 4.4e-323, 1.14e-322  …  6.4e-323, 7.0e-323, -0.5017452060042545, -0.4947274449181537, -0.47304991679126607, -0.4345240946267408, -0.3739787960033829, -0.2773500981126146, 1.0e-322, 1.04e-322])), JuLIP.Potentials.SZList{1}(JuLIP.Chemistry.AtomicNumber[<18>]), ACE1.RPI.PSH1pBasisFcn[znlm[0|1,0,0], znlm[0|1,1,-1], znlm[0|1,1,0], znlm[0|1,1,1], znlm[0|1,2,-2], znlm[0|1,2,-1], znlm[0|1,2,0], znlm[0|1,2,1], znlm[0|1,2,2], znlm[0|1,3,-3]  …  znlm[0|6,2,-2], znlm[0|6,2,-1], znlm[0|6,2,0], znlm[0|6,2,1], znlm[0|6,2,2], znlm[0|7,0,0], znlm[0|7,1,-1], znlm[0|7,1,0], znlm[0|7,1,1], znlm[0|8,0,0]], UnitRange{Int64}[1:204;;]), JuLIP.Potentials.SZList{1}(JuLIP.Chemistry.AtomicNumber[<18>]), (ACE1.InnerPIBasis([1, 1, 1, 1, 1, 1, 1, 1], [1; 65; … ; 200; 204;;], Dict{ACE1.PIBasisFcn, Int64}(ACE1.PIBasisFcn{1, ACE1.RPI.PSH1pBasisFcn}(<18>, (znlm[18|7,0,0],), ACE1.RPI.PSH1pBasisFcn) => 7, ACE1.PIBasisFcn{1, ACE1.RPI.PSH1pBasisFcn}(<18>, (znlm[18|6,0,0],), ACE1.RPI.PSH1pBasisFcn) => 6, ACE1.PIBasisFcn{1, ACE1.RPI.PSH1pBasisFcn}(<18>, (znlm[18|1,0,0],), ACE1.RPI.PSH1pBasisFcn) => 1, ACE1.PIBasisFcn{1, ACE1.RPI.PSH1pBasisFcn}(<18>, (znlm[18|8,0,0],), ACE1.RPI.PSH1pBasisFcn) => 8, ACE1.PIBasisFcn{1, ACE1.RPI.PSH1pBasisFcn}(<18>, (znlm[18|2,0,0],), ACE1.RPI.PSH1pBasisFcn) => 2, ACE1.PIBasisFcn{1, ACE1.RPI.PSH1pBasisFcn}(<18>, (znlm[18|5,0,0],), ACE1.RPI.PSH1pBasisFcn) => 5, ACE1.PIBasisFcn{1, ACE1.RPI.PSH1pBasisFcn}(<18>, (znlm[18|4,0,0],), 
ACE1.RPI.PSH1pBasisFcn) => 4, ACE1.PIBasisFcn{1, ACE1.RPI.PSH1pBasisFcn}(<18>, (znlm[18|3,0,0],), ACE1.RPI.PSH1pBasisFcn) => 3), Dict{Any, Int64}(znlm[18|3,1,-1] => 115, znlm[18|1,6,5] => 48, znlm[18|4,3,-2] => 160, znlm[18|2,3,-1] => 76, znlm[18|1,6,-3] => 40, znlm[18|1,5,3] => 34, znlm[18|3,5,-2] => 142, znlm[18|3,5,5] => 149, znlm[18|1,4,4] => 25, znlm[18|3,5,2] => 146…), 1:8, <18>, ACE1.DAG.CorrEvalGraph{Int64, Int64}([(1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0), (7, 0), (8, 0), (9, 0), (10, 0)  …  (195, 0), (196, 0), (197, 0), (198, 0), (199, 0), (200, 0), (201, 0), (202, 0), (203, 0), (204, 0)], [1, 0, 0, 0, 0, 0, 0, 0, 0, 0  …  0, 0, 0, 0, 0, 7, 0, 0, 0, 8], 204, 204)),), ACE1.DAGEvaluator()), (sparse([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 3, 4, 5, 6, 7, 8], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 8, 8),), (1:8,)))

Calculate descriptors (the sum of the local descriptors of each system)

# One descriptor vector per system: the sum of the local descriptors that
# `compute_local_descriptors` returns for that system under `basis`.
system_descriptor(sys, basis) = sum(compute_local_descriptors(sys, basis))
train_descriptors = map(s -> system_descriptor(s, ace), train_systems)
test_descriptors = map(s -> system_descriptor(s, ace), test_systems)
50-element Vector{Vector{Float64}}:
 [9.637553471101803, -26.675859876119226, 25.786917272571632, -0.7260195256973117, -13.708573482362102, -6.561704406552489, 33.060213550469, -30.240609807870495]
 [8.606741075726184, -24.751126644174448, 26.381953910263245, -5.232261587024965, -10.81076704195966, -2.2133986960506764, 24.44807197736727, -23.454597072227124]
 [7.995727465304647, -23.733500892883324, 26.924888559433835, -7.117859145193399, -12.273244261801626, 4.963518700124103, 17.832036907805147, -25.15112903959363]
 [7.921558789343797, -23.584238338889094, 26.884065458500498, -7.146565830179244, -12.796969315341903, 6.799615689782368, 14.664213370807753, -21.776297574689256]
 [7.792778687395339, -23.300464224699315, 26.884232957489132, -7.760354578458356, -12.39593962439861, 7.976413535211458, 12.267557941813429, -20.589078929705877]
 [8.055816010742193, -23.85343527092131, 26.78441005587313, -6.226639365419687, -13.997554168621804, 7.088120745001107, 15.671025846169771, -23.15509046819463]
 [7.891089195594272, -23.638451884555373, 27.302522604677687, -8.060238223548193, -11.269105007377025, 4.414171077602776, 17.14905475352475, -22.199255304250556]
 [7.843991305114745, -23.1669923055256, 26.1242961439842, -7.26255431550306, -9.809384249643417, 1.072677118774266, 18.99224856756854, -20.14770148893393]
 [8.428789831091315, -24.473857013361368, 26.447337791575837, -5.0933442951458305, -12.854964936125919, 1.7221806787543672, 21.859666259005984, -25.029519808795587]
 [7.824945011307422, -23.65537565712848, 27.720405358863843, -8.370591161502684, -12.3665597767453, 6.7233945884552835, 16.20013992232284, -23.900512454554395]
 ⋮
 [8.409183418512061, -24.244985114786722, 26.1394714656022, -6.139348481277596, -8.845038985801187, -3.696811762929591, 23.794145309696066, -20.813809169355512]
 [8.375082824736895, -24.393153967420325, 26.85824180986371, -7.121893705993498, -8.408923920070215, -4.001433391071423, 25.39906289939859, -22.881763688193935]
 [7.836335126129326, -23.496112065552804, 27.45055116841227, -9.307111022704603, -8.802719893672874, 2.1508145981594486, 17.634059047863655, -21.003239561898518]
 [8.132077419321302, -23.954048037457785, 27.1889983827391, -8.34513920372817, -9.294554173539439, 1.921171923909791, 18.301253269613646, -23.36202362919198]
 [7.844879379234444, -23.390216729527342, 26.992616470830377, -8.494346757471291, -9.768153606357954, 3.299195785298471, 16.37737674227996, -20.683395837379326]
 [7.581942977993447, -23.192375513064494, 27.74527735812395, -8.928520976963663, -12.797356763908624, 8.907982073123044, 14.493555003621701, -25.109274643736367]
 [7.919360855803918, -23.614054931612426, 27.37375182638167, -9.26516559316405, -8.211702882815906, 1.6582439920760017, 15.533644510000197, -15.264249968296193]
 [8.03748824668811, -23.779633201166398, 26.97905610248553, -7.8532866831110395, -9.771367473971607, 1.2472503238211337, 19.343452472042703, -20.885497245247535]
 [7.9135638083288065, -23.422972449123172, 26.521589090001356, -7.296545629727583, -10.976002657737606, 3.338865904543715, 17.882587846879424, -22.615946287685794]

Calculate the Mahalanobis distance of each descriptor from the descriptor mean

"""
    mahalanobis(x, μ, Σ)

Return the Mahalanobis distance `sqrt((x - μ)' * inv(Σ) * (x - μ))` of `x` from a
distribution with mean `μ` and covariance `Σ`. `Σ` may be a matrix or a
factorization of one (anything that supports `Σ \\ v`).
"""
function mahalanobis(x, μ, Σ)
    δ = x - μ
    q = dot(δ, Σ \ δ)   # solve instead of forming inv(Σ) explicitly
    # The quadratic form is ≥ 0 in exact arithmetic; clamp tiny negative
    # round-off so `sqrt` cannot throw a DomainError.
    return sqrt(max(q, zero(q)))
end

# Mean and covariance are pooled over the train and test descriptors.
all_descriptors = [train_descriptors; test_descriptors]
mean_d = mean(all_descriptors)
cov_d = cov(all_descriptors)
cov_lu = lu(cov_d)   # factor once and reuse it for every distance
# Standard definition, no 0.5 factor: D(x) = sqrt((x - μ)ᵀ Σ⁻¹ (x - μ)).
mah_dist_train = [mahalanobis(td, mean_d, cov_lu) for td in train_descriptors]
mah_dist_test = [mahalanobis(td, mean_d, cov_lu) for td in test_descriptors]
50-element Vector{Float64}:
 3.337873662581474
 1.9266726882837408
 1.534690130260185
 1.6511596807911857
 2.1181866044660094
 1.9193249989872847
 2.583475641772898
 1.7993381943009978
 1.6388316082702068
 1.6660577899567883
 ⋮
 1.5688273715850631
 2.5958049749658376
 1.4121716321806719
 3.4444540162081663
 1.610728869370596
 1.7532542295405253
 2.4945185512050854
 1.8000613443997822
 1.3867716904562941

Create a linear basis potential and estimate its coefficients β by least squares

# Linear potential over the ACE basis. Its coefficients β are the least-squares
# solution of `train_design * β ≈ train_energies`.
lb = LBasisPotential(ace)
# Design matrix with one row per training system. `reduce(hcat, ...)` avoids
# splatting every descriptor vector into a single varargs `hcat` call.
train_design = reduce(hcat, train_descriptors)'
lb.β .= train_design \ train_energies
8-element Vector{Float64}:
 808.542540874818
 627.5458496708357
 399.1204553346803
 210.47149514925792
  90.30276018279386
  30.16180912003849
   7.085777387173585
   0.9119833147776378

Calculate the potential energy of the first test system

# Potential energy of the first held-out system under the fitted linear model.
first_test_system = test_systems[1]
potential_energy(first_test_system, lb)
-37.82994931330458

This page was generated using Literate.jl.