diff --git a/tests/test_solver.py b/tests/test_solver.py index 30bf1883..6182c0aa 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -1,15 +1,25 @@ import sys -from unfurl.yamlloader import load_yaml, yaml -from unfurl.solver import solve_topology, tosca_to_rust, Node, solve, Field, FieldValue, ToscaValue, SimpleValue +import os +from unfurl.yamlloader import load_yaml, yaml, ImportResolver +from unfurl.solver import ( + solve_topology, + tosca_to_rust, + Node, + solve, + Field, + FieldValue, + ToscaValue, + SimpleValue, +) from toscaparser.tosca_template import ToscaTemplate from toscaparser.properties import Property from toscaparser.elements.portspectype import PortSpec +from toscaparser.common import exception from ruamel.yaml.comments import CommentedMap + def make_tpl(yaml_str: str): - tosca_yaml = load_yaml( - yaml, yaml_str, readonly=True - ) # export uses readonly yaml parser + tosca_yaml = load_yaml(yaml, yaml_str, readonly=True) tosca_yaml["tosca_definitions_version"] = "tosca_simple_unfurl_1_0_0" if "topology_template" not in tosca_yaml: tosca_yaml["topology_template"] = dict( @@ -17,8 +27,9 @@ def make_tpl(yaml_str: str): ) return ToscaTemplate(path=__file__, yaml_dict_tpl=tosca_yaml) + example_helloworld_yaml = """ -description: Template for deploying a single server with predefined properties. +tosca_definitions_version: tosca_simple_unfurl_1_0_0" node_types: Example: derived_from: tosca.nodes.Root @@ -29,6 +40,7 @@ def make_tpl(yaml_str: str): - host: capability: tosca.capabilities.Compute node: tosca.nodes.Compute + occurrences: [1, 1] topology_template: substitution_mappings: @@ -66,12 +78,13 @@ def make_tpl(yaml_str: str): version: "6.5" """ + def test_convert(): for val, toscatype in [ (80, "PortDef"), (CommentedMap(), "map"), (PortSpec.make("80:80"), "tosca.datatypes.network.PortSpec"), - ]: + ]: prop = Property( toscatype, val, @@ -79,8 +92,9 @@ def test_convert(): ) assert tosca_to_rust(prop) + def test_solve(): - f = Field('f', FieldValue.Property(ToscaValue(SimpleValue.integer(0)))) + f = Field("f", FieldValue.Property(ToscaValue(SimpleValue.integer(0)))) na = Node("a", "Foo", fields=[f]) assert na.name == "a" assert na.tosca_type == "Foo" @@ -89,7 +103,7 @@ def test_solve(): types = dict(a=["a", "Root"]) solved = solve(nodes, types) assert not solved - + tosca = make_tpl(example_helloworld_yaml) assert tosca.topology_template @@ -107,6 +121,16 @@ def test_solve(): # test requirement match for each type of CriteriaTerm and Constraint # test restrictions +def test_multiple(): + tosca_yaml = load_yaml(yaml, example_helloworld_yaml, readonly=True) + tosca_yaml["topology_template"]["node_templates"]["db_server2"] = tosca_yaml["topology_template"]["node_templates"]["db_server"].copy() + t = ToscaTemplate(path=__file__, yaml_dict_tpl=tosca_yaml, import_resolver=ImportResolver(None), verify=False) + exception.ExceptionCollector.start() + t.validate_relationships() + assert str(exception.ExceptionCollector.exceptions[-1]) == 'requirement "host" of node "app" found 2 targets more than max occurrences 1' + # t.topology_template.node_templates["app"]._relationships = None + # print(t.topology_template.node_templates["app"].requirements) + assert len(t.topology_template.node_templates["app"].relationships) == 2 def test_node_filter(): tosca_tpl = ( diff --git a/unfurl/solver.py b/unfurl/solver.py index 40973b03..b0ab0809 100644 --- a/unfurl/solver.py +++ b/unfurl/solver.py @@ -1,5 +1,9 @@ +# Copyright (c) 2024 Adam Souzis +# SPDX-License-Identifier: MIT from typing import Any, Dict, List, Optional, Tuple, cast import sys + +# import types from rust extension from .tosca_solver import ( # type: ignore solve, CriteriaTerm, @@ -17,6 +21,7 @@ from toscaparser.properties import Property from toscaparser.nodetemplate import NodeTemplate from toscaparser.topology_template import TopologyTemplate +from toscaparser.common import exception from .eval import Ref, analyze_expr from .logs import getLogger @@ -25,12 +30,18 @@ Solution = Dict[Tuple[str, str], List[Tuple[str, str]]] +# note: make sure Node in rust/lib.rs staying in sync class Node: + """A partial representations of a TOSCA node template (enough for [solve()])""" + def __init__(self, name, type="tosca.nodes.Root", fields=None): self.name: str = name self.tosca_type: str = type self.fields: List[Field] = fields or [] + # Set if any of its fields has restrictions self.has_restrictions: bool = False + self._reqs: Dict[str, int] = {} # extra attribute for book keeping (not used in rust) + def __repr__(self) -> str: return f"Node({self.name}, {self.tosca_type}, {self.has_restrictions}, {self.fields!r})" @@ -195,15 +206,20 @@ def filter2term( def convert( - nt: NodeTemplate, types: Dict[str, List[str]], topology_template: TopologyTemplate + node_template: NodeTemplate, + types: Dict[str, List[str]], + topology_template: TopologyTemplate, ) -> Node: - # XXX if nt is in nested topology and replaced, partially convert the outer node instead - assert nt.type_definition - entity = Node(nt.name, nt.type_definition.type) + # XXX if node_template is in nested topology and replaced, partially convert the outer node instead + if not node_template.type_definition: + return Node(node_template.name, "tosca.nodes.Root") + entity = Node(node_template.name, node_template.type_definition.type) has_restrictions = False # print( entity.name ) - types[nt.type_definition.type] = [p.type for p in nt.type_definition.ancestors()] - for cap in nt.get_capabilities_objects(): + types[node_template.type_definition.type] = [ + p.type for p in node_template.type_definition.ancestors() + ] + for cap in node_template.get_capabilities_objects(): # if cap.name == "feature": # continue types[cap.type_definition.type] = [ @@ -218,34 +234,45 @@ def convert( Field(cap.name, FieldValue.Capability(cap.type_definition.type, cap_fields)) ) - for prop in nt.get_properties_objects(): + for prop in node_template.get_properties_objects(): if include_value(prop.value): entity.fields.append(prop2field(prop)) type_requirements: Dict[str, Dict[str, Any]] = ( - nt.type_definition.requirement_definitions + node_template.type_definition.requirement_definitions ) - for name, req_dict in nt.all_requirements: + for name, req_dict in node_template.all_requirements: type_req_dict: Optional[Dict[str, Any]] = type_requirements.get(name) on_type_only = not bool(req_dict) if type_req_dict: + type_req_dict = type_req_dict.copy() + for key in ("node", "relationship", "capability"): + if key in req_dict: + type_req_dict.pop("!namespace-" + key, None) req_dict = dict(type_req_dict, **req_dict) - required = "occurrences" not in req_dict or req_dict["occurrences"][0] + if "occurrences" not in req_dict: + required = True + upper = sys.maxsize + else: + required = bool(req_dict["occurrences"][0]) + max_occurences = req_dict["occurrences"][1] + upper = sys.maxsize if max_occurences == "UNBOUNDED" else int(max_occurences) + # note: ok if multiple requirements with same name on the template, then occurrences should be on type + entity._reqs[name] = upper match_type = not on_type_only or required field, found_restrictions = get_req_terms( - nt, types, topology_template, name, req_dict, match_type + node_template, types, topology_template, name, req_dict, match_type ) - # print("terms", terms) if field: entity.fields.append(field) if found_restrictions: has_restrictions = True - # print("rels", nt.relationships, nt.missing_requirements) + # print("rels", node_template.relationships, node_template.missing_requirements) entity.has_restrictions = has_restrictions return entity -def add_match(terms, match): +def add_match(terms, match) -> None: if isinstance(match, dict) and (node_type := match.get("get_nodes_of_type")): terms.append(CriteriaTerm.NodeType(node_type)) else: @@ -380,37 +407,37 @@ def solve_topology(topology_template: TopologyTemplate) -> Solution: # print("missing", topology_template.node_templates["app"].missing_requirements) # print ('types', types) - logger.debug("!solving " + "\n\n".join(repr(n) for n in nodes.values())) + # print("!solving " + "\n\n".join(repr(n) for n in nodes.values())) solved = cast(Solution, solve(nodes, types)) - logger.debug(f"!solved!") # {solved}") + logger.debug(f"Solve found {len(solved)} matches for {len(nodes)}.") for (source_name, req), targets in solved.items(): - source = topology_template.node_templates[source_name] + source: NodeTemplate = topology_template.node_templates[source_name] + target_nodes = [ + (topology_template.node_templates[node], cap) for (node, cap) in targets + ] if len(targets) > 1: - # filter out defaults - target_nodes = [ - (t, cap) - for (t, cap) in ( - (topology_template.node_templates[node], cap) - for (node, cap) in targets - ) - if "default" not in t.directives + # filter out default nodes + no_defaults = [ + (t, cap) for (t, cap) in target_nodes if "default" not in t.directives ] - # XXX if node filter: report ambiguity - if not target_nodes: - continue # hmm... more than one default match? - if len(target_nodes) > 1: - # XXX don't just skip, only treat as error if exceeds occurrences - continue - target_node, cap = target_nodes[0] - else: - assert targets - target_node = topology_template.node_templates[targets[0][0]] - cap = targets[0][1] - target = target_node.name - # print("solved", source, req, target_node, cap) - # pass target to handle case when there is more than one match per requirement - req_dict = source.find_or_add_requirement(req, target) - req_dict["node"] = target - if cap != "feature": - req_dict["capability"] = cap + if no_defaults: + target_nodes = no_defaults + max_occurrences = nodes[source_name]._reqs[req] + if len(target_nodes) > max_occurrences: + exception.ExceptionCollector.appendException( + exception.ValidationError( + message='requirement "%s" of node "%s" found %s targets more than max occurrences %s' + % (req, source_name, len(target_nodes), max_occurrences) + ) + ) + for target_node, cap in target_nodes: + _set_target(source, req, cap, target_node.name) return solved + + +def _set_target(source: NodeTemplate, req_name: str, cap: str, target: str) -> None: + # updates requirements yaml directly so NodeTemplate won't search for a match later + req_dict: dict = source.find_or_add_requirement(req_name, target) + req_dict["node"] = target + if cap != "feature": + req_dict["capability"] = cap