diff --git a/plugins/libimhex/include/hex/data_processor/node.hpp b/plugins/libimhex/include/hex/data_processor/node.hpp index 810f40093..1fab2b429 100644 --- a/plugins/libimhex/include/hex/data_processor/node.hpp +++ b/plugins/libimhex/include/hex/data_processor/node.hpp @@ -2,6 +2,9 @@ #include +#include +#include + namespace hex::dp { class Node { @@ -31,10 +34,15 @@ namespace hex::dp { attribute.getOutputData().reset(); } + void resetProcessedInputs() { + this->m_processedInputs.clear(); + } + private: u32 m_id; std::string m_unlocalizedName; std::vector m_attributes; + std::set m_processedInputs; prv::Overlay *m_overlay = nullptr; Attribute* getConnectedInputAttribute(u32 index) { @@ -49,6 +57,12 @@ namespace hex::dp { return connectedAttribute.begin()->second; } + void markInputProcessed(u32 index) { + const auto &[iter, inserted] = this->m_processedInputs.insert(index); + if (!inserted) + throwNodeError("Recursion detected!"); + } + protected: [[noreturn]] void throwNodeError(std::string_view message) { @@ -64,6 +78,7 @@ namespace hex::dp { if (attribute->getType() != Attribute::Type::Buffer) throwNodeError("Tried to read buffer from non-buffer attribute"); + markInputProcessed(index); attribute->getParentNode()->process(); auto &outputData = attribute->getOutputData(); @@ -83,6 +98,7 @@ namespace hex::dp { if (attribute->getType() != Attribute::Type::Integer) throwNodeError("Tried to read integer from non-integer attribute"); + markInputProcessed(index); attribute->getParentNode()->process(); auto &outputData = attribute->getOutputData(); @@ -105,6 +121,7 @@ namespace hex::dp { if (attribute->getType() != Attribute::Type::Float) throwNodeError("Tried to read float from non-float attribute"); + markInputProcessed(index); attribute->getParentNode()->process(); auto &outputData = attribute->getOutputData(); diff --git a/plugins/libimhex/include/hex/lang/evaluator.hpp b/plugins/libimhex/include/hex/lang/evaluator.hpp index c289be06e..687361fef 100644 --- a/plugins/libimhex/include/hex/lang/evaluator.hpp +++ b/plugins/libimhex/include/hex/lang/evaluator.hpp @@ -24,6 +24,7 @@ namespace hex::lang { LogConsole& getConsole() { return this->m_console; } void setDefaultEndian(std::endian endian) { this->m_defaultDataEndian = endian; } + void setRecursionLimit(u32 limit) { this->m_recursionLimit = limit; } void setProvider(prv::Provider *provider) { this->m_provider = provider; } [[nodiscard]] std::endian getCurrentEndian() const { return this->m_endianStack.back(); } @@ -47,6 +48,8 @@ namespace hex::lang { std::vector*> m_currMembers; LogConsole m_console; + u32 m_recursionLimit = 16; + u32 m_currRecursionDepth; ASTNodeIntegerLiteral* evaluateScopeResolution(ASTNodeScopeResolution *node); diff --git a/plugins/libimhex/include/hex/lang/pattern_language.hpp b/plugins/libimhex/include/hex/lang/pattern_language.hpp index 9156eca91..e39c0b1c1 100644 --- a/plugins/libimhex/include/hex/lang/pattern_language.hpp +++ b/plugins/libimhex/include/hex/lang/pattern_language.hpp @@ -41,6 +41,7 @@ namespace hex::lang { prv::Provider *m_provider; std::endian m_defaultEndian; + u32 m_recursionLimit; std::optional> m_currError; }; diff --git a/plugins/libimhex/source/lang/evaluator.cpp b/plugins/libimhex/source/lang/evaluator.cpp index 354dc9443..228a4b5e9 100644 --- a/plugins/libimhex/source/lang/evaluator.cpp +++ b/plugins/libimhex/source/lang/evaluator.cpp @@ -538,6 +538,10 @@ namespace hex::lang { PatternData *pattern; + this->m_currRecursionDepth++; + if (this->m_currRecursionDepth > this->m_recursionLimit) + this->getConsole().abortEvaluation(hex::format("evaluation depth exceeds maximum of {0}. Use #pragma eval_depth to increase the maximum", this->m_recursionLimit)); + if (auto builtinTypeNode = dynamic_cast(type); builtinTypeNode != nullptr) return this->evaluateBuiltinType(builtinTypeNode); else if (auto typeDeclNode = dynamic_cast(type); typeDeclNode != nullptr) @@ -553,6 +557,8 @@ namespace hex::lang { else this->getConsole().abortEvaluation("type could not be evaluated"); + this->m_currRecursionDepth--; + if (!node->getName().empty()) pattern->setTypeName(node->getName().data()); @@ -752,6 +758,7 @@ namespace hex::lang { try { for (const auto& node : ast) { this->m_endianStack.push_back(this->m_defaultDataEndian); + this->m_currRecursionDepth = 0; if (auto variableDeclNode = dynamic_cast(node); variableDeclNode != nullptr) { this->m_globalMembers.push_back(this->evaluateVariable(variableDeclNode)); diff --git a/plugins/libimhex/source/lang/pattern_language.cpp b/plugins/libimhex/source/lang/pattern_language.cpp index ae6144c5d..0b7afa0ac 100644 --- a/plugins/libimhex/source/lang/pattern_language.cpp +++ b/plugins/libimhex/source/lang/pattern_language.cpp @@ -33,6 +33,16 @@ namespace hex::lang { } else return false; }); + + this->m_preprocessor->addPragmaHandler("eval_depth", [this](std::string value) { + auto limit = strtol(value.c_str(), nullptr, 0); + + if (limit <= 0) + return false; + + this->m_recursionLimit = limit; + return true; + }); this->m_preprocessor->addDefaultPragmaHandlers(); } @@ -56,6 +66,9 @@ namespace hex::lang { return { }; } + this->m_evaluator->setDefaultEndian(this->m_defaultEndian); + this->m_evaluator->setRecursionLimit(this->m_recursionLimit); + auto tokens = this->m_lexer->lex(preprocessedCode.value()); if (!tokens.has_value()) { this->m_currError = this->m_lexer->getError(); diff --git a/source/views/view_data_processor.cpp b/source/views/view_data_processor.cpp index b44fd8cd7..0224dc8a6 100644 --- a/source/views/view_data_processor.cpp +++ b/source/views/view_data_processor.cpp @@ -117,6 +117,10 @@ namespace hex { try { for (auto &endNode : this->m_endNodes) { endNode->resetOutputData(); + + for (auto &node : this->m_nodes) + node->resetProcessedInputs(); + endNode->process(); } } catch (dp::Node::NodeError &e) {