This repository has been archived by the owner on May 1, 2024. It is now read-only.

Support summary written to HDFS #50

Open · wants to merge 2 commits into master
Changes from all commits
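
For context, this is what the change enables: pointing a SummaryWriter at an HDFS URI instead of a local directory. A minimal usage sketch; the namenode address and path are hypothetical, and viewfs:// URIs are handled the same way:

    from mxboard import SummaryWriter

    # Hypothetical HDFS destination. With this patch the writer creates the
    # directory through pyarrow and streams the event file to HDFS; plain
    # local paths behave exactly as before.
    sw = SummaryWriter(logdir='hdfs://namenode:9000/user/me/mxboard-logs')
    sw.add_scalar(tag='loss', value=0.42, global_step=1)
    sw.close()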
2 changes: 1 addition & 1 deletion python/mxboard/__init__.py
@@ -20,4 +20,4 @@
 from __future__ import absolute_import
 from .writer import SummaryWriter
 
-__version__ = '0.1.0'
+__version__ = '0.1.1'
11 changes: 9 additions & 2 deletions python/mxboard/event_file_writer.py
@@ -130,8 +130,15 @@ def __init__(self, logdir, max_queue=10, flush_secs=120, filename_suffix=None, v
         the event file:
         """
         self._logdir = logdir
-        if not os.path.exists(self._logdir):
-            os.makedirs(self._logdir)
+        parse_result = six.moves.urllib.parse.urlparse(self._logdir)
+        if parse_result.scheme == '':
+            if not os.path.exists(self._logdir):
+                os.makedirs(self._logdir)
+        elif parse_result.scheme in ('hdfs', 'viewfs'):
+            import pyarrow.fs
+            hdfs = pyarrow.fs.HadoopFileSystem(host='default', port=0)
+            hdfs.create_dir(parse_result.path)
+
         self._event_queue = six.moves.queue.Queue(max_queue)
         self._ev_writer = EventsWriter(os.path.join(self._logdir, "events"), verbose=verbose)
         self._flush_secs = flush_secs
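
The dispatch above hinges on urlparse leaving the scheme empty for plain local paths. A standalone sketch of that behavior (not part of the diff; the URIs are hypothetical):

    from six.moves.urllib import parse

    # A local path has no scheme, so the os.makedirs branch runs.
    print(parse.urlparse('/tmp/mxboard-logs').scheme)  # ''

    # An HDFS URI splits into scheme, authority, and path; only the path
    # component is handed to pyarrow, which resolves the cluster itself.
    result = parse.urlparse('hdfs://namenode:9000/user/me/logs')
    print(result.scheme, result.path)  # hdfs /user/me/logs

Note that the events file path handed to EventsWriter keeps the full URI; RecordWriter re-parses it when it opens the output stream.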
10 changes: 9 additions & 1 deletion python/mxboard/record_writer.py
@@ -18,6 +18,7 @@
 """Writer for writing events to the event file."""
 
 from __future__ import absolute_import
+import six
 import struct
 from ._crc32c import crc32c
 
@@ -37,7 +38,14 @@ class RecordWriter(object):
     def __init__(self, path):
         self._writer = None
         try:
-            self._writer = open(path, 'wb')
+            parse_result = six.moves.urllib.parse.urlparse(path)
+            if parse_result.scheme == '':
+                self._writer = open(path, 'wb')
+            elif parse_result.scheme in ('hdfs', 'viewfs'):
+                import pyarrow.fs
+                # Use fs.defaultFS from core-site.xml
+                hdfs = pyarrow.fs.HadoopFileSystem(host='default', port=0)
+                self._writer = hdfs.open_output_stream(parse_result.path)
         except (OSError, IOError) as err:
             raise ValueError('failed to open file {}: {}'.format(path, str(err)))
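
A minimal sketch of the pyarrow call pattern the new branch relies on, assuming a cluster reachable through core-site.xml; the path is hypothetical:

    import pyarrow.fs

    # host='default' with port=0 defers to fs.defaultFS from core-site.xml;
    # an explicit pair such as host='namenode', port=9000 also works.
    hdfs = pyarrow.fs.HadoopFileSystem(host='default', port=0)
    with hdfs.open_output_stream('/user/me/logs/events.demo') as stream:
        stream.write(b'raw record bytes')  # pyarrow streams take bytes, not str

open_output_stream returns a pyarrow NativeFile, which supports the write/flush/close calls RecordWriter makes on a regular binary file handle, so the rest of the class works unchanged.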
101 changes: 76 additions & 25 deletions python/mxboard/utils.py
@@ -21,6 +21,8 @@
 import os
 import logging
 import numpy as np
+import six
+
 
 try:
     from PIL import Image
@@ -199,7 +202,24 @@ def _save_image(image, filename, nrow=8, padding=2, square_image=True):
     if Image is None:
         raise ImportError('saving image failed because PIL is not found')
     im = Image.fromarray(image.asnumpy())
-    im.save(filename)
+    parse_result = six.moves.urllib.parse.urlparse(filename)
+    if parse_result.scheme == '':
+        im.save(filename)
+    elif parse_result.scheme in ('hdfs', 'viewfs'):
+        import pyarrow.fs
+        import tempfile
+        # PIL cannot write to an HDFS stream directly,
+        # so save to a local temporary file first.
+        with tempfile.TemporaryDirectory() as tmpdir:
+            _, filename_tail = os.path.split(filename)
+            with open(os.path.join(tmpdir, filename_tail), 'wb') as fp:
+                im.save(fp)
+            with open(os.path.join(tmpdir, filename_tail), 'rb') as fp:
+                hdfs = pyarrow.fs.HadoopFileSystem(host='default', port=0)
+                hdfs_fp = hdfs.open_output_stream(parse_result.path)
+                hdfs_fp.write(fp.read())
+                hdfs_fp.close()
 
 
 def _prepare_image(img, nrow=8, padding=2, square_image=False):
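
The temp-file round trip exists because PIL wants a real file, and a file opened from a path carries a name it can infer the format from. The same could be done in memory; a sketch of that alternative (not what the diff does), passing the format explicitly since a buffer has no filename:

    import io
    import pyarrow.fs

    def save_image_to_hdfs(im, hdfs_path):
        # Render the PIL image into an in-memory buffer; the explicit format
        # replaces the extension-based inference a real filename provides.
        buf = io.BytesIO()
        im.save(buf, format='PNG')
        hdfs = pyarrow.fs.HadoopFileSystem(host='default', port=0)
        with hdfs.open_output_stream(hdfs_path) as stream:
            stream.write(buf.getvalue())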
@@ -254,10 +274,18 @@ def _make_metadata_tsv(metadata, save_path):
 
     if len(metadata.shape) == 1:
         metadata = metadata.reshape(-1, 1)
-    with open(os.path.join(save_path, 'metadata.tsv'), 'w') as f:
-        for row in metadata:
-            f.write('\t'.join([str(x) for x in row]) + '\n')
+    parse_result = six.moves.urllib.parse.urlparse(save_path)
+    if parse_result.scheme == '':
+        with open(os.path.join(save_path, 'metadata.tsv'), 'w') as f:
+            for row in metadata:
+                f.write('\t'.join([str(x) for x in row]) + '\n')
+    elif parse_result.scheme in ('hdfs', 'viewfs'):
+        import pyarrow.fs
+        hdfs = pyarrow.fs.HadoopFileSystem(host='default', port=0)
+        hdfs_fp = hdfs.open_output_stream(os.path.join(parse_result.path, 'metadata.tsv'))
+        for row in metadata:
+            hdfs_fp.write('\t'.join([str(x) for x in row]).encode() + b'\n')
+        hdfs_fp.close()
 
 
 def _make_sprite_image(images, save_path):
     """Given an NDArray as a batch images, make a sprite image out of it following the rule
@@ -288,25 +316,39 @@ def _add_embedding_config(file_path, data_dir, has_metadata=False, label_img_sha
     """Creates a config file used by the embedding projector.
     Adapted from the TensorFlow function `visualize_embeddings()` at
     https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/tensorboard/plugins/projector/__init__.py"""
-    with open(os.path.join(file_path, 'projector_config.pbtxt'), 'a') as f:
-        s = 'embeddings {\n'
-        s += 'tensor_name: "{}"\n'.format(data_dir)
-        s += 'tensor_path: "{}"\n'.format(os.path.join(data_dir, 'tensors.tsv'))
-        if has_metadata:
-            s += 'metadata_path: "{}"\n'.format(os.path.join(data_dir, 'metadata.tsv'))
-        if label_img_shape is not None:
-            if len(label_img_shape) != 4:
-                logging.warning('expected 4D sprite image in the format NCHW, while received image'
-                                ' ndim=%d, skipping saving sprite'
-                                ' image info', len(label_img_shape))
-            else:
-                s += 'sprite {\n'
-                s += 'image_path: "{}"\n'.format(os.path.join(data_dir, 'sprite.png'))
-                s += 'single_image_dim: {}\n'.format(label_img_shape[3])
-                s += 'single_image_dim: {}\n'.format(label_img_shape[2])
-                s += '}\n'
-        s += '}\n'
-        f.write(s)
+    s = 'embeddings {\n'
+    s += 'tensor_name: "{}"\n'.format(data_dir)
+    s += 'tensor_path: "{}"\n'.format(os.path.join(data_dir, 'tensors.tsv'))
+    if has_metadata:
+        s += 'metadata_path: "{}"\n'.format(os.path.join(data_dir, 'metadata.tsv'))
+    if label_img_shape is not None:
+        if len(label_img_shape) != 4:
+            logging.warning('expected 4D sprite image in the format NCHW, while received image'
+                            ' ndim=%d, skipping saving sprite'
+                            ' image info', len(label_img_shape))
+        else:
+            s += 'sprite {\n'
+            s += 'image_path: "{}"\n'.format(os.path.join(data_dir, 'sprite.png'))
+            s += 'single_image_dim: {}\n'.format(label_img_shape[3])
+            s += 'single_image_dim: {}\n'.format(label_img_shape[2])
+            s += '}\n'
+    s += '}\n'
+    parse_result = six.moves.urllib.parse.urlparse(file_path)
+    if parse_result.scheme == '':
+        with open(os.path.join(file_path, 'projector_config.pbtxt'), 'a') as f:
+            f.write(s)
+    elif parse_result.scheme in ('hdfs', 'viewfs'):
+        import pyarrow.fs
+        hdfs = pyarrow.fs.HadoopFileSystem(host='default', port=0)
+        full_path = os.path.join(parse_result.path, 'projector_config.pbtxt')
+        # open_append_stream does not create a missing target file,
+        # see https://issues.apache.org/jira/browse/ARROW-14925
+        if hdfs.get_file_info(full_path).type == pyarrow.fs.FileType.NotFound:
+            hdfs_fp = hdfs.open_output_stream(full_path)
+        else:
+            hdfs_fp = hdfs.open_append_stream(full_path)
+        hdfs_fp.write(s.encode())
+        hdfs_fp.close()
 
 
 def _save_embedding_tsv(data, file_path):
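
The create-or-append logic for projector_config.pbtxt generalizes into a small helper; a sketch under the same ARROW-14925 caveat (open_for_append is hypothetical, not part of the diff):

    import pyarrow.fs

    def open_for_append(hdfs, path):
        # open_append_stream does not create a missing target on HDFS
        # (ARROW-14925), so fall back to creating the file on first use.
        if hdfs.get_file_info(path).type == pyarrow.fs.FileType.NotFound:
            return hdfs.open_output_stream(path)
        return hdfs.open_append_stream(path)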
@@ -319,7 +361,17 @@ def _save_embedding_tsv(data, file_path):
     else:
         raise TypeError('expected NDArray or np.ndarray, while received type {}'.format(
             str(type(data))))
-    with open(os.path.join(file_path, 'tensors.tsv'), 'w') as f:
-        for x in data_list:
-            x = [str(i) for i in x]
-            f.write('\t'.join(x) + '\n')
+    parse_result = six.moves.urllib.parse.urlparse(file_path)
+    if parse_result.scheme == '':
+        with open(os.path.join(file_path, 'tensors.tsv'), 'w') as f:
+            for x in data_list:
+                x = [str(i) for i in x]
+                f.write('\t'.join(x) + '\n')
+    elif parse_result.scheme in ('hdfs', 'viewfs'):
+        import pyarrow.fs
+        hdfs = pyarrow.fs.HadoopFileSystem(host='default', port=0)
+        hdfs_fp = hdfs.open_output_stream(os.path.join(parse_result.path, 'tensors.tsv'))
+        for x in data_list:
+            x = [str(i) for i in x]
+            hdfs_fp.write('\t'.join(x).encode() + b'\n')
+        hdfs_fp.close()
2 changes: 1 addition & 1 deletion python/setup.py
@@ -50,7 +50,7 @@ def compile_summary_protobuf():
 
 setup(
     name='mxboard',
-    version='0.1.0',
+    version='0.1.1',
     description='A logging tool for visualizing MXNet data in TensorBoard',
     long_description='MXBoard is a logging tool that enables visualization of MXNet data in TensorBoard.',
     author='Amazon Web Services',