Skip to content

Commit

Permalink
solver: better handling of requirements with multiple matches.
Browse files Browse the repository at this point in the history
  • Loading branch information
aszs committed Oct 10, 2024
1 parent 2991f51 commit 2889642
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 52 deletions.
42 changes: 33 additions & 9 deletions tests/test_solver.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,35 @@
import sys
from unfurl.yamlloader import load_yaml, yaml
from unfurl.solver import solve_topology, tosca_to_rust, Node, solve, Field, FieldValue, ToscaValue, SimpleValue
import os
from unfurl.yamlloader import load_yaml, yaml, ImportResolver
from unfurl.solver import (
solve_topology,
tosca_to_rust,
Node,
solve,
Field,
FieldValue,
ToscaValue,
SimpleValue,
)
from toscaparser.tosca_template import ToscaTemplate
from toscaparser.properties import Property
from toscaparser.elements.portspectype import PortSpec
from toscaparser.common import exception
from ruamel.yaml.comments import CommentedMap


def make_tpl(yaml_str: str):
tosca_yaml = load_yaml(
yaml, yaml_str, readonly=True
) # export uses readonly yaml parser
tosca_yaml = load_yaml(yaml, yaml_str, readonly=True)
tosca_yaml["tosca_definitions_version"] = "tosca_simple_unfurl_1_0_0"
if "topology_template" not in tosca_yaml:
tosca_yaml["topology_template"] = dict(
node_templates={}, relationship_templates={}
)
return ToscaTemplate(path=__file__, yaml_dict_tpl=tosca_yaml)


example_helloworld_yaml = """
description: Template for deploying a single server with predefined properties.
tosca_definitions_version: tosca_simple_unfurl_1_0_0"
node_types:
Example:
derived_from: tosca.nodes.Root
Expand All @@ -29,6 +40,7 @@ def make_tpl(yaml_str: str):
- host:
capability: tosca.capabilities.Compute
node: tosca.nodes.Compute
occurrences: [1, 1]
topology_template:
substitution_mappings:
Expand Down Expand Up @@ -66,21 +78,23 @@ def make_tpl(yaml_str: str):
version: "6.5"
"""


def test_convert():
for val, toscatype in [
(80, "PortDef"),
(CommentedMap(), "map"),
(PortSpec.make("80:80"), "tosca.datatypes.network.PortSpec"),
]:
]:
prop = Property(
toscatype,
val,
dict(type=toscatype),
)
assert tosca_to_rust(prop)


def test_solve():
f = Field('f', FieldValue.Property(ToscaValue(SimpleValue.integer(0))))
f = Field("f", FieldValue.Property(ToscaValue(SimpleValue.integer(0))))
na = Node("a", "Foo", fields=[f])
assert na.name == "a"
assert na.tosca_type == "Foo"
Expand All @@ -89,7 +103,7 @@ def test_solve():
types = dict(a=["a", "Root"])
solved = solve(nodes, types)
assert not solved

tosca = make_tpl(example_helloworld_yaml)
assert tosca.topology_template

Expand All @@ -107,6 +121,16 @@ def test_solve():
# test requirement match for each type of CriteriaTerm and Constraint
# test restrictions

def test_multiple():
tosca_yaml = load_yaml(yaml, example_helloworld_yaml, readonly=True)
tosca_yaml["topology_template"]["node_templates"]["db_server2"] = tosca_yaml["topology_template"]["node_templates"]["db_server"].copy()
t = ToscaTemplate(path=__file__, yaml_dict_tpl=tosca_yaml, import_resolver=ImportResolver(None), verify=False)
exception.ExceptionCollector.start()
t.validate_relationships()
assert str(exception.ExceptionCollector.exceptions[-1]) == 'requirement "host" of node "app" found 2 targets more than max occurrences 1'
# t.topology_template.node_templates["app"]._relationships = None
# print(t.topology_template.node_templates["app"].requirements)
assert len(t.topology_template.node_templates["app"].relationships) == 2

def test_node_filter():
tosca_tpl = (
Expand Down
113 changes: 70 additions & 43 deletions unfurl/solver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Copyright (c) 2024 Adam Souzis
# SPDX-License-Identifier: MIT
from typing import Any, Dict, List, Optional, Tuple, cast
import sys

# import types from rust extension
from .tosca_solver import ( # type: ignore
solve,
CriteriaTerm,
Expand All @@ -17,6 +21,7 @@
from toscaparser.properties import Property
from toscaparser.nodetemplate import NodeTemplate
from toscaparser.topology_template import TopologyTemplate
from toscaparser.common import exception
from .eval import Ref, analyze_expr
from .logs import getLogger

Expand All @@ -25,12 +30,18 @@
Solution = Dict[Tuple[str, str], List[Tuple[str, str]]]


# note: make sure Node in rust/lib.rs staying in sync
class Node:
"""A partial representations of a TOSCA node template (enough for [solve()])"""

def __init__(self, name, type="tosca.nodes.Root", fields=None):
self.name: str = name
self.tosca_type: str = type
self.fields: List[Field] = fields or []
# Set if any of its fields has restrictions
self.has_restrictions: bool = False
self._reqs: Dict[str, int] = {} # extra attribute for book keeping (not used in rust)


def __repr__(self) -> str:
return f"Node({self.name}, {self.tosca_type}, {self.has_restrictions}, {self.fields!r})"
Expand Down Expand Up @@ -195,15 +206,20 @@ def filter2term(


def convert(
nt: NodeTemplate, types: Dict[str, List[str]], topology_template: TopologyTemplate
node_template: NodeTemplate,
types: Dict[str, List[str]],
topology_template: TopologyTemplate,
) -> Node:
# XXX if nt is in nested topology and replaced, partially convert the outer node instead
assert nt.type_definition
entity = Node(nt.name, nt.type_definition.type)
# XXX if node_template is in nested topology and replaced, partially convert the outer node instead
if not node_template.type_definition:
return Node(node_template.name, "tosca.nodes.Root")
entity = Node(node_template.name, node_template.type_definition.type)
has_restrictions = False
# print( entity.name )
types[nt.type_definition.type] = [p.type for p in nt.type_definition.ancestors()]
for cap in nt.get_capabilities_objects():
types[node_template.type_definition.type] = [
p.type for p in node_template.type_definition.ancestors()
]
for cap in node_template.get_capabilities_objects():
# if cap.name == "feature":
# continue
types[cap.type_definition.type] = [
Expand All @@ -218,34 +234,45 @@ def convert(
Field(cap.name, FieldValue.Capability(cap.type_definition.type, cap_fields))
)

for prop in nt.get_properties_objects():
for prop in node_template.get_properties_objects():
if include_value(prop.value):
entity.fields.append(prop2field(prop))

type_requirements: Dict[str, Dict[str, Any]] = (
nt.type_definition.requirement_definitions
node_template.type_definition.requirement_definitions
)
for name, req_dict in nt.all_requirements:
for name, req_dict in node_template.all_requirements:
type_req_dict: Optional[Dict[str, Any]] = type_requirements.get(name)
on_type_only = not bool(req_dict)
if type_req_dict:
type_req_dict = type_req_dict.copy()
for key in ("node", "relationship", "capability"):
if key in req_dict:
type_req_dict.pop("!namespace-" + key, None)
req_dict = dict(type_req_dict, **req_dict)
required = "occurrences" not in req_dict or req_dict["occurrences"][0]
if "occurrences" not in req_dict:
required = True
upper = sys.maxsize
else:
required = bool(req_dict["occurrences"][0])
max_occurences = req_dict["occurrences"][1]
upper = sys.maxsize if max_occurences == "UNBOUNDED" else int(max_occurences)
# note: ok if multiple requirements with same name on the template, then occurrences should be on type
entity._reqs[name] = upper
match_type = not on_type_only or required
field, found_restrictions = get_req_terms(
nt, types, topology_template, name, req_dict, match_type
node_template, types, topology_template, name, req_dict, match_type
)
# print("terms", terms)
if field:
entity.fields.append(field)
if found_restrictions:
has_restrictions = True
# print("rels", nt.relationships, nt.missing_requirements)
# print("rels", node_template.relationships, node_template.missing_requirements)
entity.has_restrictions = has_restrictions
return entity


def add_match(terms, match):
def add_match(terms, match) -> None:
if isinstance(match, dict) and (node_type := match.get("get_nodes_of_type")):
terms.append(CriteriaTerm.NodeType(node_type))
else:
Expand Down Expand Up @@ -380,37 +407,37 @@ def solve_topology(topology_template: TopologyTemplate) -> Solution:

# print("missing", topology_template.node_templates["app"].missing_requirements)
# print ('types', types)
logger.debug("!solving " + "\n\n".join(repr(n) for n in nodes.values()))
# print("!solving " + "\n\n".join(repr(n) for n in nodes.values()))
solved = cast(Solution, solve(nodes, types))
logger.debug(f"!solved!") # {solved}")
logger.debug(f"Solve found {len(solved)} matches for {len(nodes)}.")
for (source_name, req), targets in solved.items():
source = topology_template.node_templates[source_name]
source: NodeTemplate = topology_template.node_templates[source_name]
target_nodes = [
(topology_template.node_templates[node], cap) for (node, cap) in targets
]
if len(targets) > 1:
# filter out defaults
target_nodes = [
(t, cap)
for (t, cap) in (
(topology_template.node_templates[node], cap)
for (node, cap) in targets
)
if "default" not in t.directives
# filter out default nodes
no_defaults = [
(t, cap) for (t, cap) in target_nodes if "default" not in t.directives
]
# XXX if node filter: report ambiguity
if not target_nodes:
continue # hmm... more than one default match?
if len(target_nodes) > 1:
# XXX don't just skip, only treat as error if exceeds occurrences
continue
target_node, cap = target_nodes[0]
else:
assert targets
target_node = topology_template.node_templates[targets[0][0]]
cap = targets[0][1]
target = target_node.name
# print("solved", source, req, target_node, cap)
# pass target to handle case when there is more than one match per requirement
req_dict = source.find_or_add_requirement(req, target)
req_dict["node"] = target
if cap != "feature":
req_dict["capability"] = cap
if no_defaults:
target_nodes = no_defaults
max_occurrences = nodes[source_name]._reqs[req]
if len(target_nodes) > max_occurrences:
exception.ExceptionCollector.appendException(
exception.ValidationError(
message='requirement "%s" of node "%s" found %s targets more than max occurrences %s'
% (req, source_name, len(target_nodes), max_occurrences)
)
)
for target_node, cap in target_nodes:
_set_target(source, req, cap, target_node.name)
return solved


def _set_target(source: NodeTemplate, req_name: str, cap: str, target: str) -> None:
# updates requirements yaml directly so NodeTemplate won't search for a match later
req_dict: dict = source.find_or_add_requirement(req_name, target)
req_dict["node"] = target
if cap != "feature":
req_dict["capability"] = cap

0 comments on commit 2889642

Please sign in to comment.