Win32: use safe bit transfromation
[git-rebase2.git] / app / Rehi / Win32bits.hsc
blobb816e2140e78d56914941a4cb48a0f7f5a2120ac
1 {-# LANGUAGE ScopedTypeVariables #-}
2 {-# LANGUAGE TypeFamilies #-}
3 module Rehi.Win32bits where
5 import Data.Bits
6 import Data.Int
7 import Data.Proxy
8 import Data.Typeable
9 import Data.Word
10 import Foreign.Ptr(Ptr(), nullPtr)
11 import Foreign.Marshal.Alloc (allocaBytesAligned, alloca)
12 import Foreign.Storable
13 import Numeric (showHex)
15 import qualified System.Win32.Types as WT
17 #include <windef.h>
18 #include <winnt.h>
19 #include <ntdef.h>
21 #if defined(i386_HOST_ARCH)
22 #let WINDOWS_CCONV = "stdcall"
23 #elif defined(x86_64_HOST_ARCH)
24 #let WINDOWS_CCONV = "ccall"
25 #else
26 # error Unknown mingw32 arch
27 #endif
29 #def typedef struct __PUBLIC_OBJECT_TYPE_INFORMATION {
30       UNICODE_STRING TypeName;
31       ULONG Reserved [22];
32   } PUBLIC_OBJECT_TYPE_INFORMATION;
34 #def typedef enum {
35     ObjectNameInformation = 1,
36   } OBJECT_INFORMATION_CLASS;
38 #def NTSTATUS NtQueryObject(
39   HANDLE                   Handle,
40   OBJECT_INFORMATION_CLASS ObjectInformationClass,
41   *PUBLIC_OBJECT_TYPE_INFORMATION  ObjectInformation,
42   ULONG                    ObjectInformationLength,
43   PULONG                   ReturnLength
44  );
46 -- https://stackoverflow.com/a/8354582/2303202
47 -- https://wiki.haskell.org/FFICookBook#Working_with_structs
48 #if __GLASGOW_HASKELL__ < 800
49 #let alignment t = "%lu", (unsigned long)offsetof(struct {char x__; t (y__); }, y__)
50 #endif
52 data NT_OBJECT_NAME_INFORMATION = NT_OBJECT_NAME_INFORMATION
53     { noniLength :: WT.USHORT
54     , noniMaximumLength :: WT.USHORT
55     , noniBuffer :: WT.LPWSTR }
57 instance Storable NT_OBJECT_NAME_INFORMATION where
58   sizeOf = const #{size PUBLIC_OBJECT_TYPE_INFORMATION}
59   alignment = const #{alignment PUBLIC_OBJECT_TYPE_INFORMATION}
60   peek p = NT_OBJECT_NAME_INFORMATION
61             <$> #{peek UNICODE_STRING, Length} p
62             <*> #{peek UNICODE_STRING, MaximumLength} p
63             <*> #{peek UNICODE_STRING, Buffer} p
64   poke p o = do
65     #{poke UNICODE_STRING, Length} p (noniLength o)
66     #{poke UNICODE_STRING, MaximumLength} p (noniMaximumLength o)
67     #{poke UNICODE_STRING, Buffer} p (noniBuffer o)
69 type ObjectInformationClass = #{type OBJECT_INFORMATION_CLASS}
71 #enum ObjectInformationClass, , hs_ObjectNameInformation = ObjectNameInformation
73 type family Unsigned t :: *
74 type instance Unsigned Int32 = Word32
76 type NTSTATUS = Unsigned #{type NTSTATUS}
78 type ULONG = WT.DWORD
80 foreign import #{WINDOWS_CCONV} "NtQueryObject"
81   c_NtQueryObject :: WT.HANDLE
82                   -> ObjectInformationClass
83                   -> Ptr NT_OBJECT_NAME_INFORMATION
84                   -> ULONG
85                   -> Ptr ULONG
86                   -> IO NTSTATUS
88 toIntegralSizedM :: forall a b . (Bits a, Integral a, Show a, Bits b, Integral b, Typeable b) => a -> IO b
89 toIntegralSizedM v = case toIntegralSized v of
90   Just res -> pure res
91   Nothing -> fail ("Cannot cast from " ++ show v ++ " to " ++ show (typeRep (undefined :: Proxy b)))
93 getFileNameInformation :: WT.HANDLE -> IO String
94 getFileNameInformation h =
95   alloca $ \ (p_len :: Ptr ULONG) -> do
96     checkNtStatus
97       (== 0xC0000004) -- STATUS_INFO_LENGTH_MISMATCH
98       $ c_NtQueryObject h hs_ObjectNameInformation nullPtr 0 p_len
99     len <- peek p_len
100     len_signed <- toIntegralSizedM len
101     allocaBytesAligned
102         len_signed
103         (alignment (undefined :: NT_OBJECT_NAME_INFORMATION))
104         $ \ p_oni -> do
105       checkNtStatus
106         (\s -> s >= 0 && s <= 0x7FFFFFFF) -- https://msdn.microsoft.com/en-us/library/windows/hardware/ff565436.aspx
107         $ c_NtQueryObject h hs_ObjectNameInformation p_oni len p_len
108       res <- peek p_len
109       oni <- peek p_oni
110       oni_length <- toIntegralSizedM (noniLength oni `div` 2)
111       WT.peekTStringLen (noniBuffer oni, oni_length)
112   where
113     checkNtStatus :: (NTSTATUS -> Bool) -> IO NTSTATUS -> IO ()
114     checkNtStatus p f = do
115       s <- f
116       if p s
117         then pure ()
118         else fail ("NtQueryObject(ObjectNameInformation) failed: " ++ showHex s "")