Skip to content

Commit

Permalink
Only redact secrets when the method return value is passed to the user
Browse files Browse the repository at this point in the history
  • Loading branch information
themylogin committed Nov 28, 2024
1 parent b475c10 commit 1b1b658
Show file tree
Hide file tree
Showing 11 changed files with 123 additions and 58 deletions.
6 changes: 3 additions & 3 deletions src/middlewared/middlewared/api/base/handler/accept.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def accept_params(model: type[BaseModel], args: list, *, exclude_unset=False, ex
:param model: `BaseModel` that defines method args.
:param args: a list of method args.
:param exclude_unset: if true, will not append default parameters to the list.
:param expose_secrets: if false, will replace `Private` parameters with a placeholder.
:param expose_secrets: if false, will replace `Secret` parameters with a placeholder.
:return: a validated list of method args.
"""
args_as_dict = model_dict_from_list(model, args)
Expand Down Expand Up @@ -60,7 +60,7 @@ def validate_model(model: type[BaseModel], data: dict, *, exclude_unset=False, e
:param model: `BaseModel` subclass.
:param data: provided data.
:param exclude_unset: if true, will not add default values.
:param expose_secrets: if false, will replace `Private` fields with a placeholder.
:param expose_secrets: if false, will replace `Secret` fields with a placeholder.
:return: validated data.
"""
try:
Expand All @@ -83,5 +83,5 @@ def validate_model(model: type[BaseModel], data: dict, *, exclude_unset=False, e
context={"expose_secrets": expose_secrets},
exclude_unset=exclude_unset,
warnings=False,
by_alias=True
by_alias=True,
)
8 changes: 4 additions & 4 deletions src/middlewared/middlewared/api/base/handler/dump_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def dump_params(model: type[BaseModel], args: list, expose_secrets: bool) -> lis
:param model: `BaseModel` that defines method args.
:param args: a list of method args.
:param expose_secrets: if false, will replace `Private` parameters with a placeholder.
:param expose_secrets: if false, will replace `Secret` parameters with a placeholder.
:return: A list of method call arguments ready to be printed.
"""
try:
Expand All @@ -32,10 +32,10 @@ def dump_params(model: type[BaseModel], args: list, expose_secrets: bool) -> lis

def remove_secrets(model: type[BaseModel], value):
"""
Removes `Private` values from a model value.
Removes `Secret` values from a model value.
:param model: `BaseModel` that corresponds to `value`.
:param value: value that potentially contains `Private` data.
:return: `value` with `Private` parameters replaced with a placeholder.
:param value: value that potentially contains `Secret` data.
:return: `value` with `Secret` parameters replaced with a placeholder.
"""
if isinstance(value, dict) and (nested_model := model_field_is_model(model)):
return {
Expand Down
8 changes: 8 additions & 0 deletions src/middlewared/middlewared/api/base/handler/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@


def serialize_result(model, result, expose_secrets):
"""
Serializes a `result` of the method execution using the corresponding `model`.
:param model: `BaseModel` that defines method return value.
:param result: method return value.
:param expose_secrets: if false, will replace `Secret` parameters with a placeholder.
:return: serialized method execution result.
"""
return model(result=result).model_dump(
context={"expose_secrets": expose_secrets},
warnings=False,
Expand Down
10 changes: 9 additions & 1 deletion src/middlewared/middlewared/api/base/handler/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ def adapt(self, value: dict, model_name: str, version1: str, version2: str) -> d
:param version2: target API version that needs `value`
:return: converted value
"""
return self.adapt_model(value, model_name, version1, version2)[1]

def adapt_model(self, value: dict, model_name: str, version1: str, version2: str) -> tuple[type[BaseModel], dict]:
"""
Same as `adapt`, but returned value will be a tuple of `version2` model instance and converted value.
"""
try:
version1_index = self.versions_history.index(version1)
except ValueError:
Expand All @@ -101,6 +107,7 @@ def adapt(self, value: dict, model_name: str, version1: str, version2: str) -> d
raise APIVersionDoesNotContainModelException(current_version.version, model_name)

value_factory = functools.partial(validate_model, current_version_model, value)
model = current_version_model

if version1_index < version2_index:
step = 1
Expand All @@ -115,10 +122,11 @@ def adapt(self, value: dict, model_name: str, version1: str, version2: str) -> d
value_factory = functools.partial(
self._adapt_model, value_factory, model_name, current_version, new_version, direction,
)
model = new_version.models.get(model_name)

current_version = new_version

return value_factory()
return model, value_factory()

def _adapt_model(self, value_factory: Callable[[], dict], model_name: str, current_version: APIVersion,
new_version: APIVersion, direction: Direction):
Expand Down
38 changes: 22 additions & 16 deletions src/middlewared/middlewared/api/base/server/legacy_api_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,16 @@ def __init__(self, middleware: "Middleware", name: str, api_version: str, adapte
methodobj = self.methodobj
if crud_methodobj := real_crud_method(methodobj):
methodobj = crud_methodobj
if hasattr(methodobj, "new_style_accepts"):
if hasattr(methodobj, "new_style_accepts"): # FIXME: Remove this check when all models become new style
self.accepts_model = methodobj.new_style_accepts
self.returns_model = methodobj.new_style_returns
else:
self.accepts_model = None
self.returns_model = None

async def call(self, app: "RpcWebSocketApp", params):
if self.accepts_model:
return self._adapt_result(await super().call(app, self._adapt_params(params)))
if self.accepts_model: # FIXME: Remove this check when all models become new style
params = self._adapt_params(params)

return await super().call(app, params)

Expand All @@ -70,22 +70,28 @@ def _adapt_params(self, params):

return [adapted_params_dict[field] for field in self.accepts_model.model_fields]

def _adapt_result(self, result):
try:
return self.adapter.adapt(
{"result": result},
self.returns_model.__name__,
self.adapter.current_version,
self.api_version,
)["result"]
except APIVersionDoesNotContainModelException:
if self.passthrough_nonexistent_methods:
return result
def _dump_result(self, app: "RpcWebSocketApp", methodobj, result):
if self.accepts_model: # FIXME: Remove this check when all models become new style
try:
model, result = self.adapter.adapt_model(
{"result": result},
self.returns_model.__name__,
self.adapter.current_version,
self.api_version,
)
except APIVersionDoesNotContainModelException:
if self.passthrough_nonexistent_methods:
return super()._dump_result(app, methodobj, result)

raise

return self.middleware.dump_result(self.serviceobj, methodobj, app, result["result"],
new_style_returns_model=model)

raise
return super()._dump_result(app, methodobj, result)

def dump_args(self, params):
if self.accepts_model:
if self.accepts_model: # FIXME: Remove this check when all models become new style
return dump_params(self.accepts_model, params, False)

return super().dump_args(params)
10 changes: 7 additions & 3 deletions src/middlewared/middlewared/api/base/server/method.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,17 @@ async def call(self, app: "RpcWebSocketApp", params: list):

result = await self.middleware.call_with_audit(self.name, self.serviceobj, methodobj, params, app)
if isinstance(result, Job):
result = result.id
elif isinstance(result, types.GeneratorType):
return result.id

if isinstance(result, types.GeneratorType):
result = list(result)
elif isinstance(result, types.AsyncGeneratorType):
result = [i async for i in result]

return result
return self._dump_result(app, methodobj, result)

def _dump_result(self, app: "RpcWebSocketApp", methodobj, result):
return self.middleware.dump_result(self.serviceobj, methodobj, app, result)

def dump_args(self, params: list) -> list:
"""
Expand Down
50 changes: 28 additions & 22 deletions src/middlewared/middlewared/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,9 @@ async def call_method(self, message, serviceobj, methodobj):
result = [i async for i in result]
else:
if lam.returns_model:
result = lam._adapt_result(result)
result = lam._dump_result(self, methodobj, result)
else:
result = self.middleware.dump_result(serviceobj, methodobj, self, result)

self._send({
'id': message['id'],
Expand Down Expand Up @@ -1510,22 +1512,37 @@ def dump_args(self, args, method=None, method_name=None):
return [method.accepts[i].dump(arg) if i < len(method.accepts) else arg
for i, arg in enumerate(args)]

def dump_result(self, method, result, expose_secrets):
def dump_result(self, serviceobj, methodobj, app, result, *, new_style_returns_model=None):
expose_secrets = True
if app and app.authenticated_credentials:
if app.authenticated_credentials.is_user_session and not (
credential_has_full_admin(app.authenticated_credentials) or
(
serviceobj._config.role_prefix and
app.authenticated_credentials.has_role(f'{serviceobj._config.role_prefix}_WRITE')
)
):
expose_secrets = False

if isinstance(result, Job):
return result

if method_self := getattr(method, "__self__", None):
if method.__name__ in ["create", "update", "delete"]:
if do_method := getattr(method_self, f"do_{method.__name__}", None):
if method_self := getattr(methodobj, "__self__", None):
if methodobj.__name__ in ["create", "update", "delete"]:
if do_method := getattr(method_self, f"do_{methodobj.__name__}", None):
if hasattr(do_method, "new_style_returns"):
# FIXME: Get rid of `create`/`do_create` duality
method = do_method
methodobj = do_method

if hasattr(methodobj, "new_style_returns"):
# FIXME: When all models become new style, this should be passed explicitly
if new_style_returns_model is None:
new_style_returns_model = methodobj.new_style_returns

if hasattr(method, "new_style_returns"):
return serialize_result(method.new_style_returns, result, expose_secrets)
return serialize_result(new_style_returns_model, result, expose_secrets)

if not expose_secrets and hasattr(method, "returns") and method.returns:
schema = method.returns[0]
if not expose_secrets and hasattr(methodobj, "returns") and methodobj.returns:
schema = methodobj.returns[0]
if isinstance(schema, OROperator):
result = schema.dump(result, False)
else:
Expand Down Expand Up @@ -1620,18 +1637,7 @@ async def job_on_finish_cb(job):
job = result
await job.set_on_finish_cb(job_on_finish_cb)

expose_secrets = True
if app and app.authenticated_credentials:
if app.authenticated_credentials.is_user_session and not (
credential_has_full_admin(app.authenticated_credentials) or
(
serviceobj._config.role_prefix and
app.authenticated_credentials.has_role(f'{serviceobj._config.role_prefix}_WRITE')
)
):
expose_secrets = False

result = self.dump_result(methodobj, result, expose_secrets)
return result
finally:
# If the method is a job, audit message will be logged by `job_on_finish_cb`
if job is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,3 @@ def method(number, text, multiplier):

def test_adapt_params():
assert legacy_api_method._adapt_params([1]) == [1, "Default", 2]


def test_adapt_result():
assert legacy_api_method._adapt_result(1) == "1"
1 change: 1 addition & 0 deletions src/middlewared/middlewared/restful.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,6 +856,7 @@ async def do(self, http_method, req, resp, app, authorized, **kwargs):
if authorized:
result = await self.middleware.call_with_audit(methodname, serviceobj, methodobj, method_args,
**method_kwargs)
result = self.middleware.dump_result(serviceobj, methodobj, app, result)
else:
await self.middleware.log_audit_message_for_method(methodname, methodobj, method_args, app,
True, False, False)
Expand Down
6 changes: 4 additions & 2 deletions src/middlewared/middlewared/service/core_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,11 +804,11 @@ async def bulk(self, app, job, method, params, description):
# entries for external callers to methods. app is only None
# on internal calls to core.bulk.
if app:
msg = await self.middleware.call_with_audit(method, serviceobj, methodobj, p, app=app)
msg = await self.middleware.call_with_audit(method, serviceobj, methodobj, p, app)
else:
msg = await self.middleware.call(method, *p)

status = {"result": msg, "error": None}
status = {"error": None}

if isinstance(msg, Job):
b_job = msg
Expand All @@ -817,6 +817,8 @@ async def bulk(self, app, job, method, params, description):

if b_job.error:
status["error"] = b_job.error
else:
status["result"] = self.middleware.dump_result(serviceobj, methodobj, app, msg)

statuses.append(status)
except Exception as e:
Expand Down
40 changes: 37 additions & 3 deletions tests/api2/test_legacy_websocket.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import random
import string

import pytest

from truenas_api_client import Client

from middlewared.test.integration.assets.account import unprivileged_user
from middlewared.test.integration.assets.cloud_sync import credential
from middlewared.test.integration.utils import password, websocket_url

Expand All @@ -17,7 +21,28 @@ def c():
yield c


def test_adapts_cloud_credentials(c):
@pytest.fixture(scope="module")
def unprivileged_client():
suffix = "".join([random.choice(string.ascii_lowercase + string.digits) for _ in range(8)])
with unprivileged_user(
username=f"unprivileged_{suffix}",
group_name=f"unprivileged_users_{suffix}",
privilege_name=f"Unprivileged users ({suffix})",
allowlist=[],
roles=["READONLY_ADMIN"],
web_shell=False,
) as t:
with Client(websocket_url() + "/websocket") as c:
c.call("auth.login_ex", {
"mechanism": "PASSWORD_PLAIN",
"username": t.username,
"password": t.password,
})
yield c


@pytest.fixture(scope="module")
def ftp_credential():
with credential({
"provider": {
"type": "FTP",
Expand All @@ -27,5 +52,14 @@ def test_adapts_cloud_credentials(c):
"pass": "",
},
}) as cred:
result = c.call("cloudsync.credentials.get_instance", cred["id"])
assert result["provider"] == "FTP"
yield cred


def test_adapts_cloud_credentials(c, ftp_credential):
result = c.call("cloudsync.credentials.get_instance", ftp_credential["id"])
assert result["provider"] == "FTP"


def test_adapts_cloud_credentials_for_unprivileged(unprivileged_client, ftp_credential):
result = unprivileged_client.call("cloudsync.credentials.get_instance", ftp_credential["id"])
assert result["attributes"] == "********"

0 comments on commit 1b1b658

Please sign in to comment.