push!(Base.LOAD_PATH, "../../")

using InteratomicPotentials
using AtomsBase
using UnitfulAtomic
using Unitful
using StaticArrays
using LinearAlgebra
using Statistics

Load the LJ cluster data (first 100 entries)

# Locate the LJCluster example directory inside the InteratomicPotentials
# package and load the dataset through the helper script stored there.
# `pkgdir` gives the package root directly, and `joinpath` builds every path
# with the platform's separator instead of hard-coded "/".
example_path = joinpath(pkgdir(InteratomicPotentials), "examples", "LJCluster")
include(joinpath(example_path, "load_data.jl"))
systems, energies, forces =
    load_data(; num_entries = 100,
                file = joinpath(example_path, "curated_lj_cluster.xyz"));
101

Split the data into a training set (the first 50 systems) and a test set (the remaining systems)

# Deterministic split: the first `n_train` systems are used for fitting and
# all remaining systems are held out for testing. Keeping the split point in
# one variable avoids the train/test ranges drifting out of sync.
n_train = 50
train_systems, train_energies, train_forces =
    systems[1:n_train], energies[1:n_train], forces[1:n_train];
test_systems, test_energies, test_forces =
    systems[(n_train + 1):end], energies[(n_train + 1):end], forces[(n_train + 1):end];

Create an ACE basis (2-body, polynomial degree 8)

# 2-body ACE basis for a single species (argon, :Ar) with polynomial degree 8.
# NOTE(review): the meaning of r0 / wL / csp is defined by the ACE
# constructor in InteratomicPotentials; see its docs before changing them.
ace = ACE(;
    species = [:Ar],
    body_order = 2,
    polynomial_degree = 8,
    r0 = 1.0,
    rcutoff = 2.0,
    wL = 1.0,
    csp = 1.0,
)
ACE([:Ar], 2, 8, 2.0, 1.0, 1.0, 1.0, ACE1.RPI.RPIBasis{Float64, ACE1.RPI.BasicPSH1pBasis{Float64, 1, ACE1.OrthPolys.TransformedPolys{Float64, ACE1.Transforms.PolyTransform{Int64, Float64}, ACE1.OrthPolys.OrthPolyBasis{Float64}, ACE1.OrthPolys.OneEnvelope}}, 1, ACE1.DAGEvaluator}(ACE1.PIBasis{ACE1.RPI.BasicPSH1pBasis{Float64, 1, ACE1.OrthPolys.TransformedPolys{Float64, ACE1.Transforms.PolyTransform{Int64, Float64}, ACE1.OrthPolys.OrthPolyBasis{Float64}, ACE1.OrthPolys.OneEnvelope}}, 1, ACE1.DAGEvaluator}(ACE1.RPI.BasicPSH1pBasis{Float64, 1, ACE1.OrthPolys.TransformedPolys{Float64, ACE1.Transforms.PolyTransform{Int64, Float64}, ACE1.OrthPolys.OrthPolyBasis{Float64}, ACE1.OrthPolys.OneEnvelope}}(ACE1.OrthPolys.TransformedPolys{Float64, ACE1.Transforms.PolyTransform{Int64, Float64}, ACE1.OrthPolys.OrthPolyBasis{Float64}, ACE1.OrthPolys.OneEnvelope}(ACE1.OrthPolys.OrthPolyBasis{Float64}(2, 0.4444444444444444, 0, 1.4692378328741966, [2.1291808597507096, 6.927551862540507, 5.163501330270071, 4.623432881123967, 4.375893418295161, 4.239373572278734, 4.155366284472768, 4.099731135119783], [0.0, -8.995000596148373, -5.822564729770655, -4.8976868220198, -4.485978541670284, -4.263265659159493, -4.128082315742123, -4.0394597857417605], [0.0, 0.0, -0.7453572968815381, -0.8954065440092176, -0.9464598126125071, -0.968801834741769, -0.9801840327648054, -0.9866112526443657], [0.4449568411386593, 0.44598163452708905, 0.4470064279155188, 0.44803122130394857, 0.4490560146923783, 0.4500808080808081, 0.45110560146923784, 0.45213039485766754, 0.4531551882460973, 0.45417998163452705  …  1.459502295684114, 1.4605270890725437, 1.4615518824609735, 1.4625766758494032, 1.463601469237833, 1.4646262626262627, 1.4656510560146925, 1.4666758494031222, 1.467700642791552, 1.4687254361799817], [0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001  …  0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001]), ACE1.Transforms.PolyTransform{Int64, Float64}(2, 1.0, 1.0), 0.65, 
2.0, ACE1.OrthPolys.OneEnvelope()), ACE1.SphericalHarmonics.SHBasis{Float64}(7, ACE1.SphericalHarmonics.ALPCoefficients{Float64}([6.9437777806413e-310, 0.0, 6.94383766289244e-310, 1.9364916731037085, 0.0, 6.94383766289244e-310, 1.9720265943665387, 2.091650066335189, 6.94383766289244e-310, 6.94376252655344e-310  …  6.943762526525e-310, 6.943762526525e-310, 1.9948914348241344, 2.0155644370746373, 2.0816659994661326, 2.207940216581962, 2.4308621740219887, 2.850438562747845, 0.0, 0.0], [6.9437777806413e-310, 0.0, 6.94377457478294e-310, -0.5773502691896257, 0.0, 6.94383766289244e-310, -0.5163977794943222, -0.4472135954999579, 6.94383766289244e-310, 6.94376252655344e-310  …  6.943762526525e-310, 6.943762526525e-310, -0.5017452060042545, -0.4947274449181537, -0.47304991679126607, -0.4345240946267408, -0.3739787960033829, -0.2773500981126146, 0.0, 6.94376119432953e-310])), JuLIP.Potentials.SZList{1}(JuLIP.Chemistry.AtomicNumber[<18>]), ACE1.RPI.PSH1pBasisFcn[znlm[0|1,0,0], znlm[0|1,1,-1], znlm[0|1,1,0], znlm[0|1,1,1], znlm[0|1,2,-2], znlm[0|1,2,-1], znlm[0|1,2,0], znlm[0|1,2,1], znlm[0|1,2,2], znlm[0|1,3,-3]  …  znlm[0|6,2,-2], znlm[0|6,2,-1], znlm[0|6,2,0], znlm[0|6,2,1], znlm[0|6,2,2], znlm[0|7,0,0], znlm[0|7,1,-1], znlm[0|7,1,0], znlm[0|7,1,1], znlm[0|8,0,0]], UnitRange{Int64}[1:204]), JuLIP.Potentials.SZList{1}(JuLIP.Chemistry.AtomicNumber[<18>]), (ACE1.InnerPIBasis([1, 1, 1, 1, 1, 1, 1, 1], [1; 65; … ; 200; 204], Dict{ACE1.PIBasisFcn, Int64}(ACE1.PIBasisFcn{1, ACE1.RPI.PSH1pBasisFcn}(<18>, (znlm[18|5,0,0],), ACE1.RPI.PSH1pBasisFcn) => 5, ACE1.PIBasisFcn{1, ACE1.RPI.PSH1pBasisFcn}(<18>, (znlm[18|1,0,0],), ACE1.RPI.PSH1pBasisFcn) => 1, ACE1.PIBasisFcn{1, ACE1.RPI.PSH1pBasisFcn}(<18>, (znlm[18|3,0,0],), ACE1.RPI.PSH1pBasisFcn) => 3, ACE1.PIBasisFcn{1, ACE1.RPI.PSH1pBasisFcn}(<18>, (znlm[18|2,0,0],), ACE1.RPI.PSH1pBasisFcn) => 2, ACE1.PIBasisFcn{1, ACE1.RPI.PSH1pBasisFcn}(<18>, (znlm[18|6,0,0],), ACE1.RPI.PSH1pBasisFcn) => 6, ACE1.PIBasisFcn{1, 
ACE1.RPI.PSH1pBasisFcn}(<18>, (znlm[18|4,0,0],), ACE1.RPI.PSH1pBasisFcn) => 4, ACE1.PIBasisFcn{1, ACE1.RPI.PSH1pBasisFcn}(<18>, (znlm[18|7,0,0],), ACE1.RPI.PSH1pBasisFcn) => 7, ACE1.PIBasisFcn{1, ACE1.RPI.PSH1pBasisFcn}(<18>, (znlm[18|8,0,0],), ACE1.RPI.PSH1pBasisFcn) => 8), Dict{Any, Int64}(znlm[18|1,2,-1] => 6, znlm[18|3,4,2] => 136, znlm[18|2,5,-3] => 92, znlm[18|3,5,-3] => 141, znlm[18|1,6,4] => 47, znlm[18|3,4,0] => 134, znlm[18|6,2,0] => 197, znlm[18|1,6,1] => 44, znlm[18|1,3,-1] => 12, znlm[18|2,2,2] => 73…), 1:8, <18>, ACE1.DAG.CorrEvalGraph{Int64, Int64}([(1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0), (7, 0), (8, 0), (9, 0), (10, 0)  …  (195, 0), (196, 0), (197, 0), (198, 0), (199, 0), (200, 0), (201, 0), (202, 0), (203, 0), (204, 0)], [1, 0, 0, 0, 0, 0, 0, 0, 0, 0  …  0, 0, 0, 0, 0, 7, 0, 0, 0, 8], 204, 204)),), ACE1.DAGEvaluator()), (
 1.0   ⋅    ⋅    ⋅    ⋅    ⋅    ⋅    ⋅ 
  ⋅   1.0   ⋅    ⋅    ⋅    ⋅    ⋅    ⋅ 
  ⋅    ⋅   1.0   ⋅    ⋅    ⋅    ⋅    ⋅ 
  ⋅    ⋅    ⋅   1.0   ⋅    ⋅    ⋅    ⋅ 
  ⋅    ⋅    ⋅    ⋅   1.0   ⋅    ⋅    ⋅ 
  ⋅    ⋅    ⋅    ⋅    ⋅   1.0   ⋅    ⋅ 
  ⋅    ⋅    ⋅    ⋅    ⋅    ⋅   1.0   ⋅ 
  ⋅    ⋅    ⋅    ⋅    ⋅    ⋅    ⋅   1.0,), (1:8,)))

Calculate a global descriptor for each system by summing its local descriptors

# One global descriptor per system: the sum of the local descriptors that
# `compute_local_descriptors` returns for that system with the `ace` basis.
train_descriptors = map(s -> sum(compute_local_descriptors(s, ace)), train_systems)
test_descriptors = map(s -> sum(compute_local_descriptors(s, ace)), test_systems)
50-element Vector{Vector{Float64}}:
 [9.637553471101803, -26.675859876119294, 25.78691727257195, -0.7260195256964641, -13.70857348236253, -6.561704406553085, 33.06021355046952, -30.24060980787009]
 [8.606741075726184, -24.751126644174505, 26.38195391026354, -5.232261587024202, -10.810767041960057, -2.213398696051166, 24.448071977367757, -23.454597072226875]
 [7.995727465304647, -23.733500892883377, 26.9248885594341, -7.117859145192682, -12.273244261802004, 4.963518700123678, 17.832036907805644, -25.151129039593535]
 [7.921558789343797, -23.58423833888915, 26.884065458500764, -7.14656583017853, -12.79696931534228, 6.79961568978195, 14.664213370808248, -21.776297574689178]
 [7.792778687395339, -23.30046422469937, 26.884232957489388, -7.760354578457655, -12.395939624398983, 7.976413535211053, 12.267557941813918, -20.58907892970582]
 [8.055816010742193, -23.853435270921366, 26.784410055873394, -6.226639365418963, -13.997554168622187, 7.088120745000675, 15.671025846170275, -23.155090468194537]
 [7.891089195594272, -23.638451884555426, 27.30252260467795, -8.060238223547483, -11.2691050073774, 4.414171077602365, 17.149054753525242, -22.199255304250457]
 [7.843991305114745, -23.166992305525657, 26.124296143984463, -7.262554315502359, -9.809384249643784, 1.0726771187738446, 18.99224856756902, -20.147701488933773]
 [8.428789831091315, -24.473857013361428, 26.447337791576118, -5.093344295145077, -12.854964936126313, 1.7221806787538947, 21.859666259006485, -25.029519808795403]
 [7.824945011307422, -23.655375657128534, 27.720405358864106, -8.370591161501979, -12.366559776745678, 6.723394588454882, 16.20013992232335, -23.90051245455434]
 ⋮
 [8.409183418512061, -24.24498511478678, 26.139471465602483, -6.139348481276848, -8.845038985801576, -3.6968117629300647, 23.79414530969654, -20.813809169355256]
 [8.375082824736895, -24.393153967420385, 26.85824180986399, -7.121893705992752, -8.408923920070606, -4.001433391071884, 25.399062899399084, -22.8817636881937]
 [7.836335126129326, -23.49611206555286, 27.450551168412535, -9.307111022703898, -8.802719893673249, 2.1508145981590467, 17.634059047864138, -21.00323956189841]
 [8.132077419321302, -23.954048037457845, 27.188998382739378, -8.345139203727445, -9.294554173539824, 1.9211719239093623, 18.301253269614126, -23.362023629191842]
 [7.844879379234444, -23.3902167295274, 26.992616470830637, -8.494346757470588, -9.768153606358332, 3.299195785298064, 16.37737674228044, -20.68339583737922]
 [7.581942977993447, -23.19237551306455, 27.745277358124206, -8.928520976962973, -12.797356763908997, 8.907982073122668, 14.4935550036222, -25.109274643736363]
 [7.919360855803918, -23.61405493161248, 27.373751826381937, -9.26516559316334, -8.21170288281628, 1.6582439920755903, 15.533644510000675, -15.264249968296056]
 [8.03748824668811, -23.77963320116645, 26.9790561024858, -7.853286683110317, -9.771367473971985, 1.2472503238207042, 19.34345247204319, -20.885497245247382]
 [7.9135638083288065, -23.42297244912323, 26.521589090001623, -7.296545629726875, -10.97600265773798, 3.3388659045432942, 17.8825878468799, -22.615946287685667]

Calculate the (scaled) Mahalanobis distance (`mah_dist`) of each system's descriptor from the mean of all descriptors

# Estimate the descriptor distribution from the pooled train + test set.
all_descriptors = [train_descriptors; test_descriptors]
mean_d = mean(all_descriptors)   # element-wise mean descriptor vector
cov_d = cov(all_descriptors)     # sample covariance matrix of the descriptors

# Factorize the covariance once and solve against the factorization instead
# of forming an explicit inverse, which is less accurate for ill-conditioned
# matrices. Cholesky requires cov_d to be positive definite, which the
# original code also implicitly needed (otherwise its sqrt could see a
# negative argument).
cov_fact_d = cholesky(Symmetric(cov_d))

"""
    mahalanobis_dist(x, μ, F)

Return `sqrt(0.5 * (x - μ)' * Σ⁻¹ * (x - μ))`, where `F` is a factorization
of the covariance matrix Σ (for example the result of `cholesky`).
"""
function mahalanobis_dist(x, μ, F)
    δ = x - μ
    # NOTE(review): the textbook Mahalanobis distance has no 0.5 factor; it is
    # kept to preserve the values this example originally produced. Confirm.
    return sqrt(0.5 * dot(δ, F \ δ))
end

mah_dist_train = [mahalanobis_dist(td, mean_d, cov_fact_d) for td in train_descriptors]
mah_dist_test = [mahalanobis_dist(td, mean_d, cov_fact_d) for td in test_descriptors]
50-element Vector{Float64}:
 3.3378724871694687
 1.9266719000759462
 1.534689333384129
 1.6511595991074999
 2.1181863590909615
 1.9193251175586432
 2.5834743178832107
 1.7993385042560726
 1.6388298161716983
 1.6660578578771954
 ⋮
 1.5688273634270469
 2.595805008639468
 1.412171689044755
 3.444453484787928
 1.6107284239217066
 1.7532542800342996
 2.494518188211018
 1.800061167821093
 1.386771947466478

Create a linear basis potential and estimate its coefficients β by a least-squares fit to the training energies

# Linear potential over the ACE basis; its coefficients β are fitted below.
lb = LBasisPotential(ace)
# Design matrix: one row per training system, one column per basis function.
# `reduce(hcat, v)` builds the same matrix as `hcat(v...)` without splatting
# a data-sized collection into varargs. For this tall matrix, `\` returns the
# least-squares solution.
lb.β .= reduce(hcat, train_descriptors)' \ train_energies
8-element Vector{Float64}:
 808.5425410338428
 627.5458497961872
 399.1204554158797
 210.47149519239613
  90.30276020116625
  30.16180912601384
   7.085777388514522
   0.9119833149357848

Calculate the potential energy of the first test system with the fitted potential

# Energy predicted by the linear potential `lb` for the first held-out system.
potential_energy(test_systems[begin], lb)
-37.82994931330177

This page was generated using Literate.jl.