diff --git a/lib/libimhex/include/hex/pattern_language/parser.hpp b/lib/libimhex/include/hex/pattern_language/parser.hpp index 928462193..14cccc42b 100644 --- a/lib/libimhex/include/hex/pattern_language/parser.hpp +++ b/lib/libimhex/include/hex/pattern_language/parser.hpp @@ -27,7 +27,7 @@ namespace hex::pl { private: std::optional m_error; TokenIter m_curr; - TokenIter m_originalPosition; + TokenIter m_originalPosition, m_partOriginalPosition; std::unordered_map m_types; std::vector m_matchedOptionals; @@ -71,6 +71,7 @@ namespace hex::pl { ASTNode *parseStringLiteral(); std::string parseNamespaceResolution(); ASTNode *parseScopeResolution(); + ASTNode *parseRValue(); ASTNode *parseRValue(ASTNodeRValue::Path &path); ASTNode *parseFactor(); ASTNode *parseCastExpression(); @@ -92,7 +93,8 @@ namespace hex::pl { ASTNode *parseFunctionDefinition(); ASTNode *parseFunctionVariableDecl(); ASTNode *parseFunctionStatement(); - ASTNode *parseFunctionVariableAssignment(); + ASTNode *parseFunctionVariableAssignment(const std::string &lvalue); + ASTNode *parseFunctionVariableCompoundAssignment(const std::string &lvalue); ASTNode *parseFunctionControlFlowStatement(); std::vector parseStatementBody(); ASTNode *parseFunctionConditional(); @@ -158,12 +160,29 @@ namespace hex::pl { return true; } + bool partBegin() { + this->m_partOriginalPosition = this->m_curr; + this->m_matchedOptionals.clear(); + + return true; + } + void reset() { this->m_curr = this->m_originalPosition; } + void partReset() { + this->m_curr = this->m_partOriginalPosition; + } + + bool resetIfFailed(bool value) { + if (!value) reset(); + + return value; + } + template - bool sequence() { + bool sequenceImpl() { if constexpr (S == Normal) return true; else if constexpr (S == Not) @@ -173,17 +192,17 @@ namespace hex::pl { } template - bool sequence(Token::Type type, auto value, auto... args) { + bool sequenceImpl(Token::Type type, auto value, auto... args) { if constexpr (S == Normal) { if (!peek(type, value)) { - reset(); + partReset(); return false; } this->m_curr++; - if (!sequence(args...)) { - reset(); + if (!sequenceImpl(args...)) { + partReset(); return false; } @@ -194,17 +213,22 @@ namespace hex::pl { this->m_curr++; - if (!sequence(args...)) + if (!sequenceImpl(args...)) return true; - reset(); + partReset(); return false; } else __builtin_unreachable(); } template - bool oneOf() { + bool sequence(Token::Type type, auto value, auto... args) { + return partBegin() && sequenceImpl(type, value, args...); + } + + template + bool oneOfImpl() { if constexpr (S == Normal) return false; else if constexpr (S == Not) @@ -214,19 +238,24 @@ namespace hex::pl { } template - bool oneOf(Token::Type type, auto value, auto... args) { + bool oneOfImpl(Token::Type type, auto value, auto... args) { if constexpr (S == Normal) - return sequence(type, value) || oneOf(args...); + return sequenceImpl(type, value) || oneOfImpl(args...); else if constexpr (S == Not) - return sequence(type, value) && oneOf(args...); + return sequenceImpl(type, value) && oneOfImpl(args...); else __builtin_unreachable(); } - bool variant(Token::Type type1, auto value1, Token::Type type2, auto value2) { + template + bool oneOf(Token::Type type, auto value, auto... args) { + return partBegin() && oneOfImpl(type, value, args...); + } + + bool variantImpl(Token::Type type1, auto value1, Token::Type type2, auto value2) { if (!peek(type1, value1)) { if (!peek(type2, value2)) { - reset(); + partReset(); return false; } } @@ -236,7 +265,11 @@ namespace hex::pl { return true; } - bool optional(Token::Type type, auto value) { + bool variant(Token::Type type1, auto value1, Token::Type type2, auto value2) { + return partBegin() && variantImpl(type1, value1, type2, value2); + } + + bool optionalImpl(Token::Type type, auto value) { if (peek(type, value)) { this->m_matchedOptionals.push_back(this->m_curr); this->m_curr++; @@ -245,6 +278,10 @@ namespace hex::pl { return true; } + bool optional(Token::Type type, auto value) { + return partBegin() && optionalImpl(type, value); + } + bool peek(Token::Type type, auto value, i32 index = 0) { return this->m_curr[index].type == type && this->m_curr[index] == value; } diff --git a/lib/libimhex/include/hex/pattern_language/token.hpp b/lib/libimhex/include/hex/pattern_language/token.hpp index d6bed5044..eaaadef3f 100644 --- a/lib/libimhex/include/hex/pattern_language/token.hpp +++ b/lib/libimhex/include/hex/pattern_language/token.hpp @@ -306,6 +306,7 @@ namespace hex::pl { #define IDENTIFIER hex::pl::Token::Type::Identifier, "" #define STRING hex::pl::Token::Type::String, hex::pl::Token::Literal("") +#define OPERATOR_ANY COMPONENT(Operator, Any) #define OPERATOR_AT COMPONENT(Operator, AtDeclaration) #define OPERATOR_ASSIGNMENT COMPONENT(Operator, Assignment) #define OPERATOR_INHERIT COMPONENT(Operator, Inherit) @@ -341,7 +342,6 @@ namespace hex::pl { #define VALUETYPE_UNSIGNED COMPONENT(ValueType, Unsigned) #define VALUETYPE_SIGNED COMPONENT(ValueType, Signed) #define VALUETYPE_FLOATINGPOINT COMPONENT(ValueType, FloatingPoint) -#define VALUETYPE_INTEGER COMPONENT(ValueType, Integer) #define VALUETYPE_ANY COMPONENT(ValueType, Any) #define SEPARATOR_ROUNDBRACKETOPEN COMPONENT(Separator, RoundBracketOpen) diff --git a/lib/libimhex/source/pattern_language/parser.cpp b/lib/libimhex/source/pattern_language/parser.cpp index de071a686..0aec12b67 100644 --- a/lib/libimhex/source/pattern_language/parser.cpp +++ b/lib/libimhex/source/pattern_language/parser.cpp @@ -2,7 +2,7 @@ #include -#define MATCHES(x) (begin() && x) +#define MATCHES(x) (begin() && resetIfFailed(x)) // Definition syntax: // [A] : Either A or no token @@ -87,6 +87,11 @@ namespace hex::pl { throwParseError("failed to parse scope resolution. Expected 'TypeName::Identifier'"); } + ASTNode *Parser::parseRValue() { + ASTNodeRValue::Path path; + return this->parseRValue(path); + } + // ASTNode *Parser::parseRValue(ASTNodeRValue::Path &path) { if (peek(IDENTIFIER, -1)) @@ -136,28 +141,33 @@ namespace hex::pl { } else if (peek(OPERATOR_SCOPERESOLUTION, 0)) { return this->parseScopeResolution(); } else { - ASTNodeRValue::Path path; - return this->parseRValue(path); + return this->parseRValue(); } } else if (MATCHES(oneOf(KEYWORD_PARENT, KEYWORD_THIS))) { - ASTNodeRValue::Path path; - return this->parseRValue(path); + return this->parseRValue(); } else if (MATCHES(sequence(OPERATOR_DOLLAR))) { return new ASTNodeRValue({ "$" }); } else if (MATCHES(oneOf(OPERATOR_ADDRESSOF, OPERATOR_SIZEOF) && sequence(SEPARATOR_ROUNDBRACKETOPEN))) { auto op = getValue(-2); - if (!MATCHES(oneOf(IDENTIFIER, KEYWORD_PARENT, KEYWORD_THIS))) { - throwParseError("expected rvalue identifier"); + ASTNode *result = nullptr; + + if (MATCHES(oneOf(IDENTIFIER, KEYWORD_PARENT, KEYWORD_THIS))) { + result = create(new ASTNodeTypeOperator(op, this->parseRValue())); + } else if (MATCHES(sequence(VALUETYPE_ANY))) { + auto type = getValue(-1); + + result = new ASTNodeLiteral(u128(Token::getTypeSize(type))); + } else { + throwParseError("expected rvalue identifier or built-in type"); } - ASTNodeRValue::Path path; - auto node = create(new ASTNodeTypeOperator(op, this->parseRValue(path))); if (!MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE))) { - delete node; + delete result; throwParseError("expected closing parenthesis"); } - return node; + + return result; } else throwParseError("expected value or parenthesis"); } @@ -490,9 +500,13 @@ namespace hex::pl { ASTNode *statement; if (MATCHES(sequence(IDENTIFIER, OPERATOR_ASSIGNMENT))) - statement = parseFunctionVariableAssignment(); + statement = parseFunctionVariableAssignment(getValue(-2).get()); else if (MATCHES(sequence(OPERATOR_DOLLAR, OPERATOR_ASSIGNMENT))) - statement = create(new ASTNodeAssignment("$", parseMathematicalExpression())); + statement = parseFunctionVariableAssignment("$"); + else if (MATCHES(oneOf(IDENTIFIER) && oneOf(OPERATOR_PLUS, OPERATOR_MINUS, OPERATOR_STAR, OPERATOR_SLASH, OPERATOR_PERCENT, OPERATOR_SHIFTLEFT, OPERATOR_SHIFTRIGHT, OPERATOR_BITOR, OPERATOR_BITAND, OPERATOR_BITXOR) && sequence(OPERATOR_ASSIGNMENT))) + statement = parseFunctionVariableCompoundAssignment(getValue(-3).get()); + else if (MATCHES(oneOf(OPERATOR_DOLLAR) && oneOf(OPERATOR_PLUS, OPERATOR_MINUS, OPERATOR_STAR, OPERATOR_SLASH, OPERATOR_PERCENT, OPERATOR_SHIFTLEFT, OPERATOR_SHIFTRIGHT, OPERATOR_BITOR, OPERATOR_BITAND, OPERATOR_BITXOR) && sequence(OPERATOR_ASSIGNMENT))) + statement = parseFunctionVariableCompoundAssignment("$"); else if (MATCHES(oneOf(KEYWORD_RETURN, KEYWORD_BREAK, KEYWORD_CONTINUE))) statement = parseFunctionControlFlowStatement(); else if (MATCHES(sequence(KEYWORD_IF, SEPARATOR_ROUNDBRACKETOPEN))) { @@ -533,14 +547,20 @@ namespace hex::pl { return statement; } - ASTNode *Parser::parseFunctionVariableAssignment() { - const auto &lvalue = getValue(-2).get(); - + ASTNode *Parser::parseFunctionVariableAssignment(const std::string &lvalue) { auto rvalue = this->parseMathematicalExpression(); return create(new ASTNodeAssignment(lvalue, rvalue)); } + ASTNode *Parser::parseFunctionVariableCompoundAssignment(const std::string &lvalue) { + const auto &op = getValue(-2); + + auto rvalue = this->parseMathematicalExpression(); + + return create(new ASTNodeAssignment(lvalue, create(new ASTNodeMathematicalExpression(create(new ASTNodeRValue({ lvalue })), rvalue, op)))); + } + ASTNode *Parser::parseFunctionControlFlowStatement() { ControlFlowStatement type; if (peek(KEYWORD_RETURN, -1)) @@ -640,7 +660,7 @@ namespace hex::pl { if (!MATCHES(sequence(IDENTIFIER, OPERATOR_ASSIGNMENT))) throwParseError("expected for loop variable assignment"); - auto postExpression = parseFunctionVariableAssignment(); + auto postExpression = parseFunctionVariableAssignment(getValue(-2).get()); auto postExpressionCleanup = SCOPE_GUARD { delete postExpression; }; std::vector body; @@ -880,7 +900,9 @@ namespace hex::pl { else if (MATCHES(sequence(KEYWORD_CONTINUE))) member = new ASTNodeControlFlowStatement(ControlFlowStatement::Continue, nullptr); else if (MATCHES(sequence(OPERATOR_DOLLAR, OPERATOR_ASSIGNMENT))) - member = create(new ASTNodeAssignment("$", parseMathematicalExpression())); + member = parseFunctionVariableAssignment("$"); + else if (MATCHES(oneOf(OPERATOR_DOLLAR) && oneOf(OPERATOR_PLUS, OPERATOR_MINUS, OPERATOR_STAR, OPERATOR_SLASH, OPERATOR_PERCENT, OPERATOR_SHIFTLEFT, OPERATOR_SHIFTRIGHT, OPERATOR_BITOR, OPERATOR_BITAND, OPERATOR_BITXOR) && sequence(OPERATOR_ASSIGNMENT))) + member = parseFunctionVariableCompoundAssignment("$"); else throwParseError("invalid struct member", 0);