Skip to content

Commit

Permalink
dsl: improve type safety on TOSCA constraints and operation decorator.
Browse files Browse the repository at this point in the history
  • Loading branch information
aszs committed Dec 17, 2024
1 parent 10c4a8a commit ee525e5
Showing 1 changed file with 54 additions and 34 deletions.
88 changes: 54 additions & 34 deletions tosca-package/tosca/_tosca.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ def safe_mode() -> bool:
def global_state_mode() -> str:
"""
This function returns the execution state (either "spec" or "runtime") that the current thread is in.
Returns "spec" or "runtime"
"""
return global_state.mode

Expand Down Expand Up @@ -119,8 +121,8 @@ def set_evaluation_mode(mode: str):
.. code-block:: python
with set_mode("spec"):
assert tosca.global_state.mode == "spec"
with set_evaluation_mode("spec"):
assert tosca.global_state_mode() == "spec"
"""
saved = global_state.mode
try:
Expand Down Expand Up @@ -164,52 +166,52 @@ def apply_constraint(self, val: T) -> bool:
return True


class equal(DataConstraint):
class equal(DataConstraint[T]):
pass


class greater_than(DataConstraint):
class greater_than(DataConstraint[T]):
pass


class greater_or_equal(DataConstraint):
class greater_or_equal(DataConstraint[T]):
pass


class less_than(DataConstraint):
class less_than(DataConstraint[T]):
pass


class less_or_equal(DataConstraint):
class less_or_equal(DataConstraint[T]):
pass


class in_range(DataConstraint, Generic[T]):
class in_range(DataConstraint[T]):
def __init__(self, min: T, max: T):
super().__init__([min, max])
self.constraint = [min, max] # type: ignore


class valid_values(DataConstraint):
class valid_values(DataConstraint[T]):
pass


class length(DataConstraint):
class length(DataConstraint[T]):
pass


class min_length(DataConstraint):
class min_length(DataConstraint[T]):
pass


class max_length(DataConstraint):
class max_length(DataConstraint[T]):
pass


class pattern(DataConstraint):
class pattern(DataConstraint[T]):
pass


class schema(DataConstraint):
class schema(DataConstraint[T]):
pass


Expand Down Expand Up @@ -293,7 +295,10 @@ class OperationFunc(Protocol):
invoke: Optional[str]


@overload
def operation(
func: None = None,
*,
name="",
apply_to: Optional[Sequence[str]] = None,
timeout: Optional[float] = None,
Expand All @@ -303,7 +308,26 @@ def operation(
outputs: Optional[Dict[str, Optional[str]]] = None,
entry_state: Optional[str] = None,
invoke: Optional[str] = None,
) -> Callable[[Callable], Callable]:
) -> Callable[[F], F]: ...


@overload
def operation(func: F) -> F: ...


def operation(
func: Optional[F] = None,
*,
name="",
apply_to: Optional[Sequence[str]] = None,
timeout: Optional[float] = None,
operation_host: Optional[str] = None,
environment: Optional[Dict[str, str]] = None,
dependencies: Optional[List[Union[str, Dict[str, Any]]]] = None,
outputs: Optional[Dict[str, Optional[str]]] = None,
entry_state: Optional[str] = None,
invoke: Optional[str] = None,
) -> Union[F, Callable[[F], F]]:
"""Function decorator that marks a function or methods as a TOSCA operation.
Args:
Expand All @@ -326,7 +350,7 @@ def default(self):
return self.my_artifact.execute()
"""

def decorator_operation(func_: Callable) -> Callable:
def decorator_operation(func_: F) -> F:
func = cast(OperationFunc, func_)
func.operation_name = name or func.__name__
func.apply_to = apply_to
Expand Down Expand Up @@ -397,23 +421,19 @@ def __repr__(self) -> str:
any="object",
range="Tuple[int, int]",
)
TOSCA_SIMPLE_TYPES.update(
{
"scalar-unit.size": "Size",
"scalar-unit.frequency": "Frequency",
"scalar-unit.time": "Time",
"scalar-unit.bitrate": "Bitrate",
}
)
TOSCA_SIMPLE_TYPES.update({
"scalar-unit.size": "Size",
"scalar-unit.frequency": "Frequency",
"scalar-unit.time": "Time",
"scalar-unit.bitrate": "Bitrate",
})

PYTHON_TO_TOSCA_TYPES = {v: k for k, v in TOSCA_SIMPLE_TYPES.items()}
PYTHON_TO_TOSCA_TYPES.update(
{
"Tuple": "range",
"dict": "map",
"list": "list",
}
)
PYTHON_TO_TOSCA_TYPES.update({
"Tuple": "range",
"dict": "map",
"list": "list",
})

TOSCA_SHORT_NAMES = {
"PortDef": "tosca.datatypes.network.PortDef",
Expand Down Expand Up @@ -1085,8 +1105,8 @@ def _to_artifact_yaml(self, converter: Optional["PythonToYaml"]) -> Dict[str, An
assert len(info.types) == 1
_type = info.types[0]
assert issubclass(_type, _ToscaType), (self, _type)
cap_def: dict = yaml_cls(type=_type.tosca_type_name())
return cap_def
type_only_def: dict = yaml_cls(type=_type.tosca_type_name())
return type_only_def

def pytype_to_tosca_schema(self, _type) -> Tuple[dict, bool]:
# dict[str, list[int, constraint], constraint]
Expand Down

0 comments on commit ee525e5

Please sign in to comment.