#include "lang/evaluator.hpp"

#include <algorithm>
#include <bit>
#include <optional>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

namespace hex::lang {

    namespace {

        /**
         * Looks up a registered type by name without modifying the type table.
         *
         * operator[] would insert a nullptr entry for unknown names, and a later emplace() of a
         * real definition with that name would then be silently ignored.
         *
         * @return the registered node, or nullptr if no type of that name is registered
         */
        template <typename TypeMap>
        typename TypeMap::mapped_type lookupType(const TypeMap &types, const std::string &name) {
            const auto it = types.find(name);
            if (it == types.end())
                return nullptr;

            return it->second;
        }

        /**
         * Looks up a name and returns it as a typedef node, or nullptr when the name is unknown or
         * refers to something other than a TypeDecl (the table also holds struct/union/enum/bitfield
         * nodes, which must not be reinterpreted as ASTNodeTypeDecl).
         */
        template <typename TypeMap>
        ASTNodeTypeDecl* lookupTypeDecl(const TypeMap &types, const std::string &name) {
            auto *node = lookupType(types, name);
            if (node == nullptr || node->getType() != ASTNode::Type::TypeDecl)
                return nullptr;

            return static_cast<ASTNodeTypeDecl*>(node);
        }

        /**
         * Frees patterns that were created but not yet handed over to a parent pattern, so an
         * evaluation error does not leak partially built results.
         */
        void deletePatterns(std::vector<PatternData*> &patterns) {
            for (auto *pattern : patterns)
                delete pattern;
            patterns.clear();
        }

        /**
         * Reads the value of an earlier unsigned member called `name` from the provider, to be used
         * as the length of a dynamically sized array. If several members share that name, the last
         * one wins (same as the original lookup loop).
         *
         * @return the decoded value, or std::nullopt if no previous unsigned member has that name
         */
        std::optional<u64> readArraySizeFromMembers(prv::Provider *provider, std::endian endian,
                                                    const std::vector<PatternData*> &members, const std::string &name) {
            std::optional<u64> result;

            for (auto *prevMember : members) {
                if (prevMember->getPatternType() != PatternData::Type::Unsigned || prevMember->getName() != name)
                    continue;

                u64 value = 0;
                provider->read(prevMember->getOffset(), &value, prevMember->getSize());
                result = hex::changeEndianess(value, prevMember->getSize(), endian);
            }

            return result;
        }

    }

    Evaluator::Evaluator(prv::Provider* &provider, std::endian dataEndianess) : m_provider(provider), m_dataEndianess(dataEndianess) {
        PatternData::setEndianess(dataEndianess);
    }

    /**
     * Builds a struct pattern: members are laid out back to back starting at `offset`.
     * Pointer members occupy their pointer size in the struct; their pointee is evaluated at the
     * address read from the data.
     *
     * @return the pattern and its size in bytes, or { nullptr, 0 } with m_error set on failure
     */
    std::pair<PatternData*, size_t> Evaluator::createStructPattern(ASTNodeVariableDecl *varDeclNode, u64 offset) {
        std::vector<PatternData*> members;

        auto *structNode = static_cast<ASTNodeStruct*>(lookupType(this->m_types, varDeclNode->getCustomVariableTypeName()));
        if (structNode == nullptr) {
            this->m_error = { varDeclNode->getLineNumber(), hex::format("'%s' does not name a type", varDeclNode->getCustomVariableTypeName().c_str()) };
            return { nullptr, 0 };
        }

        size_t structSize = 0;
        for (const auto &node : structNode->getNodes()) {
            const auto &member = static_cast<ASTNodeVariableDecl*>(node);

            // For pointer members the data at the current position holds the pointee's address.
            u64 memberOffset = 0;
            if (member->getPointerSize().has_value()) {
                this->m_provider->read(offset + structSize, &memberOffset, member->getPointerSize().value());
                memberOffset = hex::changeEndianess(memberOffset, member->getPointerSize().value(), this->m_dataEndianess);
            } else
                memberOffset = offset + structSize;

            auto *typeDeclNode = lookupTypeDecl(this->m_types, member->getCustomVariableTypeName());

            PatternData *pattern = nullptr;
            u64 memberSize = 0;

            const bool isCharType = member->getVariableType() == Token::TypeToken::Type::Signed8Bit
                || (member->getVariableType() == Token::TypeToken::Type::CustomType && typeDeclNode != nullptr && typeDeclNode->getAssignedType() == Token::TypeToken::Type::Signed8Bit);

            if (isCharType && member->getArraySize() > 1) {
                std::tie(pattern, memberSize) = this->createStringPattern(member, memberOffset);
            } else if (member->getArraySize() > 1) {
                std::tie(pattern, memberSize) = this->createArrayPattern(member, memberOffset);
            } else if (member->getArraySizeVariable().has_value()) {
                // Copy: getArraySizeVariable() may return the optional by value.
                const std::string sizeVariable = member->getArraySizeVariable().value();
                const auto arraySize = readArraySizeFromMembers(this->m_provider, this->m_dataEndianess, members, sizeVariable);

                if (!arraySize.has_value()) {
                    this->m_error = { varDeclNode->getLineNumber(), hex::format("'%s' does not name a previous member of '%s'", sizeVariable.c_str(), varDeclNode->getCustomVariableTypeName().c_str()) };
                    deletePatterns(members);
                    return { nullptr, 0 };
                }

                // An empty array cannot be represented (createArrayPattern needs at least one element).
                if (arraySize.value() == 0) {
                    this->m_error = { varDeclNode->getLineNumber(), hex::format("Value of '%s' is zero", sizeVariable.c_str()) };
                    deletePatterns(members);
                    return { nullptr, 0 };
                }

                // Temporary declaration with the resolved length; patterns do not keep a reference to it.
                ASTNodeVariableDecl processedMember(member->getLineNumber(), member->getVariableType(), member->getVariableName(), member->getCustomVariableTypeName(), member->getOffset(), arraySize.value());
                std::tie(pattern, memberSize) = this->createArrayPattern(&processedMember, memberOffset);
            } else if (member->getVariableType() != Token::TypeToken::Type::CustomType) {
                std::tie(pattern, memberSize) = this->createBuiltInTypePattern(member, memberOffset);
            } else {
                std::tie(pattern, memberSize) = this->createCustomTypePattern(member, memberOffset);
            }

            if (pattern == nullptr) {
                deletePatterns(members);
                return { nullptr, 0 };
            }

            if (member->getPointerSize().has_value()) {
                members.push_back(new PatternDataPointer(offset + structSize, member->getPointerSize().value(), member->getVariableName(), pattern));
                structSize += member->getPointerSize().value();
            } else {
                members.push_back(pattern);
                structSize += memberSize;
            }
        }

        return { new PatternDataStruct(offset, structSize, varDeclNode->getVariableName(), structNode->getName(), members, 0x00FFFFFF), structSize };
    }

    /**
     * Builds a union pattern: every member starts at `offset`, the union's size is that of its
     * largest member. Pointer members store their pointee's address at `offset` as well.
     *
     * @return the pattern and its size in bytes, or { nullptr, 0 } with m_error set on failure
     */
    std::pair<PatternData*, size_t> Evaluator::createUnionPattern(ASTNodeVariableDecl *varDeclNode, u64 offset) {
        std::vector<PatternData*> members;

        auto *unionNode = static_cast<ASTNodeUnion*>(lookupType(this->m_types, varDeclNode->getCustomVariableTypeName()));
        if (unionNode == nullptr) {
            this->m_error = { varDeclNode->getLineNumber(), hex::format("'%s' does not name a type", varDeclNode->getCustomVariableTypeName().c_str()) };
            return { nullptr, 0 };
        }

        size_t unionSize = 0;
        for (const auto &node : unionNode->getNodes()) {
            const auto &member = static_cast<ASTNodeVariableDecl*>(node);

            // All union members overlap at `offset`, so a pointer member's address is read from there
            // (not from offset + unionSize, which is a struct-style layout).
            u64 memberOffset = 0;
            if (member->getPointerSize().has_value()) {
                this->m_provider->read(offset, &memberOffset, member->getPointerSize().value());
                memberOffset = hex::changeEndianess(memberOffset, member->getPointerSize().value(), this->m_dataEndianess);
            } else
                memberOffset = offset;

            auto *typeDeclNode = lookupTypeDecl(this->m_types, member->getCustomVariableTypeName());

            PatternData *pattern = nullptr;
            u64 memberSize = 0;

            const bool isCharType = member->getVariableType() == Token::TypeToken::Type::Signed8Bit
                || (member->getVariableType() == Token::TypeToken::Type::CustomType && typeDeclNode != nullptr && typeDeclNode->getAssignedType() == Token::TypeToken::Type::Signed8Bit);

            if (isCharType && member->getArraySize() > 1) {
                std::tie(pattern, memberSize) = this->createStringPattern(member, memberOffset);
            } else if (member->getArraySize() > 1) {
                std::tie(pattern, memberSize) = this->createArrayPattern(member, memberOffset);
            } else if (member->getArraySizeVariable().has_value()) {
                // Copy: getArraySizeVariable() may return the optional by value.
                const std::string sizeVariable = member->getArraySizeVariable().value();
                const auto arraySize = readArraySizeFromMembers(this->m_provider, this->m_dataEndianess, members, sizeVariable);

                if (!arraySize.has_value()) {
                    this->m_error = { varDeclNode->getLineNumber(), hex::format("'%s' does not name a previous member of '%s'", sizeVariable.c_str(), varDeclNode->getCustomVariableTypeName().c_str()) };
                    deletePatterns(members);
                    return { nullptr, 0 };
                }

                if (arraySize.value() == 0) {
                    this->m_error = { varDeclNode->getLineNumber(), hex::format("Value of '%s' is zero", sizeVariable.c_str()) };
                    deletePatterns(members);
                    return { nullptr, 0 };
                }

                // Temporary declaration with the resolved length; patterns do not keep a reference to it.
                ASTNodeVariableDecl processedMember(member->getLineNumber(), member->getVariableType(), member->getVariableName(), member->getCustomVariableTypeName(), member->getOffset(), arraySize.value());
                std::tie(pattern, memberSize) = this->createArrayPattern(&processedMember, memberOffset);
            } else if (member->getVariableType() != Token::TypeToken::Type::CustomType) {
                std::tie(pattern, memberSize) = this->createBuiltInTypePattern(member, memberOffset);
            } else {
                std::tie(pattern, memberSize) = this->createCustomTypePattern(member, memberOffset);
            }

            if (pattern == nullptr) {
                deletePatterns(members);
                return { nullptr, 0 };
            }

            if (member->getPointerSize().has_value()) {
                members.push_back(new PatternDataPointer(offset, member->getPointerSize().value(), member->getVariableName(), pattern));
                unionSize = std::max(size_t(member->getPointerSize().value()), unionSize);
            } else {
                members.push_back(pattern);
                unionSize = std::max(size_t(memberSize), unionSize);
            }
        }

        return { new PatternDataUnion(offset, unionSize, varDeclNode->getVariableName(), unionNode->getName(), members, 0x00FFFFFF), unionSize };
    }

    /**
     * Builds an enum pattern whose size is that of the enum's underlying type.
     *
     * @return the pattern and its size in bytes, or { nullptr, 0 } with m_error set on failure
     */
    std::pair<PatternData*, size_t> Evaluator::createEnumPattern(ASTNodeVariableDecl *varDeclNode, u64 offset) {
        auto *enumType = static_cast<ASTNodeEnum*>(lookupType(this->m_types, varDeclNode->getCustomVariableTypeName()));
        if (enumType == nullptr) {
            this->m_error = { varDeclNode->getLineNumber(), hex::format("'%s' does not name a type", varDeclNode->getCustomVariableTypeName().c_str()) };
            return { nullptr, 0 };
        }

        size_t size = getTypeSize(enumType->getUnderlyingType());

        return { new PatternDataEnum(offset, size, varDeclNode->getVariableName(), enumType->getName(), enumType->getValues()), size };
    }

    /**
     * Builds a bitfield pattern. The summed field widths (in bits) are rounded up to the next
     * power of two, with a minimum of one byte.
     *
     * @return the pattern and its size in bytes, or { nullptr, 0 } with m_error set on failure
     */
    std::pair<PatternData*, size_t> Evaluator::createBitfieldPattern(ASTNodeVariableDecl *varDeclNode, u64 offset) {
        auto *bitfieldType = static_cast<ASTNodeBitfield*>(lookupType(this->m_types, varDeclNode->getCustomVariableTypeName()));
        if (bitfieldType == nullptr) {
            this->m_error = { varDeclNode->getLineNumber(), hex::format("'%s' does not name a type", varDeclNode->getCustomVariableTypeName().c_str()) };
            return { nullptr, 0 };
        }

        size_t bitCount = 0;
        for (auto &[fieldName, fieldSize] : bitfieldType->getFields())
            bitCount += fieldSize;

        // bit_ceil(bits) / 8 alone gives 0 bytes for bitfields narrower than 8 bits.
        const size_t size = std::bit_ceil(std::max<size_t>(bitCount, 8)) / 8;

        return { new PatternDataBitfield(offset, size, varDeclNode->getVariableName(), bitfieldType->getName(), bitfieldType->getFields()), size };
    }

    /**
     * Builds an array pattern with getArraySize() elements laid out back to back. Padding arrays
     * become a single padding region of getArraySize() bytes. All elements share the color of the
     * first element.
     *
     * @return the pattern and its size in bytes, or { nullptr, 0 } on failure
     */
    std::pair<PatternData*, size_t> Evaluator::createArrayPattern(ASTNodeVariableDecl *varDeclNode, u64 offset) {
        if (varDeclNode->getVariableType() == Token::TypeToken::Type::Padding)
            return { new PatternDataPadding(offset, varDeclNode->getArraySize()), varDeclNode->getArraySize() };

        std::vector<PatternData*> entries;
        size_t arrayOffset = 0;
        std::optional<u32> arrayColor;

        for (u32 i = 0; i < varDeclNode->getArraySize(); i++) {
            // Single-element declaration for entry i; the created pattern does not keep a reference to it,
            // so a stack object is enough (and cannot leak on the early returns below).
            ASTNodeVariableDecl elementDecl(varDeclNode->getLineNumber(), varDeclNode->getVariableType(), "[" + std::to_string(i) + "]", varDeclNode->getCustomVariableTypeName(), varDeclNode->getOffset(), 1);

            const auto [pattern, size] = (varDeclNode->getVariableType() != Token::TypeToken::Type::CustomType)
                    ? this->createBuiltInTypePattern(&elementDecl, offset + arrayOffset)
                    : this->createCustomTypePattern(&elementDecl, offset + arrayOffset);

            if (pattern == nullptr) {
                deletePatterns(entries);
                return { nullptr, 0 };
            }

            if (!arrayColor.has_value())
                arrayColor = pattern->getColor();
            pattern->setColor(arrayColor.value());

            entries.push_back(pattern);
            arrayOffset += size;
        }

        // Without at least one element there is no color to use; report instead of throwing bad_optional_access.
        if (!arrayColor.has_value()) {
            this->m_error = { varDeclNode->getLineNumber(), hex::format("'%s' is an empty array", varDeclNode->getVariableName().c_str()) };
            return { nullptr, 0 };
        }

        return { new PatternDataArray(offset, arrayOffset, varDeclNode->getVariableName(), entries, arrayColor.value()), arrayOffset };
    }

    /**
     * Builds a string pattern covering getArraySize() bytes.
     *
     * @return the pattern and its size in bytes
     */
    std::pair<PatternData*, size_t> Evaluator::createStringPattern(ASTNodeVariableDecl *varDeclNode, u64 offset) {
        size_t arraySize = varDeclNode->getArraySize();

        return { new PatternDataString(offset, arraySize, varDeclNode->getVariableName()), arraySize };
    }

    /**
     * Dispatches a custom-typed variable to the builder matching the kind of its registered type.
     *
     * @return the pattern and its size in bytes, or { nullptr, 0 } on failure
     */
    std::pair<PatternData*, size_t> Evaluator::createCustomTypePattern(ASTNodeVariableDecl *varDeclNode, u64 offset) {
        auto *currType = lookupType(this->m_types, varDeclNode->getCustomVariableTypeName());

        if (currType == nullptr) {
            this->m_error = { varDeclNode->getLineNumber(), hex::format("'%s' does not name a type", varDeclNode->getCustomVariableTypeName().c_str()) };
            return { nullptr, 0 };
        }

        switch (currType->getType()) {
            case ASTNode::Type::Struct:
                return this->createStructPattern(varDeclNode, offset);
            case ASTNode::Type::Union:
                return this->createUnionPattern(varDeclNode, offset);
            case ASTNode::Type::Enum:
                return this->createEnumPattern(varDeclNode, offset);
            case ASTNode::Type::Bitfield:
                return this->createBitfieldPattern(varDeclNode, offset);
            case ASTNode::Type::TypeDecl:
                return this->createBuiltInTypePattern(varDeclNode, offset);
        }

        return { nullptr, 0 };
    }

    /**
     * Builds a pattern for a built-in type, or for a typedef that resolves to one. Arrays are
     * forwarded to createArrayPattern.
     *
     * @return the pattern and its size in bytes, or { nullptr, 0 } on failure
     */
    std::pair<PatternData*, size_t> Evaluator::createBuiltInTypePattern(ASTNodeVariableDecl *varDeclNode, u64 offset) {
        auto type = varDeclNode->getVariableType();

        if (type == Token::TypeToken::Type::CustomType) {
            auto *currType = lookupTypeDecl(this->m_types, varDeclNode->getCustomVariableTypeName());
            if (currType == nullptr) {
                this->m_error = { varDeclNode->getLineNumber(), hex::format("'%s' does not name a type", varDeclNode->getCustomVariableTypeName().c_str()) };
                return { nullptr, 0 };
            }

            type = currType->getAssignedType();
        }

        size_t typeSize = getTypeSize(type);
        size_t arraySize = varDeclNode->getArraySize();

        // All classifications use the resolved `type`; the declared type is CustomType for typedefs.
        if (isSigned(type)) {
            if (typeSize == 1 && arraySize == 1)
                return { new PatternDataCharacter(offset, typeSize, varDeclNode->getVariableName()), 1 };
            else if (arraySize > 1)
                return createArrayPattern(varDeclNode, offset);
            else
                return { new PatternDataSigned(offset, typeSize, varDeclNode->getVariableName()), typeSize * arraySize };
        } else if (isUnsigned(type)) {
            if (arraySize > 1)
                return createArrayPattern(varDeclNode, offset);
            else
                return { new PatternDataUnsigned(offset, typeSize, varDeclNode->getVariableName()), typeSize * arraySize };
        } else if (isFloatingPoint(type)) {
            if (arraySize > 1)
                return createArrayPattern(varDeclNode, offset);
            else
                return { new PatternDataFloat(offset, typeSize, varDeclNode->getVariableName()), typeSize * arraySize };
        }

        return { nullptr, 0 };
    }

    /**
     * Registers all type declarations of `ast`, then builds a pattern for every top-level
     * variable declaration. Evaluation stops at the first failing variable.
     *
     * @return { ResultSuccess, patterns } on success, or { ResultEvaluatorError, {} } on failure
     *         (m_error describes the failure when the failing builder set it)
     */
    std::pair<Result, std::vector<PatternData*>> Evaluator::evaluate(const std::vector<ASTNode*> &ast) {
        // Entries from a previous call may point into an AST that no longer exists, and emplace()
        // below would refuse to replace them.
        this->m_types.clear();

        // Evaluate types
        for (const auto &node : ast) {
            switch (node->getType()) {
                case ASTNode::Type::Struct: {
                    auto *structNode = static_cast<ASTNodeStruct*>(node);
                    this->m_types.emplace(structNode->getName(), structNode);
                }
                    break;
                case ASTNode::Type::Union: {
                    auto *unionNode = static_cast<ASTNodeUnion*>(node);
                    this->m_types.emplace(unionNode->getName(), unionNode);
                }
                    break;
                case ASTNode::Type::Enum: {
                    auto *enumNode = static_cast<ASTNodeEnum*>(node);
                    this->m_types.emplace(enumNode->getName(), enumNode);
                }
                    break;
                case ASTNode::Type::Bitfield: {
                    auto *bitfieldNode = static_cast<ASTNodeBitfield*>(node);
                    this->m_types.emplace(bitfieldNode->getName(), bitfieldNode);
                }
                    break;
                case ASTNode::Type::TypeDecl: {
                    auto *typeDeclNode = static_cast<ASTNodeTypeDecl*>(node);

                    if (typeDeclNode->getAssignedType() == Token::TypeToken::Type::CustomType) {
                        // An alias of a custom type resolves directly to the aliased node. An unknown target
                        // leaves the alias unregistered (later uses report "does not name a type") rather
                        // than inserting a nullptr entry that would block a later definition of the target.
                        auto *aliasedType = lookupType(this->m_types, typeDeclNode->getAssignedCustomTypeName());
                        if (aliasedType != nullptr)
                            this->m_types.emplace(typeDeclNode->getTypeName(), aliasedType);
                    } else
                        this->m_types.emplace(typeDeclNode->getTypeName(), typeDeclNode);
                }
                    break;
                case ASTNode::Type::VariableDecl: break;
                case ASTNode::Type::Scope: break;
            }
        }

        // Evaluate variable declarations
        std::vector<PatternData*> variables;
        for (const auto &node : ast) {
            if (node->getType() != ASTNode::Type::VariableDecl)
                continue;

            auto *varDeclNode = static_cast<ASTNodeVariableDecl*>(node);

            // Top-level variables need an explicit placement; report instead of throwing bad_optional_access.
            if (!varDeclNode->getOffset().has_value()) {
                this->m_error = { varDeclNode->getLineNumber(), hex::format("'%s' has no placement offset", varDeclNode->getVariableName().c_str()) };
                deletePatterns(variables);
                return { ResultEvaluatorError, { } };
            }
            const u64 offset = varDeclNode->getOffset().value();

            PatternData *pattern = nullptr;
            if (varDeclNode->getVariableType() == Token::TypeToken::Type::Signed8Bit && varDeclNode->getArraySize() > 1)
                pattern = this->createStringPattern(varDeclNode, offset).first;
            else if (varDeclNode->getArraySize() > 1)
                pattern = this->createArrayPattern(varDeclNode, offset).first;
            else if (varDeclNode->getVariableType() != Token::TypeToken::Type::CustomType)
                pattern = this->createBuiltInTypePattern(varDeclNode, offset).first;
            else
                pattern = this->createCustomTypePattern(varDeclNode, offset).first;

            if (pattern == nullptr) {
                deletePatterns(variables);
                return { ResultEvaluatorError, { } };
            }

            variables.push_back(pattern);
        }

        return { ResultSuccess, variables };
    }

}