Merge branch 'idea90' of git.labs.intellij.net:idea/community into idea90
[fedora-idea.git] / native / focusKiller / HookImportFunction.cpp
blobab11935a25da31db5410c069878649b4eb4b7fde
1 /*
2 Module : HookImportFunction.cpp
3 Purpose: Defines the implementation for code to hook a call to any imported Win32 SDK
4 Created: PJN / 23-10-1999
5 History: PJN / 01-01-2001 1. Now includes copyright message in the source code and documentation.
6 2. Fixed an access violation in where I was getting the name of the import
7 function but not checking for failure.
8 3. Fixed a compiler error where I was incorrectly casting to a PDWORD instead
9 of a DWORD
10 PJN / 20-04-2002 1. Fixed a potential infinite loop in HookImportFunctionByName. Thanks to
11 David Defoort for spotting this problem.
13 Copyright (c) 1996 - 2002 by PJ Naughter. (Web: www.naughter.com, Email: pjna@naughter.com)
15 All rights reserved.
17 Copyright / Usage Details:
19 You are allowed to include the source code in any product (commercial, shareware, freeware or otherwise)
20 when your product is released in binary form. You are allowed to modify the source code in any way you want
21 except you cannot modify the copyright details at the top of each module. If you want to distribute source
22 code with your application, then you are only allowed to distribute versions released by the author. This is
23 to maintain a single distribution point for the source code.
28 ////////////////// Includes ////////////////////////////////////
30 #include <windows.h>
31 #include "HookImportFunction.h"
//Minimal stand-ins for the MFC diagnostic macros so this file builds without MFC.
// ASSERT : compiled out; its argument is NOT evaluated.
#define ASSERT(e) ((void)0)
// VERIFY : always evaluates its argument (for its side effects) and ignores the result.
#define VERIFY(e) (e)
// TRACE0 : sends a narrow (char) string to the debugger. The ANSI entry point is
//          named explicitly because _T below keeps every literal narrow, so the
//          TCHAR-mapped OutputDebugString would not compile in UNICODE builds.
#define TRACE0(s) OutputDebugStringA(s)
// _T : string and character literals in this file are always narrow.
#define _T(s) s
////////////////// Defines / Locals ////////////////////////////

#ifdef _DEBUG
//MFC-style debug boilerplate. DEBUG_NEW is an MFC macro and this file does not
// use MFC, so `new` is only remapped when something else actually provides it.
#ifdef DEBUG_NEW
#define new DEBUG_NEW
#endif
#undef THIS_FILE
static char THIS_FILE[] = __FILE__;
#endif

//Yields (cast)(ptr + AddValue) computed in bytes, e.g. to turn an RVA into an
// address inside a loaded image. DWORD_PTR is pointer-sized, so the address is
// not truncated on 64-bit Windows as it would be with a plain DWORD.
#define MakePtr(cast, ptr, AddValue) (cast)((DWORD_PTR)(ptr)+(DWORD_PTR)(AddValue))
48 BOOL IsNT();
52 ////////////////// Implementation //////////////////////////////
54 BOOL HookImportFunctionsByName(HMODULE hModule, LPCSTR szImportMod, UINT uiCount,
55 LPHOOKFUNCDESC paHookArray, PROC* paOrigFuncs, UINT* puiHooked)
57 // Double check the parameters.
58 ASSERT(szImportMod);
59 ASSERT(uiCount);
60 ASSERT(!IsBadReadPtr(paHookArray, sizeof(HOOKFUNCDESC)*uiCount));
62 #ifdef _DEBUG
63 if (paOrigFuncs)
64 ASSERT(!IsBadWritePtr(paOrigFuncs, sizeof(PROC)*uiCount));
65 if (puiHooked)
66 ASSERT(!IsBadWritePtr(puiHooked, sizeof(UINT)));
68 //Check each function name in the hook array.
69 for (UINT i = 0; i<uiCount; i++)
71 ASSERT(paHookArray[i].szFunc);
72 ASSERT(*paHookArray[i].szFunc != _T('\0'));
74 //If the proc is not NULL, then it is checked.
75 if (paHookArray[i].pProc)
76 ASSERT(!IsBadCodePtr(paHookArray[i].pProc));
78 #endif
80 //Do the parameter validation for real.
81 if (uiCount == 0 || szImportMod == NULL || IsBadReadPtr(paHookArray, sizeof(HOOKFUNCDESC)* uiCount))
83 ASSERT(FALSE);
84 SetLastErrorEx(ERROR_INVALID_ADDRESS, SLE_ERROR);
85 return FALSE;
88 if (paOrigFuncs && IsBadWritePtr(paOrigFuncs, sizeof(PROC)*uiCount))
90 ASSERT(FALSE);
91 SetLastErrorEx(ERROR_INVALID_ADDRESS, SLE_ERROR);
92 return FALSE;
95 if (puiHooked && IsBadWritePtr(puiHooked, sizeof(UINT)))
97 ASSERT(FALSE);
98 SetLastErrorEx(ERROR_INVALID_ADDRESS, SLE_ERROR );
99 return FALSE;
102 //Is this a system DLL, which Windows95 will not let you patch
103 //since it is above the 2GB line?
104 if (!IsNT() && ((DWORD)hModule >= 0x80000000))
106 #ifdef _DEBUG
107 CString sMsg;
108 sMsg.Format(_T("Could not hook module %x because we are on Win9x and it is in shared memory\n"), hModule);
109 OutputDebugString(sMsg);
110 #endif
111 SetLastErrorEx(ERROR_INVALID_HANDLE, SLE_ERROR);
112 return FALSE;
115 //TODO TODO
116 // Should each item in the hook array be checked in release builds?
118 if (puiHooked)
119 *puiHooked = 0; //Set the number of functions hooked to zero.
121 //Get the specific import descriptor.
122 PIMAGE_IMPORT_DESCRIPTOR pImportDesc = GetNamedImportDescriptor(hModule, szImportMod);
123 if (NULL == pImportDesc)
124 return FALSE; // The requested module was not imported.
126 HINSTANCE hImportMod = GetModuleHandle(szImportMod);
127 if (NULL == hImportMod)
129 ASSERT(FALSE);
130 SetLastErrorEx(ERROR_HOOK_NEEDS_HMOD, SLE_ERROR);
131 return FALSE; // The requested module was not available.
134 //Set all the values in paOrigFuncs to NULL.
135 if (NULL != paOrigFuncs)
136 memset(paOrigFuncs, NULL, sizeof(PROC)*uiCount);
138 //Get the original thunk information for this DLL. I cannot use
139 // the thunk information stored in the pImportDesc->FirstThunk
140 // because the that is the array that the loader
141 // has already bashed to fix up all the imports.
142 // This pointer gives us acess to the function names.
143 PIMAGE_THUNK_DATA pOrigThunk = MakePtr(PIMAGE_THUNK_DATA, hModule, pImportDesc->OriginalFirstThunk);
145 //Get the array pointed to by the pImportDesc->FirstThunk.
146 // This is where I will do the actual bash.
147 PIMAGE_THUNK_DATA pRealThunk = MakePtr(PIMAGE_THUNK_DATA, hModule, pImportDesc->FirstThunk);
149 //Loop through and look for the one that matches the name.
150 for (; NULL != pOrigThunk->u1.Function;
151 // Increment both tables.
152 pOrigThunk++, pRealThunk++)
154 //Only look at those that are imported by name, not ordinal.
155 if (IMAGE_ORDINAL_FLAG == (IMAGE_ORDINAL_FLAG & pOrigThunk->u1.Ordinal))
156 continue;
158 //Look get the name of this imported function.
159 PIMAGE_IMPORT_BY_NAME pByName = MakePtr(PIMAGE_IMPORT_BY_NAME, hModule, pOrigThunk->u1.AddressOfData);
161 if (IsBadReadPtr(pByName, MAX_PATH+4))
163 SetLastErrorEx(ERROR_INVALID_ADDRESS, SLE_ERROR);
164 continue;
167 //If the name starts with NULL, then just skip to next.
168 if (_T('\0') == pByName->Name[0])
169 continue;
171 //Determines if we do the hook.
172 BOOL bDoHook = FALSE;
174 //TODO {
175 // Might want to consider bsearch here.
176 //TODO }
177 //See if the particular function name is in the import
178 // list. It might be good to consider requiring the
179 // paHookArray to be in sorted order so bsearch could be
180 // used so the lookup will be faster. However, the size of
181 // uiCount coming into this function should be rather small
182 // but it is called for each function imported by szImportMod.
183 for (UINT i = 0; i<uiCount; i++)
185 if ((paHookArray[i].szFunc[0] == pByName->Name[0]) &&
186 (strcmpi(paHookArray[i].szFunc, (char*)pByName->Name) == 0))
188 //If the proc is NULL, kick out, otherwise
189 // go ahead and hook it.
190 if (paHookArray[i].pProc)
191 bDoHook = TRUE;
192 break;
196 if (FALSE == bDoHook)
197 continue;
199 // I found it. Now I need to change the protection to
200 // writable before I do the blast. Note that I am now
201 // blasting into the real thunk area!
202 MEMORY_BASIC_INFORMATION mbi_thunk;
203 VirtualQuery(pRealThunk, &mbi_thunk, sizeof(MEMORY_BASIC_INFORMATION));
204 VERIFY(VirtualProtect(mbi_thunk.BaseAddress, mbi_thunk.RegionSize, PAGE_READWRITE, &mbi_thunk.Protect));
206 // Get fast/simple pointer
207 PROC* pFunction = (PROC*) &(pRealThunk->u1.Function);
208 if (*pFunction == paHookArray[i].pProc)
210 SetLastErrorEx(ERROR_ALREADY_INITIALIZED, SLE_ERROR);
211 return FALSE;
213 if (IsBadCodePtr(*pFunction))
215 ASSERT(FALSE);
216 SetLastErrorEx(ERROR_INVALID_ADDRESS, SLE_ERROR);
217 return FALSE;
219 //Save the original address if requested.
220 if (NULL != paOrigFuncs)
222 if ((DWORD)(*pFunction) < (DWORD)hImportMod && ((DWORD)(0x80000000) > (DWORD)hImportMod))
224 ASSERT(FALSE);
225 SetLastErrorEx(ERROR_INVALID_ADDRESS, SLE_ERROR);
226 return FALSE;
228 if (*pFunction != paOrigFuncs[i])
230 if (NULL != paOrigFuncs[i])
232 if (paHookArray[i].pProc != paOrigFuncs[i])
234 ASSERT(FALSE);
235 SetLastErrorEx(ERROR_INVALID_ADDRESS, SLE_ERROR);
236 return FALSE;
239 paOrigFuncs[i] = * pFunction;
242 //Do the actual hook.
243 *pFunction = paHookArray[i].pProc;
245 //Increment the total number hooked.
246 if (puiHooked)
247 *puiHooked += 1;
249 //Change the protection back to what it was before I blasted.
250 DWORD dwOldProtect;
251 VERIFY(VirtualProtect(mbi_thunk.BaseAddress, mbi_thunk.RegionSize, mbi_thunk.Protect, &dwOldProtect));
253 //All OK, JumpMaster!
254 SetLastError(ERROR_SUCCESS);
255 return TRUE;
258 PIMAGE_IMPORT_DESCRIPTOR GetNamedImportDescriptor(HMODULE hModule, LPCSTR szImportMod)
260 //Always check parameters.
261 ASSERT(szImportMod);
262 ASSERT(hModule);
263 if ((szImportMod == NULL) || (hModule == NULL))
265 ASSERT(FALSE);
266 SetLastErrorEx(ERROR_INVALID_PARAMETER, SLE_ERROR);
267 return NULL;
270 //Get the Dos header.
271 PIMAGE_DOS_HEADER pDOSHeader = (PIMAGE_DOS_HEADER) hModule;
273 // Is this the MZ header?
274 if (IsBadReadPtr(pDOSHeader, sizeof(IMAGE_DOS_HEADER)) || (pDOSHeader->e_magic != IMAGE_DOS_SIGNATURE))
276 #ifdef _DEBUG
277 CString sMsg;
278 sMsg.Format(_T("Could not find the MZ Header for %x\n"), hModule);
279 OutputDebugString(sMsg);
280 #endif
281 SetLastErrorEx( ERROR_BAD_EXE_FORMAT, SLE_ERROR);
282 return NULL;
285 // Get the PE header.
286 PIMAGE_NT_HEADERS pNTHeader = MakePtr(PIMAGE_NT_HEADERS, pDOSHeader, pDOSHeader->e_lfanew);
288 //Is this a real PE image?
289 if (IsBadReadPtr(pNTHeader, sizeof(IMAGE_NT_HEADERS)) || (pNTHeader->Signature != IMAGE_NT_SIGNATURE))
291 ASSERT(FALSE);
292 SetLastErrorEx( ERROR_INVALID_EXE_SIGNATURE, SLE_ERROR);
293 return NULL;
296 //If there is no imports section, leave now.
297 if (pNTHeader->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT].VirtualAddress == 0)
298 return NULL;
300 // Get the pointer to the imports section.
301 PIMAGE_IMPORT_DESCRIPTOR pImportDesc = MakePtr(PIMAGE_IMPORT_DESCRIPTOR, pDOSHeader, pNTHeader->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT].VirtualAddress);
303 //Loop through the import module descriptors looking for the module whose name matches szImportMod.
304 while (pImportDesc->Name)
306 PSTR szCurrMod = MakePtr(PSTR, pDOSHeader, pImportDesc->Name);
307 if (stricmp(szCurrMod, szImportMod) == 0)
308 break; // Found it.
310 //Look at the next one.
311 pImportDesc++;
314 //If the name is NULL, then the module is not imported.
315 if (pImportDesc->Name == NULL)
316 return NULL;
318 //All OK, Jumpmaster!
319 return pImportDesc;
322 BOOL IsNT()
324 OSVERSIONINFO stOSVI;
325 memset(&stOSVI, NULL, sizeof(OSVERSIONINFO));
326 stOSVI.dwOSVersionInfoSize = sizeof(OSVERSIONINFO);
328 BOOL bRet = GetVersionEx(&stOSVI);
329 ASSERT(TRUE == bRet);
330 if (FALSE == bRet)
332 TRACE0("GetVersionEx failed!\n");
333 return FALSE;
336 //Check the version and call the appropriate thing.
337 return (VER_PLATFORM_WIN32_NT == stOSVI.dwPlatformId);