1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Catalyst.Net Ltd 2023
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <https://www.gnu.org/licenses/>.
19 """Group Key Distribution Service module"""
22 from functools
import total_ordering
23 from typing
import Optional
, Tuple
25 from cryptography
.hazmat
.primitives
import hashes
27 from samba
import _glue
28 from samba
.dcerpc
import gkdi
, misc
29 from samba
.ndr
import ndr_pack
, ndr_unpack
30 from samba
.nt_time
import NtTime
, NtTimeDelta
# Largest value that fits in an unsigned 64-bit integer.
uint64_max: int = (1 << 64) - 1

# Parameters of the key hierarchy. They are read from the _glue extension
# module rather than being spelled out here.
L1_KEY_ITERATION: int = _glue.GKDI_L1_KEY_ITERATION
L2_KEY_ITERATION: int = _glue.GKDI_L2_KEY_ITERATION
KEY_CYCLE_DURATION: NtTimeDelta = _glue.GKDI_KEY_CYCLE_DURATION
MAX_CLOCK_SKEW: NtTimeDelta = _glue.GKDI_MAX_CLOCK_SKEW
43 class Algorithm(Enum
):
def algorithm(self) -> hashes.HashAlgorithm:
    """Return a new ``hashes`` algorithm object matching this member.

    Raises:
        RuntimeError: if this member is none of SHA1, SHA256, SHA384 or
            SHA512.
    """
    if self is Algorithm.SHA1:
        return hashes.SHA1()

    if self is Algorithm.SHA256:
        return hashes.SHA256()

    if self is Algorithm.SHA384:
        return hashes.SHA384()

    if self is Algorithm.SHA512:
        return hashes.SHA512()

    # This has to be an f-string: without the prefix the text "{self}" is
    # emitted literally and the offending member is never named.
    raise RuntimeError(f"unknown hash algorithm {self}")
64 def __repr__(self
) -> str:
68 def from_kdf_parameters(kdf_param
: Optional
[bytes
]) -> "Algorithm":
70 return Algorithm
.SHA256
# the default used by Windows.
72 kdf_parameters
= ndr_unpack(gkdi
.KdfParameters
, kdf_param
)
73 return Algorithm(kdf_parameters
.hash_algorithm
)
78 L0_SEED_KEY
= object()
79 L1_SEED_KEY
= object()
80 L2_SEED_KEY
= object()
def description(self) -> str:
    """Return a short noun phrase (with article) naming this GKID type.

    The phrase is meant to be embedded in a longer message, e.g.
    "... from an L0 seed key".

    Raises:
        RuntimeError: if this member is not one of the known GKID types.
    """
    if self is GkidType.DEFAULT:
        return "a default GKID"

    if self is GkidType.L0_SEED_KEY:
        return "an L0 seed key"

    if self is GkidType.L1_SEED_KEY:
        return "an L1 seed key"

    if self is GkidType.L2_SEED_KEY:
        return "an L2 seed key"

    # This has to be an f-string: without the prefix the text "{self}" is
    # emitted literally and the offending member is never named.
    raise RuntimeError(f"unknown GKID type {self}")
98 class InvalidDerivation(Exception):
102 class UndefinedStartTime(Exception):
108 # L2 increments every 10 hours. It rolls over after 320 hours (13 days and 8 hours).
109 # L1 increments every 320 hours. It rolls over after 10240 hours (426 days and 16 hours).
110 # L0 increments every 10240 hours. It rolls over after 43980465111040 hours (five billion years).
112 __slots__
= ["_l0_idx", "_l1_idx", "_l2_idx"]
114 max_l0_idx
= 0x7FFF_FFFF
def __init__(self, l0_idx: int, l1_idx: int, l2_idx: int) -> None:
    """Validate the three indices and store them.

    Each index may be -1. Otherwise the L0 index may be at most
    ``Gkid.max_l0_idx``, and the L1 and L2 indices must be below
    ``L1_KEY_ITERATION`` and ``L2_KEY_ITERATION`` respectively. If an index
    is -1, the index one level below it must be -1 as well.

    Raises:
        ValueError: if an index is out of range, or if -1 and non -1
            indices are combined in a disallowed way.
    """
    l0_in_range = -1 <= l0_idx <= Gkid.max_l0_idx
    if not l0_in_range:
        raise ValueError(f"L0 index {l0_idx} out of range")

    l1_in_range = -1 <= l1_idx < L1_KEY_ITERATION
    if not l1_in_range:
        raise ValueError(f"L1 index {l1_idx} out of range")

    l2_in_range = -1 <= l2_idx < L2_KEY_ITERATION
    if not l2_in_range:
        raise ValueError(f"L2 index {l2_idx} out of range")

    # Walk down the hierarchy: an index of -1 forces the index directly
    # beneath it to be -1 too.
    for upper, lower in ((l0_idx, l1_idx), (l1_idx, l2_idx)):
        if upper == -1 and lower != -1:
            raise ValueError(
                "invalid combination of negative and non‐negative indices"
            )

    self._l0_idx, self._l1_idx, self._l2_idx = l0_idx, l1_idx, l2_idx
137 def l0_idx(self
) -> int:
141 def l1_idx(self
) -> int:
145 def l2_idx(self
) -> int:
def gkid_type(self) -> GkidType:
    """Classify this GKID by the highest level whose index is -1.

    An L0 index of -1 gives DEFAULT, an L1 index of -1 gives L0_SEED_KEY,
    an L2 index of -1 gives L1_SEED_KEY; with no -1 index the result is
    L2_SEED_KEY.
    """
    if self.l0_idx == -1:
        kind = GkidType.DEFAULT
    elif self.l1_idx == -1:
        kind = GkidType.L0_SEED_KEY
    elif self.l2_idx == -1:
        kind = GkidType.L1_SEED_KEY
    else:
        kind = GkidType.L2_SEED_KEY

    return kind
160 def wrapped_l1_idx(self
) -> int:
161 if self
.l1_idx
== -1:
162 return L1_KEY_ITERATION
166 def wrapped_l2_idx(self
) -> int:
167 if self
.l2_idx
== -1:
168 return L2_KEY_ITERATION
172 def derive_l1_seed_key(self
) -> "Gkid":
173 gkid_type
= self
.gkid_type()
175 gkid_type
is not GkidType
.L0_SEED_KEY
176 and gkid_type
is not GkidType
.L1_SEED_KEY
178 raise InvalidDerivation(
179 "Invalid attempt to derive an L1 seed key from"
180 f
" {gkid_type.description()}"
184 raise InvalidDerivation("No further derivation of L1 seed keys is possible")
186 return Gkid(self
.l0_idx
, self
.wrapped_l1_idx() - 1, self
.l2_idx
)
188 def derive_l2_seed_key(self
) -> "Gkid":
189 gkid_type
= self
.gkid_type()
191 gkid_type
is not GkidType
.L1_SEED_KEY
192 and gkid_type
is not GkidType
.L2_SEED_KEY
194 raise InvalidDerivation(
195 f
"Attempt to derive an L2 seed key from {gkid_type.description()}"
199 raise InvalidDerivation("No further derivation of L2 seed keys is possible")
201 return Gkid(self
.l0_idx
, self
.l1_idx
, self
.wrapped_l2_idx() - 1)
def __str__(self) -> str:
    """Return the GKID as ``Gkid(<l0>, <l1>, <l2>)``."""
    return "Gkid({}, {}, {})".format(self.l0_idx, self.l1_idx, self.l2_idx)
206 def __repr__(self
) -> str:
209 f
"{cls.__qualname__}({repr(self.l0_idx)}, {repr(self.l1_idx)},"
210 f
" {repr(self.l2_idx)})"
213 def __eq__(self
, other
: object) -> bool:
214 if not isinstance(other
, Gkid
):
215 return NotImplemented
217 return (self
.l0_idx
, self
.l1_idx
, self
.l2_idx
) == (
def __lt__(self, other: object) -> bool:
    """Return whether this GKID sorts before ``other``.

    Returns ``NotImplemented`` when ``other`` is not a ``Gkid``.
    """
    if not isinstance(other, Gkid):
        return NotImplemented

    def sort_key(gkid: Gkid) -> Tuple[int, int, int]:
        # DEFAULT is considered less than everything else (see the lexical
        # ordering requirement in [MS-GKDI] 3.1.4.1.3 (GetKey)), so its raw
        # indices are used unchanged.
        if gkid.gkid_type() is GkidType.DEFAULT:
            return gkid.l0_idx, gkid.l1_idx, gkid.l2_idx

        # Every other type uses the wrapped indices so that L1 seed keys
        # are considered greater than their children L2 seed keys, and L0
        # seed keys are considered greater than their children L1 seed
        # keys.
        return gkid.l0_idx, gkid.wrapped_l1_idx(), gkid.wrapped_l2_idx()

    return sort_key(self) < sort_key(other)
def __hash__(self) -> int:
    """Return a hash computed from the (L0, L1, L2) index triple."""
    indices = (self.l0_idx, self.l1_idx, self.l2_idx)
    return hash(indices)
def default() -> "Gkid":
    """Return the GKID whose three indices are all -1."""
    return Gkid(l0_idx=-1, l1_idx=-1, l2_idx=-1)
def l0_seed_key(l0_idx: int) -> "Gkid":
    """Return the GKID with the given L0 index and L1 and L2 both set to -1."""
    return Gkid(l0_idx=l0_idx, l1_idx=-1, l2_idx=-1)
def l1_seed_key(l0_idx: int, l1_idx: int) -> "Gkid":
    """Return the GKID with the given L0 and L1 indices and L2 set to -1."""
    return Gkid(l0_idx=l0_idx, l1_idx=l1_idx, l2_idx=-1)
260 def from_nt_time(nt_time
: NtTime
) -> "Gkid":
261 l0
= nt_time
// (L1_KEY_ITERATION
* L2_KEY_ITERATION
* KEY_CYCLE_DURATION
)
264 % (L1_KEY_ITERATION
* L2_KEY_ITERATION
* KEY_CYCLE_DURATION
)
265 // (L2_KEY_ITERATION
* KEY_CYCLE_DURATION
)
267 l2
= nt_time
% (L2_KEY_ITERATION
* KEY_CYCLE_DURATION
) // KEY_CYCLE_DURATION
269 return Gkid(l0
, l1
, l2
)
271 def start_nt_time(self
) -> NtTime
:
272 gkid_type
= self
.gkid_type()
273 if gkid_type
is not GkidType
.L2_SEED_KEY
:
274 raise UndefinedStartTime(
275 f
"{gkid_type.description()} has no defined start time"
280 self
.l0_idx
* L1_KEY_ITERATION
* L2_KEY_ITERATION
281 + self
.l1_idx
* L2_KEY_ITERATION
287 if not 0 <= start_time
<= uint64_max
:
288 raise OverflowError(f
"start time {start_time} out of range")
294 __slots__
= ["l1_key", "l2_key", "gkid", "hash_algorithm", "root_key_id"]
298 l1_key
: Optional
[bytes
],
299 l2_key
: Optional
[bytes
],
301 hash_algorithm
: Algorithm
,
302 root_key_id
: misc
.GUID
,
304 if l1_key
is not None and len(l1_key
) != KEY_LEN_BYTES
:
305 raise ValueError(f
"L1 key ({repr(l1_key)}) must be {KEY_LEN_BYTES} bytes")
306 if l2_key
is not None and len(l2_key
) != KEY_LEN_BYTES
:
307 raise ValueError(f
"L2 key ({repr(l2_key)}) must be {KEY_LEN_BYTES} bytes")
312 self
.hash_algorithm
= hash_algorithm
313 self
.root_key_id
= root_key_id
315 def __str__(self
) -> str:
316 l1_key_hex
= None if self
.l1_key
is None else self
.l1_key
.hex()
317 l2_key_hex
= None if self
.l2_key
is None else self
.l2_key
.hex()
320 f
"SeedKeyPair(L1Key({l1_key_hex}), L2Key({l2_key_hex}), {self.gkid},"
321 f
" {self.root_key_id}, {self.hash_algorithm})"
324 def __repr__(self
) -> str:
327 f
"{cls.__qualname__}({repr(self.l1_key)}, {repr(self.l2_key)},"
328 f
" {repr(self.gkid)}, {repr(self.hash_algorithm)},"
329 f
" {repr(self.root_key_id)})"
332 def __eq__(self
, other
: object) -> bool:
333 if not isinstance(other
, SeedKeyPair
):
334 return NotImplemented
346 other
.hash_algorithm
,
350 def __hash__(self
) -> int:
356 ndr_pack(self
.root_key_id
),
361 __slots__
= ["gkid", "key", "hash_algorithm", "root_key_id"]
364 self
, key
: bytes
, gkid
: Gkid
, hash_algorithm
: Algorithm
, root_key_id
: misc
.GUID
366 if key
is not None and len(key
) != KEY_LEN_BYTES
:
367 raise ValueError(f
"Key ({repr(key)}) must be {KEY_LEN_BYTES} bytes")
371 self
.hash_algorithm
= hash_algorithm
372 self
.root_key_id
= root_key_id
374 def __str__(self
) -> str:
376 f
"GroupKey(Key({self.key.hex()}), {self.gkid}, {self.hash_algorithm},"
377 f
" {self.root_key_id})"
380 def __repr__(self
) -> str:
383 f
"{cls.__qualname__}({repr(self.key)}, {repr(self.gkid)},"
384 f
" {repr(self.hash_algorithm)}, {repr(self.root_key_id)})"
387 def __eq__(self
, other
: object) -> bool:
388 if not isinstance(other
, GroupKey
):
389 return NotImplemented
391 return (self
.key
, self
.gkid
, self
.hash_algorithm
, self
.root_key_id
) == (
394 other
.hash_algorithm
,
398 def __hash__(self
) -> int:
400 (self
.key
, self
.gkid
, self
.hash_algorithm
, ndr_pack(self
.root_key_id
))