-
Notifications
You must be signed in to change notification settings - Fork 8
/
pair_allegro.h
80 lines (57 loc) · 2.13 KB
/
pair_allegro.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
/* -*- c++ -*- ----------------------------------------------------------
LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
http://lammps.sandia.gov, Sandia National Laboratories
Steve Plimpton, sjplimp@sandia.gov
Copyright (2003) Sandia Corporation. Under the terms of Contract
DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains
certain rights in this software. This software is distributed under
the GNU General Public License.
See the README file in the top-level LAMMPS directory.
------------------------------------------------------------------------- */
#ifdef PAIR_CLASS
// Style-name -> class bindings. This branch is compiled only when the
// including file defines PAIR_CLASS; PairStyle itself is not defined in this
// header and must be supplied by that file.
//
// The Precision argument picks the float/double types declared in PairAllegro
// (inputtype / outputtype). Going by those typedefs, the digit pairs in the
// style names read as <input bits><output bits>, with float taken as 32 and
// double as 64.
// "allegro" binds to the same instantiation as "allegro3264".
PairStyle(allegro,PairAllegro<lowhigh>)
PairStyle(allegro3232,PairAllegro<lowlow>)
PairStyle(allegro6464,PairAllegro<highhigh>)
PairStyle(allegro3264,PairAllegro<lowhigh>)
PairStyle(allegro6432,PairAllegro<highlow>)
#else
#ifndef LMP_PAIR_ALLEGRO_H
#define LMP_PAIR_ALLEGRO_H
#include "pair.h"
#include <torch/torch.h>
#include <vector>
#include <type_traits>
#include <map>
#include <string>
// Precision selector used as the template argument of PairAllegro.
// Each enumerator names two levels, "<input><output>", where "low" selects
// float and "high" selects double in PairAllegro's inputtype / outputtype.
// Declared at global scope as an unscoped enum, so the enumerators are usable
// unqualified and convert implicitly to int.
enum Precision {
  lowlow   = 0,  // input float,  output float
  highhigh = 1,  // input double, output double
  lowhigh  = 2,  // input float,  output double
  highlow  = 3   // input double, output float
};
namespace LAMMPS_NS {
// Pair style built around a TorchScript module, templated on Precision.
//
// The template argument fixes two scalar types used by the class:
//   inputtype  -> float for lowlow and lowhigh, double for highhigh and highlow
//   outputtype -> float for lowlow and highlow, double for highhigh and lowhigh
// Member functions are only declared in this header.
template<Precision precision>
class PairAllegro : public Pair {
 public:
  PairAllegro(class LAMMPS *);
  virtual ~PairAllegro();

  // Pair-style entry points (declarations only).
  virtual void compute(int, int);
  void settings(int, char **);
  virtual void coeff(int, char **);
  virtual double init_one(int, int);
  virtual void init_style();
  void allocate();

  // Zero-initialised so the value is never indeterminate before it is assigned.
  double cutoff = 0.0;
  torch::jit::Module model;
  torch::Device device = torch::kCPU;  // CPU unless reassigned elsewhere
  // NOTE(review): presumably maps LAMMPS atom types to model types - confirm.
  std::vector<int> type_mapper;
  // NOTE(review): -1 looks like a "not set" sentinel - confirm against the
  // implementation.
  int batch_size = -1;

  // Scalar types selected by the Precision template argument.
  using inputtype = std::conditional_t<precision==lowlow || precision==lowhigh, float, double>;
  using outputtype = std::conditional_t<precision==lowlow || precision==highlow, float, double>;

  // Torch scalar-type tags corresponding to inputtype / outputtype.
  torch::ScalarType inputtorchtype = torch::CppTypeToScalarType<inputtype>();
  torch::ScalarType outputtorchtype = torch::CppTypeToScalarType<outputtype>();

  // Names and tensors of additional outputs, keyed by name.
  // NOTE(review): presumably filled via add_custom_output() and compute() -
  // confirm.
  std::vector<std::string> custom_output_names;
  std::map<std::string, torch::Tensor> custom_output;
  void add_custom_output(std::string);

 protected:
  int debug_mode = 0;  // off by default
  // Null until assigned, so an unallocated matrix is distinguishable from a
  // dangling pointer. NOTE(review): ownership/allocation is handled outside
  // this header - confirm it is released in the destructor.
  double** cutoff_matrix = nullptr;
};
}
#endif
#endif