Detour GetModuleHandleExW function in x86

vcruntime dispatches tasks to Windows thread pool. When task finishes
the completion callback is invoked. In order to ensure the dll that
contais that callback code is still loaded, the refcount for the dll is
incremented (via GetModuleHandleExW) when task is scheduled, and
decremented (vie FreeLibrary) after callback finishes.

FreeLibrary called with a handle to unregistered module returns an
error, which is converted into unhandled exception and resulting in
crash.
dependabot/npm_and_yarn/Src/WebController/UI/websocket-extensions-0.1.4
Grzegorz Rychlik 2020-01-29 09:58:16 +01:00
parent ba5617a5e1
commit af219394af
1 changed files with 37 additions and 10 deletions

View File

@ -29,7 +29,7 @@ namespace MWR::Loader
} moduleData;
#if defined _M_AMD64
void* RtlPcToFileHeaderHook(PVOID pc, PVOID* baseOfImage)
void* RtlPcToFileHeaderDetour(PVOID pc, PVOID* baseOfImage)
{
if (pc > (void*)moduleData.m_DllBaseAddress and pc < (void*)(moduleData.m_DllBaseAddress + moduleData.m_SizeOfTheDll))
{
@ -41,15 +41,41 @@ namespace MWR::Loader
return RtlPcToFileHeader(pc, baseOfImage);
}
}
#elif defined _M_IX86
BOOL GetModuleHandleExWDetour(DWORD dwFlags, LPCWSTR lpModuleName, HMODULE* phModule)
{
// try to filter out different call sites by checking if the flags match the ones set by `msvcp140.dll.Concurrency::details::'anonymous namespace'::_Task_scheduler_callback`
// and if the address is within our Dll
auto addr = reinterpret_cast<const void*>(lpModuleName);
if ((dwFlags == (GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT)) and
addr > (void*)moduleData.m_DllBaseAddress and addr < (void*)(moduleData.m_DllBaseAddress + moduleData.m_SizeOfTheDll))
return 0;
else
return GetModuleHandleExW(dwFlags, lpModuleName, phModule);
}
#endif
void* GetHookAddress(const char* dllName, const char* funcName)
/// Get address of detour function
/// @param dllName - name of dll to import function from
/// @param funcName - name of a function to import
/// @returns If function should be redirected - address of detour function, else nullptr
void* GetDetourAddress(const char* dllName, const char* funcName)
{
#if defined _M_AMD64
// detour RtlPcToFileHeader to return our Dll base address, std::eception throwing and exception_ptr creation use this address
if (_stricmp(dllName,"kernel32.dll") == 0 && strcmp(funcName, "RtlPcToFileHeader") == 0)
return (void*)RtlPcToFileHeaderHook;
#endif
return (void*)RtlPcToFileHeaderDetour;
#elif defined _M_IX86
// detour GetModuleHandleExW because Windows thread pool tries to free our Dll on callback completion
// see msvcp140.dll.Concurrency::details::`anonymous namespace'::_Task_scheduler_callback
// see C:\Program Files (x86)\Microsoft Visual Studio\2017\Professional\VC\Tools\MSVC\14.16.27023\crt\src\stl\taskscheduler.cpp -> around line 147
if (_stricmp(dllName, "kernel32.dll") == 0 && strcmp(funcName, "GetModuleHandleExW") == 0)
return (void*)GetModuleHandleExWDetour;
#endif
return nullptr;
}
@ -100,6 +126,10 @@ namespace MWR::Loader
if (!baseAddress)
return 1;
// set global module data
moduleData.m_DllBaseAddress = baseAddress;
moduleData.m_SizeOfTheDll = ntHeaders->OptionalHeader.SizeOfImage;
/// Copy headers
memcpy((void*)baseAddress, dllData, ntHeaders->OptionalHeader.SizeOfHeaders);
@ -173,7 +203,8 @@ namespace MWR::Loader
else
{
auto importByName = Rva2Va<PIMAGE_IMPORT_BY_NAME>(baseAddress, origFirstThunk->u1.AddressOfData);
void* addr = GetHookAddress(libName, importByName->Name);
// check if the function should be detoured by redirecting the imported function address
void* addr = GetDetourAddress(libName, importByName->Name);
if (!addr)
addr = GetProcAddress((HMODULE)libraryAddress, importByName->Name);
firstThunk->u1.Function = (ULONG_PTR)addr;
@ -209,7 +240,7 @@ namespace MWR::Loader
else
{
auto importByName = Rva2Va<PIMAGE_IMPORT_BY_NAME>(baseAddress, origFirstThunk->u1.AddressOfData);
void* addr = GetHookAddress(libName, importByName->Name);
void* addr = GetDetourAddress(libName, importByName->Name);
if (!addr)
addr = GetProcAddress((HMODULE)libraryAddress, importByName->Name);
firstThunk->u1.Function = (ULONG_PTR)addr;
@ -294,10 +325,6 @@ namespace MWR::Loader
return 1;
}
// register VEH
moduleData.m_DllBaseAddress = baseAddress;
moduleData.m_SizeOfTheDll = ntHeaders->OptionalHeader.SizeOfImage;
#elif defined _M_IX86
MWR::Loader::UnexportedWinApi::RtlInsertInvertedFunctionTable((void*)baseAddress, ntHeaders->OptionalHeader.SizeOfImage);
#endif