push!(Base.LOAD_PATH, "../../")
using InteratomicPotentials
using AtomsBase
using UnitfulAtomic
using Unitful
using StaticArrays
using LinearAlgebra
using Statistics
Load data
# Locate the LJCluster example directory relative to the installed package
# sources. Every path component is passed to `joinpath` separately so that no
# directory separator is hard-coded.
example_path = joinpath(
    dirname(pathof(InteratomicPotentials)), "..", "examples", "LJCluster"
)

# `load_data.jl` lives next to the data set and defines the `load_data` helper.
include(joinpath(example_path, "load_data.jl"))

# Request 100 entries from the curated LJ cluster file. The trailing `;`
# suppresses display of the returned collections.
systems, energies, forces = load_data(;
    num_entries = 100,
    file = joinpath(example_path, "curated_lj_cluster.xyz"),
);
101
Split into training and test sets
# Entries 1-50 form the training set.
train_systems = systems[1:50]
train_energies = energies[1:50]
train_forces = forces[1:50]

# Everything from entry 51 onwards is held out for testing.
test_systems = systems[51:end]
test_energies = energies[51:end]
test_forces = forces[51:end];
Create the ACE basis (body order 2, polynomial degree 8)
# ACE descriptor basis for a single species (:Ar) with body order 2 and
# polynomial degree 8. The remaining keywords (wL, csp, r0, rcutoff) are
# forwarded to the ACE constructor unchanged.
# NOTE(review): units of r0 and rcutoff are not visible here; confirm against
# the InteratomicPotentials documentation.
ace = ACE(;
    species = [:Ar],
    body_order = 2,
    polynomial_degree = 8,
    wL = 1.0,
    csp = 1.0,
    r0 = 1.0,
    rcutoff = 2.0,
)
ACE([:Ar], 2, 8, 2.0, 1.0, 1.0, 1.0, ACE1.RPI.RPIBasis{Float64, ACE1.RPI.BasicPSH1pBasis{Float64, 1, ACE1.OrthPolys.TransformedPolys{Float64, ACE1.Transforms.PolyTransform{Int64, Float64}, ACE1.OrthPolys.OrthPolyBasis{Float64}, ACE1.OrthPolys.OneEnvelope}}, 1, ACE1.DAGEvaluator}(ACE1.PIBasis{ACE1.RPI.BasicPSH1pBasis{Float64, 1, ACE1.OrthPolys.TransformedPolys{Float64, ACE1.Transforms.PolyTransform{Int64, Float64}, ACE1.OrthPolys.OrthPolyBasis{Float64}, ACE1.OrthPolys.OneEnvelope}}, 1, ACE1.DAGEvaluator}(ACE1.RPI.BasicPSH1pBasis{Float64, 1, ACE1.OrthPolys.TransformedPolys{Float64, ACE1.Transforms.PolyTransform{Int64, Float64}, ACE1.OrthPolys.OrthPolyBasis{Float64}, ACE1.OrthPolys.OneEnvelope}}(ACE1.OrthPolys.TransformedPolys{Float64, ACE1.Transforms.PolyTransform{Int64, Float64}, ACE1.OrthPolys.OrthPolyBasis{Float64}, ACE1.OrthPolys.OneEnvelope}(ACE1.OrthPolys.OrthPolyBasis{Float64}(2, 0.4444444444444444, 0, 1.4692378328741966, [2.1291808597507096, 6.927551862540507, 5.163501330270071, 4.623432881123967, 4.37589341829516, 4.239373572278735, 4.155366284472768, 4.099731135119782], [0.0, -8.995000596148365, -5.822564729770663, -4.897686822019796, -4.485978541670284, -4.263265659159493, -4.128082315742125, -4.039459785741758], [0.0, 0.0, -0.7453572968815858, -0.8954065440091757, -0.946459812612521, -0.9688018347417651, -0.9801840327648093, -0.9866112526443569], [0.4449568411386593, 0.44598163452708905, 0.4470064279155188, 0.44803122130394857, 0.4490560146923783, 0.4500808080808081, 0.45110560146923784, 0.45213039485766754, 0.4531551882460973, 0.45417998163452705 … 1.459502295684114, 1.4605270890725437, 1.4615518824609735, 1.4625766758494032, 1.463601469237833, 1.4646262626262627, 1.4656510560146925, 1.4666758494031222, 1.467700642791552, 1.4687254361799817], [0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001 … 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001]), ACE1.Transforms.PolyTransform{Int64, Float64}(2, 1.0, 1.0), 0.65, 2.0, 
ACE1.OrthPolys.OneEnvelope()), ACE1.SphericalHarmonics.SHBasis{Float64}(7, ACE1.SphericalHarmonics.ALPCoefficients{Float64}([5.0e-324, 1.0e-323, 1.5e-323, 1.9364916731037085, 2.5e-323, 3.0e-323, 1.9720265943665387, 2.091650066335189, 4.4e-323, 5.0e-323 … 1.33e-322, 1.4e-322, 1.9948914348241344, 2.0155644370746373, 2.0816659994661326, 2.207940216581962, 2.4308621740219887, 2.850438562747845, 1.73e-322, 1.8e-322], [5.0e-324, 1.0e-323, 1.5e-323, -0.5773502691896257, 2.5e-323, 3.0e-323, -0.5163977794943222, -0.4472135954999579, 4.4e-323, 1.14e-322 … 6.4e-323, 7.0e-323, -0.5017452060042545, -0.4947274449181537, -0.47304991679126607, -0.4345240946267408, -0.3739787960033829, -0.2773500981126146, 1.0e-322, 1.04e-322])), JuLIP.Potentials.SZList{1}(JuLIP.Chemistry.AtomicNumber[<18>]), ACE1.RPI.PSH1pBasisFcn[znlm[0|1,0,0], znlm[0|1,1,-1], znlm[0|1,1,0], znlm[0|1,1,1], znlm[0|1,2,-2], znlm[0|1,2,-1], znlm[0|1,2,0], znlm[0|1,2,1], znlm[0|1,2,2], znlm[0|1,3,-3] … znlm[0|6,2,-2], znlm[0|6,2,-1], znlm[0|6,2,0], znlm[0|6,2,1], znlm[0|6,2,2], znlm[0|7,0,0], znlm[0|7,1,-1], znlm[0|7,1,0], znlm[0|7,1,1], znlm[0|8,0,0]], UnitRange{Int64}[1:204;;]), JuLIP.Potentials.SZList{1}(JuLIP.Chemistry.AtomicNumber[<18>]), (ACE1.InnerPIBasis([1, 1, 1, 1, 1, 1, 1, 1], [1; 65; … ; 200; 204;;], Dict{ACE1.PIBasisFcn, Int64}(ACE1.PIBasisFcn{1, ACE1.RPI.PSH1pBasisFcn}(<18>, (znlm[18|7,0,0],), ACE1.RPI.PSH1pBasisFcn) => 7, ACE1.PIBasisFcn{1, ACE1.RPI.PSH1pBasisFcn}(<18>, (znlm[18|6,0,0],), ACE1.RPI.PSH1pBasisFcn) => 6, ACE1.PIBasisFcn{1, ACE1.RPI.PSH1pBasisFcn}(<18>, (znlm[18|1,0,0],), ACE1.RPI.PSH1pBasisFcn) => 1, ACE1.PIBasisFcn{1, ACE1.RPI.PSH1pBasisFcn}(<18>, (znlm[18|8,0,0],), ACE1.RPI.PSH1pBasisFcn) => 8, ACE1.PIBasisFcn{1, ACE1.RPI.PSH1pBasisFcn}(<18>, (znlm[18|2,0,0],), ACE1.RPI.PSH1pBasisFcn) => 2, ACE1.PIBasisFcn{1, ACE1.RPI.PSH1pBasisFcn}(<18>, (znlm[18|5,0,0],), ACE1.RPI.PSH1pBasisFcn) => 5, ACE1.PIBasisFcn{1, ACE1.RPI.PSH1pBasisFcn}(<18>, (znlm[18|4,0,0],), ACE1.RPI.PSH1pBasisFcn) => 4, 
ACE1.PIBasisFcn{1, ACE1.RPI.PSH1pBasisFcn}(<18>, (znlm[18|3,0,0],), ACE1.RPI.PSH1pBasisFcn) => 3), Dict{Any, Int64}(znlm[18|3,1,-1] => 115, znlm[18|1,6,5] => 48, znlm[18|4,3,-2] => 160, znlm[18|2,3,-1] => 76, znlm[18|1,6,-3] => 40, znlm[18|1,5,3] => 34, znlm[18|3,5,-2] => 142, znlm[18|3,5,5] => 149, znlm[18|1,4,4] => 25, znlm[18|3,5,2] => 146…), 1:8, <18>, ACE1.DAG.CorrEvalGraph{Int64, Int64}([(1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0), (7, 0), (8, 0), (9, 0), (10, 0) … (195, 0), (196, 0), (197, 0), (198, 0), (199, 0), (200, 0), (201, 0), (202, 0), (203, 0), (204, 0)], [1, 0, 0, 0, 0, 0, 0, 0, 0, 0 … 0, 0, 0, 0, 0, 7, 0, 0, 0, 8], 204, 204)),), ACE1.DAGEvaluator()), (sparse([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 3, 4, 5, 6, 7, 8], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 8, 8),), (1:8,)))
Calculate descriptors
# Collapse the local descriptors of each system into one descriptor per
# system by summing them.
train_descriptors = map(train_systems) do sys
    sum(compute_local_descriptors(sys, ace))
end
test_descriptors = map(test_systems) do sys
    sum(compute_local_descriptors(sys, ace))
end
50-element Vector{Vector{Float64}}:
[9.637553471101803, -26.675859876119226, 25.786917272571632, -0.7260195256973117, -13.708573482362102, -6.561704406552489, 33.060213550469, -30.240609807870495]
[8.606741075726184, -24.751126644174448, 26.381953910263245, -5.232261587024965, -10.81076704195966, -2.2133986960506764, 24.44807197736727, -23.454597072227124]
[7.995727465304647, -23.733500892883324, 26.924888559433835, -7.117859145193399, -12.273244261801626, 4.963518700124103, 17.832036907805147, -25.15112903959363]
[7.921558789343797, -23.584238338889094, 26.884065458500498, -7.146565830179244, -12.796969315341903, 6.799615689782368, 14.664213370807753, -21.776297574689256]
[7.792778687395339, -23.300464224699315, 26.884232957489132, -7.760354578458356, -12.39593962439861, 7.976413535211458, 12.267557941813429, -20.589078929705877]
[8.055816010742193, -23.85343527092131, 26.78441005587313, -6.226639365419687, -13.997554168621804, 7.088120745001107, 15.671025846169771, -23.15509046819463]
[7.891089195594272, -23.638451884555373, 27.302522604677687, -8.060238223548193, -11.269105007377025, 4.414171077602776, 17.14905475352475, -22.199255304250556]
[7.843991305114745, -23.1669923055256, 26.1242961439842, -7.26255431550306, -9.809384249643417, 1.072677118774266, 18.99224856756854, -20.14770148893393]
[8.428789831091315, -24.473857013361368, 26.447337791575837, -5.0933442951458305, -12.854964936125919, 1.7221806787543672, 21.859666259005984, -25.029519808795587]
[7.824945011307422, -23.65537565712848, 27.720405358863843, -8.370591161502684, -12.3665597767453, 6.7233945884552835, 16.20013992232284, -23.900512454554395]
⋮
[8.409183418512061, -24.244985114786722, 26.1394714656022, -6.139348481277596, -8.845038985801187, -3.696811762929591, 23.794145309696066, -20.813809169355512]
[8.375082824736895, -24.393153967420325, 26.85824180986371, -7.121893705993498, -8.408923920070215, -4.001433391071423, 25.39906289939859, -22.881763688193935]
[7.836335126129326, -23.496112065552804, 27.45055116841227, -9.307111022704603, -8.802719893672874, 2.1508145981594486, 17.634059047863655, -21.003239561898518]
[8.132077419321302, -23.954048037457785, 27.1889983827391, -8.34513920372817, -9.294554173539439, 1.921171923909791, 18.301253269613646, -23.36202362919198]
[7.844879379234444, -23.390216729527342, 26.992616470830377, -8.494346757471291, -9.768153606357954, 3.299195785298471, 16.37737674227996, -20.683395837379326]
[7.581942977993447, -23.192375513064494, 27.74527735812395, -8.928520976963663, -12.797356763908624, 8.907982073123044, 14.493555003621701, -25.109274643736367]
[7.919360855803918, -23.614054931612426, 27.37375182638167, -9.26516559316405, -8.211702882815906, 1.6582439920760017, 15.533644510000197, -15.264249968296193]
[8.03748824668811, -23.779633201166398, 26.97905610248553, -7.8532866831110395, -9.771367473971607, 1.2472503238211337, 19.343452472042703, -20.885497245247535]
[7.9135638083288065, -23.422972449123172, 26.521589090001356, -7.296545629727583, -10.976002657737606, 3.338865904543715, 17.882587846879424, -22.615946287685794]
Calculate the Mahalanobis-type distance of each descriptor from the pooled mean
# Pool the training and test descriptors once; the mean vector and the
# covariance matrix are both estimated from this combined sample.
all_descriptors = [train_descriptors; test_descriptors]
mean_d = mean(all_descriptors)
cov_d = cov(all_descriptors)
inv_cov_d = inv(cov_d)

"""
    scaled_mahalanobis(d, μ, Σinv)

Return `sqrt(0.5 * (d - μ)' * Σinv * (d - μ))`, where `Σinv` is an inverse
covariance matrix. Note the extra factor `0.5` inside the square root, so the
result is the usual Mahalanobis distance divided by `sqrt(2)`.
"""
function scaled_mahalanobis(d, μ, Σinv)
    δ = d - μ  # centre once instead of recomputing the difference on both sides
    return sqrt(0.5 * δ' * Σinv * δ)
end

mah_dist_train = [scaled_mahalanobis(td, mean_d, inv_cov_d) for td in train_descriptors]
mah_dist_test = [scaled_mahalanobis(td, mean_d, inv_cov_d) for td in test_descriptors]
50-element Vector{Float64}:
3.337873662581474
1.9266726882837408
1.534690130260185
1.6511596807911857
2.1181866044660094
1.9193249989872847
2.583475641772898
1.7993381943009978
1.6388316082702068
1.6660577899567883
⋮
1.5688273715850631
2.5958049749658376
1.4121716321806719
3.4444540162081663
1.610728869370596
1.7532542295405253
2.4945185512050854
1.8000613443997822
1.3867716904562941
Create a linear basis potential and estimate β by least squares
# Linear-basis potential built on top of the ACE basis.
lb = LBasisPotential(ace)

# Stack the training descriptors as columns and transpose, so each row of the
# design matrix holds the descriptor of one training system. `reduce(hcat, …)`
# builds the same matrix as `hcat(v...)` without splatting a collection whose
# length depends on the data. With more rows than columns, `\` returns the
# least-squares solution, which is written into the existing coefficient
# vector in place.
lb.β .= reduce(hcat, train_descriptors)' \ train_energies
8-element Vector{Float64}:
808.542540874818
627.5458496708357
399.1204553346803
210.47149514925792
90.30276018279386
30.16180912003849
7.085777387173585
0.9119833147776378
Calculate the potential energy of the first test system
# Evaluate the fitted potential on the first held-out system.
potential_energy(first(test_systems), lb)
-37.82994931330458
This page was generated using Literate.jl.