From 7be7526007380d88dab646a935ebad37fa6a710d Mon Sep 17 00:00:00 2001 From: Klyuenkov Vladimir Date: Thu, 3 Oct 2019 11:50:11 +0300 Subject: [PATCH] Add checkpoints 16+mb support to MongoDB (#396) * small fix * small fix * Add checkpoints 16+mb support to MongoDB * Update mongo.py * Update __version__.py --- catalyst/__version__.py | 2 +- catalyst/rl/db/mongo.py | 44 ++++++++++++++++++++++++----------------- 2 files changed, 27 insertions(+), 19 deletions(-) diff --git a/catalyst/__version__.py b/catalyst/__version__.py index 5ecae5441a..c2c3b8e3b5 100644 --- a/catalyst/__version__.py +++ b/catalyst/__version__.py @@ -1 +1 @@ -__version__ = "19.09.4" +__version__ = "19.10" diff --git a/catalyst/rl/db/mongo.py b/catalyst/rl/db/mongo.py index fe5d7b73c4..461d0fa78d 100644 --- a/catalyst/rl/db/mongo.py +++ b/catalyst/rl/db/mongo.py @@ -2,6 +2,7 @@ import datetime import pymongo +import gridfs import safitty from catalyst.rl import utils @@ -26,8 +27,10 @@ def __init__( self._trajectory_collection = self._shared_db["trajectories"] self._raw_trajectory_collection = self._shared_db["raw_trajectories"] - self._checkpoints_collection = self._agent_db["checkpoints"] - self._messages_collection = self._agent_db["messages"] + self._checkpoint_collection =\ + gridfs.GridFS(self._agent_db, collection="checkpoints") + self._message_collection = self._agent_db["messages"] + self._last_datetime = datetime.datetime.min self._epoch = 0 @@ -35,7 +38,7 @@ def __init__( def _set_flag(self, key, value): try: - self._messages_collection.replace_one( + self._message_collection.replace_one( {"key": key}, {"key": key, "value": value}, upsert=True @@ -46,7 +49,7 @@ def _set_flag(self, key, value): def _get_flag(self, key, default=None): try: - flag_obj = self._messages_collection.find_one( + flag_obj = self._message_collection.find_one( {"key": {"$eq": key}} ) except pymongo.errors.AutoReconnect: @@ -124,8 +127,8 @@ def get_trajectory(self, index=None): self._last_datetime = trajectory_obj["date"] trajectory, trajectory_epoch = \ - utils.unpack( - trajectory_obj["trajectory"]), trajectory_obj["epoch"] + utils.unpack(trajectory_obj["trajectory"]), \ + trajectory_obj["epoch"] if self._sync_epoch and self._epoch != trajectory_epoch: trajectory = None else: @@ -146,37 +149,42 @@ def put_checkpoint(self, checkpoint, epoch): try: self._epoch = epoch checkpoint_ = utils.pack(checkpoint) - self._checkpoints_collection.replace_one( - {"prefix": "checkpoint"}, { - "checkpoint": checkpoint_, - "prefix": "checkpoint", - "epoch": self._epoch - }, - upsert=True + if self._checkpoint_collection.exists({"filename": "checkpoint"}): + self.del_checkpoint() + + self._checkpoint_collection.put( + checkpoint_, + encoding="ascii", + filename="checkpoint", + epoch=self._epoch ) + except pymongo.errors.AutoReconnect: time.sleep(self._reconnect_timeout) return self.put_checkpoint(checkpoint, epoch) def get_checkpoint(self): try: - checkpoint_obj = self._checkpoints_collection.find_one( - {"prefix": "checkpoint"} + checkpoint_obj = self._checkpoint_collection.find_one( + {"filename": "checkpoint"} ) except pymongo.errors.AutoReconnect: time.sleep(self._reconnect_timeout) return self.get_checkpoint() if checkpoint_obj is not None: - checkpoint = checkpoint_obj.get("checkpoint") - self._epoch = checkpoint_obj["epoch"] + checkpoint = checkpoint_obj.read().decode("ascii") + self._epoch = checkpoint_obj.epoch checkpoint = utils.unpack(checkpoint) else: checkpoint = None return checkpoint def del_checkpoint(self): - self._checkpoints_collection.delete_one({"prefix": "checkpoint"}) + id_ = self._checkpoint_collection.find_one( + {"filename": "checkpoint"} + )._id + self._checkpoint_collection.delete(id_) __all__ = ["MongoDB"]