diff --git a/python/mxboard/__init__.py b/python/mxboard/__init__.py
index 656c5d8..f51e535 100644
--- a/python/mxboard/__init__.py
+++ b/python/mxboard/__init__.py
@@ -20,4 +20,4 @@ from __future__ import absolute_import
 
 from .writer import SummaryWriter
 
-__version__ = '0.1.0'
+__version__ = '0.1.1'
diff --git a/python/mxboard/event_file_writer.py b/python/mxboard/event_file_writer.py
index e23c14f..75522b7 100644
--- a/python/mxboard/event_file_writer.py
+++ b/python/mxboard/event_file_writer.py
@@ -130,8 +130,15 @@ def __init__(self, logdir, max_queue=10, flush_secs=120, filename_suffix=None, v
         the event file:
         """
         self._logdir = logdir
-        if not os.path.exists(self._logdir):
-            os.makedirs(self._logdir)
+        parse_result = six.moves.urllib.parse.urlparse(self._logdir)
+        if parse_result.scheme == '':
+            if not os.path.exists(self._logdir):
+                os.makedirs(self._logdir)
+        elif parse_result.scheme in ('hdfs', 'viewfs'):
+            import pyarrow.fs
+            hdfs = pyarrow.fs.HadoopFileSystem(host='default', port=0)
+            hdfs.create_dir(parse_result.path)
+
         self._event_queue = six.moves.queue.Queue(max_queue)
         self._ev_writer = EventsWriter(os.path.join(self._logdir, "events"), verbose=verbose)
         self._flush_secs = flush_secs
diff --git a/python/mxboard/record_writer.py b/python/mxboard/record_writer.py
index f4609e5..7f1b7c9 100644
--- a/python/mxboard/record_writer.py
+++ b/python/mxboard/record_writer.py
@@ -18,6 +18,7 @@
 """Writer for writing events to the event file."""
 
 from __future__ import absolute_import
+import six
 import struct
 from ._crc32c import crc32c
 
@@ -37,7 +38,14 @@ class RecordWriter(object):
     def __init__(self, path):
         self._writer = None
         try:
-            self._writer = open(path, 'wb')
+            parse_result = six.moves.urllib.parse.urlparse(path)
+            if parse_result.scheme == '':
+                self._writer = open(path, 'wb')
+            elif parse_result.scheme in ('hdfs', 'viewfs'):
+                import pyarrow.fs
+                # Use fs.defaultFS from core-site.xml
+                hdfs = pyarrow.fs.HadoopFileSystem(host='default', port=0)
+                self._writer = hdfs.open_output_stream(parse_result.path)
         except (OSError, IOError) as err:
             raise ValueError('failed to open file {}: {}'.format(path, str(err)))
 
diff --git a/python/mxboard/utils.py b/python/mxboard/utils.py
index 8ff5acd..b73c1b7 100644
--- a/python/mxboard/utils.py
+++ b/python/mxboard/utils.py
@@ -21,6 +21,8 @@
 import os
 import logging
 import numpy as np
+import six
+
 
 try:
     from PIL import Image
@@ -199,7 +202,24 @@ def _save_image(image, filename, nrow=8, padding=2, square_image=True):
     if Image is None:
         raise ImportError('saving image failed because PIL is not found')
     im = Image.fromarray(image.asnumpy())
-    im.save(filename)
+    parse_result = six.moves.urllib.parse.urlparse(filename)
+    if parse_result.scheme == '':
+        im.save(filename)
+    elif parse_result.scheme in ('hdfs', 'viewfs'):
+        import pyarrow.fs
+        # PIL cannot write directly to an HDFS stream,
+        # so write to a local temporary file first
+        import tempfile
+        with tempfile.TemporaryDirectory() as tmpdir:
+            _, filename_tail = os.path.split(filename)
+            with open(os.path.join(tmpdir, filename_tail), 'wb') as fp:
+                im.save(fp)
+            with open(os.path.join(tmpdir, filename_tail), 'rb') as fp:
+                hdfs = pyarrow.fs.HadoopFileSystem(host='default', port=0)
+                hdfs_fp = hdfs.open_output_stream(parse_result.path)
+                hdfs_fp.write(fp.read())
+                hdfs_fp.close()
+
 
 
 def _prepare_image(img, nrow=8, padding=2, square_image=False):
@@ -254,10 +274,18 @@
 
     if len(metadata.shape) == 1:
         metadata = metadata.reshape(-1, 1)
-    with open(os.path.join(save_path, 'metadata.tsv'), 'w') as f:
+    parse_result = six.moves.urllib.parse.urlparse(save_path)
+    if parse_result.scheme == '':
+        with open(os.path.join(save_path, 'metadata.tsv'), 'w') as f:
+            for row in metadata:
+                f.write('\t'.join([str(x) for x in row]) + '\n')
+    elif parse_result.scheme in ('hdfs', 'viewfs'):
+        import pyarrow.fs
+        hdfs = pyarrow.fs.HadoopFileSystem(host='default', port=0)
+        hdfs_fp = hdfs.open_output_stream(os.path.join(parse_result.path, 'metadata.tsv'))
         for row in metadata:
-            f.write('\t'.join([str(x) for x in row]) + '\n')
-
+            hdfs_fp.write('\t'.join([str(x) for x in row]).encode() + b'\n')
+        hdfs_fp.close()
 
 def _make_sprite_image(images, save_path):
     """Given an NDArray as a batch images, make a sprite image out of it following the rule
@@ -288,25 +316,39 @@ def _add_embedding_config(file_path, data_dir, has_metadata=False, label_img_sha
     """Creates a config file used by the embedding projector.
     Adapted from the TensorFlow function `visualize_embeddings()` at
     https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/tensorboard/plugins/projector/__init__.py"""
-    with open(os.path.join(file_path, 'projector_config.pbtxt'), 'a') as f:
-        s = 'embeddings {\n'
-        s += 'tensor_name: "{}"\n'.format(data_dir)
-        s += 'tensor_path: "{}"\n'.format(os.path.join(data_dir, 'tensors.tsv'))
-        if has_metadata:
-            s += 'metadata_path: "{}"\n'.format(os.path.join(data_dir, 'metadata.tsv'))
-        if label_img_shape is not None:
-            if len(label_img_shape) != 4:
-                logging.warning('expected 4D sprite image in the format NCHW, while received image'
-                                ' ndim=%d, skipping saving sprite'
-                                ' image info', len(label_img_shape))
-            else:
-                s += 'sprite {\n'
-                s += 'image_path: "{}"\n'.format(os.path.join(data_dir, 'sprite.png'))
-                s += 'single_image_dim: {}\n'.format(label_img_shape[3])
-                s += 'single_image_dim: {}\n'.format(label_img_shape[2])
-                s += '}\n'
-        s += '}\n'
-        f.write(s)
+    s = 'embeddings {\n'
+    s += 'tensor_name: "{}"\n'.format(data_dir)
+    s += 'tensor_path: "{}"\n'.format(os.path.join(data_dir, 'tensors.tsv'))
+    if has_metadata:
+        s += 'metadata_path: "{}"\n'.format(os.path.join(data_dir, 'metadata.tsv'))
+    if label_img_shape is not None:
+        if len(label_img_shape) != 4:
+            logging.warning('expected 4D sprite image in the format NCHW, while received image'
+                            ' ndim=%d, skipping saving sprite'
+                            ' image info', len(label_img_shape))
+        else:
+            s += 'sprite {\n'
+            s += 'image_path: "{}"\n'.format(os.path.join(data_dir, 'sprite.png'))
+            s += 'single_image_dim: {}\n'.format(label_img_shape[3])
+            s += 'single_image_dim: {}\n'.format(label_img_shape[2])
+            s += '}\n'
+    s += '}\n'
+    parse_result = six.moves.urllib.parse.urlparse(file_path)
+    if parse_result.scheme == '':
+        with open(os.path.join(file_path, 'projector_config.pbtxt'), 'a') as f:
+            f.write(s)
+    elif parse_result.scheme in ('hdfs', 'viewfs'):
+        import pyarrow.fs
+        hdfs = pyarrow.fs.HadoopFileSystem(host='default', port=0)
+        full_path = os.path.join(parse_result.path, 'projector_config.pbtxt')
+        # open_append_stream does not create empty file if target does not exist
+        # See https://issues.apache.org/jira/browse/ARROW-14925
+        if hdfs.get_file_info(full_path).type == pyarrow.fs.FileType.NotFound:
+            hdfs_fp = hdfs.open_output_stream(full_path)
+        else:
+            hdfs_fp = hdfs.open_append_stream(full_path)
+        hdfs_fp.write(s.encode())
+        hdfs_fp.close()
 
 
 def _save_embedding_tsv(data, file_path):
@@ -319,7 +361,17 @@
     else:
         raise TypeError('expected NDArray of np.ndarray, while received type {}'.format(
             str(type(data))))
-    with open(os.path.join(file_path, 'tensors.tsv'), 'w') as f:
+    parse_result = six.moves.urllib.parse.urlparse(file_path)
+    if parse_result.scheme == '':
+        with open(os.path.join(file_path, 'tensors.tsv'), 'w') as f:
+            for x in data_list:
+                x = [str(i) for i in x]
+                f.write('\t'.join(x) + '\n')
+    elif parse_result.scheme in ('hdfs', 'viewfs'):
+        import pyarrow.fs
+        hdfs = pyarrow.fs.HadoopFileSystem(host='default', port=0)
+        hdfs_fp = hdfs.open_output_stream(os.path.join(parse_result.path, 'tensors.tsv'))
         for x in data_list:
             x = [str(i) for i in x]
-            f.write('\t'.join(x) + '\n')
+            hdfs_fp.write('\t'.join(x).encode() + b'\n')
+        hdfs_fp.close()
\ No newline at end of file
diff --git a/python/setup.py b/python/setup.py
index 40c4503..4ed088b 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -50,7 +50,7 @@ def compile_summary_protobuf():
 
 setup(
     name='mxboard',
-    version='0.1.0',
+    version='0.1.1',
     description='A logging tool for visualizing MXNet data in TensorBoard',
     long_description='MXBoard is a logging tool that enables visualization of MXNet data in TensorBoard.',
     author='Amazon Web Services',
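
Usage note (not part of the patch): a minimal sketch of how the HDFS support above is exercised. It assumes pyarrow is installed with libhdfs reachable (HADOOP_HOME/CLASSPATH set and fs.defaultFS defined in core-site.xml). The namenode address below is illustrative; because the patch connects with HadoopFileSystem(host='default', port=0), the cluster is resolved from fs.defaultFS and only the path component of the logdir URI is used. A viewfs:// logdir is handled the same way.

from mxboard import SummaryWriter

# The event file is created under the HDFS path and streamed through
# pyarrow's HadoopFileSystem by the patched RecordWriter.
sw = SummaryWriter(logdir='hdfs://namenode:8020/tmp/mxboard-demo', flush_secs=5)
for step in range(10):
    sw.add_scalar(tag='loss', value=1.0 / (step + 1), global_step=step)
sw.close()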