Skip to content

Commit

Permalink
Adds a possibility to access the first Token of a SyntaxNode. (#904)
Browse files Browse the repository at this point in the history
* Adds PtrTokenOrSyntax class

In order to add a getFirstToken() as a modifiable value, this structure
will allows retrieving a token by address rather than by value.

There is no real need to get a constptr though so it is not implemented.

* Implements getChildPtr

Allows to retrieve a mutable version of a token of any SyntaxNode, and
in particular the first and last token.

* Adds function to get any token pointer.

Adds the getTokenPtr function.
  • Loading branch information
suzizecat authored Mar 1, 2024
1 parent 13e2f12 commit f0cd4e5
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 4 deletions.
49 changes: 49 additions & 0 deletions include/slang/syntax/SyntaxNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,32 @@ namespace slang::syntax {

class SyntaxNode;

/// A token pointer or a syntax node.
struct SLANG_EXPORT PtrTokenOrSyntax : public std::variant<parsing::Token*, SyntaxNode*> {
using Base = std::variant<parsing::Token*, SyntaxNode*>;
PtrTokenOrSyntax(parsing::Token* token) : Base(token) {}
PtrTokenOrSyntax(SyntaxNode* node) : Base(node) {}
PtrTokenOrSyntax(nullptr_t) : Base((parsing::Token*)nullptr) {}

/// @return true if the object is a token.
bool isToken() const { return this->index() == 0; }

/// @return true if the object is a syntax node.
bool isNode() const { return this->index() == 1; }

/// Gets access to the object as a token (throwing an exception
/// if it's not actually a token).
parsing::Token* token() const { return std::get<0>(*this); }

/// Gets access to the object as a syntax node (throwing an exception
/// if it's not actually a syntax node).
SyntaxNode* node() const { return std::get<1>(*this); }

/// Gets the source range for the token or syntax node or NoLocation if it
/// points to nullptr.
SourceRange range() const;
};

/// A token or a constant syntax node.
struct SLANG_EXPORT ConstTokenOrSyntax : public std::variant<parsing::Token, const SyntaxNode*> {
using Base = std::variant<parsing::Token, const SyntaxNode*>;
Expand Down Expand Up @@ -81,6 +107,12 @@ class SLANG_EXPORT SyntaxNode {
/// Get the last leaf token in this subtree.
Token getLastToken() const;

/// Get the first leaf token as a mutable pointer in this subtree.
Token* getFirstTokenPtr();

/// Get the last leaf token a mutable pointer in this subtree.
Token* getLastTokenPtr();

/// Get the source range of the node.
SourceRange sourceRange() const;

Expand All @@ -97,6 +129,11 @@ class SLANG_EXPORT SyntaxNode {
/// an empty Token.
Token childToken(size_t index) const;

/// Gets a pointer to the child token at the specified index. If the
/// child at the given index is not a token (probably a node) then
/// this returns null.
Token* childTokenPtr(size_t index);

/// Gets the number of (direct) children underneath this node in the tree.
size_t getChildCount() const; // Note: implemented in AllSyntax.cpp

Expand Down Expand Up @@ -159,6 +196,7 @@ class SLANG_EXPORT SyntaxNode {
private:
ConstTokenOrSyntax getChild(size_t index) const;
TokenOrSyntax getChild(size_t index);
PtrTokenOrSyntax getChildPtr(size_t index);
};

/// @brief Performs a shallow clone of the given syntax node.
Expand Down Expand Up @@ -234,6 +272,9 @@ class SLANG_EXPORT SyntaxListBase : public SyntaxNode {
/// Gets the child (token or node) at the given index.
virtual ConstTokenOrSyntax getChild(size_t index) const = 0;

// Gets the child pointer (token or node) at given index.
virtual PtrTokenOrSyntax getChildPtr(size_t index) = 0;

/// Sets the child (token or node) at the given index.
virtual void setChild(size_t index, TokenOrSyntax child) = 0;

Expand Down Expand Up @@ -263,6 +304,7 @@ class SLANG_EXPORT SyntaxList : public SyntaxListBase, public std::span<T*> {
private:
TokenOrSyntax getChild(size_t index) final { return (*this)[index]; }
ConstTokenOrSyntax getChild(size_t index) const final { return (*this)[index]; }
PtrTokenOrSyntax getChildPtr(size_t index) final { return (*this)[index]; };

void setChild(size_t index, TokenOrSyntax child) final {
(*this)[index] = &child.node()->as<T>();
Expand Down Expand Up @@ -301,6 +343,7 @@ class SLANG_EXPORT TokenList : public SyntaxListBase, public std::span<parsing::
private:
TokenOrSyntax getChild(size_t index) final { return (*this)[index]; }
ConstTokenOrSyntax getChild(size_t index) const final { return (*this)[index]; }
PtrTokenOrSyntax getChildPtr(size_t index) final { return &(*this)[index]; };
void setChild(size_t index, TokenOrSyntax child) final { (*this)[index] = child.token(); }

SyntaxListBase* clone(BumpAllocator& alloc) const final {
Expand Down Expand Up @@ -392,6 +435,12 @@ class SLANG_EXPORT SeparatedSyntaxList : public SyntaxListBase {
private:
TokenOrSyntax getChild(size_t index) final { return elements[index]; }
ConstTokenOrSyntax getChild(size_t index) const final { return elements[index]; }
PtrTokenOrSyntax getChildPtr(size_t index) final {
if (elements[index].isNode())
return elements[index].node();
else
return &(std::get<parsing::Token>(elements[index]));
}
void setChild(size_t index, TokenOrSyntax child) final { elements[index] = child; }

SyntaxListBase* clone(BumpAllocator& alloc) const final {
Expand Down
34 changes: 30 additions & 4 deletions scripts/syntax_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ def generateSyntax(builddir, alltypes, kindmap):

outf.write(" TokenOrSyntax getChild(size_t index);\n")
outf.write(" ConstTokenOrSyntax getChild(size_t index) const;\n")
outf.write(" PtrTokenOrSyntax getChildPtr(size_t index);\n")
outf.write(" void setChild(size_t index, TokenOrSyntax child);\n\n")

docf.write(
Expand All @@ -382,6 +383,14 @@ def generateSyntax(builddir, alltypes, kindmap):
name
)
)
docf.write(
" @brief Gets the child member (token or syntax node) as a pointer at the provided index within this struct\n"
)
docf.write(
" @fn PtrTokenOrSyntax slang::syntax::{}::getChildPtr(size_t index)\n".format(
name
)
)
docf.write(
" @brief Gets the child member (token or syntax node) at the provided index within this struct\n"
)
Expand Down Expand Up @@ -473,19 +482,35 @@ def generateSyntax(builddir, alltypes, kindmap):
cppf.write("}\n\n")

if v.members or v.final != "":
for returnType in ("TokenOrSyntax", "ConstTokenOrSyntax"):
for returnType in (
"TokenOrSyntax",
"ConstTokenOrSyntax",
"PtrTokenOrSyntax",
):
cppf.write(
"{} {}::getChild(size_t index){} {{\n".format(
returnType, k, "" if returnType == "TokenOrSyntax" else " const"
"{} {}::getChild{}(size_t index){} {{\n".format(
returnType,
k,
("Ptr" if returnType.startswith("Ptr") else ""),
"" if not returnType.startswith("Const") else " const",
)
)

returnPointer = returnType == "PtrTokenOrSyntax"

if v.combinedMembers:
cppf.write(" switch (index) {\n")

index = 0
for m in v.combinedMembers:
addr = "&" if m[1] in v.pointerMembers else ""
addr = ""
if returnPointer:
if m[0] == "Token" or (m[1] in v.pointerMembers):
addr = "&"
elif m[1] in v.pointerMembers:
addr = "&"

# addr = "&" if != (returnPointer and not (m[1] in v.notNullMembers)) else ""
get = ".get()" if m[1] in v.notNullMembers else ""
cppf.write(
" case {}: return {}{}{};\n".format(
Expand Down Expand Up @@ -636,6 +661,7 @@ def generateSyntax(builddir, alltypes, kindmap):
)
outf.write(" TokenOrSyntax getChild(size_t) { return nullptr; }\n")
outf.write(" ConstTokenOrSyntax getChild(size_t) const { return nullptr; }\n")
outf.write(" PtrTokenOrSyntax getChildPtr(size_t) { return nullptr; }\n")
outf.write(" void setChild(size_t, TokenOrSyntax) {}\n")
outf.write("};\n\n")

Expand Down
64 changes: 64 additions & 0 deletions source/syntax/SyntaxNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ namespace {
using namespace slang;
using namespace slang::syntax;

struct PtrGetChildVisitor {
template<typename T>
PtrTokenOrSyntax visit(T& node, size_t index) {
return node.getChildPtr(index);
}
};

struct ConstGetChildVisitor {
template<typename T>
ConstTokenOrSyntax visit(const T& node, size_t index) {
Expand All @@ -33,6 +40,17 @@ struct GetChildVisitor {

namespace slang::syntax {

SourceRange PtrTokenOrSyntax::range() const {
if (isNode())
return node()->sourceRange();
else {
if (token() == nullptr)
return SourceRange::NoLocation;
else
return token()->range();
}
}

SourceRange ConstTokenOrSyntax::range() const {
return isNode() ? node()->sourceRange() : token().range();
}
Expand Down Expand Up @@ -75,6 +93,40 @@ parsing::Token SyntaxNode::getLastToken() const {
return Token();
}

parsing::Token* SyntaxNode::getFirstTokenPtr() {
size_t childCount = getChildCount();
for (size_t i = 0; i < childCount; i++) {
auto child = getChildPtr(i);
if (child.isToken()) {
if (child.token())
return child.token();
}
else if (child.node()) {
auto result = child.node()->getFirstTokenPtr();
if (result)
return result;
}
}
return nullptr;
}

parsing::Token* SyntaxNode::getLastTokenPtr() {
size_t childCount = getChildCount();
for (ptrdiff_t i = ptrdiff_t(childCount) - 1; i >= 0; i--) {
auto child = getChildPtr(size_t(i));
if (child.isToken()) {
if (child.token())
return child.token();
}
else if (child.node()) {
auto result = child.node()->getLastTokenPtr();
if (result)
return result;
}
}
return nullptr;
}

SourceRange SyntaxNode::sourceRange() const {
Token firstToken = getFirstToken();
Token lastToken = getLastToken();
Expand All @@ -86,6 +138,11 @@ ConstTokenOrSyntax SyntaxNode::getChild(size_t index) const {
return visit(visitor, index);
}

PtrTokenOrSyntax SyntaxNode::getChildPtr(size_t index) {
PtrGetChildVisitor visitor;
return visit(visitor, index);
}

TokenOrSyntax SyntaxNode::getChild(size_t index) {
GetChildVisitor visitor;
return visit(visitor, index);
Expand All @@ -112,6 +169,13 @@ parsing::Token SyntaxNode::childToken(size_t index) const {
return child.token();
}

parsing::Token* SyntaxNode::childTokenPtr(size_t index) {
auto child = getChildPtr(index);
if (!child.isToken())
return nullptr;
return child.token();
}

bool SyntaxNode::isEquivalentTo(const SyntaxNode& other) const {
size_t childCount = getChildCount();
if (kind != other.kind || childCount != other.getChildCount())
Expand Down

0 comments on commit f0cd4e5

Please sign in to comment.