Skip to content

Commit

Permalink
fix(paramserver): validates yaml or json data
Browse files Browse the repository at this point in the history
Validates the JSON or YAML data if the file is not uploaded as a blob
before invoking the create_file API.

Wrike Ticket: https://www.wrike.com/open.htm?id=1315956250
  • Loading branch information
pallabpain committed Mar 20, 2024
1 parent d2ee458 commit 03efd6c
Show file tree
Hide file tree
Showing 8 changed files with 607 additions and 438 deletions.
1 change: 1 addition & 0 deletions Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ urllib3 = ">=1.23"
python-dateutil = ">=2.8.1"
pytz = "*"
jsonschema = "==4.0.0"
pyyaml = ">=5.4.1"

[dev-packages]
testtools = "==2.5.0"
Expand Down
722 changes: 412 additions & 310 deletions Pipfile.lock

Large diffs are not rendered by default.

43 changes: 29 additions & 14 deletions rapyuta_io/clients/paramserver.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
from __future__ import absolute_import

import enum
import errno
import hashlib
import mimetypes
import os
import tempfile
from concurrent import futures
from os import listdir, makedirs
from os.path import isdir, join
from shutil import rmtree, copyfile

from concurrent import futures
import enum
import tempfile
import os
import hashlib
import mimetypes
import six

from rapyuta_io.utils import RestClient, InvalidParameterException, ConfigNotFoundException
from rapyuta_io.utils.error import InvalidJSONError, InvalidYAMLError
from rapyuta_io.utils.rest_client import HttpMethod
from rapyuta_io.utils.settings import PARAMSERVER_API_TREE_PATH, PARAMSERVER_API_TREEBLOBS_PATH, PARAMSERVER_API_FILENODE_PATH
from rapyuta_io.utils.settings import PARAMSERVER_API_TREE_PATH, PARAMSERVER_API_TREEBLOBS_PATH, \
PARAMSERVER_API_FILENODE_PATH
from rapyuta_io.utils.utils import create_auth_header, prepend_bearer_to_auth_token, get_api_response_data, \
validate_list_of_strings
import six
validate_list_of_strings, is_valid_json, is_valid_yaml


class _Node(str, enum.Enum):

Expand All @@ -28,6 +32,7 @@ def __str__(self):
Attribute = 'AttributeNode'
Folder = 'FolderNode'


class _ParamserverClient:
"""
Internal client for paramserver. Not for public use.
Expand All @@ -37,7 +42,6 @@ class _ParamserverClient:
default_binary_content_type = "application/octet-stream"
max_non_binary_size = 128 * 1024


def __init__(self, auth_token, project, core_api_host):
self._auth_token = auth_token
self._headers = create_auth_header(prepend_bearer_to_auth_token(auth_token), project)
Expand Down Expand Up @@ -142,10 +146,14 @@ def process_dir(self, executor, rootdir, tree_path, level, dir_futures, file_fut
if file_name.endswith('.yaml'):
with open(full_path, 'r') as f:
data = f.read()
if not is_valid_yaml(data):
raise InvalidYAMLError(full_path)
future = executor.submit(self.create_file, new_tree_path, data)
elif file_name.endswith('.json'):
with open(full_path, 'r') as f:
data = f.read()
if not is_valid_json(data):
raise InvalidJSONError(full_path)
future = executor.submit(self.create_file, new_tree_path, data, content_type=self.json_content_type)
else:
future = executor.submit(self.create_binary_file, new_tree_path, full_path)
Expand All @@ -167,17 +175,21 @@ def process_folder(self, executor, rootdir, tree_path, level, dir_futures, file_
elif file_name.endswith('.yaml'):
with open(full_path, 'r') as f:
data = f.read()
if not is_valid_yaml(data):
raise InvalidYAMLError(full_path)
future = executor.submit(self.create_file, new_tree_path, data)
elif file_name.endswith('.json'):
with open(full_path, 'r') as f:
data = f.read()
if not is_valid_json(data):
raise InvalidJSONError(full_path)
future = executor.submit(self.create_file, new_tree_path, data, content_type=self.json_content_type)
else:
future = executor.submit(self.create_binary_file, new_tree_path, full_path)
file_futures[future] = new_tree_path
return dir_futures, file_futures

def upload_configurations(self, rootdir, tree_names, delete_existing_trees, as_folder = False):
def upload_configurations(self, rootdir, tree_names, delete_existing_trees, as_folder=False):
self.validate_args(rootdir, tree_names, delete_existing_trees, as_folder)
with futures.ThreadPoolExecutor(max_workers=15) as executor:
dir_futures = self.process_root_dir(executor, rootdir, tree_names, delete_existing_trees)
Expand All @@ -193,7 +205,8 @@ def upload_configurations(self, rootdir, tree_names, delete_existing_trees, as_f
raise exc

processor_func = self.process_dir if not as_folder else self.process_folder
dir_futures, file_futures = processor_func(executor, rootdir, tree_path, level, dir_futures, file_futures)
dir_futures, file_futures = processor_func(executor, rootdir, tree_path, level, dir_futures,
file_futures)
done = futures.wait(dir_futures, return_when=futures.FIRST_COMPLETED).done
future = done.pop() if len(done) else None

Expand Down Expand Up @@ -241,7 +254,8 @@ def download_tree(self, tree_name, rootdir, delete_existing, blob_temp_dir):

def get_blob_data(self, tree_names):
url = self._core_api_host + PARAMSERVER_API_TREEBLOBS_PATH
response = RestClient(url).method(HttpMethod.GET).query_param({'treeNames': tree_names}).headers(self._headers).retry(0).execute()
response = RestClient(url).method(HttpMethod.GET).query_param({'treeNames': tree_names}).headers(
self._headers).retry(0).execute()
blob_data = get_api_response_data(response, parse_full=True).get('data', {})
return blob_data

Expand All @@ -254,7 +268,7 @@ def download_blob_file(blob, blob_temp_dir):
f.write(chunk)

@staticmethod
def validate_args(rootdir, tree_names, delete_existing_trees, as_folder = False):
def validate_args(rootdir, tree_names, delete_existing_trees, as_folder=False):
if not isinstance(rootdir, six.string_types):
raise InvalidParameterException('rootdir must be a string')
if tree_names:
Expand All @@ -263,6 +277,7 @@ def validate_args(rootdir, tree_names, delete_existing_trees, as_folder = False)
raise InvalidParameterException('delete_existing_trees must be a boolean')
if not isinstance(as_folder, bool):
raise InvalidParameterException('as_folder must be a boolean')

def download_configurations(self, rootdir, tree_names, delete_existing_trees):
self.validate_args(rootdir, tree_names, delete_existing_trees)
self._safe_makedirs(rootdir)
Expand Down
19 changes: 19 additions & 0 deletions rapyuta_io/utils/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ class DeploymentNotRunningException(Exception):
"""
:ivar deployment_status: Deployment status object retrieved from the last poll
"""

def __init__(self, msg, deployment_status=None):
self.deployment_status = deployment_status
Exception.__init__(self, msg)
Expand Down Expand Up @@ -167,3 +168,21 @@ def __init__(self, msg=None):
class BuildOperationFailed(Exception):
def __init__(self, msg):
Exception.__init__(self, msg)


class InvalidJSONError(Exception):
def __init__(self, file_path=None):
msg = "Invalid JSON"
if file_path:
msg += ": {}".format(file_path)

Exception.__init__(self, msg)


class InvalidYAMLError(Exception):
def __init__(self, file_path=None):
msg = "Invalid YAML"
if file_path:
msg += ": {}".format(file_path)

Exception.__init__(self, msg)
32 changes: 31 additions & 1 deletion rapyuta_io/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@

import requests
import six
import yaml
from six.moves import range

from rapyuta_io.utils import APIError, ParameterMissingException, InvalidParameterException, \
UnauthorizedError, ResourceNotFoundError, BadRequestError, InternalServerError, ConflictError, \
ForbiddenError
from rapyuta_io.utils.settings import EMPTY, DEFAULT_RANDOM_VALUE_LENGTH
from six.moves import range

BEARER = "Bearer"

Expand Down Expand Up @@ -140,3 +142,31 @@ def is_true(value):
def is_false(value):
return value in [False, 'False', 'false']


def is_valid_json(data):
"""Check if the given data is a valid JSON"""
try:
json.loads(data)
except json.decoder.JSONDecodeError:
return False

return True


def is_valid_yaml(data):
"""Check if the given data is a valid YAML"""

try:
loaded = yaml.safe_load(data)
except yaml.YAMLError:
return False

# For example, consider a file with just the following text.
# The yaml.safe_load() function will still parse this file.
#
# invalid data
#
if not isinstance(loaded, dict):
return False

return True
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"urllib3>=1.23",
"python-dateutil>=2.8.2",
"pytz",
"pyyaml>=5.4.1",
"setuptools",
"jsonschema==4.0.0",
],
Expand Down
14 changes: 6 additions & 8 deletions tests/paramserver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pyfakefs import fake_filesystem_unittest
from requests import Response

import rapyuta_io.utils.error
from rapyuta_io.utils.error import BadRequestError, InternalServerError
from tests.utils.client import get_client, headers
from tests.utils.paramserver import UPLOAD_SUCCESS_TREE_PATHS, UPLOAD_SUCCESS_MOCK_CALLS, UPLOAD_FAILURE_400CASE_TREE_PATHS, \
Expand Down Expand Up @@ -149,20 +150,17 @@ def side_effect(*args, **kwargs):
mock_response = MagicMock(spec=Response)
url = kwargs['url']
url_suffix = url[len(self.URL_PREFIX):]
if url_suffix == '/tree2/robot_type/AMR/motors.yaml':
mock_response.status_code = requests.codes.BAD_REQUEST
mock_response.text = '{"error": "invalid data"}'
else:
if url_suffix != '/tree2/robot_type/AMR/motors.yaml':
mock_response.status_code = requests.codes.OK
mock_response.text = 'null'
return mock_response
mock_request.side_effect = side_effect

with self.assertRaisesRegex(BadRequestError, 'invalid data') as exc:
with self.assertRaisesRegex(rapyuta_io.utils.error.InvalidYAMLError, 'Invalid YAML') as exc:
get_client().upload_configurations(rootdir)
self.assertEqual('tree2/robot_type/AMR/motors.yaml', exc.exception.tree_path)
mock_request.assert_has_calls(expected_mock_calls, any_order=True)
self.assertEqual(len(expected_mock_calls), mock_request.call_count, 'extra request calls were made')
self.assertRegex(str(exc.exception), 'tree2/robot_type/AMR/motors.yaml')
self.assertNotEqual(len(expected_mock_calls), mock_request.call_count,
'expected fewer calls due to client side exception')

@patch('requests.request')
def test_upload_configurations_failure_500case(self, mock_request):
Expand Down
Loading

0 comments on commit 03efd6c

Please sign in to comment.