Skip to content

Commit

Permalink
Merge pull request #837 from microsoft/malmoenvvideo
Browse files Browse the repository at this point in the history
add video recording
  • Loading branch information
AndKram authored Sep 5, 2019
2 parents 29e1c05 + 1b29f67 commit 5d1d91f
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 5 deletions.
19 changes: 15 additions & 4 deletions MalmoEnv/malmoenv/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class Env:
def __init__(self, reshape=False):
self.action_space = None
self.observation_space = None
self.metadata = {'render.modes': ['rgb_array']}
self.xml = None
self.integratedServerPort = 0
self.role = 0
Expand All @@ -104,11 +105,12 @@ def __init__(self, reshape=False):
self.height = 0
self.depth = 0
self.reshape = reshape
self.last_obs = None

def init(self, xml, port, server=None,
server2=None, port2=None,
role=0, exp_uid=None, episode=0,
action_filter=None, resync=0, step_options=0, action_space=None):
action_filter=None, resync=0, step_options=0, action_space=None, reshape=False):
""""Initialize a Malmo environment.
xml - the mission xml.
port - the MalmoEnv service's port.
Expand Down Expand Up @@ -206,6 +208,7 @@ def init(self, xml, port, server=None,
})
self.xml.insert(2, e)

self.reshape = reshape
video_producers = self.xml.findall('.//' + self.ns + 'VideoProducer')
assert len(video_producers) == self.agent_count
video_producer = video_producers[self.role]
Expand Down Expand Up @@ -236,6 +239,7 @@ def reset(self):

@retry
def _start_up(self):
self.last_obs = None
self.resets += 1
if self.role != 0:
self._find_server()
Expand Down Expand Up @@ -275,7 +279,7 @@ def _peek_obs(self):
obs = np.zeros(self.height * self.width * self.depth, dtype=np.uint8)
elif self.reshape:
obs = obs.reshape((self.height, self.width, self.depth)).astype(np.uint8)

self.last_obs = obs
return obs

def _quit_episode(self):
Expand All @@ -284,9 +288,15 @@ def _quit_episode(self):
ok, = struct.unpack('!I', reply)
return ok != 0

def render(self):
def render(self, mode=None):
"""gym api render"""
pass
if self.last_obs is None:
if self.reshape:
self.last_obs = np.zeros((self.height, self.width, self.depth), dtype=np.uint8)
else:
self.last_obs = np.zeros(self.height * self.width * self.depth, dtype=np.uint8)

return np.flipud(self.last_obs)

def seed(self):
pass
Expand Down Expand Up @@ -336,6 +346,7 @@ def step(self, action):
obs = np.zeros((self.height, self.width, self.depth), dtype=np.uint8)
else:
obs = obs.reshape((self.height, self.width, self.depth)).astype(np.uint8)
self.last_obs = obs

return obs, reward, self.done, info

Expand Down
2 changes: 1 addition & 1 deletion MalmoEnv/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

setuptools.setup(
name="malmoenv",
version="0.0.6",
version="0.0.7",
author="Andre Kramer",
author_email="[email protected]",
description="A gym environemnt for Malmo",
Expand Down
84 changes: 84 additions & 0 deletions MalmoEnv/video_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# ------------------------------------------------------------------------------------------------
# Copyright (c) 2018 Microsoft Corporation
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ------------------------------------------------------------------------------------------------

import malmoenv
import argparse
from pathlib import Path
import time
import gym
from gym.wrappers.monitoring.video_recorder import VideoRecorder

import logging
logging.basicConfig(level=logging.DEBUG)

if __name__ == '__main__':

parser = argparse.ArgumentParser(description='malmovnv test')
parser.add_argument('--mission', type=str, default='missions/mobchase_single_agent.xml', help='the mission xml')
parser.add_argument('--port', type=int, default=9000, help='the mission server port')
parser.add_argument('--server', type=str, default='127.0.0.1', help='the mission server DNS or IP address')
parser.add_argument('--port2', type=int, default=None, help="(Multi-agent) role N's mission port. Defaults to server port.")
parser.add_argument('--server2', type=str, default=None, help="(Multi-agent) role N's server DNS or IP")
parser.add_argument('--episodes', type=int, default=1, help='the number of resets to perform - default is 1')
parser.add_argument('--episode', type=int, default=0, help='the start episode - default is 0')
parser.add_argument('--role', type=int, default=0, help='the agent role - defaults to 0')
parser.add_argument('--episodemaxsteps', type=int, default=0, help='max number of steps per episode')
parser.add_argument('--resync', type=int, default=0, help='exit and re-sync every N resets'
' - default is 0 meaning never.')
parser.add_argument('--experimentUniqueId', type=str, default='test1', help="the experiment's unique id.")
parser.add_argument('--video_path', type=str, default="video.mp4", help="Optional video path.")
args = parser.parse_args()
if args.server2 is None:
args.server2 = args.server

xml = Path(args.mission).read_text()
env = malmoenv.make()

env.init(xml, args.port,
server=args.server,
server2=args.server2, port2=args.port2,
role=args.role,
exp_uid=args.experimentUniqueId,
episode=args.episode,
resync=args.resync,
reshape=True)

rec = VideoRecorder(env, args.video_path)

for i in range(args.episodes):
print("reset " + str(i))
obs = env.reset()
rec.capture_frame()

steps = 0
done = False
while not done and (args.episodemaxsteps <= 0 or steps < args.episodemaxsteps):
action = env.action_space.sample()
rec.capture_frame()
obs, reward, done, info = env.step(action)
steps += 1
print("reward: " + str(reward))
# print("done: " + str(done))
print("obs: " + str(obs))
# print("info" + info)

time.sleep(.05)

rec.close()
env.close()

0 comments on commit 5d1d91f

Please sign in to comment.