-
Notifications
You must be signed in to change notification settings - Fork 17
/
utils.py
25 lines (21 loc) · 823 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import os
from openmm import Platform
def get_platform():
    """Select the OpenMM ``Platform`` to run simulations on.

    If the ``PLATFORM`` environment variable is set to a non-empty string,
    the platform with that name is returned. Otherwise every registered
    platform is inspected and the one with the highest ``getSpeed()``
    estimate is chosen (the first one wins on ties).

    When the selected platform is ``CUDA`` or ``OpenCL``, its default
    ``Precision`` property is set to ``'mixed'`` before returning.

    Returns:
        openmm.Platform: the selected platform.

    Raises:
        RuntimeError: if ``PLATFORM`` is not set and OpenMM reports no
            registered platforms.
    """
    requested = os.getenv('PLATFORM')
    if requested:
        # An unknown name is not handled here; whatever
        # Platform.getPlatformByName raises propagates to the caller.
        platform = Platform.getPlatformByName(requested)
    else:
        # Work out the fastest platform.
        candidates = [Platform.getPlatform(i)
                      for i in range(Platform.getNumPlatforms())]
        if not candidates:
            raise RuntimeError('No OpenMM platforms are available')
        # max() returns the first maximal element, matching the previous
        # strict ">" scan. Unlike that scan, it never leaves `platform`
        # unbound when every speed estimate is <= 0.
        platform = max(candidates, key=lambda p: p.getSpeed())

    name = platform.getName()
    print('Using platform', name)
    # On the GPU platforms, set the default precision to mixed. This changes
    # the platform's property default, not a single Context.
    if name in ('CUDA', 'OpenCL'):
        platform.setPropertyDefaultValue('Precision', 'mixed')
        print('Set precision for platform', name, 'to mixed')
    return platform