Optimize ACE hyper-parameters: minimize force time and fitting error.
Setup experiment
Load packages.
using AtomsBase, InteratomicPotentials, PotentialLearning
using Unitful, UnitfulAtomic
using LinearAlgebra, Random, DisplayAs
using DataFrames, Hyperopt
Define paths.
base_path = haskey(ENV, "BASE_PATH") ? ENV["BASE_PATH"] : "../../"
ds_path = "$base_path/examples/data/a-HfO2/a-HfO2-300K-NVT-6000.extxyz"
res_path = "$base_path/examples/Opt-ACE-aHfO2/results/";
Load utility functions.
include("$base_path/examples/utils/utils.jl");
Create experiment folder.
mkpath(res_path);
Load datasets
Load the atomistic dataset: atomistic configurations (atom positions, geometry, etc.) plus DFT data (energies, forces, etc.).
ds = load_data(ds_path, uparse("eV"), uparse("Å"))[1:1000]; # Load first 1K samples.
Split the atomistic dataset into training and test sets.
n_train, n_test = 50, 50 # Only 50 samples per dataset are used in this example.
conf_train, conf_test = split(ds, n_train, n_test)
(DataSet{num_configs = 50}
Configuration{S, AtomsBase.FlexibleSystem{3, AtomsBase.Atom, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Energy, Forces}
Configuration{S, AtomsBase.FlexibleSystem{3, AtomsBase.Atom, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Energy, Forces}
⋮
Configuration{S, AtomsBase.FlexibleSystem{3, AtomsBase.Atom, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Energy, Forces}, DataSet{num_configs = 50}
Configuration{S, AtomsBase.FlexibleSystem{3, AtomsBase.Atom, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Energy, Forces}
Configuration{S, AtomsBase.FlexibleSystem{3, AtomsBase.Atom, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Energy, Forces}
⋮
Configuration{S, AtomsBase.FlexibleSystem{3, AtomsBase.Atom, Unitful.Quantity{Float64, 𝐋, Unitful.FreeUnits{(Å,), 𝐋, nothing}}}, Energy, Forces})
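To make the split step concrete, here is a minimal sketch of drawing disjoint training and test subsets by shuffling indices. The helper `split_sketch` and the integer stand-in data are hypothetical and only illustrate the idea; the package's own `split` may select samples differently.

```julia
using Random

# Hypothetical helper: draw disjoint train/test subsets from any indexable collection.
function split_sketch(rng::AbstractRNG, data, n_train::Integer, n_test::Integer)
    n_train + n_test <= length(data) || error("not enough samples")
    perm = randperm(rng, length(data))            # shuffled indices 1:length(data)
    train = data[perm[1:n_train]]                 # first n_train shuffled entries
    test = data[perm[n_train+1:n_train+n_test]]   # next n_test entries, no overlap
    return train, test
end

rng = MersenneTwister(0)
stand_in_ds = collect(1:1000)                     # stand-in for 1000 configurations
train, test = split_sketch(rng, stand_in_ds, 50, 50)
println(length(train), " training and ", length(test), " test samples")
```

Because both subsets come from one permutation, no configuration can end up in both the training and the test set.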
Optimize hyper-parameters
Define a custom loss function. Here, we minimize fitting error and force calculation time. Possible metrics are e_mae, e_rmse, e_rsq, f_mae, f_rmse, f_rsq, and time_us.
function custom_loss(
    metrics::OrderedDict
)
    e_mae     = metrics[:e_mae]
    f_mae     = metrics[:f_mae]
    time_us   = metrics[:time_us]
    e_mae_max = 0.05 # eV/atom
    f_mae_max = 0.05 # eV/Å
    w_e       = e_mae/e_mae_max # Weights grow with the error relative to its target.
    w_f       = f_mae/f_mae_max
    w_t       = 1.0E-3
    loss = w_e * e_mae + w_f * f_mae + w_t * time_us # The force weight multiplies f_mae.
    return loss
end;
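As a rough illustration of how this loss trades accuracy against speed, the sketch below evaluates a stand-alone variant on two sets of metrics with values similar to those reported later on this page. It assumes the force term is meant to be weighted against `f_mae`; the name `sketch_loss` and the use of a plain `Dict` (instead of `OrderedDict`) are choices made only for this example.

```julia
# Hypothetical stand-alone variant of the loss; a plain Dict avoids extra packages.
function sketch_loss(metrics::AbstractDict;
                     e_mae_max = 0.05,  # eV/atom
                     f_mae_max = 0.05,  # eV/Å
                     w_t = 1.0e-3)      # weight per µs of force time
    e_mae = metrics[:e_mae]
    f_mae = metrics[:f_mae]
    time_us = metrics[:time_us]
    # Each error is scaled by its ratio to the target, so the error terms are quadratic.
    return (e_mae / e_mae_max) * e_mae + (f_mae / f_mae_max) * f_mae + w_t * time_us
end

fast_inaccurate = Dict(:e_mae => 0.162, :f_mae => 0.249, :time_us => 45.498)
slow_accurate = Dict(:e_mae => 0.078, :f_mae => 0.090, :time_us => 848.531)

println("fast but inaccurate: ", sketch_loss(fast_inaccurate))
println("slow but accurate:   ", sketch_loss(slow_accurate))
```

With these weights the slower, more accurate candidate scores lower, because the quadratic error terms outweigh the time penalty of about 0.85. Raising `w_t` shifts the balance towards faster potentials.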
Define model and hyper-parameter value ranges to be optimized.
model = ACE
pars = OrderedDict( :body_order        => [2, 3, 4],
                    :polynomial_degree => [3, 4, 5],
                    :rcutoff           => LinRange(4, 6, 10),
                    :wL                => LinRange(0.5, 1.5, 10),
                    :csp               => LinRange(0.5, 1.5, 10),
                    :r0                => LinRange(0.5, 1.5, 10));
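The ranges above define a discrete grid of candidate values. The short sketch below, which uses only Base Julia, lists the candidate `rcutoff` values and counts how many combinations a full grid search would need; the variable names are illustrative only.

```julia
# Ten evenly spaced cutoff candidates between 4 and 6 (inclusive).
rcutoff_values = collect(LinRange(4, 6, 10))
spacing = rcutoff_values[2] - rcutoff_values[1]   # (6 - 4) / 9
println(round.(rcutoff_values; digits = 5))

# Number of candidate values per hyper-parameter, in the order used above.
level_counts = [3, 3, 10, 10, 10, 10]
n_combinations = prod(level_counts)
println(n_combinations)  # prints 90000
```

A full grid would require 90000 fits, which is why only a small number of sampled combinations is evaluated.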
Use Latin hypercube sampling to search the hyper-parameter space. Alternatively, use random sampling (sampler = RandomSampler()).
sampler = CLHSampler(dims=[Categorical(3), Categorical(3), Continuous(),
                           Continuous(), Continuous(), Continuous()])
iap, res = hyperlearn!(model, pars, conf_train;
                       n_samples = 10, sampler = sampler,
                       loss = custom_loss, ws = [1.0, 1.0], int = true);
E_MAE:0.196 eV/atom, F_MAE:0.303 eV/Å, Time per force per atom:65.657 µs
E_MAE:0.211 eV/atom, F_MAE:0.245 eV/Å, Time per force per atom:97.812 µs
E_MAE:0.162 eV/atom, F_MAE:0.249 eV/Å, Time per force per atom:45.498 µs
E_MAE:0.191 eV/atom, F_MAE:0.298 eV/Å, Time per force per atom:119.695 µs
E_MAE:0.204 eV/atom, F_MAE:0.295 eV/Å, Time per force per atom:120.154 µs
E_MAE:0.087 eV/atom, F_MAE:0.134 eV/Å, Time per force per atom:112.686 µs
E_MAE:0.076 eV/atom, F_MAE:0.113 eV/Å, Time per force per atom:380.355 µs
E_MAE:0.23 eV/atom, F_MAE:0.25 eV/Å, Time per force per atom:51.989 µs
E_MAE:0.078 eV/atom, F_MAE:0.09 eV/Å, Time per force per atom:848.531 µs
E_MAE:0.07 eV/atom, F_MAE:0.114 eV/Å, Time per force per atom:343.02 µs
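The idea behind Latin hypercube sampling can be sketched in a few lines: with n samples and n levels per dimension, every level of every dimension is visited exactly once. The code below is a simplified, unoptimized variant written for illustration; `latin_hypercube_indices` is a hypothetical helper, not the Hyperopt implementation, and it ignores the categorical dimensions.

```julia
using Random

# Hypothetical helper: an n_samples × n_dims matrix of level indices in which
# each column is an independent random permutation of 1:n_samples.
function latin_hypercube_indices(rng::AbstractRNG, n_samples::Integer, n_dims::Integer)
    columns = [randperm(rng, n_samples) for _ in 1:n_dims]
    return reduce(hcat, columns)
end

rng = MersenneTwister(1234)
levels = collect(LinRange(0.5, 1.5, 10))      # e.g. candidate values for wL
idx = latin_hypercube_indices(rng, 10, 4)     # 10 samples, 4 continuous parameters
wL_samples = levels[idx[:, 1]]                # every candidate value appears once
println(wL_samples)
```

Compared with purely random sampling, this guarantees that no region of any single parameter range is left unexplored, even with only 10 samples.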
Post-process results
Save and show results.
@save_var res_path iap.β
@save_var res_path iap.β0
@save_var res_path iap.basis
@save_dataframe res_path res
res
Row | e_mae | e_rmse | e_rsq | f_mae | f_rmse | f_rsq | time_us | body_order | polynomial_degree | rcutoff | wL | csp | r0 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
Type | Any | Any | Any | Any | Any | Any | Any | Any | Any | Any | Any | Any | Any |
1 | 0.0699476 | 0.0887077 | 0.875851 | 0.113789 | 0.146506 | 0.941304 | 343.02 | 2.0 | 3.0 | 5.77778 | 1.5 | 0.944444 | 0.944444 |
2 | 0.0764879 | 0.0973517 | 0.850477 | 0.113466 | 0.145909 | 0.941782 | 380.355 | 2.0 | 5.0 | 4.88889 | 0.611111 | 1.16667 | 0.611111 |
3 | 0.0783818 | 0.0926494 | 0.864573 | 0.0896341 | 0.114964 | 0.963857 | 848.531 | 2.0 | 3.0 | 5.77778 | 1.5 | 0.944444 | 0.944444 |
4 | 0.0867853 | 0.110484 | 0.807417 | 0.134411 | 0.174444 | 0.916784 | 112.686 | 3.0 | 3.0 | 4.0 | 0.722222 | 1.27778 | 1.16667 |
5 | 0.162323 | 0.205959 | 0.330761 | 0.248512 | 0.321802 | 0.716811 | 45.4981 | 3.0 | 4.0 | 5.55556 | 0.833333 | 1.38889 | 1.27778 |
6 | 0.191251 | 0.246166 | 0.0439552 | 0.297826 | 0.389138 | 0.585902 | 119.695 | 3.0 | 4.0 | 5.55556 | 0.833333 | 1.38889 | 1.27778 |
7 | 0.1962 | 0.255247 | -0.0278851 | 0.303027 | 0.394134 | 0.5752 | 65.6567 | 2.0 | 4.0 | 4.44444 | 1.27778 | 1.05556 | 1.5 |
8 | 0.203927 | 0.250195 | 0.0124086 | 0.295062 | 0.384647 | 0.595404 | 120.154 | 3.0 | 3.0 | 4.0 | 0.722222 | 1.27778 | 1.16667 |
9 | 0.210765 | 0.264187 | -0.101143 | 0.244558 | 0.31363 | 0.731012 | 97.8123 | 2.0 | 4.0 | 4.44444 | 1.27778 | 1.05556 | 1.5 |
10 | 0.229571 | 0.281583 | -0.250938 | 0.249951 | 0.324096 | 0.71276 | 51.9888 | 2.0 | 5.0 | 4.88889 | 0.611111 | 1.16667 | 0.611111 |
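For readers unfamiliar with the error columns, the sketch below computes MAE, RMSE, and R² for toy data, assuming the textbook definitions (the package may compute them differently). With R² defined as one minus the ratio of residual to total sum of squares, a poor fit yields a negative value, which would be consistent with the negative e_rsq entries in the table. The function names are hypothetical.

```julia
using Statistics

# Textbook definitions, assumed here for illustration.
mae_sketch(y, y_pred) = mean(abs.(y_pred .- y))
rmse_sketch(y, y_pred) = sqrt(mean(abs2, y_pred .- y))
rsq_sketch(y, y_pred) = 1 - sum(abs2, y_pred .- y) / sum(abs2, y .- mean(y))

y = [1.0, 2.0, 3.0, 4.0]            # reference values
y_good = [1.1, 1.9, 3.2, 3.8]       # predictions close to the reference
y_bad = [4.0, 3.0, 2.0, 1.0]        # predictions worse than predicting the mean

println("good fit: MAE = ", mae_sketch(y, y_good), ", R² = ", rsq_sketch(y, y_good))
println("bad fit:  MAE = ", mae_sketch(y, y_bad), ", R² = ", rsq_sketch(y, y_bad))
```

Under this definition, a negative R² means the model predicts worse than a constant equal to the mean of the reference data.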
Plot error vs time.
err_time = plot_err_time(res)
@save_fig res_path err_time
DisplayAs.PNG(err_time)
![Example block output](0c8593f5.png)
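One way to read an error-versus-time trade-off is to look for non-dominated candidates, that is, rows for which no other row is both more accurate and faster. The sketch below applies this to the e_mae and time_us columns copied from the results table; the choice of e_mae as the error measure and the helper `pareto_rows` are assumptions made for this example.

```julia
# Values copied from the e_mae and time_us columns of the results table (rows 1 to 10).
e_mae = [0.0699476, 0.0764879, 0.0783818, 0.0867853, 0.162323,
         0.191251, 0.1962, 0.203927, 0.210765, 0.229571]
time_us = [343.02, 380.355, 848.531, 112.686, 45.4981,
           119.695, 65.6567, 120.154, 97.8123, 51.9888]

# Hypothetical helper: indices of rows that no other row dominates.
function pareto_rows(err, t)
    # Row i dominates row j if it is no worse in both metrics and better in at least one.
    dominates(i, j) = err[i] <= err[j] && t[i] <= t[j] &&
                      (err[i] < err[j] || t[i] < t[j])
    return [j for j in eachindex(err) if !any(dominates(i, j) for i in eachindex(err))]
end

front = pareto_rows(e_mae, time_us)
println(front)  # prints [1, 4, 5]
```

Rows 1, 4, and 5 form the trade-off front for these two columns: row 1 is the most accurate, row 5 the fastest, and row 4 sits in between.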
This page was generated using Literate.jl.