Skip to content

Commit

Permalink
Added better compatibility between soap and soap_turbo and streamline…
Browse files Browse the repository at this point in the history
…d the definition of hyperparameters in the soap_turbo string
  • Loading branch information
mcaroba committed Oct 15, 2023
1 parent 6c3375f commit c05aff9
Showing 1 changed file with 220 additions and 27 deletions.
247 changes: 220 additions & 27 deletions descriptors.f95
Original file line number Diff line number Diff line change
Expand Up @@ -3236,8 +3236,6 @@ subroutine soap_turbo_initialise(this,args_str,error)

type(Dictionary) :: params

logical :: has_atom_sigma_angular

integer :: l, k, i, j, m, n, n_nonzero
real(dp) :: fact, fact1, fact2, ppi, atom_sigma_radial_normalised, cutoff_hard,&
s2, I_n, N_n, N_np1, N_np2, I_np1, I_np2, C2
Expand All @@ -3248,15 +3246,53 @@ subroutine soap_turbo_initialise(this,args_str,error)
real(dp), dimension(:,:), allocatable :: sqrt_overlap, u, v
real(dp), parameter :: sqrt_two = sqrt(2.0_dp)

! Variables for equivalences with regular SOAP
logical :: is_n_max_set, is_cutoff_set, is_cutoff_transition_width_set, &
is_atom_sigma_r_set, is_atom_sigma_t_set, is_atom_sigma_r_scaling_set, &
is_atom_sigma_t_scaling_set, is_central_weight_set, is_amplitude_scaling_set, &
is_atom_sigma_set, set_sigma_t_to_r, is_atom_sigma_scaling_set, set_sigma_t_to_r_scaling
character(len=STRING_LENGTH) :: var_set

is_n_max_set = .false.
is_cutoff_set = .false.
is_cutoff_transition_width_set = .false.
is_atom_sigma_set = .false.
set_sigma_t_to_r = .false.
is_atom_sigma_scaling_set = .false.
set_sigma_t_to_r_scaling = .false.
is_atom_sigma_r_set = .false.
is_atom_sigma_t_set = .false.
is_atom_sigma_r_scaling_set = .false.
is_atom_sigma_t_scaling_set = .false.
is_central_weight_set = .false.
is_amplitude_scaling_set = .false.

INIT_ERROR(error)

call finalise(this)

call initialise(params)

! Look for those parameters defined as in regular SOAP
if( index(args_str,"cutoff=") /= 0 .and. index(args_str,"rcut_hard=") == 0 )then
is_cutoff_set = .true.
call param_register(params, 'cutoff', PARAM_MANDATORY, this%rcut_hard, help_string="TODO")
else
call param_register(params, 'rcut_hard', PARAM_MANDATORY, this%rcut_hard, help_string="Hard cutoff")
end if
if( index(args_str,"rcut_soft=") == 0 )then
is_cutoff_transition_width_set = .true.
! We store the transition width in rcut_soft, then fix it later
call param_register(params, 'cutoff_transition_width', "0.5", this%rcut_soft, help_string="TODO")
else
call param_register(params, 'rcut_soft', PARAM_MANDATORY, this%rcut_soft, help_string="Soft cutoff")
end if

! Look for the rest of scalar parameters
call param_register(params, 'l_max', PARAM_MANDATORY, this%l_max, help_string="Angular basis resolution")
call param_register(params, 'n_species', '1', this%n_species, help_string="Number of species for the descriptor")
call param_register(params, 'rcut_hard', PARAM_MANDATORY, this%rcut_hard, help_string="Hard cutoff")
call param_register(params, 'rcut_soft', PARAM_MANDATORY, this%rcut_soft, help_string="Soft cutoff")

! These parameters are not mandatory; these are sensible defaults
call param_register(params, 'nf', "4.0", this%nf, help_string="TODO")
call param_register(params, 'radial_enhancement', "0", this%radial_enhancement, help_string="TODO")
call param_register(params, 'basis', "poly3", this%basis, help_string="poly3 or poly3gauss")
Expand All @@ -3271,6 +3307,18 @@ subroutine soap_turbo_initialise(this,args_str,error)

call finalise(params)

! Fix the soft cutoff if needed
if( is_cutoff_transition_width_set )then
this%rcut_soft = this%rcut_hard - this%rcut_soft
end if

! All of these hyperparameters are species-dependent and thus given as arrays
! We try to infer intended use from the regular SOAP equivalent parameters, e.g., we
! infer alpha_max(1:n_species) = n_max, UNLESS the array definitions are provided
! explicitly, in which case explicit definitions ALWAYS override implicit definitions,
! e.g., if both n_max and alpha_max are defined, the alpha_max definition will
! override the n_max definition

allocate(this%atom_sigma_r(this%n_species))
allocate(this%atom_sigma_r_scaling(this%n_species))
allocate(this%atom_sigma_t(this%n_species))
Expand All @@ -3280,48 +3328,188 @@ subroutine soap_turbo_initialise(this,args_str,error)
allocate(this%alpha_max(this%n_species))
allocate(this%species_Z(this%n_species))

! central_weight is special because regular SOAP and soap_turbo use the same keyword
call initialise(params)
! If it's set as a vector
if( index(args_str,"central_weight={") /= 0 )then
if( this%n_species == 1 )then
is_central_weight_set = .true.
call param_register(params, 'central_weight', "1.0", this%central_weight(1), &
help_string="Weight of central atom in environment")
end if
! If it's set as a scalar or not set
else
is_central_weight_set = .true.
call param_register(params, 'central_weight', "1.0", this%central_weight(1), &
help_string="Weight of central atom in environment")
end if
call finalise(params)
if( is_central_weight_set )then
this%central_weight = this%central_weight(1)
end if

! Now we set the soap_turbo hypers with the explicit array definitions OR use the implicit definitions
! to set them
call initialise(params)
if(this%n_species == 1) then
call param_register(params, 'alpha_max', PARAM_MANDATORY, this%alpha_max(1), &
help_string="Radial basis resolution for each species")
! alpha_max
if( index(args_str,"n_max=") /= 0 .and. index(args_str,"alpha_max=") == 0 )then
is_n_max_set = .true.
call param_register(params, 'n_max', PARAM_MANDATORY, this%alpha_max(1), help_string="TODO")
else
if( this%n_species == 1 )then
call param_register(params, 'alpha_max', PARAM_MANDATORY, this%alpha_max(1), &
help_string="Radial basis resolution for each species")
else
call param_register(params, 'alpha_max', '//MANDATORY//', this%alpha_max, &
help_string="Radial basis resultion for each species")
end if
end if
! atom_sigma_r
if( index(args_str,"atom_sigma_r={") /= 0 )then
if( this%n_species == 1 )then
call param_register(params, 'atom_sigma_r', PARAM_MANDATORY, this%atom_sigma_r(1), &
help_string="Width of atomic Gaussians for soap-type descriptors in the radial direction")
else
call param_register(params, 'atom_sigma_r', '//MANDATORY//', this%atom_sigma_r, &
help_string="Width of atomic Gaussians for soap-type descriptors in the radial direction")
end if
else if( index(args_str,"atom_sigma_r=") /= 0 )then
is_atom_sigma_r_set = .true.
call param_register(params, 'atom_sigma_r', PARAM_MANDATORY, this%atom_sigma_r(1), &
help_string="Width of atomic Gaussians for soap-type descriptors in the radial direction")
call param_register(params, 'atom_sigma_r_scaling', PARAM_MANDATORY, this%atom_sigma_r_scaling(1), &
help_string="Scaling rate of radial sigma: scaled as a function of neighbour distance")
else
is_atom_sigma_r_set = .true.
is_atom_sigma_set = .true.
call param_register(params, 'atom_sigma', PARAM_MANDATORY, this%atom_sigma_r(1), &
help_string="Width of atomic Gaussians for soap-type descriptors")
end if
! atom_sigma_t
if( index(args_str,"atom_sigma_t={") /= 0 )then
if( this%n_species == 1 )then
call param_register(params, 'atom_sigma_t', PARAM_MANDATORY, this%atom_sigma_t(1), &
help_string="Width of atomic Gaussians for soap-type descriptors in the angular direction")
else
call param_register(params, 'atom_sigma_t', '//MANDATORY//', this%atom_sigma_t, &
help_string="Width of atomic Gaussians for soap-type descriptors in the angular direction")
end if
else if( index(args_str,"atom_sigma_t=") /= 0 )then
is_atom_sigma_t_set = .true.
call param_register(params, 'atom_sigma_t', PARAM_MANDATORY, this%atom_sigma_t(1), &
help_string="Width of atomic Gaussians for soap-type descriptors in the angular direction")
else
is_atom_sigma_t_set = .true.
if( is_atom_sigma_set )then
set_sigma_t_to_r = .true.
else
call param_register(params, 'atom_sigma', PARAM_MANDATORY, this%atom_sigma_t(1), &
help_string="Width of atomic Gaussians for soap-type descriptors")
end if
end if
! atom_sigma_r_scaling
if( index(args_str,"atom_sigma_r_scaling={") /= 0 )then
if( this%n_species == 1 )then
call param_register(params, 'atom_sigma_r_scaling', PARAM_MANDATORY, this%atom_sigma_r_scaling(1), &
help_string="Scaling rate of radial sigma: scaled as a function of neighbour distance")
else
call param_register(params, 'atom_sigma_r_scaling', '//MANDATORY//', this%atom_sigma_r_scaling, &
help_string="Scaling rate of radial sigma: scaled as a function of neighbour distance")
end if
else if( index(args_str,"atom_sigma_r_scaling=") /= 0 )then
is_atom_sigma_r_scaling_set = .true.
call param_register(params, 'atom_sigma_r_scaling', PARAM_MANDATORY, this%atom_sigma_r_scaling(1), &
help_string="Scaling rate of radial sigma: scaled as a function of neighbour distance")
else
is_atom_sigma_r_scaling_set = .true.
is_atom_sigma_scaling_set = .true.
call param_register(params, 'atom_sigma_scaling', "0.0", this%atom_sigma_r_scaling(1), &
help_string="Scaling rate of atom sigma: scaled as a function of neighbour distance")
end if
! atom_sigma_t_scaling
if( index(args_str,"atom_sigma_t_scaling={") /= 0 )then
if( this%n_species == 1 )then
call param_register(params, 'atom_sigma_t_scaling', PARAM_MANDATORY, this%atom_sigma_t_scaling(1), &
help_string="Scaling rate of angular sigma: scaled as a function of neighbour distance")
else
call param_register(params, 'atom_sigma_t_scaling', '//MANDATORY//', this%atom_sigma_t_scaling, &
help_string="Scaling rate of angular sigma: scaled as a function of neighbour distance")
end if
else if( index(args_str,"atom_sigma_t_scaling=") /= 0 )then
is_atom_sigma_t_scaling_set = .true.
call param_register(params, 'atom_sigma_t_scaling', PARAM_MANDATORY, this%atom_sigma_t_scaling(1), &
help_string="Scaling rate of angular sigma: scaled as a function of neighbour distance")
else
is_atom_sigma_t_scaling_set = .true.
if( is_atom_sigma_scaling_set )then
set_sigma_t_to_r_scaling = .true.
else
call param_register(params, 'atom_sigma_scaling', "0.0", this%atom_sigma_t_scaling(1), &
help_string="Scaling rate of atom sigma: scaled as a function of neighbour distance")
end if
end if
! amplitude_scaling
if( index(args_str,"amplitude_scaling={") /= 0 )then
if( this%n_species == 1 )then
call param_register(params, 'amplitude_scaling', PARAM_MANDATORY, this%amplitude_scaling(1), &
help_string="Scaling rate of amplitude: scaled as an inverse function of neighbour distance")
else
call param_register(params, 'amplitude_scaling', '//MANDATORY//', this%amplitude_scaling, &
help_string="Scaling rate of amplitude: scaled as an inverse function of neighbour distance")
end if
else if( index(args_str,"amplitude_scaling=") /= 0 )then
is_amplitude_scaling_set = .true.
call param_register(params, 'amplitude_scaling', PARAM_MANDATORY, this%amplitude_scaling(1), &
help_string="Scaling rate of amplitude: scaled as an inverse function of neighbour distance")
call param_register(params, 'central_weight', PARAM_MANDATORY, this%central_weight(1), &
help_string="Weight of central atom in environment")
else
is_amplitude_scaling_set = .true.
call param_register(params, 'amplitude_scaling', "1.0", this%amplitude_scaling(1), &
help_string="Scaling rate of amplitude: scaled as an inverse function of neighbour distance")
end if
! species_Z
if( this%n_species == 1 )then
call param_register(params, 'species_Z', PARAM_MANDATORY, this%species_Z(1), &
help_string="Atomic number of species, including the central atom")
else
call param_register(params, 'alpha_max', '//MANDATORY//', this%alpha_max, &
help_string="Radial basis resultion for each species")
call param_register(params, 'atom_sigma_r', '//MANDATORY//', this%atom_sigma_r, &
help_string="Width of atomic Gaussians for soap-type descriptors in the radial direction")
call param_register(params, 'atom_sigma_r_scaling', '//MANDATORY//', this%atom_sigma_r_scaling, &
help_string="Scaling rate of radial sigma: scaled as a function of neighbour distance")
call param_register(params, 'atom_sigma_t', '//MANDATORY//', this%atom_sigma_t, &
help_string="Width of atomic Gaussians for soap-type descriptors in the angular direction")
call param_register(params, 'atom_sigma_t_scaling', '//MANDATORY//', this%atom_sigma_t_scaling, &
help_string="Scaling rate of angular sigma: scaled as a function of neighbour distance")
call param_register(params, 'amplitude_scaling', '//MANDATORY//', this%amplitude_scaling, &
help_string="Scaling rate of amplitude: scaled as an inverse function of neighbour distance")
call param_register(params, 'central_weight', '//MANDATORY//', this%central_weight, &
help_string="Weight of central atom in environment")
call param_register(params, 'species_Z', '//MANDATORY//', this%species_Z, &
help_string="Atomic number of species, including the central atom")
endif
end if
! central_weight
if( .not. is_central_weight_set )then
call param_register(params, 'central_weight', '//MANDATORY//', this%central_weight, &
help_string="Weight of central atom in environment")
end if


if (.not. param_read_line(params, args_str, ignore_unknown=.true.,task='soap_turbo_initialise args_str')) then
RAISE_ERROR("soap_turbo_initialise failed to parse args_str='"//trim(args_str)//"'", error)
endif
call finalise(params)

if( is_n_max_set )then
this%alpha_max = this%alpha_max(1)
end if
if( is_atom_sigma_r_set )then
this%atom_sigma_r = this%atom_sigma_r(1)
end if
if( is_atom_sigma_t_set )then
this%atom_sigma_t = this%atom_sigma_t(1)
end if
if( is_atom_sigma_r_scaling_set )then
this%atom_sigma_r_scaling = this%atom_sigma_r_scaling(1)
end if
if( is_atom_sigma_t_scaling_set )then
this%atom_sigma_t_scaling = this%atom_sigma_t_scaling(1)
end if
if( is_amplitude_scaling_set )then
this%amplitude_scaling = this%amplitude_scaling(1)
end if
if( set_sigma_t_to_r )then
this%atom_sigma_t = this%atom_sigma_r
end if
if( set_sigma_t_to_r_scaling )then
this%atom_sigma_t_scaling = this%atom_sigma_r_scaling
end if


! Here we read in the compression information from a file (compress_file) or rely on a keyword provided
! by the user (compress_mode) which leads to a predefined recipe to compress the soap_turbo descriptor
! The file always takes precedence over the keyword.
Expand Down Expand Up @@ -3530,7 +3718,12 @@ subroutine descriptor_str_add_species(this,species,descriptor_str,error)
case(DT_SOAP_EXPRESS)
RAISE_ERROR("descriptor_str_add_species: no recipe for "//my_descriptor_type//" yet.",error)
case(DT_SOAP_TURBO)
RAISE_ERROR("descriptor_str_add_species: no recipe for "//my_descriptor_type//" yet.",error)
! RAISE_ERROR("descriptor_str_add_species: no recipe for "//my_descriptor_type//" yet.",error)
allocate(descriptor_str(n_species))
do i = 1, n_species
! descriptor_str(i) = trim(this)//" n_species="//n_species//" Z="//species(i)//" species_Z={"//species//"}"
descriptor_str(i) = trim(this)//" n_species="//n_species//" species_Z={"//species//"} central_index="//i
enddo
case default
RAISE_ERROR("descriptor_str_add_species: unknown descriptor type "//my_descriptor_type,error)
endselect
Expand Down

0 comments on commit c05aff9

Please sign in to comment.