Skip to content

Commit

Permalink
Add checkpoints 16+mb support to MongoDB (#396)
Browse files Browse the repository at this point in the history
* small fix

* small fix

* Add checkpoints 16+mb support  to MongoDB

* Update mongo.py

* Update __version__.py
  • Loading branch information
vaklyuenkov authored and Scitator committed Oct 3, 2019
1 parent b1e0132 commit 7be7526
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 19 deletions.
2 changes: 1 addition & 1 deletion catalyst/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "19.09.4"
__version__ = "19.10"
44 changes: 26 additions & 18 deletions catalyst/rl/db/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import datetime

import pymongo
import gridfs
import safitty

from catalyst.rl import utils
Expand All @@ -26,16 +27,18 @@ def __init__(

self._trajectory_collection = self._shared_db["trajectories"]
self._raw_trajectory_collection = self._shared_db["raw_trajectories"]
self._checkpoints_collection = self._agent_db["checkpoints"]
self._messages_collection = self._agent_db["messages"]
self._checkpoint_collection =\
gridfs.GridFS(self._agent_db, collection="checkpoints")
self._message_collection = self._agent_db["messages"]

self._last_datetime = datetime.datetime.min

self._epoch = 0
self._sync_epoch = sync_epoch

def _set_flag(self, key, value):
try:
self._messages_collection.replace_one(
self._message_collection.replace_one(
{"key": key},
{"key": key, "value": value},
upsert=True
Expand All @@ -46,7 +49,7 @@ def _set_flag(self, key, value):

def _get_flag(self, key, default=None):
try:
flag_obj = self._messages_collection.find_one(
flag_obj = self._message_collection.find_one(
{"key": {"$eq": key}}
)
except pymongo.errors.AutoReconnect:
Expand Down Expand Up @@ -124,8 +127,8 @@ def get_trajectory(self, index=None):
self._last_datetime = trajectory_obj["date"]

trajectory, trajectory_epoch = \
utils.unpack(
trajectory_obj["trajectory"]), trajectory_obj["epoch"]
utils.unpack(trajectory_obj["trajectory"]), \
trajectory_obj["epoch"]
if self._sync_epoch and self._epoch != trajectory_epoch:
trajectory = None
else:
Expand All @@ -146,37 +149,42 @@ def put_checkpoint(self, checkpoint, epoch):
try:
self._epoch = epoch
checkpoint_ = utils.pack(checkpoint)
self._checkpoints_collection.replace_one(
{"prefix": "checkpoint"}, {
"checkpoint": checkpoint_,
"prefix": "checkpoint",
"epoch": self._epoch
},
upsert=True
if self._checkpoint_collection.exists({"filename": "checkpoint"}):
self.del_checkpoint()

self._checkpoint_collection.put(
checkpoint_,
encoding="ascii",
filename="checkpoint",
epoch=self._epoch
)

except pymongo.errors.AutoReconnect:
time.sleep(self._reconnect_timeout)
return self.put_checkpoint(checkpoint, epoch)

def get_checkpoint(self):
try:
checkpoint_obj = self._checkpoints_collection.find_one(
{"prefix": "checkpoint"}
checkpoint_obj = self._checkpoint_collection.find_one(
{"filename": "checkpoint"}
)
except pymongo.errors.AutoReconnect:
time.sleep(self._reconnect_timeout)
return self.get_checkpoint()

if checkpoint_obj is not None:
checkpoint = checkpoint_obj.get("checkpoint")
self._epoch = checkpoint_obj["epoch"]
checkpoint = checkpoint_obj.read().decode("ascii")
self._epoch = checkpoint_obj.epoch
checkpoint = utils.unpack(checkpoint)
else:
checkpoint = None
return checkpoint

def del_checkpoint(self):
self._checkpoints_collection.delete_one({"prefix": "checkpoint"})
id_ = self._checkpoint_collection.find_one(
{"filename": "checkpoint"}
)._id
self._checkpoint_collection.delete(id_)


__all__ = ["MongoDB"]

0 comments on commit 7be7526

Please sign in to comment.