Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NAS-132676 / 25.04 / Only redact secrets when the method return value is passed to the user #15042

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/middlewared/middlewared/api/base/handler/accept.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def accept_params(model: type[BaseModel], args: list, *, exclude_unset=False, ex
:param model: `BaseModel` that defines method args.
:param args: a list of method args.
:param exclude_unset: if true, will not append default parameters to the list.
:param expose_secrets: if false, will replace `Private` parameters with a placeholder.
:param expose_secrets: if false, will replace `Secret` parameters with a placeholder.
:return: a validated list of method args.
"""
args_as_dict = model_dict_from_list(model, args)
Expand Down Expand Up @@ -60,7 +60,7 @@ def validate_model(model: type[BaseModel], data: dict, *, exclude_unset=False, e
:param model: `BaseModel` subclass.
:param data: provided data.
:param exclude_unset: if true, will not add default values.
:param expose_secrets: if false, will replace `Private` fields with a placeholder.
:param expose_secrets: if false, will replace `Secret` fields with a placeholder.
:return: validated data.
"""
try:
Expand All @@ -83,5 +83,5 @@ def validate_model(model: type[BaseModel], data: dict, *, exclude_unset=False, e
context={"expose_secrets": expose_secrets},
exclude_unset=exclude_unset,
warnings=False,
by_alias=True
by_alias=True,
)
8 changes: 4 additions & 4 deletions src/middlewared/middlewared/api/base/handler/dump_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def dump_params(model: type[BaseModel], args: list, expose_secrets: bool) -> lis

:param model: `BaseModel` that defines method args.
:param args: a list of method args.
:param expose_secrets: if false, will replace `Private` parameters with a placeholder.
:param expose_secrets: if false, will replace `Secret` parameters with a placeholder.
:return: A list of method call arguments ready to be printed.
"""
try:
Expand All @@ -32,10 +32,10 @@ def dump_params(model: type[BaseModel], args: list, expose_secrets: bool) -> lis

def remove_secrets(model: type[BaseModel], value):
"""
Removes `Private` values from a model value.
Removes `Secret` values from a model value.
:param model: `BaseModel` that corresponds to `value`.
:param value: value that potentially contains `Private` data.
:return: `value` with `Private` parameters replaced with a placeholder.
:param value: value that potentially contains `Secret` data.
:return: `value` with `Secret` parameters replaced with a placeholder.
"""
if isinstance(value, dict) and (nested_model := model_field_is_model(model)):
return {
Expand Down
8 changes: 8 additions & 0 deletions src/middlewared/middlewared/api/base/handler/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@


def serialize_result(model, result, expose_secrets):
"""
Serializes a `result` of the method execution using the corresponding `model`.

:param model: `BaseModel` that defines method return value.
:param result: method return value.
:param expose_secrets: if false, will replace `Secret` parameters with a placeholder.
:return: serialized method execution result.
"""
return model(result=result).model_dump(
context={"expose_secrets": expose_secrets},
warnings=False,
Expand Down
10 changes: 9 additions & 1 deletion src/middlewared/middlewared/api/base/handler/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ def adapt(self, value: dict, model_name: str, version1: str, version2: str) -> d
:param version2: target API version that needs `value`
:return: converted value
"""
return self.adapt_model(value, model_name, version1, version2)[1]

def adapt_model(self, value: dict, model_name: str, version1: str, version2: str) -> tuple[type[BaseModel], dict]:
"""
Same as `adapt`, but returned value will be a tuple of `version2` model instance and converted value.
"""
try:
version1_index = self.versions_history.index(version1)
except ValueError:
Expand All @@ -101,6 +107,7 @@ def adapt(self, value: dict, model_name: str, version1: str, version2: str) -> d
raise APIVersionDoesNotContainModelException(current_version.version, model_name)

value_factory = functools.partial(validate_model, current_version_model, value)
model = current_version_model

if version1_index < version2_index:
step = 1
Expand All @@ -115,10 +122,11 @@ def adapt(self, value: dict, model_name: str, version1: str, version2: str) -> d
value_factory = functools.partial(
self._adapt_model, value_factory, model_name, current_version, new_version, direction,
)
model = new_version.models.get(model_name)

current_version = new_version

return value_factory()
return model, value_factory()

def _adapt_model(self, value_factory: Callable[[], dict], model_name: str, current_version: APIVersion,
new_version: APIVersion, direction: Direction):
Expand Down
38 changes: 22 additions & 16 deletions src/middlewared/middlewared/api/base/server/legacy_api_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,16 @@ def __init__(self, middleware: "Middleware", name: str, api_version: str, adapte
methodobj = self.methodobj
if crud_methodobj := real_crud_method(methodobj):
methodobj = crud_methodobj
if hasattr(methodobj, "new_style_accepts"):
if hasattr(methodobj, "new_style_accepts"): # FIXME: Remove this check when all models become new style
self.accepts_model = methodobj.new_style_accepts
self.returns_model = methodobj.new_style_returns
else:
self.accepts_model = None
self.returns_model = None

async def call(self, app: "RpcWebSocketApp", params):
if self.accepts_model:
return self._adapt_result(await super().call(app, self._adapt_params(params)))
if self.accepts_model: # FIXME: Remove this check when all models become new style
params = self._adapt_params(params)

return await super().call(app, params)

Expand All @@ -70,22 +70,28 @@ def _adapt_params(self, params):

return [adapted_params_dict[field] for field in self.accepts_model.model_fields]

def _adapt_result(self, result):
try:
return self.adapter.adapt(
{"result": result},
self.returns_model.__name__,
self.adapter.current_version,
self.api_version,
)["result"]
except APIVersionDoesNotContainModelException:
if self.passthrough_nonexistent_methods:
return result
def _dump_result(self, app: "RpcWebSocketApp", methodobj, result):
if self.accepts_model: # FIXME: Remove this check when all models become new style
try:
model, result = self.adapter.adapt_model(
{"result": result},
self.returns_model.__name__,
self.adapter.current_version,
self.api_version,
)
except APIVersionDoesNotContainModelException:
if self.passthrough_nonexistent_methods:
return super()._dump_result(app, methodobj, result)

raise

return self.middleware.dump_result(self.serviceobj, methodobj, app, result["result"],
new_style_returns_model=model)

raise
return super()._dump_result(app, methodobj, result)

def dump_args(self, params):
if self.accepts_model:
if self.accepts_model: # FIXME: Remove this check when all models become new style
return dump_params(self.accepts_model, params, False)

return super().dump_args(params)
10 changes: 7 additions & 3 deletions src/middlewared/middlewared/api/base/server/method.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,17 @@ async def call(self, app: "RpcWebSocketApp", params: list):

result = await self.middleware.call_with_audit(self.name, self.serviceobj, methodobj, params, app)
if isinstance(result, Job):
result = result.id
elif isinstance(result, types.GeneratorType):
return result.id

if isinstance(result, types.GeneratorType):
result = list(result)
elif isinstance(result, types.AsyncGeneratorType):
result = [i async for i in result]

return result
return self._dump_result(app, methodobj, result)

def _dump_result(self, app: "RpcWebSocketApp", methodobj, result):
return self.middleware.dump_result(self.serviceobj, methodobj, app, result)

def dump_args(self, params: list) -> list:
"""
Expand Down
2 changes: 1 addition & 1 deletion src/middlewared/middlewared/api/base/types/base/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __get_pydantic_core_schema__(
LongString = Annotated[
LongStringWrapper,
BeforeValidator(LongStringWrapper),
PlainSerializer(lambda x: undefined if x == undefined else x.value),
PlainSerializer(lambda x: x.value if isinstance(x, LongStringWrapper) else x),
]

NonEmptyString = Annotated[str, Field(min_length=1)]
Expand Down
4 changes: 2 additions & 2 deletions src/middlewared/middlewared/api/v25_04_0/cloud_backup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ class CloudBackupCreate(BaseModel):
pre_script: LongString = ""
post_script: LongString = ""
snapshot: bool = False
include: list[NonEmptyString]
exclude: list[NonEmptyString]
include: list[NonEmptyString] = []
exclude: list[NonEmptyString] = []
args: LongString = ""
enabled: bool = True

Expand Down
50 changes: 28 additions & 22 deletions src/middlewared/middlewared/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,9 @@ async def call_method(self, message, serviceobj, methodobj):
result = [i async for i in result]
else:
if lam.returns_model:
result = lam._adapt_result(result)
result = lam._dump_result(self, methodobj, result)
else:
result = self.middleware.dump_result(serviceobj, methodobj, self, result)

self._send({
'id': message['id'],
Expand Down Expand Up @@ -1510,22 +1512,37 @@ def dump_args(self, args, method=None, method_name=None):
return [method.accepts[i].dump(arg) if i < len(method.accepts) else arg
for i, arg in enumerate(args)]

def dump_result(self, method, result, expose_secrets):
def dump_result(self, serviceobj, methodobj, app, result, *, new_style_returns_model=None):
expose_secrets = True
if app and app.authenticated_credentials:
if app.authenticated_credentials.is_user_session and not (
credential_has_full_admin(app.authenticated_credentials) or
(
serviceobj._config.role_prefix and
app.authenticated_credentials.has_role(f'{serviceobj._config.role_prefix}_WRITE')
)
):
expose_secrets = False

if isinstance(result, Job):
return result

if method_self := getattr(method, "__self__", None):
if method.__name__ in ["create", "update", "delete"]:
if do_method := getattr(method_self, f"do_{method.__name__}", None):
if method_self := getattr(methodobj, "__self__", None):
if methodobj.__name__ in ["create", "update", "delete"]:
if do_method := getattr(method_self, f"do_{methodobj.__name__}", None):
if hasattr(do_method, "new_style_returns"):
# FIXME: Get rid of `create`/`do_create` duality
method = do_method
methodobj = do_method

if hasattr(methodobj, "new_style_returns"):
# FIXME: When all models become new style, this should be passed explicitly
if new_style_returns_model is None:
new_style_returns_model = methodobj.new_style_returns

if hasattr(method, "new_style_returns"):
return serialize_result(method.new_style_returns, result, expose_secrets)
return serialize_result(new_style_returns_model, result, expose_secrets)

if not expose_secrets and hasattr(method, "returns") and method.returns:
schema = method.returns[0]
if not expose_secrets and hasattr(methodobj, "returns") and methodobj.returns:
schema = methodobj.returns[0]
if isinstance(schema, OROperator):
result = schema.dump(result, False)
else:
Expand Down Expand Up @@ -1620,18 +1637,7 @@ async def job_on_finish_cb(job):
job = result
await job.set_on_finish_cb(job_on_finish_cb)

expose_secrets = True
if app and app.authenticated_credentials:
if app.authenticated_credentials.is_user_session and not (
credential_has_full_admin(app.authenticated_credentials) or
(
serviceobj._config.role_prefix and
app.authenticated_credentials.has_role(f'{serviceobj._config.role_prefix}_WRITE')
)
):
expose_secrets = False

result = self.dump_result(methodobj, result, expose_secrets)
return result
finally:
# If the method is a job, audit message will be logged by `job_on_finish_cb`
if job is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,3 @@ def method(number, text, multiplier):

def test_adapt_params():
assert legacy_api_method._adapt_params([1]) == [1, "Default", 2]


def test_adapt_result():
assert legacy_api_method._adapt_result(1) == "1"
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,11 @@ class SecretLongStringMethodArgs(BaseModel):

def test_secret_long_string():
assert accept_params(SecretLongStringMethodArgs, ["test"]) == ["test"]


class LongStringDefaultMethodArgs(BaseModel):
str: LongString = ""


def test_long_string_default():
assert accept_params(LongStringDefaultMethodArgs, []) == [""]
1 change: 1 addition & 0 deletions src/middlewared/middlewared/restful.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,6 +856,7 @@ async def do(self, http_method, req, resp, app, authorized, **kwargs):
if authorized:
result = await self.middleware.call_with_audit(methodname, serviceobj, methodobj, method_args,
**method_kwargs)
result = self.middleware.dump_result(serviceobj, methodobj, app, result)
else:
await self.middleware.log_audit_message_for_method(methodname, methodobj, method_args, app,
True, False, False)
Expand Down
6 changes: 4 additions & 2 deletions src/middlewared/middlewared/service/core_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,11 +804,11 @@ async def bulk(self, app, job, method, params, description):
# entries for external callers to methods. app is only None
# on internal calls to core.bulk.
if app:
msg = await self.middleware.call_with_audit(method, serviceobj, methodobj, p, app=app)
msg = await self.middleware.call_with_audit(method, serviceobj, methodobj, p, app)
else:
msg = await self.middleware.call(method, *p)

status = {"result": msg, "error": None}
status = {"error": None}

if isinstance(msg, Job):
b_job = msg
Expand All @@ -817,6 +817,8 @@ async def bulk(self, app, job, method, params, description):

if b_job.error:
status["error"] = b_job.error
else:
status["result"] = self.middleware.dump_result(serviceobj, methodobj, app, msg)

statuses.append(status)
except Exception as e:
Expand Down
2 changes: 1 addition & 1 deletion tests/api2/test_account_privilege_role_private_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def vmware():
("acme.dns.authenticator", dns_authenticator, {}, ["attributes"]),
("certificate", 1, {}, ["privatekey", "issuer"]),
("certificateauthority", certificateauthority, {}, ["privatekey", "issuer"]),
("cloud_backup", cloudbackup, {}, ["credentials.provider", "password"]),
("cloud_backup", cloudbackup, {}, ["credentials.provider.pass", "password"]),
("cloudsync.credentials", cloudsync_credential, {}, ["provider.pass"]),
("cloudsync", cloudsync, {}, ["credentials.provider", "encryption_password"]),
("disk", disk, {"extra": {"passwords": True}}, ["passwd"]),
Expand Down
40 changes: 37 additions & 3 deletions tests/api2/test_legacy_websocket.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import random
import string

import pytest

from truenas_api_client import Client

from middlewared.test.integration.assets.account import unprivileged_user
from middlewared.test.integration.assets.cloud_sync import credential
from middlewared.test.integration.utils import password, websocket_url

Expand All @@ -17,7 +21,28 @@ def c():
yield c


def test_adapts_cloud_credentials(c):
@pytest.fixture(scope="module")
def unprivileged_client():
suffix = "".join([random.choice(string.ascii_lowercase + string.digits) for _ in range(8)])
with unprivileged_user(
username=f"unprivileged_{suffix}",
group_name=f"unprivileged_users_{suffix}",
privilege_name=f"Unprivileged users ({suffix})",
allowlist=[],
roles=["READONLY_ADMIN"],
web_shell=False,
) as t:
with Client(websocket_url() + "/websocket") as c:
c.call("auth.login_ex", {
"mechanism": "PASSWORD_PLAIN",
"username": t.username,
"password": t.password,
})
yield c


@pytest.fixture(scope="module")
def ftp_credential():
with credential({
"provider": {
"type": "FTP",
Expand All @@ -27,5 +52,14 @@ def test_adapts_cloud_credentials(c):
"pass": "",
},
}) as cred:
result = c.call("cloudsync.credentials.get_instance", cred["id"])
assert result["provider"] == "FTP"
yield cred


def test_adapts_cloud_credentials(c, ftp_credential):
result = c.call("cloudsync.credentials.get_instance", ftp_credential["id"])
assert result["provider"] == "FTP"


def test_adapts_cloud_credentials_for_unprivileged(unprivileged_client, ftp_credential):
result = unprivileged_client.call("cloudsync.credentials.get_instance", ftp_credential["id"])
assert result["attributes"] == "********"