Skip to content

Commit

Permalink
BEGIN_PUBLIC
Browse files Browse the repository at this point in the history
Add support for testing against more complex subject types.

We can now use subjects such as subjects.result(subjects.str) for something that may return a string, or fail.
END_PUBLIC

PiperOrigin-RevId: 608971309
Change-Id: I9ae61c988a597189b84fb6ccef75c96697c6e364
  • Loading branch information
Googler authored and copybara-github committed Feb 21, 2024
1 parent 35fe45e commit 9eb790f
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 15 deletions.
122 changes: 122 additions & 0 deletions tests/rule_based_toolchain/generics.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright 2024 The Bazel Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of a result type for use with rules_testing."""

load("@rules_testing//lib:truth.bzl", "subjects")

visibility("//tests/rule_based_toolchain/...")

def result_fn_wrapper(fn):
"""Wraps a function that may fail in a type similar to rust's Result type.
An example usage is the following:
# Implementation file
def get_only(value, fail=fail):
if len(value) == 1:
return value[0]
elif not value:
fail("Unexpectedly empty")
else:
fail("%r had length %d, expected 1" % (value, len(value))
# Test file
load("...", _fn=fn)
fn = result_fn_wrapper(_fn)
int_result = result_subject(subjects.int)
def my_test(env, _):
env.expect.that_value(fn([]), factory=int_result)
.err().equals("Unexpectedly empty")
env.expect.that_value(fn([1]), factory=int_result)
.ok().equals(1)
env.expect.that_value(fn([1, 2]), factory=int_result)
.err().contains("had length 2, expected 1")
Args:
fn: A function that takes in a parameter fail and calls it on failure.
Returns:
On success: struct(ok = <result>, err = None)
On failure: struct(ok = None, err = <first error message>
"""

def new_fn(*args, **kwargs):
# Use a mutable type so that the fail_wrapper can modify this.
failures = []

def fail_wrapper(msg):
failures.append(msg)

result = fn(fail = fail_wrapper, *args, **kwargs)
if failures:
return struct(ok = None, err = failures[0])
else:
return struct(ok = result, err = None)

return new_fn

def result_subject(factory):
"""A subject factory for Result<T>.
Args:
factory: A subject factory for T
Returns:
A subject factory for Result<T>
"""

def new_factory(value, *, meta):
def ok():
if value.err != None:
meta.add_failure("Wanted a value, but got an error", value.err)
return factory(value.ok, meta = meta.derive("ok()"))

def err():
if value.err == None:
meta.add_failure("Wanted an error, but got a value", value.ok)
return subjects.str(value.err, meta = meta.derive("err()"))

return struct(ok = ok, err = err)

return new_factory

def optional_subject(factory):
"""A subject factory for Optional<T>.
Args:
factory: A subject factory for T
Returns:
A subject factory for Optional<T>
"""

def new_factory(value, *, meta):
def some():
if value == None:
meta.add_failure("Wanted a value, but got None", None)
return factory(value, meta = meta)

def is_none():
if value != None:
meta.add_failure("Wanted None, but got a value", value)

return struct(some = some, is_none = is_none)

return new_factory

# Curry subjects.struct so the type is actually generic.
struct_subject = lambda **attrs: lambda value, *, meta: subjects.struct(
value,
meta = meta,
attrs = attrs,
)
43 changes: 28 additions & 15 deletions tests/rule_based_toolchain/subjects.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
# limitations under the License.
"""Test subjects for cc_toolchain_info providers."""

load("@rules_testing//lib:truth.bzl", "subjects")
load("@bazel_skylib//lib:structs.bzl", "structs")
load("@rules_testing//lib:truth.bzl", _subjects = "subjects")
load(
"//cc/toolchains:cc_toolchain_info.bzl",
"ActionConfigInfo",
Expand All @@ -29,15 +30,16 @@ load(
"ToolInfo",
)
load(":generate_factory.bzl", "ProviderDepset", "ProviderSequence", "generate_factory")
load(":generics.bzl", "optional_subject", "result_subject", "struct_subject", _result_fn_wrapper = "result_fn_wrapper")

visibility("private")
visibility("//tests/rule_based_toolchain/...")

# buildifier: disable=name-conventions
_ActionTypeFactory = generate_factory(
ActionTypeInfo,
"ActionTypeInfo",
dict(
name = subjects.str,
name = _subjects.str,
),
)

Expand All @@ -54,17 +56,17 @@ _ActionTypeSetFactory = generate_factory(
_MutuallyExclusiveCategoryFactory = generate_factory(
MutuallyExclusiveCategoryInfo,
"MutuallyExclusiveCategoryInfo",
dict(name = subjects.str),
dict(name = _subjects.str),
)

_FEATURE_FLAGS = dict(
name = subjects.str,
enabled = subjects.bool,
name = _subjects.str,
enabled = _subjects.bool,
flag_sets = None,
implies = None,
requires_any_of = None,
provides = ProviderSequence(_MutuallyExclusiveCategoryFactory),
known = subjects.bool,
known = _subjects.bool,
overrides = None,
)

Expand Down Expand Up @@ -98,8 +100,8 @@ _AddArgsFactory = generate_factory(
AddArgsInfo,
"AddArgsInfo",
dict(
args = subjects.collection,
files = subjects.depset_file,
args = _subjects.collection,
files = _subjects.depset_file,
),
)

Expand All @@ -110,8 +112,8 @@ _ArgsFactory = generate_factory(
dict(
actions = ProviderDepset(_ActionTypeFactory),
args = ProviderSequence(_AddArgsFactory),
env = subjects.dict,
files = subjects.depset_file,
env = _subjects.dict,
files = _subjects.depset_file,
requires_any_of = ProviderSequence(_FeatureConstraintFactory),
),
)
Expand All @@ -132,9 +134,10 @@ _ToolFactory = generate_factory(
ToolInfo,
"ToolInfo",
dict(
exe = subjects.file,
runifles = subjects.depset_file,
exe = _subjects.file,
runfiles = _subjects.depset_file,
requires_any_of = ProviderSequence(_FeatureConstraintFactory),
execution_requirements = _subjects.collection,
),
)

Expand All @@ -144,11 +147,11 @@ _ActionConfigFactory = generate_factory(
"ActionConfigInfo",
dict(
action = _ActionTypeFactory,
enabled = subjects.bool,
enabled = _subjects.bool,
tools = ProviderSequence(_ToolFactory),
flag_sets = ProviderSequence(_ArgsFactory),
implies = ProviderDepset(_FeatureFactory),
files = subjects.depset_file,
files = _subjects.depset_file,
),
)

Expand Down Expand Up @@ -185,3 +188,13 @@ FACTORIES = [
_ToolFactory,
_ActionConfigSetFactory,
]

result_fn_wrapper = _result_fn_wrapper

subjects = struct(
**(structs.to_dict(_subjects) | dict(
result = result_subject,
optional = optional_subject,
struct = struct_subject,
) | {factory.name: factory.factory for factory in FACTORIES})
)

0 comments on commit 9eb790f

Please sign in to comment.