diff --git a/client/sources/MyLoadLibrary.c b/client/sources/MyLoadLibrary.c index 8923801a..455d7a29 100644 --- a/client/sources/MyLoadLibrary.c +++ b/client/sources/MyLoadLibrary.c @@ -17,6 +17,8 @@ typedef struct { PSTR name; PSTR fileName; + LPTOP_LEVEL_EXCEPTION_FILTER ehFilter; + HCUSTOMMODULE module; int refcount; @@ -287,6 +289,7 @@ static PHCUSTOMLIBRARY _AddMemoryModule( hmodule->name = strdup(srcName); hmodule->fileName = strdup(name); hmodule->module = module; + hmodule->ehFilter = NULL; _strupr(hmodule->name); _strupr(hmodule->fileName); @@ -673,10 +676,32 @@ LPVOID CALLBACK MyLoadResource(HMODULE hModule, HRSRC resource) return res; } -VOID MySetUnhandledExceptionFilter(LPTOP_LEVEL_EXCEPTION_FILTER handler) +BOOL MySetUnhandledExceptionFilter(LPCSTR pszModuleName, LPTOP_LEVEL_EXCEPTION_FILTER handler) { - lpDefaultExceptionHandler = handler; - dprint("Set default thread handler to %p\n", lpDefaultExceptionHandler); + PHCUSTOMLIBRARY lib; + + if (!pszModuleName) { + lpDefaultExceptionHandler = handler; + dprint("Set default thread handler to %p\n", lpDefaultExceptionHandler); + return TRUE; + } + + lib = _FindMemoryModule(pszModuleName, NULL); + if (!lib) { + dprint( + "Failed to set default thread handler for %s to %p - module not found\n", + pszModuleName, lpDefaultExceptionHandler + ); + } + + lib->ehFilter = handler; + + dprint( + "Set default thread handler for %s to %p\n", + pszModuleName, lpDefaultExceptionHandler + ); + + return TRUE; } LPTOP_LEVEL_EXCEPTION_FILTER MyGetUnhandledExceptionFilter(VOID) { @@ -684,11 +709,51 @@ LPTOP_LEVEL_EXCEPTION_FILTER MyGetUnhandledExceptionFilter(VOID) { } LONG WINAPI ThreadUnhandledExceptionFilter( - PEXCEPTION_POINTERS pExceptionPointers, LPTOP_LEVEL_EXCEPTION_FILTER lpFilter, LONG lResult + DWORD dwExceptionCode, PEXCEPTION_POINTERS pExceptionPointers, + PVOID pvThreadProc, LPTOP_LEVEL_EXCEPTION_FILTER lpFilter, LONG lResult ) { - dprint("ThreadUnhandledExceptionFilter; Process exception info at %p\n", pExceptionPointers); - lpFilter(pExceptionPointers); - return lResult; + LPCSTR pszName = NULL; + LPTOP_LEVEL_EXCEPTION_FILTER lpCustomFilter = NULL; + + LONG lVerdict = EXCEPTION_CONTINUE_SEARCH; + + if (dwExceptionCode == EXCEPTION_BREAKPOINT) { + dprint( + "ThreadUnhandledExceptionFilter (ThreadProc=%p): hit breakpoint (???) - ignore", + pvThreadProc + ); + + return EXCEPTION_CONTINUE_SEARCH; + } + + if (MyFindMemoryModuleNameByAddr(pvThreadProc, &pszName, NULL, &lpCustomFilter)) { + dprint( + "ThreadUnhandledExceptionFilter (ThreadProc=%p) Fatal exception, " + "original ThreadProc from %s\n", + pvThreadProc, pszName + ); + + if (lpCustomFilter) { + lpFilter = lpCustomFilter; + + dprint( + "Using custom exception filter for %s: %p\n", + pszName, lpCustomFilter + ); + + } + } else { + dprint( + "ThreadUnhandledExceptionFilter (ThreadProc=%p); " + "Handling fatal exception with filter %p\n", + pvThreadProc, lpFilter + ); + } + + if (lpFilter) + lVerdict = lpFilter(pExceptionPointers); + + return lVerdict; } static DWORD WINAPI WrappedThreadRoutine(LPVOID lpThreadParameter) @@ -713,9 +778,18 @@ static DWORD WINAPI WrappedThreadRoutine(LPVOID lpThreadParameter) ); } __except(ThreadUnhandledExceptionFilter( - GetExceptionInformation(), OriginalThreadArgs.lpExceptionFilter, EXCEPTION_CONTINUE_SEARCH + GetExceptionCode(), GetExceptionInformation(), + OriginalThreadArgs.lpOriginalRoutine, + OriginalThreadArgs.lpExceptionFilter, EXCEPTION_CONTINUE_SEARCH )) { - dprint("Thread wrapper caught fatal error (%p)\n", OriginalThreadArgs.lpOriginalRoutine); + dprint( + "Thread wrapper with original args: %p(%p) fatal error " + "and will die, but we'll try to countine\n", + OriginalThreadArgs.lpOriginalRoutine, + OriginalThreadArgs.lpOriginalParameter + ); + + return (DWORD)(-1); } dprint("Thread wrapper exited (%p)\n", OriginalThreadArgs.lpOriginalRoutine); @@ -731,26 +805,24 @@ HANDLE CALLBACK MyCreateThread( LPDWORD lpThreadId ) { + PORIGINAL_THREAD_ARGS pOriginalArgsCopy = LocalAlloc( + LMEM_FIXED, sizeof(ORIGINAL_THREAD_ARGS) + ); + dprint( "MyCreateThread(func=%p, args=%p eh=%p)\n", lpStartAddress, lpParameter, lpDefaultExceptionHandler ); - if (lpDefaultExceptionHandler) { - PORIGINAL_THREAD_ARGS pOriginalArgsCopy = LocalAlloc( - LMEM_FIXED, sizeof(ORIGINAL_THREAD_ARGS) - ); + if (pOriginalArgsCopy) { + pOriginalArgsCopy->lpOriginalRoutine = lpStartAddress; + pOriginalArgsCopy->lpOriginalParameter = lpParameter; + pOriginalArgsCopy->lpExceptionFilter = lpDefaultExceptionHandler; - if (pOriginalArgsCopy) { - pOriginalArgsCopy->lpOriginalRoutine = lpStartAddress; - pOriginalArgsCopy->lpOriginalParameter = lpParameter; - pOriginalArgsCopy->lpExceptionFilter = lpDefaultExceptionHandler; - - lpStartAddress = WrappedThreadRoutine; - lpParameter = (PVOID) pOriginalArgsCopy; - } else { - dprint("MyCreateThread: LocalAlloc failed\n"); - } + lpStartAddress = WrappedThreadRoutine; + lpParameter = (PVOID) pOriginalArgsCopy; + } else { + dprint("MyCreateThread: LocalAlloc failed\n"); } return CreateThread( @@ -794,3 +866,38 @@ VOID MyEnumerateLibraries(LibraryInfoCb_t callback, PVOID pvCallbackData) dprint("Enumerating libraries: %p - complete\n", libraries); } + +BOOL MyFindMemoryModuleNameByAddr( + PVOID pvAddress, LPCSTR *ppszName, PVOID *ppvBaseAddress, + LPTOP_LEVEL_EXCEPTION_FILTER *pehFilter +) { + PHCUSTOMLIBRARY module, tmp; + UINT_PTR uiAddress = (UINT_PTR) pvAddress; + + if (!pvAddress) + return FALSE; + + HASH_ITER(by_module, libraries->by_module, module, tmp) { + PVOID pvBaseAddress = NULL; + ULONG ulSize = 0; + + if (GetMemoryModuleInfo(module->module, &pvBaseAddress, &ulSize)) { + UINT_PTR uiBaseAddress = (UINT_PTR) pvBaseAddress; + + if (uiAddress >= uiBaseAddress && uiAddress <= (uiBaseAddress + ulSize)) { + if (ppszName) + *ppszName = module->name; + + if (ppvBaseAddress) + *ppvBaseAddress = pvBaseAddress; + + if (pehFilter) + *pehFilter = module->ehFilter; + + return TRUE; + } + } + } + + return FALSE; +} diff --git a/client/sources/MyLoadLibrary.h b/client/sources/MyLoadLibrary.h index 06b7e5fd..8482866c 100644 --- a/client/sources/MyLoadLibrary.h +++ b/client/sources/MyLoadLibrary.h @@ -42,7 +42,10 @@ HANDLE CALLBACK MyCreateThread( ); VOID MySetLibraries(PVOID pLibraries); -VOID MySetUnhandledExceptionFilter(LPTOP_LEVEL_EXCEPTION_FILTER handler); +BOOL MySetUnhandledExceptionFilter( + LPCSTR pszModuleName, LPTOP_LEVEL_EXCEPTION_FILTER handler +); + LPTOP_LEVEL_EXCEPTION_FILTER MyGetUnhandledExceptionFilter(VOID); PVOID MyGetLibraries(); @@ -51,6 +54,10 @@ typedef BOOL (*LibraryInfoCb_t) ( ); VOID MyEnumerateLibraries(LibraryInfoCb_t callback, PVOID pvCallbackData); +BOOL MyFindMemoryModuleNameByAddr( + PVOID pvAddress, LPCSTR *ppszName, PVOID *ppvBaseAddress, + LPTOP_LEVEL_EXCEPTION_FILTER *pehFilter +); #ifndef DLL_QUERY_HMODULE #define DLL_QUERY_HMODULE 6