From 7f0bdc95dad9f3774f11cae0a4a7a8f784eb1d90 Mon Sep 17 00:00:00 2001 From: WerWolv Date: Sun, 20 Jun 2021 21:22:31 +0200 Subject: [PATCH] patterns: Added support for declaring custom functions --- .../source/content/lang_builtin_functions.cpp | 58 ++-- .../include/hex/api/content_registry.hpp | 4 +- .../libimhex/include/hex/lang/ast_node.hpp | 107 +++++- .../libimhex/include/hex/lang/evaluator.hpp | 7 + plugins/libimhex/include/hex/lang/parser.hpp | 3 + .../include/hex/lang/pattern_data.hpp | 13 +- plugins/libimhex/include/hex/lang/token.hpp | 33 +- .../libimhex/source/api/content_registry.cpp | 2 +- plugins/libimhex/source/lang/evaluator.cpp | 327 ++++++++++++------ plugins/libimhex/source/lang/lexer.cpp | 26 +- plugins/libimhex/source/lang/parser.cpp | 76 +++- source/views/view_pattern.cpp | 2 +- source/window.cpp | 2 +- 13 files changed, 471 insertions(+), 189 deletions(-) diff --git a/plugins/builtin/source/content/lang_builtin_functions.cpp b/plugins/builtin/source/content/lang_builtin_functions.cpp index 1f5981ad7..a250dd59b 100644 --- a/plugins/builtin/source/content/lang_builtin_functions.cpp +++ b/plugins/builtin/source/content/lang_builtin_functions.cpp @@ -40,7 +40,7 @@ namespace hex::plugin::builtin { continue; } - return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit, offset }); + return new ASTNodeIntegerLiteral(offset); } } @@ -63,11 +63,11 @@ namespace hex::plugin::builtin { SharedData::currentProvider->read(address, value, size); switch ((u8)size) { - case 1: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned8Bit, *reinterpret_cast(value) }); - case 2: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned16Bit, *reinterpret_cast(value) }); - case 4: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned32Bit, *reinterpret_cast(value) }); - case 8: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit, *reinterpret_cast(value) }); - case 16: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned128Bit, *reinterpret_cast(value) }); + case 1: return new ASTNodeIntegerLiteral(*reinterpret_cast(value)); + case 2: return new ASTNodeIntegerLiteral(*reinterpret_cast(value)); + case 4: return new ASTNodeIntegerLiteral(*reinterpret_cast(value)); + case 8: return new ASTNodeIntegerLiteral(*reinterpret_cast(value)); + case 16: return new ASTNodeIntegerLiteral(*reinterpret_cast(value)); default: ctx.getConsole().abortEvaluation("invalid read size"); } }, address, size); @@ -89,11 +89,11 @@ namespace hex::plugin::builtin { SharedData::currentProvider->read(address, value, size); switch ((u8)size) { - case 1: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed8Bit, *reinterpret_cast(value) }); - case 2: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed16Bit, *reinterpret_cast(value) }); - case 4: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed32Bit, *reinterpret_cast(value) }); - case 8: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed64Bit, *reinterpret_cast(value) }); - case 16: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed128Bit, *reinterpret_cast(value) }); + case 1: return new ASTNodeIntegerLiteral(*reinterpret_cast(value)); + case 2: return new ASTNodeIntegerLiteral(*reinterpret_cast(value)); + case 4: return new ASTNodeIntegerLiteral(*reinterpret_cast(value)); + case 8: return new ASTNodeIntegerLiteral(*reinterpret_cast(value)); + case 16: return new ASTNodeIntegerLiteral(*reinterpret_cast(value)); default: ctx.getConsole().abortEvaluation("invalid read size"); } }, address, size); @@ -127,25 +127,19 @@ namespace hex::plugin::builtin { for (auto& param : params) { if (auto integerLiteral = dynamic_cast(param); integerLiteral != nullptr) { std::visit([&](auto &&value) { - switch (integerLiteral->getType()) { - case lang::Token::ValueType::Character: message += (char)value; break; - case lang::Token::ValueType::Boolean: message += value == 0 ? "false" : "true"; break; - case lang::Token::ValueType::Unsigned8Bit: - case lang::Token::ValueType::Unsigned16Bit: - case lang::Token::ValueType::Unsigned32Bit: - case lang::Token::ValueType::Unsigned64Bit: - case lang::Token::ValueType::Unsigned128Bit: - message += std::to_string(static_cast(value)); - break; - case lang::Token::ValueType::Signed8Bit: - case lang::Token::ValueType::Signed16Bit: - case lang::Token::ValueType::Signed32Bit: - case lang::Token::ValueType::Signed64Bit: - case lang::Token::ValueType::Signed128Bit: - message += std::to_string(static_cast(value)); - break; - default: message += "< Custom Type >"; - } + using Type = std::remove_cvref_t; + if constexpr (std::is_same_v) + message += (char)value; + else if constexpr (std::is_same_v) + message += value == 0 ? "false" : "true"; + else if constexpr (std::is_unsigned_v) + message += std::to_string(static_cast(value)); + else if constexpr (std::is_signed_v) + message += std::to_string(static_cast(value)); + else if constexpr (std::is_floating_point_v) + message += std::to_string(value); + else + message += "< Custom Type >"; }, integerLiteral->getValue()); } else if (auto stringLiteral = dynamic_cast(param); stringLiteral != nullptr) @@ -167,12 +161,12 @@ namespace hex::plugin::builtin { return remainder != 0 ? u64(value) + (u64(alignment) - remainder) : u64(value); }, alignment, value); - return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit, u64(result) }); + return new ASTNodeIntegerLiteral(u64(result)); }); /* dataSize() */ ContentRegistry::PatternLanguageFunctions::add("dataSize", ContentRegistry::PatternLanguageFunctions::NoParameters, [](auto &ctx, auto params) -> ASTNode* { - return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit, u64(SharedData::currentProvider->getActualSize()) }); + return new ASTNodeIntegerLiteral(u64(SharedData::currentProvider->getActualSize())); }); } diff --git a/plugins/libimhex/include/hex/api/content_registry.hpp b/plugins/libimhex/include/hex/api/content_registry.hpp index 27f2850fe..d2da9704e 100644 --- a/plugins/libimhex/include/hex/api/content_registry.hpp +++ b/plugins/libimhex/include/hex/api/content_registry.hpp @@ -87,10 +87,10 @@ namespace hex { struct Function { u32 parameterCount; - std::function)> func; + std::function&)> func; }; - static void add(std::string_view name, u32 parameterCount, const std::function)> &func); + static void add(std::string_view name, u32 parameterCount, const std::function&)> &func); static std::map& getEntries(); }; diff --git a/plugins/libimhex/include/hex/lang/ast_node.hpp b/plugins/libimhex/include/hex/lang/ast_node.hpp index d3fa72e5c..d69d6631f 100644 --- a/plugins/libimhex/include/hex/lang/ast_node.hpp +++ b/plugins/libimhex/include/hex/lang/ast_node.hpp @@ -56,11 +56,7 @@ namespace hex::lang { } [[nodiscard]] const auto& getValue() const { - return this->m_literal.second; - } - - [[nodiscard]] Token::ValueType getType() const { - return this->m_literal.first; + return this->m_literal; } private: @@ -622,4 +618,105 @@ namespace hex::lang { Token::Operator m_op; ASTNode *m_expression; }; + + class ASTNodeFunctionDefinition : public ASTNode { + public: + ASTNodeFunctionDefinition(std::string name, std::vector params, std::vector body) + : m_name(std::move(name)), m_params(std::move(params)), m_body(std::move(body)) { + + } + + ASTNodeFunctionDefinition(const ASTNodeFunctionDefinition &other) : ASTNode(other) { + this->m_name = other.m_name; + this->m_params = other.m_params; + + for (auto statement : other.m_body) { + this->m_body.push_back(statement->clone()); + } + } + + [[nodiscard]] ASTNode* clone() const override { + return new ASTNodeFunctionDefinition(*this); + } + + ~ASTNodeFunctionDefinition() override { + for (auto statement : this->m_body) + delete statement; + } + + [[nodiscard]] std::string_view getName() const { + return this->m_name; + } + + [[nodiscard]] const auto& getParams() const { + return this->m_params; + } + + [[nodiscard]] const auto& getBody() const { + return this->m_body; + } + + private: + std::string m_name; + std::vector m_params; + std::vector m_body; + }; + + class ASTNodeAssignment : public ASTNode { + public: + ASTNodeAssignment(std::string lvalueName, ASTNode *rvalue) : m_lvalueName(std::move(lvalueName)), m_rvalue(rvalue) { + + } + + ASTNodeAssignment(const ASTNodeAssignment &other) : ASTNode(other) { + this->m_lvalueName = other.m_lvalueName; + this->m_rvalue = other.m_rvalue->clone(); + } + + [[nodiscard]] ASTNode* clone() const override { + return new ASTNodeAssignment(*this); + } + + ~ASTNodeAssignment() override { + delete this->m_rvalue; + } + + [[nodiscard]] std::string_view getLValueName() const { + return this->m_lvalueName; + } + + [[nodiscard]] ASTNode* getRValue() const { + return this->m_rvalue; + } + + private: + std::string m_lvalueName; + ASTNode *m_rvalue; + }; + + class ASTNodeReturnStatement : public ASTNode { + public: + ASTNodeReturnStatement(ASTNode *rvalue) : m_rvalue(rvalue) { + + } + + ASTNodeReturnStatement(const ASTNodeReturnStatement &other) : ASTNode(other) { + this->m_rvalue = other.m_rvalue->clone(); + } + + [[nodiscard]] ASTNode* clone() const override { + return new ASTNodeReturnStatement(*this); + } + + ~ASTNodeReturnStatement() override { + delete this->m_rvalue; + } + + [[nodiscard]] ASTNode* getRValue() const { + return this->m_rvalue; + } + + private: + ASTNode *m_rvalue; + }; } \ No newline at end of file diff --git a/plugins/libimhex/include/hex/lang/evaluator.hpp b/plugins/libimhex/include/hex/lang/evaluator.hpp index 182a7d46f..a355c7d0c 100644 --- a/plugins/libimhex/include/hex/lang/evaluator.hpp +++ b/plugins/libimhex/include/hex/lang/evaluator.hpp @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -46,12 +47,17 @@ namespace hex::lang { std::vector m_endianStack; std::vector m_globalMembers; std::vector*> m_currMembers; + std::vector*> m_localVariables; std::vector m_currMemberScope; + std::vector m_localStack; + std::map m_definedFunctions; LogConsole m_console; u32 m_recursionLimit; u32 m_currRecursionDepth; + void createLocalVariable(std::string_view varName, PatternData *pattern); + void setLocalVariableValue(std::string_view varName, const void *value, size_t size); ASTNodeIntegerLiteral* evaluateScopeResolution(ASTNodeScopeResolution *node); ASTNodeIntegerLiteral* evaluateRValue(ASTNodeRValue *node); @@ -61,6 +67,7 @@ namespace hex::lang { ASTNodeIntegerLiteral* evaluateOperand(ASTNode *node); ASTNodeIntegerLiteral* evaluateTernaryExpression(ASTNodeTernaryExpression *node); ASTNodeIntegerLiteral* evaluateMathematicalExpression(ASTNodeNumericExpression *node); + void evaluateFunctionDefinition(ASTNodeFunctionDefinition *node); PatternData* findPattern(std::vector currMembers, const ASTNodeRValue::Path &path); PatternData* evaluateAttributes(ASTNode *currNode, PatternData *currPattern); diff --git a/plugins/libimhex/include/hex/lang/parser.hpp b/plugins/libimhex/include/hex/lang/parser.hpp index 3d0678375..d48678aba 100644 --- a/plugins/libimhex/include/hex/lang/parser.hpp +++ b/plugins/libimhex/include/hex/lang/parser.hpp @@ -72,6 +72,9 @@ namespace hex::lang { ASTNode* parseMathematicalExpression(); void parseAttribute(Attributable *currNode); + ASTNode* parseFunctionDefintion(); + ASTNode* parseVariableAssignment(); + ASTNode* parseReturnStatement(); ASTNode* parseConditional(); ASTNode* parseWhileStatement(); ASTNode* parseType(s32 startIndex); diff --git a/plugins/libimhex/include/hex/lang/pattern_data.hpp b/plugins/libimhex/include/hex/lang/pattern_data.hpp index cd8b85fef..649f07721 100644 --- a/plugins/libimhex/include/hex/lang/pattern_data.hpp +++ b/plugins/libimhex/include/hex/lang/pattern_data.hpp @@ -167,10 +167,18 @@ namespace hex::lang { this->m_hidden = hidden; } - bool isHidden() const { + [[nodiscard]] bool isHidden() const { return this->m_hidden; } + void setLocal(bool local) { + this->m_local = local; + } + + [[nodiscard]] bool isLocal() const { + return this->m_local; + } + protected: void createDefaultEntry(std::string_view value) const { ImGui::TableNextRow(); @@ -217,6 +225,7 @@ namespace hex::lang { std::string m_typeName; PatternData *m_parent; + bool m_local = false; }; class PatternDataPadding : public PatternData { @@ -886,7 +895,7 @@ namespace hex::lang { } return false; - }, entryValueLiteral.second); + }, entryValueLiteral); if (matches) break; } diff --git a/plugins/libimhex/include/hex/lang/token.hpp b/plugins/libimhex/include/hex/lang/token.hpp index 4fbcb524f..3a3935f2f 100644 --- a/plugins/libimhex/include/hex/lang/token.hpp +++ b/plugins/libimhex/include/hex/lang/token.hpp @@ -33,7 +33,9 @@ namespace hex::lang { If, Else, Parent, - While + While, + Function, + Return }; enum class Operator { @@ -107,8 +109,7 @@ namespace hex::lang { EndOfProgram }; - using Integers = std::variant; - using IntegerLiteral = std::pair; + using IntegerLiteral = std::variant; using ValueTypes = std::variant; Token(Type type, auto value, u32 lineNumber) : type(type), value(value), lineNumber(lineNumber) { @@ -131,28 +132,6 @@ namespace hex::lang { return static_cast(type) >> 4; } - [[nodiscard]] constexpr static inline IntegerLiteral castTo(ValueType type, const Integers &literal) { - return std::visit([type](auto &&value) { - switch (type) { - case ValueType::Signed8Bit: return IntegerLiteral(type, static_cast(value)); - case ValueType::Signed16Bit: return IntegerLiteral(type, static_cast(value)); - case ValueType::Signed32Bit: return IntegerLiteral(type, static_cast(value)); - case ValueType::Signed64Bit: return IntegerLiteral(type, static_cast(value)); - case ValueType::Signed128Bit: return IntegerLiteral(type, static_cast(value)); - case ValueType::Unsigned8Bit: return IntegerLiteral(type, static_cast(value)); - case ValueType::Unsigned16Bit: return IntegerLiteral(type, static_cast(value)); - case ValueType::Unsigned32Bit: return IntegerLiteral(type, static_cast(value)); - case ValueType::Unsigned64Bit: return IntegerLiteral(type, static_cast(value)); - case ValueType::Unsigned128Bit: return IntegerLiteral(type, static_cast(value)); - case ValueType::Float: return IntegerLiteral(type, static_cast(value)); - case ValueType::Double: return IntegerLiteral(type, static_cast(value)); - case ValueType::Character: return IntegerLiteral(type, static_cast(value)); - case ValueType::Character16: return IntegerLiteral(type, static_cast(value)); - default: __builtin_unreachable(); - } - }, literal); - } - [[nodiscard]] constexpr static auto getTypeName(const lang::Token::ValueType type) { switch (type) { case ValueType::Signed8Bit: return "s8"; @@ -227,8 +206,10 @@ namespace hex::lang { #define KEYWORD_ELSE COMPONENT(Keyword, Else) #define KEYWORD_PARENT COMPONENT(Keyword, Parent) #define KEYWORD_WHILE COMPONENT(Keyword, While) +#define KEYWORD_FUNCTION COMPONENT(Keyword, Function) +#define KEYWORD_RETURN COMPONENT(Keyword, Return) -#define INTEGER hex::lang::Token::Type::Integer, hex::lang::Token::IntegerLiteral(hex::lang::Token::ValueType::Any, u64(0)) +#define INTEGER hex::lang::Token::Type::Integer, hex::lang::Token::IntegerLiteral(u64(0)) #define IDENTIFIER hex::lang::Token::Type::Identifier, "" #define STRING hex::lang::Token::Type::String, "" diff --git a/plugins/libimhex/source/api/content_registry.cpp b/plugins/libimhex/source/api/content_registry.cpp index c2f2629b4..8297565a8 100644 --- a/plugins/libimhex/source/api/content_registry.cpp +++ b/plugins/libimhex/source/api/content_registry.cpp @@ -146,7 +146,7 @@ namespace hex { /* Pattern Language Functions */ - void ContentRegistry::PatternLanguageFunctions::add(std::string_view name, u32 parameterCount, const std::function)> &func) { + void ContentRegistry::PatternLanguageFunctions::add(std::string_view name, u32 parameterCount, const std::function&)> &func) { getEntries()[name.data()] = Function{ parameterCount, func }; } diff --git a/plugins/libimhex/source/lang/evaluator.cpp b/plugins/libimhex/source/lang/evaluator.cpp index 6969c6222..1ada296f3 100644 --- a/plugins/libimhex/source/lang/evaluator.cpp +++ b/plugins/libimhex/source/lang/evaluator.cpp @@ -80,8 +80,10 @@ namespace hex::lang { if (currPattern != nullptr) { if (auto arrayPattern = dynamic_cast(currPattern); arrayPattern != nullptr) { - if (Token::isFloatingPoint(arrayIndexNode->getType())) - this->getConsole().abortEvaluation("cannot use float to index into array"); + std::visit([this](auto &&arrayIndex) { + if (std::is_floating_point_v) + this->getConsole().abortEvaluation("cannot use float to index into array"); + }, arrayIndexNode->getValue()); std::visit([&](auto &&arrayIndex){ if (arrayIndex >= 0 && arrayIndex < arrayPattern->getEntries().size()) @@ -110,9 +112,13 @@ namespace hex::lang { PatternData *currPattern = nullptr; - // Local member access - if (this->m_currMembers.size() > 1) + // Local variable access + currPattern = this->findPattern(*this->m_localVariables.back(), path); + + // If no local variable was found try local structure members + if (this->m_currMembers.size() > 1) { currPattern = this->findPattern(*this->m_currMembers[this->m_currMembers.size() - 2], path); + } // If no local member was found, try globally if (currPattern == nullptr) { @@ -143,45 +149,55 @@ namespace hex::lang { ASTNodeIntegerLiteral* Evaluator::evaluateRValue(ASTNodeRValue *node) { if (node->getPath().size() == 1) { if (auto part = std::get_if(&node->getPath()[0]); part != nullptr && *part == "$") - return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit, this->m_currOffset }); + return new ASTNodeIntegerLiteral(this->m_currOffset); } auto currPattern = this->patternFromName(node->getPath()); if (auto unsignedPattern = dynamic_cast(currPattern); unsignedPattern != nullptr) { + u8 value[unsignedPattern->getSize()]; - this->m_provider->read(unsignedPattern->getOffset(), value, unsignedPattern->getSize()); + if (currPattern->isLocal()) + std::memcpy(value, this->m_localStack.data() + unsignedPattern->getOffset(), unsignedPattern->getSize()); + else + this->m_provider->read(unsignedPattern->getOffset(), value, unsignedPattern->getSize()); switch (unsignedPattern->getSize()) { - case 1: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned8Bit, hex::changeEndianess(*reinterpret_cast(value), 1, unsignedPattern->getEndian()) }); - case 2: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned16Bit, hex::changeEndianess(*reinterpret_cast(value), 2, unsignedPattern->getEndian()) }); - case 4: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned32Bit, hex::changeEndianess(*reinterpret_cast(value), 4, unsignedPattern->getEndian()) }); - case 8: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit, hex::changeEndianess(*reinterpret_cast(value), 8, unsignedPattern->getEndian()) }); - case 16: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned128Bit, hex::changeEndianess(*reinterpret_cast(value), 16, unsignedPattern->getEndian()) }); + case 1: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 1, unsignedPattern->getEndian())); + case 2: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 2, unsignedPattern->getEndian())); + case 4: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 4, unsignedPattern->getEndian())); + case 8: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 8, unsignedPattern->getEndian())); + case 16: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 16, unsignedPattern->getEndian())); default: this->getConsole().abortEvaluation("invalid rvalue size"); } } else if (auto signedPattern = dynamic_cast(currPattern); signedPattern != nullptr) { u8 value[signedPattern->getSize()]; - this->m_provider->read(signedPattern->getOffset(), value, signedPattern->getSize()); + if (currPattern->isLocal()) + std::memcpy(value, this->m_localStack.data() + signedPattern->getOffset(), signedPattern->getSize()); + else + this->m_provider->read(signedPattern->getOffset(), value, signedPattern->getSize()); switch (signedPattern->getSize()) { - case 1: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed8Bit, hex::changeEndianess(*reinterpret_cast(value), 1, signedPattern->getEndian()) }); - case 2: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed16Bit, hex::changeEndianess(*reinterpret_cast(value), 2, signedPattern->getEndian()) }); - case 4: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed32Bit, hex::changeEndianess(*reinterpret_cast(value), 4, signedPattern->getEndian()) }); - case 8: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed64Bit, hex::changeEndianess(*reinterpret_cast(value), 8, signedPattern->getEndian()) }); - case 16: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed128Bit, hex::changeEndianess(*reinterpret_cast(value), 16, signedPattern->getEndian()) }); + case 1: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 1, signedPattern->getEndian())); + case 2: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 2, signedPattern->getEndian())); + case 4: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 4, signedPattern->getEndian())); + case 8: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 8, signedPattern->getEndian())); + case 16: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 16, signedPattern->getEndian())); default: this->getConsole().abortEvaluation("invalid rvalue size"); } } else if (auto enumPattern = dynamic_cast(currPattern); enumPattern != nullptr) { u8 value[enumPattern->getSize()]; - this->m_provider->read(enumPattern->getOffset(), value, enumPattern->getSize()); + if (currPattern->isLocal()) + std::memcpy(value, this->m_localStack.data() + enumPattern->getOffset(), enumPattern->getSize()); + else + this->m_provider->read(enumPattern->getOffset(), value, enumPattern->getSize()); switch (enumPattern->getSize()) { - case 1: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned8Bit, hex::changeEndianess(*reinterpret_cast(value), 1, enumPattern->getEndian()) }); - case 2: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned16Bit, hex::changeEndianess(*reinterpret_cast(value), 2, enumPattern->getEndian()) }); - case 4: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned32Bit, hex::changeEndianess(*reinterpret_cast(value), 4, enumPattern->getEndian()) }); - case 8: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit, hex::changeEndianess(*reinterpret_cast(value), 8, enumPattern->getEndian()) }); - case 16: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned128Bit, hex::changeEndianess(*reinterpret_cast(value), 16, enumPattern->getEndian()) }); + case 1: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 1, enumPattern->getEndian())); + case 2: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 2, enumPattern->getEndian())); + case 4: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 4, enumPattern->getEndian())); + case 8: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 8, enumPattern->getEndian())); + case 16: return new ASTNodeIntegerLiteral(hex::changeEndianess(*reinterpret_cast(value), 16, enumPattern->getEndian())); default: this->getConsole().abortEvaluation("invalid rvalue size"); } } else @@ -204,25 +220,28 @@ namespace hex::lang { evaluatedParams.push_back(stringLiteral->clone()); } - if (!ContentRegistry::PatternLanguageFunctions::getEntries().contains(node->getFunctionName().data())) + ContentRegistry::PatternLanguageFunctions::Function *function; + if (this->m_definedFunctions.contains(node->getFunctionName().data())) + function = &this->m_definedFunctions[node->getFunctionName().data()]; + else if (ContentRegistry::PatternLanguageFunctions::getEntries().contains(node->getFunctionName().data())) + function = &ContentRegistry::PatternLanguageFunctions::getEntries()[node->getFunctionName().data()]; + else this->getConsole().abortEvaluation(hex::format("no function named '{0}' found", node->getFunctionName().data())); - auto &function = ContentRegistry::PatternLanguageFunctions::getEntries()[node->getFunctionName().data()]; - - if (function.parameterCount == ContentRegistry::PatternLanguageFunctions::UnlimitedParameters) { + if (function->parameterCount == ContentRegistry::PatternLanguageFunctions::UnlimitedParameters) { ; // Don't check parameter count } - else if (function.parameterCount & ContentRegistry::PatternLanguageFunctions::LessParametersThan) { - if (evaluatedParams.size() >= (function.parameterCount & ~ContentRegistry::PatternLanguageFunctions::LessParametersThan)) - this->getConsole().abortEvaluation(hex::format("too many parameters for function '{0}'. Expected {1}", node->getFunctionName().data(), function.parameterCount & ~ContentRegistry::PatternLanguageFunctions::LessParametersThan)); - } else if (function.parameterCount & ContentRegistry::PatternLanguageFunctions::MoreParametersThan) { - if (evaluatedParams.size() <= (function.parameterCount & ~ContentRegistry::PatternLanguageFunctions::MoreParametersThan)) - this->getConsole().abortEvaluation(hex::format("too few parameters for function '{0}'. Expected {1}", node->getFunctionName().data(), function.parameterCount & ~ContentRegistry::PatternLanguageFunctions::MoreParametersThan)); - } else if (function.parameterCount != evaluatedParams.size()) { - this->getConsole().abortEvaluation(hex::format("invalid number of parameters for function '{0}'. Expected {1}", node->getFunctionName().data(), function.parameterCount)); + else if (function->parameterCount & ContentRegistry::PatternLanguageFunctions::LessParametersThan) { + if (evaluatedParams.size() >= (function->parameterCount & ~ContentRegistry::PatternLanguageFunctions::LessParametersThan)) + this->getConsole().abortEvaluation(hex::format("too many parameters for function '{0}'. Expected {1}", node->getFunctionName().data(), function->parameterCount & ~ContentRegistry::PatternLanguageFunctions::LessParametersThan)); + } else if (function->parameterCount & ContentRegistry::PatternLanguageFunctions::MoreParametersThan) { + if (evaluatedParams.size() <= (function->parameterCount & ~ContentRegistry::PatternLanguageFunctions::MoreParametersThan)) + this->getConsole().abortEvaluation(hex::format("too few parameters for function '{0}'. Expected {1}", node->getFunctionName().data(), function->parameterCount & ~ContentRegistry::PatternLanguageFunctions::MoreParametersThan)); + } else if (function->parameterCount != evaluatedParams.size()) { + this->getConsole().abortEvaluation(hex::format("invalid number of parameters for function '{0}'. Expected {1}", node->getFunctionName().data(), function->parameterCount)); } - return function.func(*this, evaluatedParams); + return function->func(*this, evaluatedParams); } ASTNodeIntegerLiteral* Evaluator::evaluateTypeOperator(ASTNodeTypeOperator *typeOperatorNode) { @@ -231,9 +250,9 @@ namespace hex::lang { switch (typeOperatorNode->getOperator()) { case Token::Operator::AddressOf: - return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit, static_cast(pattern->getOffset()) }); + return new ASTNodeIntegerLiteral(static_cast(pattern->getOffset())); case Token::Operator::SizeOf: - return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit, static_cast(pattern->getSize()) }); + return new ASTNodeIntegerLiteral(static_cast(pattern->getSize())); default: this->getConsole().abortEvaluation("invalid type operator used. This is a bug!"); } @@ -281,85 +300,55 @@ namespace hex::lang { } ASTNodeIntegerLiteral* Evaluator::evaluateOperator(ASTNodeIntegerLiteral *left, ASTNodeIntegerLiteral *right, Token::Operator op) { - auto newType = [&] { - #define CHECK_TYPE(type) if (left->getType() == (type) || right->getType() == (type)) return (type) - #define DEFAULT_TYPE(type) return (type) - - if (left->getType() == Token::ValueType::Any && right->getType() != Token::ValueType::Any) - return right->getType(); - if (left->getType() != Token::ValueType::Any && right->getType() == Token::ValueType::Any) - return left->getType(); - - CHECK_TYPE(Token::ValueType::Double); - CHECK_TYPE(Token::ValueType::Float); - CHECK_TYPE(Token::ValueType::Unsigned128Bit); - CHECK_TYPE(Token::ValueType::Signed128Bit); - CHECK_TYPE(Token::ValueType::Unsigned64Bit); - CHECK_TYPE(Token::ValueType::Signed64Bit); - CHECK_TYPE(Token::ValueType::Unsigned32Bit); - CHECK_TYPE(Token::ValueType::Signed32Bit); - CHECK_TYPE(Token::ValueType::Unsigned16Bit); - CHECK_TYPE(Token::ValueType::Signed16Bit); - CHECK_TYPE(Token::ValueType::Unsigned8Bit); - CHECK_TYPE(Token::ValueType::Signed8Bit); - CHECK_TYPE(Token::ValueType::Character); - CHECK_TYPE(Token::ValueType::Character16); - CHECK_TYPE(Token::ValueType::Boolean); - DEFAULT_TYPE(Token::ValueType::Signed32Bit); - - #undef CHECK_TYPE - #undef DEFAULT_TYPE - }(); - try { return std::visit([&](auto &&leftValue, auto &&rightValue) -> ASTNodeIntegerLiteral * { switch (op) { case Token::Operator::Plus: - return new ASTNodeIntegerLiteral({ newType, leftValue + rightValue }); + return new ASTNodeIntegerLiteral(leftValue + rightValue); case Token::Operator::Minus: - return new ASTNodeIntegerLiteral({ newType, leftValue - rightValue }); + return new ASTNodeIntegerLiteral(leftValue - rightValue); case Token::Operator::Star: - return new ASTNodeIntegerLiteral({ newType, leftValue * rightValue }); + return new ASTNodeIntegerLiteral(leftValue * rightValue); case Token::Operator::Slash: if (rightValue == 0) this->getConsole().abortEvaluation("Division by zero"); - return new ASTNodeIntegerLiteral({ newType, leftValue / rightValue }); + return new ASTNodeIntegerLiteral(leftValue / rightValue); case Token::Operator::Percent: if (rightValue == 0) this->getConsole().abortEvaluation("Division by zero"); - return new ASTNodeIntegerLiteral({ newType, modulus(leftValue, rightValue) }); + return new ASTNodeIntegerLiteral(modulus(leftValue, rightValue)); case Token::Operator::ShiftLeft: - return new ASTNodeIntegerLiteral({ newType, shiftLeft(leftValue, rightValue) }); + return new ASTNodeIntegerLiteral(shiftLeft(leftValue, rightValue)); case Token::Operator::ShiftRight: - return new ASTNodeIntegerLiteral({ newType, shiftRight(leftValue, rightValue) }); + return new ASTNodeIntegerLiteral(shiftRight(leftValue, rightValue)); case Token::Operator::BitAnd: - return new ASTNodeIntegerLiteral({ newType, bitAnd(leftValue, rightValue) }); + return new ASTNodeIntegerLiteral(bitAnd(leftValue, rightValue)); case Token::Operator::BitXor: - return new ASTNodeIntegerLiteral({ newType, bitXor(leftValue, rightValue) }); + return new ASTNodeIntegerLiteral(bitXor(leftValue, rightValue)); case Token::Operator::BitOr: - return new ASTNodeIntegerLiteral({ newType, bitOr(leftValue, rightValue) }); + return new ASTNodeIntegerLiteral(bitOr(leftValue, rightValue)); case Token::Operator::BitNot: - return new ASTNodeIntegerLiteral({ newType, bitNot(leftValue, rightValue) }); + return new ASTNodeIntegerLiteral(bitNot(leftValue, rightValue)); case Token::Operator::BoolEquals: - return new ASTNodeIntegerLiteral({ newType, leftValue == rightValue }); + return new ASTNodeIntegerLiteral(leftValue == rightValue); case Token::Operator::BoolNotEquals: - return new ASTNodeIntegerLiteral({ newType, leftValue != rightValue }); + return new ASTNodeIntegerLiteral(leftValue != rightValue); case Token::Operator::BoolGreaterThan: - return new ASTNodeIntegerLiteral({ newType, leftValue > rightValue }); + return new ASTNodeIntegerLiteral(leftValue > rightValue); case Token::Operator::BoolLessThan: - return new ASTNodeIntegerLiteral({ newType, leftValue < rightValue }); + return new ASTNodeIntegerLiteral(leftValue < rightValue); case Token::Operator::BoolGreaterThanOrEquals: - return new ASTNodeIntegerLiteral({ newType, leftValue >= rightValue }); + return new ASTNodeIntegerLiteral(leftValue >= rightValue); case Token::Operator::BoolLessThanOrEquals: - return new ASTNodeIntegerLiteral({ newType, leftValue <= rightValue }); + return new ASTNodeIntegerLiteral(leftValue <= rightValue); case Token::Operator::BoolAnd: - return new ASTNodeIntegerLiteral({ newType, leftValue && rightValue }); + return new ASTNodeIntegerLiteral(leftValue && rightValue); case Token::Operator::BoolXor: - return new ASTNodeIntegerLiteral({ newType, leftValue && !rightValue || !leftValue && rightValue }); + return new ASTNodeIntegerLiteral(leftValue && !rightValue || !leftValue && rightValue); case Token::Operator::BoolOr: - return new ASTNodeIntegerLiteral({ newType, leftValue || rightValue }); + return new ASTNodeIntegerLiteral(leftValue || rightValue); case Token::Operator::BoolNot: - return new ASTNodeIntegerLiteral({ newType, !rightValue }); + return new ASTNodeIntegerLiteral(!rightValue); default: this->getConsole().abortEvaluation("invalid operator used in mathematical expression"); } @@ -419,6 +408,128 @@ namespace hex::lang { return evaluateOperator(leftInteger, rightInteger, node->getOperator()); } + void Evaluator::createLocalVariable(std::string_view varName, PatternData *pattern) { + auto startOffset = this->m_currOffset; + ON_SCOPE_EXIT { this->m_currOffset = startOffset; }; + + auto endOfStack = this->m_localStack.size(); + + for (auto &variable : *this->m_localVariables.back()) { + if (variable->getVariableName() == varName) + this->getConsole().abortEvaluation(hex::format("redefinition of variable {}", varName)); + } + + this->m_localStack.resize(endOfStack + pattern->getSize()); + + pattern->setVariableName(std::string(varName)); + pattern->setOffset(endOfStack); + pattern->setLocal(true); + this->m_localVariables.back()->push_back(pattern); + std::memset(this->m_localStack.data() + pattern->getOffset(), 0x00, pattern->getSize()); + + } + + void Evaluator::setLocalVariableValue(std::string_view varName, const void *value, size_t size) { + PatternData *varPattern = nullptr; + for (auto &var : *this->m_localVariables.back()) { + if (var->getVariableName() == varName) + varPattern = var; + } + + std::memset(this->m_localStack.data() + varPattern->getOffset(), 0x00, varPattern->getSize()); + std::memcpy(this->m_localStack.data() + varPattern->getOffset(), value, std::min(varPattern->getSize(), size)); + } + + void Evaluator::evaluateFunctionDefinition(ASTNodeFunctionDefinition *node) { + ContentRegistry::PatternLanguageFunctions::Function function = { + (u32)node->getParams().size(), + [paramNames = node->getParams(), body = node->getBody()](Evaluator& evaluator, std::vector ¶ms) -> ASTNode* { + // Create local variables from parameters + std::vector localVariables; + evaluator.m_localVariables.push_back(&localVariables); + + ON_SCOPE_EXIT { + u32 stackSizeToDrop = 0; + for (auto &localVar : *evaluator.m_localVariables.back()) { + stackSizeToDrop += localVar->getSize(); + delete localVar; + } + evaluator.m_localVariables.pop_back(); + evaluator.m_localStack.resize(evaluator.m_localStack.size() - stackSizeToDrop); + }; + + auto startOffset = evaluator.m_currOffset; + for (u32 i = 0; i < params.size(); i++) { + if (auto integerLiteralNode = dynamic_cast(params[i]); integerLiteralNode != nullptr) { + std::visit([&](auto &&value) { + using Type = std::remove_cvref_t; + + PatternData *pattern; + if constexpr (std::is_unsigned_v) + pattern = new PatternDataUnsigned(0, sizeof(value)); + else if constexpr (std::is_signed_v) + pattern = new PatternDataSigned(0, sizeof(value)); + else if constexpr (std::is_floating_point_v) + pattern = new PatternDataFloat(0, sizeof(value)); + else return; + + evaluator.createLocalVariable(paramNames[i], pattern); + evaluator.setLocalVariableValue(paramNames[i], &value, sizeof(value)); + }, integerLiteralNode->getValue()); + } + } + evaluator.m_currOffset = startOffset; + + for (auto &statement : body) { + ON_SCOPE_EXIT { evaluator.m_currOffset = startOffset; }; + + if (auto functionCallNode = dynamic_cast(statement); functionCallNode != nullptr) { + auto result = evaluator.evaluateFunctionCall(functionCallNode); + delete result; + } else if (auto varDeclNode = dynamic_cast(statement); varDeclNode != nullptr) { + auto pattern = evaluator.evaluateVariable(varDeclNode); + evaluator.createLocalVariable(varDeclNode->getName(), pattern); + } else if (auto assignmentNode = dynamic_cast(statement); assignmentNode != nullptr) { + if (auto numericExpressionNode = dynamic_cast(assignmentNode->getRValue()); numericExpressionNode != nullptr) { + auto value = evaluator.evaluateMathematicalExpression(numericExpressionNode); + ON_SCOPE_EXIT { delete value; }; + + std::visit([&](auto &&value) { + evaluator.setLocalVariableValue(assignmentNode->getLValueName(), &value, sizeof(value)); + }, value->getValue()); + } else { + evaluator.getConsole().abortEvaluation("invalid rvalue used in assignment"); + } + } else if (auto assignmentNode = dynamic_cast(statement); assignmentNode != nullptr) { + if (auto numericExpressionNode = dynamic_cast(assignmentNode->getRValue()); numericExpressionNode != nullptr) { + auto value = evaluator.evaluateMathematicalExpression(numericExpressionNode); + ON_SCOPE_EXIT { delete value; }; + + std::visit([&](auto &&value) { + evaluator.setLocalVariableValue(assignmentNode->getLValueName(), &value, sizeof(value)); + }, value->getValue()); + } else { + evaluator.getConsole().abortEvaluation("invalid rvalue used in assignment"); + } + } else if (auto returnNode = dynamic_cast(statement); returnNode != nullptr) { + if (auto numericExpressionNode = dynamic_cast(returnNode->getRValue()); numericExpressionNode != nullptr) { + return evaluator.evaluateMathematicalExpression(numericExpressionNode); + } else { + evaluator.getConsole().abortEvaluation("invalid rvalue used in return statement"); + } + } + } + + return nullptr; + } + }; + + if (this->m_definedFunctions.contains(std::string(node->getName()))) + this->getConsole().abortEvaluation(hex::format("redefinition of function {}", node->getName())); + + this->m_definedFunctions.insert({ std::string(node->getName()), function }); + } + PatternData* Evaluator::evaluateAttributes(ASTNode *currNode, PatternData *currPattern) { auto attributableNode = dynamic_cast(currNode); if (attributableNode == nullptr) @@ -617,7 +728,7 @@ namespace hex::lang { auto valueNode = evaluateMathematicalExpression(expression); ON_SCOPE_EXIT { delete valueNode; }; - entryPatterns.emplace_back(Token::castTo(builtinUnderlyingType->getType(), valueNode->getValue()), name); + entryPatterns.emplace_back(valueNode->getValue(), name); } this->m_currOffset += size; @@ -642,8 +753,9 @@ namespace hex::lang { auto valueNode = evaluateMathematicalExpression(expression); ON_SCOPE_EXIT { delete valueNode; }; - auto fieldBits = std::visit([this, node, type = valueNode->getType()] (auto &&value) { - if (Token::isFloatingPoint(type)) + auto fieldBits = std::visit([this] (auto &&value) { + using Type = std::remove_cvref_t; + if constexpr (std::is_floating_point_v) this->getConsole().abortEvaluation("bitfield entry size must be an integer value"); return static_cast(value); }, valueNode->getValue()); @@ -706,9 +818,10 @@ namespace hex::lang { auto valueNode = evaluateMathematicalExpression(offset); ON_SCOPE_EXIT { delete valueNode; }; - this->m_currOffset = std::visit([this, node, type = valueNode->getType()] (auto &&value) { - if (Token::isFloatingPoint(type)) - this->getConsole().abortEvaluation("placement offset must be an integer value"); + this->m_currOffset = std::visit([this] (auto &&value) { + using Type = std::remove_cvref_t; + if constexpr (std::is_floating_point_v) + this->getConsole().abortEvaluation("bitfield entry size must be an integer value"); return static_cast(value); }, valueNode->getValue()); } @@ -740,9 +853,10 @@ namespace hex::lang { auto valueNode = evaluateMathematicalExpression(offset); ON_SCOPE_EXIT { delete valueNode; }; - this->m_currOffset = std::visit([this, node, type = valueNode->getType()] (auto &&value) { - if (Token::isFloatingPoint(type)) - this->getConsole().abortEvaluation("placement offset must be an integer value"); + this->m_currOffset = std::visit([this] (auto &&value) { + using Type = std::remove_cvref_t; + if constexpr (std::is_floating_point_v) + this->getConsole().abortEvaluation("bitfield entry size must be an integer value"); return static_cast(value); }, valueNode->getValue()); } @@ -790,9 +904,10 @@ namespace hex::lang { auto valueNode = this->evaluateMathematicalExpression(numericExpression); ON_SCOPE_EXIT { delete valueNode; }; - auto arraySize = std::visit([this, node, type = valueNode->getType()] (auto &&value) { - if (Token::isFloatingPoint(type)) - this->getConsole().abortEvaluation("array size must be an integer value"); + auto arraySize = std::visit([this] (auto &&value) { + using Type = std::remove_cvref_t; + if constexpr (std::is_floating_point_v) + this->getConsole().abortEvaluation("bitfield entry size must be an integer value"); return static_cast(value); }, valueNode->getValue()); @@ -874,9 +989,10 @@ namespace hex::lang { auto valueNode = evaluateMathematicalExpression(offset); ON_SCOPE_EXIT { delete valueNode; }; - pointerOffset = std::visit([this, node, type = valueNode->getType()] (auto &&value) { - if (Token::isFloatingPoint(type)) - this->getConsole().abortEvaluation("pointer offset must be an integer value"); + pointerOffset = std::visit([this] (auto &&value) { + using Type = std::remove_cvref_t; + if constexpr (std::is_floating_point_v) + this->getConsole().abortEvaluation("bitfield entry size must be an integer value"); return static_cast(value); }, valueNode->getValue()); this->m_currOffset = pointerOffset; @@ -938,6 +1054,7 @@ namespace hex::lang { this->m_globalMembers.clear(); this->m_types.clear(); this->m_endianStack.clear(); + this->m_definedFunctions.clear(); this->m_currOffset = 0; try { @@ -974,6 +1091,8 @@ namespace hex::lang { } else if (auto functionCallNode = dynamic_cast(node); functionCallNode != nullptr) { auto result = this->evaluateFunctionCall(functionCallNode); delete result; + } else if (auto functionDefNode = dynamic_cast(node); functionDefNode != nullptr) { + this->evaluateFunctionDefinition(functionDefNode); } if (pattern != nullptr) diff --git a/plugins/libimhex/source/lang/lexer.cpp b/plugins/libimhex/source/lang/lexer.cpp index 7611dde5e..daa559971 100644 --- a/plugins/libimhex/source/lang/lexer.cpp +++ b/plugins/libimhex/source/lang/lexer.cpp @@ -121,20 +121,20 @@ namespace hex::lang { } switch (type) { - case Token::ValueType::Unsigned32Bit: return {{ type, u32(integer) }}; - case Token::ValueType::Signed32Bit: return {{ type, s32(integer) }}; - case Token::ValueType::Unsigned64Bit: return {{ type, u64(integer) }}; - case Token::ValueType::Signed64Bit: return {{ type, s64(integer) }}; - case Token::ValueType::Unsigned128Bit: return {{ type, u128(integer) }}; - case Token::ValueType::Signed128Bit: return {{ type, s128(integer) }}; + case Token::ValueType::Unsigned32Bit: return { u32(integer) }; + case Token::ValueType::Signed32Bit: return { s32(integer) }; + case Token::ValueType::Unsigned64Bit: return { u64(integer) }; + case Token::ValueType::Signed64Bit: return { s64(integer) }; + case Token::ValueType::Unsigned128Bit: return { u128(integer) }; + case Token::ValueType::Signed128Bit: return { s128(integer) }; default: return { }; } } else if (Token::isFloatingPoint(type)) { double floatingPoint = strtod(numberData.data(), nullptr); switch (type) { - case Token::ValueType::Float: return {{ type, float(floatingPoint) }}; - case Token::ValueType::Double: return {{ type, double(floatingPoint) }}; + case Token::ValueType::Float: return { float(floatingPoint) }; + case Token::ValueType::Double: return { double(floatingPoint) }; default: return { }; } } @@ -381,7 +381,7 @@ namespace hex::lang { auto [c, charSize] = character.value(); - tokens.emplace_back(VALUE_TOKEN(Integer, Token::IntegerLiteral(Token::ValueType::Character, c) )); + tokens.emplace_back(VALUE_TOKEN(Integer, c)); offset += charSize; } else if (c == '\"') { auto string = getStringLiteral(code.substr(offset)); @@ -417,13 +417,17 @@ namespace hex::lang { else if (identifier == "else") tokens.emplace_back(TOKEN(Keyword, Else)); else if (identifier == "false") - tokens.emplace_back(VALUE_TOKEN(Integer, Token::IntegerLiteral(Token::ValueType::Boolean, s32(0)))); + tokens.emplace_back(VALUE_TOKEN(Integer, bool(0))); else if (identifier == "true") - tokens.emplace_back(VALUE_TOKEN(Integer, Token::IntegerLiteral(Token::ValueType::Boolean, s32(1)))); + tokens.emplace_back(VALUE_TOKEN(Integer, bool(1))); else if (identifier == "parent") tokens.emplace_back(TOKEN(Keyword, Parent)); else if (identifier == "while") tokens.emplace_back(TOKEN(Keyword, While)); + else if (identifier == "fn") + tokens.emplace_back(TOKEN(Keyword, Function)); + else if (identifier == "return") + tokens.emplace_back(TOKEN(Keyword, Return)); // Check for built-in types else if (identifier == "u8") diff --git a/plugins/libimhex/source/lang/parser.cpp b/plugins/libimhex/source/lang/parser.cpp index 58eedbcbe..8f2cab371 100644 --- a/plugins/libimhex/source/lang/parser.cpp +++ b/plugins/libimhex/source/lang/parser.cpp @@ -5,7 +5,7 @@ #define MATCHES(x) (begin() && x) -#define TO_NUMERIC_EXPRESSION(node) new ASTNodeNumericExpression((node), new ASTNodeIntegerLiteral({ Token::ValueType::Any, s32(0) }), Token::Operator::Plus) +#define TO_NUMERIC_EXPRESSION(node) new ASTNodeNumericExpression((node), new ASTNodeIntegerLiteral(s32(0)), Token::Operator::Plus) // Definition syntax: // [A] : Either A or no token @@ -132,7 +132,7 @@ namespace hex::lang { if (MATCHES(oneOf(OPERATOR_PLUS, OPERATOR_MINUS, OPERATOR_BOOLNOT, OPERATOR_BITNOT))) { auto op = getValue(-1); - return new ASTNodeNumericExpression(new ASTNodeIntegerLiteral({ Token::ValueType::Any, 0 }), this->parseFactor(), op); + return new ASTNodeNumericExpression(new ASTNodeIntegerLiteral(0), this->parseFactor(), op); } return this->parseFactor(); @@ -358,6 +358,72 @@ namespace hex::lang { throwParseError("unfinished attribute. Expected ']]'"); } + /* Functions */ + + ASTNode* Parser::parseFunctionDefintion() { + const auto &functionName = getValue(-2); + std::vector params; + + // Parse parameter list + while (MATCHES(sequence(IDENTIFIER))) { + params.push_back(getValue(-1)); + + if (!MATCHES(sequence(SEPARATOR_COMMA))) { + if (MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE))) + break; + else + throwParseError("expected closing ')' after parameter list"); + } + } + + if (!MATCHES(sequence(SEPARATOR_CURLYBRACKETOPEN))) + throwParseError("expected opening '{' after function definition"); + + + // Parse function body + std::vector body; + auto bodyCleanup = SCOPE_GUARD { + for (auto &node : body) + delete node; + }; + + while (!MATCHES(sequence(SEPARATOR_CURLYBRACKETCLOSE))) { + ASTNode *statement; + if (MATCHES(sequence(IDENTIFIER, SEPARATOR_ROUNDBRACKETOPEN))) + statement = parseFunctionCall(); + else if (MATCHES((optional(KEYWORD_BE), optional(KEYWORD_LE)) && variant(IDENTIFIER, VALUETYPE_ANY) && sequence(IDENTIFIER, SEPARATOR_SQUAREBRACKETOPEN) && sequence(SEPARATOR_SQUAREBRACKETOPEN))) + statement = parseMemberArrayVariable(); + else if (MATCHES((optional(KEYWORD_BE), optional(KEYWORD_LE)) && variant(IDENTIFIER, VALUETYPE_ANY) && sequence(IDENTIFIER))) + statement = parseMemberVariable(); + else if (MATCHES(sequence(IDENTIFIER, OPERATOR_ASSIGNMENT))) + statement = parseVariableAssignment(); + else if (MATCHES(sequence(KEYWORD_RETURN))) + statement = parseReturnStatement(); + else + throwParseError("invalid sequence", 0); + + body.push_back(statement); + + if (!MATCHES(sequence(SEPARATOR_ENDOFEXPRESSION))) + throwParseError("missing ';' at end of expression", -1); + } + + bodyCleanup.release(); + return new ASTNodeFunctionDefinition(functionName, params, body); + } + + ASTNode* Parser::parseVariableAssignment() { + const auto &lvalue = getValue(-2); + + auto rvalue = this->parseMathematicalExpression(); + + return new ASTNodeAssignment(lvalue, rvalue); + } + + ASTNode* Parser::parseReturnStatement() { + return new ASTNodeReturnStatement(this->parseMathematicalExpression()); + } + /* Control flow */ // if ((parseMathematicalExpression)) { (parseMember) } @@ -606,9 +672,9 @@ namespace hex::lang { ASTNode *valueExpr; auto name = getValue(-1); if (enumNode->getEntries().empty()) - valueExpr = lastEntry = TO_NUMERIC_EXPRESSION(new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned8Bit, u8(0) })); + valueExpr = lastEntry = TO_NUMERIC_EXPRESSION(new ASTNodeIntegerLiteral(u8(0))); else - valueExpr = new ASTNodeNumericExpression(lastEntry->clone(), new ASTNodeIntegerLiteral({ Token::ValueType::Any, s32(1) }), Token::Operator::Plus); + valueExpr = new ASTNodeNumericExpression(lastEntry->clone(), new ASTNodeIntegerLiteral(s32(1)), Token::Operator::Plus); enumNode->addEntry(name, valueExpr); } @@ -743,6 +809,8 @@ namespace hex::lang { statement = parseBitfield(); else if (MATCHES(sequence(IDENTIFIER, SEPARATOR_ROUNDBRACKETOPEN))) statement = parseFunctionCall(); + else if (MATCHES(sequence(KEYWORD_FUNCTION, IDENTIFIER, SEPARATOR_ROUNDBRACKETOPEN))) + statement = parseFunctionDefintion(); else throwParseError("invalid sequence", 0); if (MATCHES(sequence(SEPARATOR_SQUAREBRACKETOPEN, SEPARATOR_SQUAREBRACKETOPEN))) diff --git a/source/views/view_pattern.cpp b/source/views/view_pattern.cpp index 498c7d6e0..102daf177 100644 --- a/source/views/view_pattern.cpp +++ b/source/views/view_pattern.cpp @@ -15,7 +15,7 @@ namespace hex { static TextEditor::LanguageDefinition langDef; if (!initialized) { static const char* const keywords[] = { - "using", "struct", "union", "enum", "bitfield", "be", "le", "if", "else", "false", "true", "parent", "addressof", "sizeof", "$", "while" + "using", "struct", "union", "enum", "bitfield", "be", "le", "if", "else", "false", "true", "parent", "addressof", "sizeof", "$", "while", "fn", "return" }; for (auto& k : keywords) langDef.mKeywords.insert(k); diff --git a/source/window.cpp b/source/window.cpp index b137c62c2..ef411390f 100644 --- a/source/window.cpp +++ b/source/window.cpp @@ -563,7 +563,7 @@ namespace hex { void Window::initGLFW() { glfwSetErrorCallback([](int error, const char* desc) { - fprintf(stderr, "Glfw Error %d: %s\n", error, desc); + log::error("GLFW Error [{}] : {}", error, desc); }); if (!glfwInit())