diff --git a/lib/libimhex/include/hex/api/task.hpp b/lib/libimhex/include/hex/api/task.hpp index 74aa34e00..c4729789f 100644 --- a/lib/libimhex/include/hex/api/task.hpp +++ b/lib/libimhex/include/hex/api/task.hpp @@ -10,6 +10,7 @@ #include #include #include +#include namespace hex { @@ -28,6 +29,7 @@ namespace hex { void update(u64 value = 0); void setMaxValue(u64 value); + [[nodiscard]] bool isRunning() const; [[nodiscard]] bool isBackgroundTask() const; [[nodiscard]] bool isFinished() const; [[nodiscard]] bool hadException() const; @@ -43,6 +45,8 @@ namespace hex { void setInterruptCallback(std::function callback); + void setRunning(bool running); + private: void finish(); void interruption(); @@ -52,16 +56,17 @@ namespace hex { mutable std::mutex m_mutex; std::string m_unlocalizedName; - u64 m_currValue = 0, m_maxValue = 0; - std::thread m_thread; + std::atomic m_currValue = 0, m_maxValue = 0; std::function m_interruptCallback; + std::function m_function; - bool m_shouldInterrupt = false; - bool m_background = true; + std::atomic m_running = false; + std::atomic m_shouldInterrupt = false; + std::atomic m_background = true; - bool m_interrupted = false; - bool m_finished = false; - bool m_hadException = false; + std::atomic m_interrupted = false; + std::atomic m_finished = false; + std::atomic m_hadException = false; std::string m_exceptionMessage; struct TaskInterruptor { virtual ~TaskInterruptor() = default; }; @@ -88,10 +93,14 @@ namespace hex { public: TaskManager() = delete; + static void init(); + static void exit(); + constexpr static auto NoProgress = 0; static TaskHolder createTask(std::string name, u64 maxValue, std::function function); static TaskHolder createBackgroundTask(std::string name, std::function function); + static void collectGarbage(); static size_t getRunningTaskCount(); @@ -104,6 +113,12 @@ namespace hex { static std::list> s_tasks; static std::list> s_deferredCalls; + + static std::mutex s_queueMutex; + static std::condition_variable s_jobCondVar; + static std::vector s_workers; + + static void runner(const std::stop_token &stopToken); }; } \ No newline at end of file diff --git a/lib/libimhex/source/api/task.cpp b/lib/libimhex/source/api/task.cpp index 41b7aa55a..7f45219b9 100644 --- a/lib/libimhex/source/api/task.cpp +++ b/lib/libimhex/source/api/task.cpp @@ -4,59 +4,47 @@ #include #include +#include namespace hex { std::mutex TaskManager::s_deferredCallsMutex; - std::list> TaskManager::s_tasks, s_backgroundTasks; + std::list> TaskManager::s_tasks; std::list> TaskManager::s_deferredCalls; - Task::Task(std::string unlocalizedName, u64 maxValue, bool background, std::function function) - : m_unlocalizedName(std::move(unlocalizedName)), m_currValue(0), m_maxValue(maxValue), m_background(background) { - this->m_thread = std::thread([this, func = std::move(function)] { - try { - func(*this); - } catch (const TaskInterruptor &) { - this->interruption(); - } catch (const std::exception &e) { - log::error("Exception in task {}: {}", this->m_unlocalizedName, e.what()); - this->exception(e.what()); - } catch (...) { - log::error("Exception in task {}", this->m_unlocalizedName); - this->exception("Unknown Exception"); - } + std::mutex TaskManager::s_queueMutex; + std::condition_variable TaskManager::s_jobCondVar; + std::vector TaskManager::s_workers; - this->finish(); - }); - } + Task::Task(std::string unlocalizedName, u64 maxValue, bool background, std::function function) + : m_unlocalizedName(std::move(unlocalizedName)), m_currValue(0), m_maxValue(maxValue), m_function(std::move(function)), m_background(background) { } Task::Task(hex::Task &&other) noexcept { - std::scoped_lock thisLock(this->m_mutex); - std::scoped_lock otherLock(other.m_mutex); + { + std::scoped_lock thisLock(this->m_mutex); + std::scoped_lock otherLock(other.m_mutex); - this->m_thread = std::move(other.m_thread); - this->m_unlocalizedName = std::move(other.m_unlocalizedName); + this->m_function = std::move(other.m_function); + this->m_unlocalizedName = std::move(other.m_unlocalizedName); + } - this->m_maxValue = other.m_maxValue; - this->m_currValue = other.m_currValue; + this->m_maxValue = u64(other.m_maxValue); + this->m_currValue = u64(other.m_currValue); - this->m_finished = other.m_finished; - this->m_hadException = other.m_hadException; - this->m_interrupted = other.m_interrupted; - this->m_shouldInterrupt = other.m_shouldInterrupt; + this->m_finished = bool(other.m_finished); + this->m_hadException = bool(other.m_hadException); + this->m_interrupted = bool(other.m_interrupted); + this->m_shouldInterrupt = bool(other.m_shouldInterrupt); + this->m_running = bool(other.m_running); } Task::~Task() { if (!this->isFinished()) this->interrupt(); - - this->m_thread.join(); } void Task::update(u64 value) { - std::scoped_lock lock(this->m_mutex); - this->m_currValue = value; if (this->m_shouldInterrupt) @@ -64,15 +52,11 @@ namespace hex { } void Task::setMaxValue(u64 value) { - std::scoped_lock lock(this->m_mutex); - this->m_maxValue = value; } void Task::interrupt() { - std::scoped_lock lock(this->m_mutex); - this->m_shouldInterrupt = true; if (this->m_interruptCallback) @@ -83,36 +67,34 @@ namespace hex { this->m_interruptCallback = std::move(callback); } - bool Task::isBackgroundTask() const { - std::scoped_lock lock(this->m_mutex); + void Task::setRunning(bool running) { + this->m_running = running; + } + bool Task::isBackgroundTask() const { return this->m_background; } bool Task::isFinished() const { - std::scoped_lock lock(this->m_mutex); - return this->m_finished; } bool Task::hadException() const { - std::scoped_lock lock(this->m_mutex); - return this->m_hadException; } bool Task::wasInterrupted() const { - std::scoped_lock lock(this->m_mutex); - return this->m_interrupted; } void Task::clearException() { - std::scoped_lock lock(this->m_mutex); - this->m_hadException = false; } + bool Task::isRunning() const { + return this->m_running; + } + std::string Task::getExceptionMessage() const { std::scoped_lock lock(this->m_mutex); @@ -132,14 +114,10 @@ namespace hex { } void Task::finish() { - std::scoped_lock lock(this->m_mutex); - this->m_finished = true; } void Task::interruption() { - std::scoped_lock lock(this->m_mutex); - this->m_interrupted = true; } @@ -169,21 +147,74 @@ namespace hex { } + void TaskManager::init() { + for (u32 i = 0; i < std::thread::hardware_concurrency(); i++) + TaskManager::s_workers.emplace_back(TaskManager::runner); + } + + void TaskManager::exit() { + for (auto &task : TaskManager::s_tasks) + task->interrupt(); + + for (auto &thread : TaskManager::s_workers) + thread.request_stop(); + + s_jobCondVar.notify_all(); + + TaskManager::s_workers.clear(); + } + + void TaskManager::runner(const std::stop_token &stopToken) { + std::mutex mutex; + while (true) { + std::shared_ptr task; + { + std::unique_lock lock(s_queueMutex); + s_jobCondVar.wait(lock, [&] { + return !s_tasks.empty() || stopToken.stop_requested(); + }); + if (stopToken.stop_requested()) + break; + + task = s_tasks.front(); + s_tasks.pop_front(); + } + + try { + task->m_function(*task); + } catch (const Task::TaskInterruptor &) { + task->interruption(); + } catch (const std::exception &e) { + log::error("Exception in task {}: {}", task->m_unlocalizedName, e.what()); + task->exception(e.what()); + } catch (...) { + log::error("Exception in task {}", task->m_unlocalizedName); + task->exception("Unknown Exception"); + } + + task->finish(); + } + } + TaskHolder TaskManager::createTask(std::string name, u64 maxValue, std::function function) { + std::unique_lock lock(s_queueMutex); s_tasks.emplace_back(std::make_shared(std::move(name), maxValue, false, std::move(function))); + s_jobCondVar.notify_one(); return TaskHolder(s_tasks.back()); } TaskHolder TaskManager::createBackgroundTask(std::string name, std::function function) { - s_backgroundTasks.emplace_back(std::make_shared(std::move(name), 0, true, std::move(function))); + std::unique_lock lock(s_queueMutex); + s_tasks.emplace_back(std::make_shared(std::move(name), 0, true, std::move(function))); + s_jobCondVar.notify_one(); - return TaskHolder(s_backgroundTasks.back()); + return TaskHolder(s_tasks.back()); } void TaskManager::collectGarbage() { + std::unique_lock lock(s_queueMutex); std::erase_if(s_tasks, [](const auto &task) { return task->isFinished() && !task->hadException(); }); - std::erase_if(s_backgroundTasks, [](const auto &task) { return task->isFinished(); }); } std::list> &TaskManager::getRunningTasks() { @@ -191,7 +222,11 @@ namespace hex { } size_t TaskManager::getRunningTaskCount() { - return s_tasks.size(); + std::unique_lock lock(s_queueMutex); + + return std::count_if(s_tasks.begin(), s_tasks.end(), [](const auto &task){ + return !task->isBackgroundTask(); + }); } diff --git a/main/source/init/splash_window.cpp b/main/source/init/splash_window.cpp index 3fd8e6eec..12fe01312 100644 --- a/main/source/init/splash_window.cpp +++ b/main/source/init/splash_window.cpp @@ -61,7 +61,7 @@ namespace hex::init { try { if (async) { - std::thread(runTask).detach(); + TaskManager::createBackgroundTask(name, [runTask](auto&){ runTask(); }); } else { runTask(); } diff --git a/main/source/main.cpp b/main/source/main.cpp index e26c8b667..e6e1249b1 100644 --- a/main/source/main.cpp +++ b/main/source/main.cpp @@ -33,6 +33,7 @@ int main(int argc, char **argv, char **envp) { init::WindowSplash splashWindow; + TaskManager::init(); for (const auto &[name, task, async] : init::getInitTasks()) splashWindow.addStartupTask(name, task, async); @@ -44,6 +45,7 @@ int main(int argc, char **argv, char **envp) { ON_SCOPE_EXIT { for (const auto &[name, task, async] : init::getExitTasks()) task(); + TaskManager::exit(); }; // Main window diff --git a/plugins/builtin/include/content/views/view_data_inspector.hpp b/plugins/builtin/include/content/views/view_data_inspector.hpp index 9638bfa83..44495be76 100644 --- a/plugins/builtin/include/content/views/view_data_inspector.hpp +++ b/plugins/builtin/include/content/views/view_data_inspector.hpp @@ -33,7 +33,9 @@ namespace hex::plugin::builtin { u64 m_startAddress = 0; size_t m_validBytes = 0; - std::vector m_cachedData; + std::atomic m_dataValid = false; + std::vector m_cachedData, m_workData; + TaskHolder m_updateTask; std::string m_editingValue; }; diff --git a/plugins/builtin/source/content/ui_items.cpp b/plugins/builtin/source/content/ui_items.cpp index c0505f104..52c355249 100644 --- a/plugins/builtin/source/content/ui_items.cpp +++ b/plugins/builtin/source/content/ui_items.cpp @@ -245,6 +245,9 @@ namespace hex::plugin::builtin { if (ImGui::BeginPopupContextItem("FrontTask", ImGuiPopupFlags_MouseButtonLeft)) { for (const auto &task : tasks) { + if (task->isBackgroundTask()) + continue; + ImGui::PushID(&task); ImGui::TextFormatted("{}", LangEntry(task->getUnlocalizedName())); ImGui::SameLine(); diff --git a/plugins/builtin/source/content/views/view_data_inspector.cpp b/plugins/builtin/source/content/views/view_data_inspector.cpp index 01fc11bd5..15b63d125 100644 --- a/plugins/builtin/source/content/views/view_data_inspector.cpp +++ b/plugins/builtin/source/content/views/view_data_inspector.cpp @@ -37,93 +37,105 @@ namespace hex::plugin::builtin { } void ViewDataInspector::drawContent() { - if (this->m_shouldInvalidate) { + if (this->m_shouldInvalidate && !this->m_updateTask.isRunning()) { this->m_shouldInvalidate = false; - this->m_cachedData.clear(); - auto provider = ImHexApi::Provider::get(); + this->m_updateTask = TaskManager::createBackgroundTask("Update Inspector", + [this, validBytes = this->m_validBytes, startAddress = this->m_startAddress, endian = this->m_endian, invert = this->m_invert, numberDisplayStyle = this->m_numberDisplayStyle](auto &) { + auto provider = ImHexApi::Provider::get(); - // Decode bytes using registered inspectors - for (auto &entry : ContentRegistry::DataInspector::getEntries()) { - if (this->m_validBytes < entry.requiredSize) - continue; + this->m_workData.clear(); - std::vector buffer(this->m_validBytes > entry.maxSize ? entry.maxSize : this->m_validBytes); - provider->read(this->m_startAddress, buffer.data(), buffer.size()); - - if (this->m_invert) { - for (auto &byte : buffer) - byte ^= 0xFF; - } - - this->m_cachedData.push_back({ - entry.unlocalizedName, - entry.generatorFunction(buffer, this->m_endian, this->m_numberDisplayStyle), - entry.editingFunction, - false - }); - } - - - // Decode bytes using custom inspectors defined using the pattern language - const std::map inVariables = { - { "numberDisplayStyle", u128(this->m_numberDisplayStyle) } - }; - - pl::PatternLanguage runtime; - ContentRegistry::PatternLanguage::configureRuntime(runtime, nullptr); - - runtime.setDataSource([this, provider](u64 offset, u8 *buffer, size_t size) { - provider->read(offset, buffer, size); - - if (this->m_invert) { - for (size_t i = 0; i < size; i++) - buffer[i] ^= 0xFF; - } - }, provider->getBaseAddress(), provider->getActualSize()); - - runtime.setDangerousFunctionCallHandler([]{ return false; }); - runtime.setDefaultEndian(this->m_endian); - runtime.setStartAddress(this->m_startAddress); - - for (const auto &folderPath : fs::getDefaultPaths(fs::ImHexPath::Inspectors)) { - for (const auto &filePath : std::fs::recursive_directory_iterator(folderPath)) { - if (!filePath.exists() || !filePath.is_regular_file() || filePath.path().extension() != ".hexpat") + // Decode bytes using registered inspectors + for (auto &entry : ContentRegistry::DataInspector::getEntries()) { + if (validBytes < entry.requiredSize) continue; - fs::File file(filePath, fs::File::Mode::Read); - if (file.isValid()) { - auto inspectorCode = file.readString(); + std::vector buffer(validBytes > entry.maxSize ? entry.maxSize : validBytes); + provider->read(startAddress, buffer.data(), buffer.size()); - if (!inspectorCode.empty()) { - if (runtime.executeString(inspectorCode, {}, inVariables, true)) { - const auto &patterns = runtime.getAllPatterns(); + if (invert) { + for (auto &byte : buffer) + byte ^= 0xFF; + } - for (const auto &pattern : patterns) { - if (pattern->isHidden()) - continue; + this->m_workData.push_back({ + entry.unlocalizedName, + entry.generatorFunction(buffer, endian, numberDisplayStyle), + entry.editingFunction, + false + }); + } - this->m_cachedData.push_back({ - pattern->getDisplayName(), - [value = pattern->getFormattedValue()]() { - ImGui::TextUnformatted(value.c_str()); - return value; - }, - std::nullopt, - false - }); + + // Decode bytes using custom inspectors defined using the pattern language + const std::map inVariables = { + { "numberDisplayStyle", u128(numberDisplayStyle) } + }; + + pl::PatternLanguage runtime; + ContentRegistry::PatternLanguage::configureRuntime(runtime, nullptr); + + runtime.setDataSource([invert, provider](u64 offset, u8 *buffer, size_t size) { + provider->read(offset, buffer, size); + + if (invert) { + for (size_t i = 0; i < size; i++) + buffer[i] ^= 0xFF; + } + }, provider->getBaseAddress(), provider->getActualSize()); + + runtime.setDangerousFunctionCallHandler([]{ return false; }); + runtime.setDefaultEndian(endian); + runtime.setStartAddress(startAddress); + + for (const auto &folderPath : fs::getDefaultPaths(fs::ImHexPath::Inspectors)) { + for (const auto &filePath : std::fs::recursive_directory_iterator(folderPath)) { + if (!filePath.exists() || !filePath.is_regular_file() || filePath.path().extension() != ".hexpat") + continue; + + fs::File file(filePath, fs::File::Mode::Read); + if (file.isValid()) { + auto inspectorCode = file.readString(); + + if (!inspectorCode.empty()) { + if (runtime.executeString(inspectorCode, {}, inVariables, true)) { + const auto &patterns = runtime.getAllPatterns(); + + for (const auto &pattern : patterns) { + if (pattern->isHidden()) + continue; + + this->m_workData.push_back({ + pattern->getDisplayName(), + [value = pattern->getFormattedValue()]() { + ImGui::TextUnformatted(value.c_str()); + return value; + }, + std::nullopt, + false + }); + } + } else { + const auto& error = runtime.getError(); + + log::error("Failed to execute inspectors.hexpat!"); + if (error.has_value()) + log::error("{}", error.value().what()); } - } else { - const auto& error = runtime.getError(); - - log::error("Failed to execute inspectors.hexpat!"); - if (error.has_value()) - log::error("{}", error.value().what()); } } } } - } + + this->m_dataValid = true; + + }); + } + + if (this->m_dataValid) { + this->m_dataValid = false; + this->m_cachedData = this->m_workData; } if (ImGui::Begin(View::toWindowName("hex.builtin.view.data_inspector.name").c_str(), &this->getWindowOpenState(), ImGuiWindowFlags_NoCollapse)) { diff --git a/plugins/builtin/source/content/views/view_pattern_editor.cpp b/plugins/builtin/source/content/views/view_pattern_editor.cpp index b9f8423a3..de3537a7b 100644 --- a/plugins/builtin/source/content/views/view_pattern_editor.cpp +++ b/plugins/builtin/source/content/views/view_pattern_editor.cpp @@ -411,13 +411,12 @@ namespace hex::plugin::builtin { this->m_hasUnevaluatedChanges = false; - std::thread([this, code = this->m_textEditor.GetText()]{ + TaskManager::createBackgroundTask("Pattern Parsing", [this, code = this->m_textEditor.GetText()](auto &){ this->parsePattern(code); if (this->m_runAutomatically) this->evaluatePattern(code); - }).detach(); - + }); } } @@ -775,7 +774,8 @@ namespace hex::plugin::builtin { this->evaluatePattern(code); this->m_textEditor.SetText(code); - std::thread([this, code] { this->parsePattern(code); }).detach(); + + TaskManager::createBackgroundTask("Parse pattern", [this, code](auto&) { this->parsePattern(code); }); } } diff --git a/plugins/windows/source/views/view_tty_console.cpp b/plugins/windows/source/views/view_tty_console.cpp index 0208f52d1..93a81a17e 100644 --- a/plugins/windows/source/views/view_tty_console.cpp +++ b/plugins/windows/source/views/view_tty_console.cpp @@ -302,7 +302,7 @@ namespace hex::plugin::windows { if (this->m_transmitting) return; - auto transmitThread = std::thread([&, this] { + TaskManager::createBackgroundTask("Transmitting data", [&, this](auto&) { OVERLAPPED overlapped = { }; overlapped.hEvent = ::CreateEvent(nullptr, true, false, nullptr); @@ -322,7 +322,6 @@ namespace hex::plugin::windows { this->m_transmitting = false; }); - transmitThread.detach(); } } \ No newline at end of file