Skip to content

Commit

Permalink
added back previous templated scoping mechanism which is sufficient a…
Browse files Browse the repository at this point in the history
…fter rework
  • Loading branch information
drexlerd committed Apr 12, 2024
1 parent 2e0d952 commit 17448a4
Show file tree
Hide file tree
Showing 16 changed files with 177 additions and 320 deletions.
78 changes: 38 additions & 40 deletions include/loki/pddl/scope.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,24 +43,36 @@ namespace loki
/// The position points to the matched location
/// in the input stream and is used for error reporting.
template<typename T>
using BindingValueType = std::tuple<T, std::optional<Position>>;
using BindingValueType = std::tuple<PDDLElement<T>, std::optional<Position>>;

/// @brief Datastructure to store bindings of a type T.
template<typename T>
using Bindings = std::unordered_map<std::string, BindingValueType<T>>;
using BindingMapType = std::unordered_map<std::string, BindingValueType<T>>;

/// @brief Encapsulates bindings for different types.
template<typename... Ts>
class Bindings
{
private:
std::tuple<BindingMapType<Ts>...> bindings;

public:
/// @brief Returns a binding if it exists.
template<typename T>
std::optional<BindingValueType<T>> get(const std::string& key) const;

/// @brief Inserts a binding of type T
template<typename T>
void insert(const std::string& key, const PDDLElement<T>& binding, const std::optional<Position>& position);
};

/// @brief Wraps bindings in a scope with reference to a parent scope.
class Scope
{
private:
const Scope* m_parent_scope;

Bindings<pddl::Type> m_types;
Bindings<pddl::Object> m_objects;
Bindings<pddl::FunctionSkeleton> m_function_skeletons;
Bindings<pddl::Variable> m_variables;
Bindings<pddl::Predicate> m_predicates;
Bindings<pddl::Predicate> m_derived_predicates;
Bindings<pddl::TypeImpl, pddl::ObjectImpl, pddl::PredicateImpl, pddl::FunctionSkeletonImpl, pddl::VariableImpl> bindings;

public:
explicit Scope(const Scope* parent_scope = nullptr);
Expand All @@ -71,26 +83,18 @@ class Scope
Scope(Scope&& other) = delete;
Scope& operator=(Scope&& other) = delete;

/// @brief Return a binding if it exists.
std::optional<BindingValueType<pddl::Type>> get_type(const std::string& name) const;
std::optional<BindingValueType<pddl::Object>> get_object(const std::string& name) const;
std::optional<BindingValueType<pddl::FunctionSkeleton>> get_function_skeleton(const std::string& name) const;
std::optional<BindingValueType<pddl::Variable>> get_variable(const std::string& name) const;
std::optional<BindingValueType<pddl::Predicate>> get_predicate(const std::string& name) const;
std::optional<BindingValueType<pddl::Predicate>> get_derived_predicate(const std::string& name) const;

/// @brief Insert a binding.
void insert_type(const std::string& name, const pddl::Type& type, const std::optional<Position>& position);
void insert_object(const std::string& name, const pddl::Object& object, const std::optional<Position>& position);
void insert_function_skeleton(const std::string& name, const pddl::FunctionSkeleton& function_skeleton, const std::optional<Position>& position);
void insert_variable(const std::string& name, const pddl::Variable& variable, const std::optional<Position>& position);
void insert_predicate(const std::string& name, const pddl::Predicate& predicate, const std::optional<Position>& position);
void insert_derived_predicate(const std::string& name, const pddl::Predicate& derived_predicate, const std::optional<Position>& position);
/// @brief Returns a binding if it exists.
template<typename T>
std::optional<BindingValueType<T>> get(const std::string& name) const;

/// @brief Insert a binding of type T.
template<typename T>
void insert(const std::string& name, const PDDLElement<T>& element, const std::optional<Position>& position);
};

/// @brief Encapsulates the result of search for a binding with the corresponding ErrorHandler.
template<typename T>
using ScopeStackSearchResult = std::tuple<T, const std::optional<Position>, const PDDLErrorHandler&>;
using ScopeStackSearchResult = std::tuple<const PDDLElement<T>, const std::optional<Position>, const PDDLErrorHandler&>;

/// @brief Implements a scoping mechanism to store bindings which are mappings from name to a pointer to a PDDL object
/// type and a position in the input stream that can be used to construct error messages with the given ErrorHandler.
Expand Down Expand Up @@ -130,21 +134,13 @@ class ScopeStack
/// @brief Deletes the topmost scope from the stack.
void close_scope();

/// @brief Return a binding if it exists.
std::optional<ScopeStackSearchResult<pddl::Type>> get_type(const std::string& name) const;
std::optional<ScopeStackSearchResult<pddl::Object>> get_object(const std::string& name) const;
std::optional<ScopeStackSearchResult<pddl::FunctionSkeleton>> get_function_skeleton(const std::string& name) const;
std::optional<ScopeStackSearchResult<pddl::Variable>> get_variable(const std::string& name) const;
std::optional<ScopeStackSearchResult<pddl::Predicate>> get_predicate(const std::string& name) const;
std::optional<ScopeStackSearchResult<pddl::Predicate>> get_derived_predicate(const std::string& name) const;

/// @brief Insert a binding.
void insert_type(const std::string& name, const pddl::Type& type, const std::optional<Position>& position);
void insert_object(const std::string& name, const pddl::Object& object, const std::optional<Position>& position);
void insert_function_skeleton(const std::string& name, const pddl::FunctionSkeleton& function_skeleton, const std::optional<Position>& position);
void insert_variable(const std::string& name, const pddl::Variable& variable, const std::optional<Position>& position);
void insert_predicate(const std::string& name, const pddl::Predicate& predicate, const std::optional<Position>& position);
void insert_derived_predicate(const std::string& name, const pddl::Predicate& derived_predicate, const std::optional<Position>& position);
/// @brief Returns a binding if it exists.
template<typename T>
std::optional<ScopeStackSearchResult<T>> get(const std::string& name) const;

/// @brief Insert a binding of type T.
template<typename T>
void insert(const std::string& name, const PDDLElement<T>& element, const std::optional<Position>& position);

/// @brief Get the error handler to print an error message.
const PDDLErrorHandler& get_error_handler() const;
Expand All @@ -155,4 +151,6 @@ class ScopeStack

}

#endif
#include "scope.tpp"

#endif
87 changes: 87 additions & 0 deletions include/loki/pddl/scope.tpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Copyright (C) 2023 Dominik Drexler
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/

namespace loki
{

template<typename... Ts>
template<typename T>
std::optional<BindingValueType<T>> Bindings<Ts...>::get(const std::string& key) const
{
const auto& t_bindings = std::get<BindingMapType<T>>(bindings);
auto it = t_bindings.find(key);
if (it != t_bindings.end())
{
return { it->second };
}
return std::nullopt;
}

template<typename... Ts>
template<typename T>
void Bindings<Ts...>::insert(const std::string& key, const PDDLElement<T>& element, const std::optional<Position>& position)
{
assert(element);
auto& t_bindings = std::get<BindingMapType<T>>(bindings);
assert(!t_bindings.count(key));
t_bindings.emplace(key, std::make_tuple(element, position));
}

template<typename T>
std::optional<BindingValueType<T>> Scope::get(const std::string& name) const
{
const auto result = bindings.get<T>(name);
if (result.has_value())
return result.value();
if (m_parent_scope)
{
return m_parent_scope->get<T>(name);
}
return std::nullopt;
}

template<typename T>
void Scope::insert(const std::string& name, const PDDLElement<T>& element, const std::optional<Position>& position)
{
assert(element);
assert(!this->get<T>(name));
bindings.insert<T>(name, element, position);
}

template<typename T>
std::optional<ScopeStackSearchResult<T>> ScopeStack::get(const std::string& name) const
{
assert(!m_stack.empty());
auto result = m_stack.back()->get<T>(name);
if (result.has_value())
{
return std::make_tuple(std::get<0>(result.value()), std::get<1>(result.value()), std::cref(m_error_handler));
}
if (m_parent)
return m_parent->get<T>(name);
return std::nullopt;
}

/// @brief Insert a binding of type T.
template<typename T>
void ScopeStack::insert(const std::string& name, const PDDLElement<T>& element, const std::optional<Position>& position)
{
assert(!m_stack.empty());
m_stack.back()->insert(name, element, position);
}

}
6 changes: 3 additions & 3 deletions src/parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ DomainParser::DomainParser(const fs::path& file_path) :
// Create base types.
const auto base_type_object = context.factories.types.get_or_create<pddl::TypeImpl>("object", pddl::TypeList());
const auto base_type_number = context.factories.types.get_or_create<pddl::TypeImpl>("number", pddl::TypeList());
context.scopes.insert_type("object", base_type_object, {});
context.scopes.insert_type("number", base_type_number, {});
context.scopes.insert("object", base_type_object, {});
context.scopes.insert("number", base_type_number, {});

// Create equal predicate with name "=" and two parameters "?left_arg" and "?right_arg"
const auto binary_parameterlist = pddl::ParameterList {
Expand All @@ -75,7 +75,7 @@ DomainParser::DomainParser(const fs::path& file_path) :

};
const auto equal_predicate = context.factories.predicates.get_or_create<pddl::PredicateImpl>("=", binary_parameterlist);
context.scopes.insert_predicate("=", equal_predicate, {});
context.scopes.insert("=", equal_predicate, {});

m_domain = parse(node, context);

Expand Down
10 changes: 5 additions & 5 deletions src/pddl/parser/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pddl::Term TermDeclarationTermVisitor::operator()(const ast::Name& node) const
{
const auto constant_name = parse(node);
// Test for undefined constant.
const auto binding = context.scopes.get_object(constant_name);
const auto binding = context.scopes.get<pddl::ObjectImpl>(constant_name);
if (!binding.has_value())
{
throw UndefinedConstantError(constant_name, context.scopes.get_error_handler()(node, ""));
Expand All @@ -63,7 +63,7 @@ pddl::Term TermDeclarationTermVisitor::operator()(const ast::Variable& node) con
{
const auto variable = parse(node, context);
// Test for multiple definition
const auto binding = context.scopes.get_variable(variable->get_name());
const auto binding = context.scopes.get<pddl::VariableImpl>(variable->get_name());
if (binding.has_value())
{
const auto message_1 = context.scopes.get_error_handler()(node, "Defined here:");
Expand All @@ -73,7 +73,7 @@ pddl::Term TermDeclarationTermVisitor::operator()(const ast::Variable& node) con
throw MultiDefinitionVariableError(variable->get_name(), message_1 + message_2);
}
// Add binding to scope
context.scopes.insert_variable(variable->get_name(), variable, node);
context.scopes.insert(variable->get_name(), variable, node);
// Construct Term and return it
const auto term = context.factories.terms.get_or_create<pddl::TermVariableImpl>(variable);
// Add position of PDDL object
Expand All @@ -87,7 +87,7 @@ pddl::Term TermReferenceTermVisitor::operator()(const ast::Name& node) const
{
const auto object_name = parse(node);
// Test for undefined constant.
const auto binding = context.scopes.get_object(object_name);
const auto binding = context.scopes.get<pddl::ObjectImpl>(object_name);
if (!binding.has_value())
{
throw UndefinedConstantError(object_name, context.scopes.get_error_handler()(node, ""));
Expand All @@ -105,7 +105,7 @@ pddl::Term TermReferenceTermVisitor::operator()(const ast::Variable& node) const
{
const auto variable = parse(node, context);
// Test for undefined variable
const auto binding = context.scopes.get_variable(variable->get_name());
const auto binding = context.scopes.get<pddl::VariableImpl>(variable->get_name());
if (!binding.has_value())
{
throw UndefinedVariableError(variable->get_name(), context.scopes.get_error_handler()(node, ""));
Expand Down
8 changes: 4 additions & 4 deletions src/pddl/parser/constants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace loki
static void test_multiple_definition(const pddl::Object& constant, const ast::Name& node, const Context& context)
{
const auto constant_name = constant->get_name();
const auto binding = context.scopes.get_object(constant_name);
const auto binding = context.scopes.get<pddl::ObjectImpl>(constant_name);
if (binding.has_value())
{
const auto message_1 = context.scopes.get_error_handler()(node, "Defined here:");
Expand All @@ -45,7 +45,7 @@ static void test_multiple_definition(const pddl::Object& constant, const ast::Na
static void insert_context_information(const pddl::Object& constant, const ast::Name& node, Context& context)
{
context.positions.push_back(constant, node);
context.scopes.insert_object(constant->get_name(), constant, node);
context.scopes.insert(constant->get_name(), constant, node);
}

static pddl::Object parse_constant_definition(const ast::Name& node, const pddl::TypeList& type_list, Context& context)
Expand Down Expand Up @@ -76,8 +76,8 @@ ConstantListVisitor::ConstantListVisitor(Context& context_) : context(context_)
pddl::ObjectList ConstantListVisitor::operator()(const std::vector<ast::Name>& name_nodes)
{
// std::vector<ast::Name> has single base type "object"
assert(context.scopes.get_type("object").has_value());
const auto [type, _position, _error_handler] = context.scopes.get_type("object").value();
assert(context.scopes.get<pddl::TypeImpl>("object").has_value());
const auto [type, _position, _error_handler] = context.scopes.get<pddl::TypeImpl>("object").value();
return parse_constant_definitions(name_nodes, pddl::TypeList { type }, context);
}

Expand Down
2 changes: 1 addition & 1 deletion src/pddl/parser/effects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ pddl::Effect parse(const ast::EffectProductionNumericFluentTotalCost& node, Cont
const auto assign_operator_increase = parse(node.assign_operator_increase);
auto function_name = parse(node.function_symbol_total_cost.name);
assert(function_name == "total-cost");
auto binding = context.scopes.get_function_skeleton(function_name);
auto binding = context.scopes.get<pddl::FunctionSkeletonImpl>(function_name);
if (!binding.has_value())
{
throw UndefinedFunctionSkeletonError(function_name, context.scopes.get_error_handler()(node.function_symbol_total_cost, ""));
Expand Down
14 changes: 7 additions & 7 deletions src/pddl/parser/functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ pddl::Function parse(const ast::FunctionHead& node, Context& context)
pddl::FunctionSkeleton parse_function_skeleton_reference(const ast::FunctionSymbol& node, Context& context)
{
auto function_name = parse(node.name);
auto binding = context.scopes.get_function_skeleton(function_name);
auto binding = context.scopes.get<pddl::FunctionSkeletonImpl>(function_name);
if (!binding.has_value())
{
throw UndefinedFunctionSkeletonError(function_name, context.scopes.get_error_handler()(node, ""));
Expand All @@ -122,7 +122,7 @@ pddl::FunctionSkeleton parse_function_skeleton_reference(const ast::FunctionSymb
static void test_multiple_definition(const pddl::FunctionSkeleton& function_skeleton, const ast::Name& node, const Context& context)
{
const auto function_name = function_skeleton->get_name();
const auto binding = context.scopes.get_function_skeleton(function_name);
const auto binding = context.scopes.get<pddl::FunctionSkeletonImpl>(function_name);
if (binding.has_value())
{
const auto message_1 = context.scopes.get_error_handler()(node, "Defined here:");
Expand All @@ -139,7 +139,7 @@ static void test_multiple_definition(const pddl::FunctionSkeleton& function_skel
static void insert_context_information(const pddl::FunctionSkeleton& function_skeleton, const ast::Name& node, Context& context)
{
context.positions.push_back(function_skeleton, node);
context.scopes.insert_function_skeleton(function_skeleton->get_name(), function_skeleton, node);
context.scopes.insert(function_skeleton->get_name(), function_skeleton, node);
}

pddl::FunctionSkeleton parse(const ast::AtomicFunctionSkeletonTotalCost& node, Context& context)
Expand All @@ -162,8 +162,8 @@ pddl::FunctionSkeleton parse(const ast::AtomicFunctionSkeletonTotalCost& node, C
context.references.untrack(pddl::RequirementEnum::NUMERIC_FLUENTS);
}

assert(context.scopes.get_type("number").has_value());
const auto [type, _position, _error_handler] = context.scopes.get_type("number").value();
assert(context.scopes.get<pddl::TypeImpl>("number").has_value());
const auto [type, _position, _error_handler] = context.scopes.get<pddl::TypeImpl>("number").value();
auto function_name = parse(node.function_symbol.name);
auto function_skeleton = context.factories.function_skeletons.get_or_create<pddl::FunctionSkeletonImpl>(function_name, pddl::ParameterList {}, type);

Expand All @@ -185,8 +185,8 @@ pddl::FunctionSkeleton parse(const ast::AtomicFunctionSkeletonGeneral& node, Con
auto function_parameters = boost::apply_visitor(ParameterListVisitor(context), node.arguments);
context.scopes.close_scope();

assert(context.scopes.get_type("number").has_value());
const auto [type, _position, _error_handler] = context.scopes.get_type("number").value();
assert(context.scopes.get<pddl::TypeImpl>("number").has_value());
const auto [type, _position, _error_handler] = context.scopes.get<pddl::TypeImpl>("number").value();
auto function_name = parse(node.function_symbol.name);
auto function_skeleton = context.factories.function_skeletons.get_or_create<pddl::FunctionSkeletonImpl>(function_name, function_parameters, type);

Expand Down
6 changes: 3 additions & 3 deletions src/pddl/parser/ground_literal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace loki
pddl::GroundAtom parse(const ast::AtomicFormulaOfNamesPredicate& node, Context& context)
{
const auto name = parse(node.predicate.name);
const auto binding = context.scopes.get_predicate(name);
const auto binding = context.scopes.get<pddl::PredicateImpl>(name);
if (!binding.has_value())
{
throw UndefinedPredicateError(name, context.scopes.get_error_handler()(node, ""));
Expand All @@ -54,8 +54,8 @@ pddl::GroundAtom parse(const ast::AtomicFormulaOfNamesEquality& node, Context& c
{
throw UndefinedRequirementError(pddl::RequirementEnum::EQUALITY, context.scopes.get_error_handler()(node, ""));
}
assert(context.scopes.get_predicate("=").has_value());
const auto [equal_predicate, _position, _error_handler] = context.scopes.get_predicate("=").value();
assert(context.scopes.get<pddl::PredicateImpl>("=").has_value());
const auto [equal_predicate, _position, _error_handler] = context.scopes.get<pddl::PredicateImpl>("=").value();
const auto object_left = parse_object_reference(node.name_left, context);
const auto object_right = parse_object_reference(node.name_right, context);
const auto atom = context.factories.ground_atoms.get_or_create<pddl::GroundAtomImpl>(equal_predicate, pddl::ObjectList { object_left, object_right });
Expand Down
Loading

0 comments on commit 17448a4

Please sign in to comment.