diff --git a/CMakeLists.txt b/CMakeLists.txt index bdecfc7d9..3a49b8539 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -135,6 +135,7 @@ add_executable(imhex ${application_type} source/lang/parser.cpp source/lang/validator.cpp source/lang/evaluator.cpp + source/lang/builtin_functions.cpp source/providers/file_provider.cpp diff --git a/include/lang/ast_node.hpp b/include/lang/ast_node.hpp index a293fa033..9ac151856 100644 --- a/include/lang/ast_node.hpp +++ b/include/lang/ast_node.hpp @@ -442,4 +442,39 @@ namespace hex::lang { std::vector m_trueBody, m_falseBody; }; + class ASTNodeFunctionCall : public ASTNode { + public: + explicit ASTNodeFunctionCall(std::string_view functionName, std::vector params) + : ASTNode(), m_functionName(functionName), m_params(std::move(params)) { } + + ~ASTNodeFunctionCall() override { + for (auto ¶m : this->m_params) + delete param; + } + + ASTNodeFunctionCall(const ASTNodeFunctionCall &other) : ASTNode(other) { + this->m_functionName = other.m_functionName; + + for (auto ¶m : other.m_params) + this->m_params.push_back(param->clone()); + } + + ASTNode* clone() const override { + return new ASTNodeFunctionCall(*this); + } + + [[nodiscard]] std::string_view getFunctionName() { + return this->m_functionName; + } + + [[nodiscard]] const std::vector& getParams() const { + return this->m_params; + } + + private: + std::string m_functionName; + std::vector m_params; + }; + + } \ No newline at end of file diff --git a/include/lang/evaluator.hpp b/include/lang/evaluator.hpp index 3ac937fc2..441e46c79 100644 --- a/include/lang/evaluator.hpp +++ b/include/lang/evaluator.hpp @@ -21,6 +21,17 @@ namespace hex::lang { const std::pair& getError() { return this->m_error; } + + struct Function { + constexpr static u32 UnlimitedParameters = 0xFFFF'FFFF; + constexpr static u32 MoreParametersThan = 0x8000'0000; + constexpr static u32 LessParametersThan = 0x4000'0000; + constexpr static u32 NoParameters = 0x0000'0000; + + u32 parameterCount; + 
std::function)> func; + }; + private: std::map m_types; prv::Provider* &m_provider; @@ -28,6 +39,7 @@ namespace hex::lang { u64 m_currOffset = 0; std::optional m_currEndian; std::vector*> m_currMembers; + std::map m_functions; std::pair m_error; @@ -41,8 +53,16 @@ namespace hex::lang { return this->m_currEndian.value_or(this->m_defaultDataEndian); } + void addFunction(std::string_view name, u32 parameterCount, std::function)> func) { + if (this->m_functions.contains(name.data())) + throwEvaluateError(hex::format("redefinition of function '%s'", name.data()), 1); + + this->m_functions[name.data()] = { parameterCount, func }; + } + ASTNodeIntegerLiteral* evaluateScopeResolution(ASTNodeScopeResolution *node); ASTNodeIntegerLiteral* evaluateRValue(ASTNodeRValue *node); + ASTNodeIntegerLiteral* evaluateFunctionCall(ASTNodeFunctionCall *node); ASTNodeIntegerLiteral* evaluateOperator(ASTNodeIntegerLiteral *left, ASTNodeIntegerLiteral *right, Token::Operator op); ASTNodeIntegerLiteral* evaluateOperand(ASTNode *node); ASTNodeIntegerLiteral* evaluateTernaryExpression(ASTNodeTernaryExpression *node); @@ -59,6 +79,14 @@ namespace hex::lang { PatternData* evaluateArray(ASTNodeArrayVariableDecl *node); PatternData* evaluatePointer(ASTNodePointerVariableDecl *node); + + #define BUILTIN_FUNCTION(name) ASTNodeIntegerLiteral* name(std::vector params) + + BUILTIN_FUNCTION(findSequence); + BUILTIN_FUNCTION(readUnsigned); + BUILTIN_FUNCTION(readSigned); + + #undef BUILTIN_FUNCTION }; } \ No newline at end of file diff --git a/include/lang/parser.hpp b/include/lang/parser.hpp index d3e384da7..f494ba956 100644 --- a/include/lang/parser.hpp +++ b/include/lang/parser.hpp @@ -53,6 +53,7 @@ namespace hex::lang { return this->m_curr[index].type; } + ASTNode* parseFunctionCall(); ASTNode* parseScopeResolution(std::vector &path); ASTNode* parseRValue(std::vector &path); ASTNode* parseFactor(); diff --git a/source/lang/builtin_functions.cpp b/source/lang/builtin_functions.cpp new file mode 
100644
index 000000000..04c4c2eac
--- /dev/null
+++ b/source/lang/builtin_functions.cpp
@@ -0,0 +1,140 @@
+#include "lang/evaluator.hpp"
+
+#include <cstring>
+#include <vector>
+
+namespace hex::lang {
+
+    // Expands to the out-of-line definition header of the Evaluator member implementing built-in `name`.
+    #define BUILTIN_FUNCTION(name) ASTNodeIntegerLiteral* Evaluator::name(std::vector<ASTNodeIntegerLiteral*> params)
+
+    // Visits the value variant `literal` and evaluates `cond` with the name `literal` bound to the held value.
+    #define LITERAL_COMPARE(literal, cond) std::visit([&, this](auto &&literal) { return (cond) != 0; }, literal)
+
+    namespace {
+
+        // Copies sizeof(T) bytes from `buffer` into a T. Unlike dereferencing a reinterpret_cast'ed
+        // pointer, this is well-defined for unaligned byte buffers.
+        template<typename T>
+        T loadValue(const u8 *buffer) {
+            T result;
+            std::memcpy(&result, buffer, sizeof(T));
+            return result;
+        }
+
+    }
+
+    // findSequence(occurrenceIndex, bytes...)
+    // Scans the provider for the byte sequence given by params[1..] and returns the offset of the match
+    // with zero-based index params[0] as an unsigned 64 bit literal.
+    // Throws an evaluate error if a byte is outside 0x00-0xFF or if no such match exists.
+    BUILTIN_FUNCTION(findSequence) {
+        auto& occurrenceIndex = params[0]->getValue();
+
+        std::vector<u8> sequence;
+        for (u32 i = 1; i < params.size(); i++) {
+            sequence.push_back(std::visit([](auto &&value) -> u8 {
+                // Negative values must be rejected too, they would otherwise be truncated silently
+                if (value >= 0 && value <= 0xFF)
+                    return value;
+                else
+                    throwEvaluateError("sequence bytes need to fit into 1 byte", 1);
+            }, params[i]->getValue()));
+        }
+
+        const u64 dataSize = this->m_provider->getSize();
+        std::vector<u8> bytes(sequence.size(), 0x00);
+        u32 occurrences = 0;
+
+        // Comparing `offset + length` against the size cannot underflow when the sequence is longer
+        // than the data and, unlike `offset < size - length`, includes the last possible position
+        for (u64 offset = 0; offset + sequence.size() <= dataSize; offset++) {
+            this->m_provider->read(offset, bytes.data(), bytes.size());
+
+            if (bytes != sequence)
+                continue;
+
+            // Skip matches until the requested occurrence is reached
+            if (LITERAL_COMPARE(occurrenceIndex, occurrences < occurrenceIndex)) {
+                occurrences++;
+                continue;
+            }
+
+            return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit, offset });
+        }
+
+        throwEvaluateError("failed to find sequence", 1);
+    }
+
+    // readUnsigned(address, size)
+    // Reads `size` bytes (1, 2, 4, 8 or 16) at `address` from the provider and returns them as an unsigned
+    // literal of that width, passed through hex::changeEndianess with the evaluator's current endianness.
+    // Throws an evaluate error if the read is out of range or the size is not one of the supported widths.
+    BUILTIN_FUNCTION(readUnsigned) {
+        auto address = params[0]->getValue();
+        auto size = params[1]->getValue();
+
+        if (LITERAL_COMPARE(address, address >= this->m_provider->getActualSize()))
+            throwEvaluateError("address out of range", 1);
+
+        return std::visit([this](auto &&address, auto &&size) {
+            if (size <= 0 || size > 16)
+                throwEvaluateError("invalid read size", 1);
+
+            // The start address was validated above; also make sure the end of the read is in range
+            if (address < 0 || static_cast<u64>(size) > this->m_provider->getActualSize() - static_cast<u64>(address))
+                throwEvaluateError("address out of range", 1);
+
+            // Fixed-size buffer instead of a non-standard variable length array; 16 is the largest accepted size
+            u8 value[16] = { 0 };
+            this->m_provider->read(address, value, size);
+
+            switch (static_cast<u8>(size)) {
+                case 1:  return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned8Bit,   hex::changeEndianess(loadValue<u8>(value),   1,  this->getCurrentEndian()) });
+                case 2:  return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned16Bit,  hex::changeEndianess(loadValue<u16>(value),  2,  this->getCurrentEndian()) });
+                case 4:  return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned32Bit,  hex::changeEndianess(loadValue<u32>(value),  4,  this->getCurrentEndian()) });
+                case 8:  return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit,  hex::changeEndianess(loadValue<u64>(value),  8,  this->getCurrentEndian()) });
+                case 16: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned128Bit, hex::changeEndianess(loadValue<u128>(value), 16, this->getCurrentEndian()) });
+                default: throwEvaluateError("invalid rvalue size", 1);
+            }
+        }, address, size);
+    }
+
+    // readSigned(address, size)
+    // Reads `size` bytes (1, 2, 4, 8 or 16) at `address` from the provider and returns them as a signed
+    // literal of that width, passed through hex::changeEndianess with the evaluator's current endianness.
+    // Throws an evaluate error if the read is out of range or the size is not one of the supported widths.
+    BUILTIN_FUNCTION(readSigned) {
+        auto address = params[0]->getValue();
+        auto size = params[1]->getValue();
+
+        if (LITERAL_COMPARE(address, address >= this->m_provider->getActualSize()))
+            throwEvaluateError("address out of range", 1);
+
+        return std::visit([this](auto &&address, auto &&size) {
+            if (size <= 0 || size > 16)
+                throwEvaluateError("invalid read size", 1);
+
+            // The start address was validated above; also make sure the end of the read is in range
+            if (address < 0 || static_cast<u64>(size) > this->m_provider->getActualSize() - static_cast<u64>(address))
+                throwEvaluateError("address out of range", 1);
+
+            // Fixed-size buffer instead of a non-standard variable length array; 16 is the largest accepted size
+            u8 value[16] = { 0 };
+            this->m_provider->read(address, value, size);
+
+            switch (static_cast<u8>(size)) {
+                case 1:  return new ASTNodeIntegerLiteral({ Token::ValueType::Signed8Bit,   hex::changeEndianess(loadValue<s8>(value),   1,  this->getCurrentEndian()) });
+                case 2:  return new ASTNodeIntegerLiteral({ Token::ValueType::Signed16Bit,  hex::changeEndianess(loadValue<s16>(value),  2,  this->getCurrentEndian()) });
+                case 4:  return new ASTNodeIntegerLiteral({ Token::ValueType::Signed32Bit,  hex::changeEndianess(loadValue<s32>(value),  4,  this->getCurrentEndian()) });
+                case 8:  return new ASTNodeIntegerLiteral({ Token::ValueType::Signed64Bit,  hex::changeEndianess(loadValue<s64>(value),  8,  this->getCurrentEndian()) });
+                case 16: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed128Bit, hex::changeEndianess(loadValue<s128>(value), 16, this->getCurrentEndian()) });
+                default: throwEvaluateError("invalid rvalue size", 1);
+            }
+        }, address, size);
+    }
+
+    #undef LITERAL_COMPARE
+    #undef BUILTIN_FUNCTION
+
+}
\ No newline at end of file
diff --git a/source/lang/evaluator.cpp b/source/lang/evaluator.cpp index 28a915619..e3a886e5f 100644 --- a/source/lang/evaluator.cpp +++ b/source/lang/evaluator.cpp @@ -12,6 +12,18 @@ namespace hex::lang { Evaluator::Evaluator(prv::Provider* &provider, std::endian defaultDataEndian) : m_provider(provider), m_defaultDataEndian(defaultDataEndian) { + + this->addFunction("findSequence", Function::MoreParametersThan | 1, [this](auto params) { + return this->findSequence(params); + }); + + this->addFunction("readUnsigned", 2, [this](auto params) { + return this->readUnsigned(params); + }); + + this->addFunction("readSigned", 2, [this](auto params) { + return this->readSigned(params); + }); } ASTNodeIntegerLiteral* Evaluator::evaluateScopeResolution(ASTNodeScopeResolution *node) { @@ -84,8 +96,6 @@ namespace hex::lang { u8 value[enumPattern->getSize()]; this->m_provider->read(enumPattern->getOffset(), value, enumPattern->getSize()); - - switch (enumPattern->getSize()) { case 1: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned8Bit, hex::changeEndianess(*reinterpret_cast(value), 1, this->getCurrentEndian()) }); case 2: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned16Bit, hex::changeEndianess(*reinterpret_cast(value), 2, this->getCurrentEndian()) }); @@ -98,6 +108,37 @@ namespace hex::lang { throwEvaluateError("tried to use non-integer value in numeric expression", node->getLineNumber()); } + ASTNodeIntegerLiteral* Evaluator::evaluateFunctionCall(ASTNodeFunctionCall *node) { + std::vector evaluatedParams; + ScopeExit paramCleanup([&] { + for (auto ¶m : evaluatedParams) + delete param; + }); + + for (auto ¶m : node->getParams()) + evaluatedParams.push_back(this->evaluateMathematicalExpression(static_cast(param))); + + if (!this->m_functions.contains(node->getFunctionName().data())) + throwEvaluateError(hex::format("no function named '%s' found", node->getFunctionName().data()), node->getLineNumber()); + + auto &function = 
this->m_functions[node->getFunctionName().data()]; + + if (function.parameterCount == Function::UnlimitedParameters) { + ; // Don't check parameter count + } + else if (function.parameterCount & Function::LessParametersThan) { + if (evaluatedParams.size() >= (function.parameterCount & ~Function::LessParametersThan)) + throwEvaluateError(hex::format("too many parameters for function '%s'. Expected %d", node->getFunctionName().data(), function.parameterCount & ~Function::LessParametersThan), node->getLineNumber()); + } else if (function.parameterCount & Function::MoreParametersThan) { + if (evaluatedParams.size() <= (function.parameterCount & ~Function::MoreParametersThan)) + throwEvaluateError(hex::format("too few parameters for function '%s'. Expected %d", node->getFunctionName().data(), function.parameterCount & ~Function::MoreParametersThan), node->getLineNumber()); + } else if (function.parameterCount != evaluatedParams.size()) { + throwEvaluateError(hex::format("invalid number of parameters for function '%s'. 
Expected %d", node->getFunctionName().data(), function.parameterCount), node->getLineNumber()); + } + + return function.func(evaluatedParams); + } + #define FLOAT_BIT_OPERATION(name) \ auto name(std::floating_point auto left, auto right) { throw std::runtime_error(""); return 0; } \ auto name(auto left, std::floating_point auto right) { throw std::runtime_error(""); return 0; } \ @@ -220,6 +261,8 @@ namespace hex::lang { return evaluateScopeResolution(exprScopeResolution); else if (auto exprTernary = dynamic_cast(node); exprTernary != nullptr) return evaluateTernaryExpression(exprTernary); + else if (auto exprFunctionCall = dynamic_cast(node); exprFunctionCall != nullptr) + return evaluateFunctionCall(exprFunctionCall); else throwEvaluateError("invalid operand", node->getLineNumber()); } diff --git a/source/lang/lexer.cpp b/source/lang/lexer.cpp index 8c59cae9f..13a2136da 100644 --- a/source/lang/lexer.cpp +++ b/source/lang/lexer.cpp @@ -54,12 +54,14 @@ namespace hex::lang { } else if (numberData.ends_with("LL")) { type = Token::ValueType::Signed128Bit; numberData.remove_suffix(2); - } else if (numberData.ends_with('F')) { - type = Token::ValueType::Float; - numberData.remove_suffix(1); - } else if (numberData.ends_with('D')) { - type = Token::ValueType::Double; - numberData.remove_suffix(1); + } else if (!numberData.starts_with("0x") && !numberData.starts_with("0b")) { + if (numberData.ends_with('F')) { + type = Token::ValueType::Float; + numberData.remove_suffix(1); + } else if (numberData.ends_with('D')) { + type = Token::ValueType::Double; + numberData.remove_suffix(1); + } } if (numberData.starts_with("0x")) { diff --git a/source/lang/parser.cpp b/source/lang/parser.cpp index ab21357e9..267c2400b 100644 --- a/source/lang/parser.cpp +++ b/source/lang/parser.cpp @@ -18,6 +18,32 @@ namespace hex::lang { /* Mathematical expressions */ + // Identifier([(parseMathematicalExpression)|<(parseMathematicalExpression),...>(parseMathematicalExpression)] + ASTNode* 
Parser::parseFunctionCall() { + auto functionName = getValue(-2); + std::vector params; + ScopeExit paramCleanup([&]{ + for (auto ¶m : params) + delete param; + }); + + while (!MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE))) { + params.push_back(parseMathematicalExpression()); + + if (MATCHES(sequence(SEPARATOR_COMMA, SEPARATOR_ROUNDBRACKETCLOSE))) + throwParseError("unexpected ',' at end of function parameter list", -1); + else if (MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE))) + break; + else if (!MATCHES(sequence(SEPARATOR_COMMA))) + throwParseError("missing ',' between parameters", -1); + + } + + paramCleanup.release(); + + return TO_NUMERIC_EXPRESSION(new ASTNodeFunctionCall(functionName, params)); + } + // Identifier:: ASTNode* Parser::parseScopeResolution(std::vector &path) { if (peek(IDENTIFIER, -1)) @@ -59,12 +85,12 @@ namespace hex::lang { std::vector path; this->m_curr--; return this->parseScopeResolution(path); - } - else if (MATCHES(sequence(IDENTIFIER))) { + } else if (MATCHES(sequence(IDENTIFIER, SEPARATOR_ROUNDBRACKETOPEN))) { + return this->parseFunctionCall(); + } else if (MATCHES(sequence(IDENTIFIER))) { std::vector path; return this->parseRValue(path); - } - else + } else throwParseError("expected integer or parenthesis"); }