diff --git a/rapyuta_io/clients/paramserver.py b/rapyuta_io/clients/paramserver.py index 827bd030..89a7efc7 100644 --- a/rapyuta_io/clients/paramserver.py +++ b/rapyuta_io/clients/paramserver.py @@ -11,14 +11,13 @@ import hashlib import mimetypes -from rapyuta_io.utils import RestClient, InvalidParameterException +from rapyuta_io.utils import RestClient, InvalidParameterException, ConfigNotFoundException from rapyuta_io.utils.rest_client import HttpMethod from rapyuta_io.utils.settings import PARAMSERVER_API_TREE_PATH, PARAMSERVER_API_TREEBLOBS_PATH, PARAMSERVER_API_FILENODE_PATH from rapyuta_io.utils.utils import create_auth_header, prepend_bearer_to_auth_token, get_api_response_data, \ validate_list_of_strings import six - class _Node(str, enum.Enum): def __str__(self): @@ -36,6 +35,8 @@ class _ParamserverClient: yaml_content_type = 'text/yaml' json_content_type = 'application/json' default_binary_content_type = "application/octet-stream" + max_non_binary_size = 128 * 1024 + def __init__(self, auth_token, project, core_api_host): self._auth_token = auth_token @@ -69,6 +70,11 @@ def create_binary_file(self, tree_path, file_path, retry_limit=0): if guessed_content_type[1]: headers['Content-Encoding'] = guessed_content_type[1] + # Override Content-Type for JSON and YAML to allow creating Binary files. + # This is required for large YAML/JSON files. + if content_type in (self.json_content_type, self.yaml_content_type): + content_type = self.default_binary_content_type + headers.update({'X-Rapyuta-Params-Version': "0", 'Content-Type': content_type}) @@ -129,7 +135,10 @@ def process_dir(self, executor, rootdir, tree_path, level, dir_futures, file_fut future = executor.submit(func, new_tree_path) dir_futures[future] = (new_tree_path, level + 1) elif not in_attribute_dir: # ignore files in attribute directories + file_stat = os.stat(full_path) file_name = os.path.basename(full_path) + if file_stat.st_size > self.max_non_binary_size: + future = executor.submit(self.create_binary_file, new_tree_path, full_path) if file_name.endswith('.yaml'): with open(full_path, 'r') as f: data = f.read() @@ -151,8 +160,11 @@ def process_folder(self, executor, rootdir, tree_path, level, dir_futures, file_ future = executor.submit(self.create_folder, new_tree_path) dir_futures[future] = (new_tree_path, level + 1) else: + file_stat = os.stat(full_path) file_name = os.path.basename(full_path) - if file_name.endswith('.yaml'): + if file_stat.st_size > self.max_non_binary_size: + future = executor.submit(self.create_binary_file, new_tree_path, full_path) + elif file_name.endswith('.yaml'): with open(full_path, 'r') as f: data = f.read() future = executor.submit(self.create_file, new_tree_path, data) @@ -266,6 +278,9 @@ def download_configurations(self, rootdir, tree_names, delete_existing_trees): if tree_names: api_tree_names = [tree_name for tree_name in api_tree_names if tree_name in tree_names] + if not api_tree_names: + raise ConfigNotFoundException('One or more trees not found') + blob_temp_dir = tempfile.mkdtemp() blob_files = self.get_blob_data(api_tree_names) diff --git a/rapyuta_io/rio_client.py b/rapyuta_io/rio_client.py index 96d48ef9..7c70d0ff 100644 --- a/rapyuta_io/rio_client.py +++ b/rapyuta_io/rio_client.py @@ -643,16 +643,18 @@ def download_configurations(self, rootdir, tree_names=None, delete_existing_tree Following example demonstrates how to use download_configurations and handle errors. >>> from rapyuta_io import Client - >>> from rapyuta_io.utils.error import APIError, InternalServerError + >>> from rapyuta_io.utils.error import APIError, InternalServerError, ConfigNotFoundException >>> client = Client(auth_token='auth_token', project='project_guid') >>> try: ... client.download_configurations('path/to/destination_dir', ... tree_names=['config_tree1', 'config_tree2'], ... delete_existing_trees=True) ... except (APIError, InternalServerError) as e: - ... print 'failed API request', e.tree_path, e + ... print('failed API request', e.tree_path, e) + except ConfigNotFoundException as e: + print ('config not found') ... except (IOError, OSError) as e: - ... print 'failed file/directory creation', e + ... print('failed file/directory creation', e) """ return self._paramserver_client.download_configurations(rootdir, tree_names, delete_existing_trees)