python:gkdi: Add Gkid.from_key_envelope() method
[Samba.git] / python / samba / gkdi.py
blob b62a00ed3c258dad36e275142580732c029a1b4c
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Catalyst.Net Ltd 2023
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <https://www.gnu.org/licenses/>.
19 """Group Key Distribution Service module"""
21 from enum import Enum
22 from functools import total_ordering
23 from typing import Optional, Tuple
25 from cryptography.hazmat.primitives import hashes
27 from samba import _glue
28 from samba.dcerpc import gkdi, misc
29 from samba.ndr import ndr_pack, ndr_unpack
30 from samba.nt_time import NtTime, NtTimeDelta
# Largest value of an unsigned 64-bit integer; Gkid.start_nt_time() rejects
# start times above it.
uint64_max: int = (1 << 64) - 1

# GKDI key-iteration and timing parameters, taken from the ``_glue``
# extension module.
L1_KEY_ITERATION: int = _glue.GKDI_L1_KEY_ITERATION
L2_KEY_ITERATION: int = _glue.GKDI_L2_KEY_ITERATION
KEY_CYCLE_DURATION: NtTimeDelta = _glue.GKDI_KEY_CYCLE_DURATION
MAX_CLOCK_SKEW: NtTimeDelta = _glue.GKDI_MAX_CLOCK_SKEW

# Required length, in bytes, of L1/L2 seed keys and of group keys.
KEY_LEN_BYTES = 64
class Algorithm(Enum):
    """Hash algorithm used for GKDI key derivation.

    Each member's value is the algorithm name as a string.
    ``from_kdf_parameters`` looks members up by that name.
    """

    SHA1 = "SHA1"
    SHA256 = "SHA256"
    SHA384 = "SHA384"
    SHA512 = "SHA512"

    def algorithm(self) -> "hashes.HashAlgorithm":
        """Return a new ``cryptography`` hash object for this algorithm.

        :raises RuntimeError: if the member has no matching hash (not
            reachable for the four members defined above).
        """
        if self is Algorithm.SHA1:
            return hashes.SHA1()

        if self is Algorithm.SHA256:
            return hashes.SHA256()

        if self is Algorithm.SHA384:
            return hashes.SHA384()

        if self is Algorithm.SHA512:
            return hashes.SHA512()

        # Must be an f-string so the message names the offending member.
        raise RuntimeError(f"unknown hash algorithm {self}")

    def __repr__(self) -> str:
        # Show "Algorithm.SHA256" rather than Enum's default
        # "<Algorithm.SHA256: 'SHA256'>".
        return str(self)

    @staticmethod
    def from_kdf_parameters(kdf_param: Optional[bytes]) -> "Algorithm":
        """Return the hash algorithm selected by NDR-packed KDF parameters.

        :param kdf_param: an NDR-packed ``gkdi.KdfParameters`` blob, or
            None/empty to request the default.
        :returns: the member whose value equals the blob's
            ``hash_algorithm`` field, or SHA256 when ``kdf_param`` is falsy.
        :raises ValueError: if ``hash_algorithm`` names no member.
        """
        if not kdf_param:
            return Algorithm.SHA256  # the default used by Windows.

        kdf_parameters = ndr_unpack(gkdi.KdfParameters, kdf_param)
        return Algorithm(kdf_parameters.hash_algorithm)
class GkidType(Enum):
    """The kind of a GKID, determined by which of its indices are -1."""

    DEFAULT = object()
    L0_SEED_KEY = object()
    L1_SEED_KEY = object()
    L2_SEED_KEY = object()

    def description(self) -> str:
        """Return a phrase with an article naming this kind, e.g. "an L0 seed key".

        :raises RuntimeError: for an unrecognised member (not reachable for
            the four members defined above).
        """
        if self is GkidType.DEFAULT:
            return "a default GKID"

        if self is GkidType.L0_SEED_KEY:
            return "an L0 seed key"

        if self is GkidType.L1_SEED_KEY:
            return "an L1 seed key"

        if self is GkidType.L2_SEED_KEY:
            return "an L2 seed key"

        # Must be an f-string so the message names the offending member.
        raise RuntimeError(f"unknown GKID type {self}")
class InvalidDerivation(Exception):
    """Raised when the requested seed key cannot be derived from a GKID."""
class UndefinedStartTime(Exception):
    """Raised when a start time is requested for a GKID that has none."""
@total_ordering
class Gkid:
    """A group key identifier: the (L0, L1, L2) index triple of a GKDI key.

    Any index may be -1, but once a level is -1 every level below it must
    be -1 as well. That gives four shapes:
    - a default GKID (-1, -1, -1);
    - an L0 seed key (l0, -1, -1);
    - an L1 seed key (l0, l1, -1);
    - an L2 seed key (l0, l1, l2).

    The indices are exposed through read-only properties. Instances are
    hashable and totally ordered (``__lt__`` plus ``functools.total_ordering``).
    """

    # L2 increments every 10 hours. It rolls over after 320 hours (13 days and 8 hours).
    # L1 increments every 320 hours. It rolls over after 10240 hours (426 days and 16 hours).
    # L0 increments every 10240 hours and may not exceed max_l0_idx.

    __slots__ = ["_l0_idx", "_l1_idx", "_l2_idx"]

    max_l0_idx = 0x7FFF_FFFF

    def __init__(self, l0_idx: int, l1_idx: int, l2_idx: int) -> None:
        """Validate and store the three indices.

        :raises ValueError: if an index is out of range for its level, or
            if a level is non-negative beneath a level that is -1.
        """
        if not -1 <= l0_idx <= Gkid.max_l0_idx:
            raise ValueError(f"L0 index {l0_idx} out of range")
        if not -1 <= l1_idx < L1_KEY_ITERATION:
            raise ValueError(f"L1 index {l1_idx} out of range")
        if not -1 <= l2_idx < L2_KEY_ITERATION:
            raise ValueError(f"L2 index {l2_idx} out of range")

        # A -1 level may only be followed by -1 levels.
        orphaned_child = (l0_idx == -1 and l1_idx != -1) or (
            l1_idx == -1 and l2_idx != -1
        )
        if orphaned_child:
            raise ValueError("invalid combination of negative and non‐negative indices")

        self._l0_idx, self._l1_idx, self._l2_idx = l0_idx, l1_idx, l2_idx

    @property
    def l0_idx(self) -> int:
        """The L0 index (-1 only for a default GKID)."""
        return self._l0_idx

    @property
    def l1_idx(self) -> int:
        """The L1 index (-1 for a default GKID or an L0 seed key)."""
        return self._l1_idx

    @property
    def l2_idx(self) -> int:
        """The L2 index (-1 unless this is an L2 seed key)."""
        return self._l2_idx

    def gkid_type(self) -> GkidType:
        """Classify this GKID by the first of its L0, L1, L2 indices that is -1."""
        for index, kind in (
            (self.l0_idx, GkidType.DEFAULT),
            (self.l1_idx, GkidType.L0_SEED_KEY),
            (self.l2_idx, GkidType.L1_SEED_KEY),
        ):
            if index == -1:
                return kind
        return GkidType.L2_SEED_KEY

    def wrapped_l1_idx(self) -> int:
        """Return the L1 index, with -1 mapped to L1_KEY_ITERATION.

        This places an absent L1 level just past the largest real L1 index.
        Derivation and ordering rely on that placement.
        """
        return L1_KEY_ITERATION if self.l1_idx == -1 else self.l1_idx

    def wrapped_l2_idx(self) -> int:
        """Return the L2 index, with -1 mapped to L2_KEY_ITERATION."""
        return L2_KEY_ITERATION if self.l2_idx == -1 else self.l2_idx

    def derive_l1_seed_key(self) -> "Gkid":
        """Return the GKID of the next L1 seed key derivable from this one.

        From an L0 seed key this gives the L1 seed key with the highest L1
        index. From an L1 seed key it gives the one whose L1 index is one
        lower.

        :raises InvalidDerivation: if this is not an L0 or L1 seed key, or
            its L1 index is already 0.
        """
        kind = self.gkid_type()
        if kind not in (GkidType.L0_SEED_KEY, GkidType.L1_SEED_KEY):
            raise InvalidDerivation(
                f"Invalid attempt to derive an L1 seed key from {kind.description()}"
            )

        if self.l1_idx == 0:
            raise InvalidDerivation("No further derivation of L1 seed keys is possible")

        return Gkid(self.l0_idx, self.wrapped_l1_idx() - 1, self.l2_idx)

    def derive_l2_seed_key(self) -> "Gkid":
        """Return the GKID of the next L2 seed key derivable from this one.

        From an L1 seed key this gives the L2 seed key with the highest L2
        index. From an L2 seed key it gives the one whose L2 index is one
        lower.

        :raises InvalidDerivation: if this is not an L1 or L2 seed key, or
            its L2 index is already 0.
        """
        kind = self.gkid_type()
        if kind not in (GkidType.L1_SEED_KEY, GkidType.L2_SEED_KEY):
            raise InvalidDerivation(
                f"Attempt to derive an L2 seed key from {kind.description()}"
            )

        if self.l2_idx == 0:
            raise InvalidDerivation("No further derivation of L2 seed keys is possible")

        return Gkid(self.l0_idx, self.l1_idx, self.wrapped_l2_idx() - 1)

    def __str__(self) -> str:
        l0, l1, l2 = self.l0_idx, self.l1_idx, self.l2_idx
        return f"Gkid({l0}, {l1}, {l2})"

    def __repr__(self) -> str:
        name = type(self).__qualname__
        return f"{name}({self.l0_idx!r}, {self.l1_idx!r}, {self.l2_idx!r})"

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Gkid):
            return (self.l0_idx, self.l1_idx, self.l2_idx) == (
                other.l0_idx,
                other.l1_idx,
                other.l2_idx,
            )
        return NotImplemented

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, Gkid):
            return NotImplemented

        def sort_key(gkid: Gkid) -> Tuple[int, int, int]:
            # DEFAULT keeps its raw -1 indices, so it sorts below every
            # other GKID. That makes the lexical ordering requirement in
            # [MS-GKDI] 3.1.4.1.3 (GetKey) make sense. Every other GKID uses
            # wrapped indices, so a seed key sorts above all of its children.
            if gkid.gkid_type() is GkidType.DEFAULT:
                return gkid.l0_idx, gkid.l1_idx, gkid.l2_idx
            return gkid.l0_idx, gkid.wrapped_l1_idx(), gkid.wrapped_l2_idx()

        return sort_key(self) < sort_key(other)

    def __hash__(self) -> int:
        return hash((self.l0_idx, self.l1_idx, self.l2_idx))

    @staticmethod
    def default() -> "Gkid":
        """Return the default GKID, with all three indices -1."""
        return Gkid(-1, -1, -1)

    @staticmethod
    def l0_seed_key(l0_idx: int) -> "Gkid":
        """Return the GKID of the L0 seed key with index ``l0_idx``."""
        return Gkid(l0_idx, -1, -1)

    @staticmethod
    def l1_seed_key(l0_idx: int, l1_idx: int) -> "Gkid":
        """Return the GKID of the L1 seed key at (``l0_idx``, ``l1_idx``)."""
        return Gkid(l0_idx, l1_idx, -1)

    @staticmethod
    def from_nt_time(nt_time: NtTime) -> "Gkid":
        """Return the GKID of the key cycle that contains ``nt_time``.

        A negative ``nt_time`` yields invalid indices and so raises
        ValueError from the constructor.
        """
        l1_span = L2_KEY_ITERATION * KEY_CYCLE_DURATION  # time covered by one L1 index
        l0_span = L1_KEY_ITERATION * l1_span  # time covered by one L0 index

        l0 = nt_time // l0_span
        l1 = nt_time % l0_span // l1_span
        l2 = nt_time % l1_span // KEY_CYCLE_DURATION

        return Gkid(l0, l1, l2)

    def start_nt_time(self) -> NtTime:
        """Return the NT time at which this L2 seed key's key cycle begins.

        :raises UndefinedStartTime: unless this is an L2 seed key.
        :raises OverflowError: if the start time exceeds ``uint64_max``.
        """
        kind = self.gkid_type()
        if kind is not GkidType.L2_SEED_KEY:
            raise UndefinedStartTime(f"{kind.description()} has no defined start time")

        # Number of whole key cycles preceding this GKID.
        cycle_count = (
            self.l0_idx * L1_KEY_ITERATION + self.l1_idx
        ) * L2_KEY_ITERATION + self.l2_idx
        start_time = NtTime(cycle_count * KEY_CYCLE_DURATION)

        if not 0 <= start_time <= uint64_max:
            raise OverflowError(f"start time {start_time} out of range")

        return start_time

    @staticmethod
    def from_key_envelope(env: gkdi.KeyEnvelope) -> "Gkid":
        """Return the GKID taken from a key envelope's L0/L1/L2 index fields."""
        return Gkid(env.l0_index, env.l1_index, env.l2_index)
class SeedKeyPair:
    """An L1 and/or L2 seed key, with the GKID, hash algorithm and root key ID they belong to.

    Either key may be None. A key that is present must be exactly
    KEY_LEN_BYTES long.
    """

    __slots__ = ["l1_key", "l2_key", "gkid", "hash_algorithm", "root_key_id"]

    def __init__(
        self,
        l1_key: Optional[bytes],
        l2_key: Optional[bytes],
        gkid: Gkid,
        hash_algorithm: Algorithm,
        root_key_id: misc.GUID,
    ) -> None:
        """Store the fields, checking the length of each key that is present.

        :raises ValueError: if a non-None key is not KEY_LEN_BYTES long.
        """
        for label, candidate in (("L1", l1_key), ("L2", l2_key)):
            if candidate is not None and len(candidate) != KEY_LEN_BYTES:
                raise ValueError(
                    f"{label} key ({candidate!r}) must be {KEY_LEN_BYTES} bytes"
                )

        self.l1_key = l1_key
        self.l2_key = l2_key
        self.gkid = gkid
        self.hash_algorithm = hash_algorithm
        self.root_key_id = root_key_id

    def __str__(self) -> str:
        def as_hex(key: Optional[bytes]) -> Optional[str]:
            return None if key is None else key.hex()

        # Note: root_key_id comes before hash_algorithm here, unlike __repr__.
        return (
            f"SeedKeyPair(L1Key({as_hex(self.l1_key)}),"
            f" L2Key({as_hex(self.l2_key)}), {self.gkid},"
            f" {self.root_key_id}, {self.hash_algorithm})"
        )

    def __repr__(self) -> str:
        parts = (
            self.l1_key,
            self.l2_key,
            self.gkid,
            self.hash_algorithm,
            self.root_key_id,
        )
        args = ", ".join(repr(part) for part in parts)
        return f"{type(self).__qualname__}({args})"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SeedKeyPair):
            return NotImplemented

        mine = (
            self.l1_key,
            self.l2_key,
            self.gkid,
            self.hash_algorithm,
            self.root_key_id,
        )
        theirs = (
            other.l1_key,
            other.l2_key,
            other.gkid,
            other.hash_algorithm,
            other.root_key_id,
        )
        return mine == theirs

    def __hash__(self) -> int:
        # The GUID is hashed through its NDR encoding.
        packed_root_key_id = ndr_pack(self.root_key_id)
        return hash(
            (
                self.l1_key,
                self.l2_key,
                self.gkid,
                self.hash_algorithm,
                packed_root_key_id,
            )
        )
class GroupKey:
    """A group key, with the GKID, hash algorithm and root key ID it belongs to."""

    __slots__ = ["gkid", "key", "hash_algorithm", "root_key_id"]

    def __init__(
        self,
        key: bytes,
        gkid: "Gkid",
        hash_algorithm: "Algorithm",
        root_key_id: "misc.GUID",
    ) -> None:
        """Store the fields after checking the key length.

        :raises ValueError: if ``key`` is not None and is not KEY_LEN_BYTES long.
        """
        # NOTE(review): ``key`` is annotated as bytes, but None passes this
        # check unchallenged; __str__ below tolerates that.
        if key is not None and len(key) != KEY_LEN_BYTES:
            raise ValueError(f"Key ({repr(key)}) must be {KEY_LEN_BYTES} bytes")

        self.key = key
        self.gkid = gkid
        self.hash_algorithm = hash_algorithm
        self.root_key_id = root_key_id

    def __str__(self) -> str:
        # __init__ accepts key=None, so guard .hex() (as SeedKeyPair.__str__
        # does) instead of raising AttributeError.
        key_hex = None if self.key is None else self.key.hex()
        return (
            f"GroupKey(Key({key_hex}), {self.gkid}, {self.hash_algorithm},"
            f" {self.root_key_id})"
        )

    def __repr__(self) -> str:
        cls = type(self)
        return (
            f"{cls.__qualname__}({repr(self.key)}, {repr(self.gkid)},"
            f" {repr(self.hash_algorithm)}, {repr(self.root_key_id)})"
        )

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, GroupKey):
            return NotImplemented

        return (self.key, self.gkid, self.hash_algorithm, self.root_key_id) == (
            other.key,
            other.gkid,
            other.hash_algorithm,
            other.root_key_id,
        )

    def __hash__(self) -> int:
        # The GUID is hashed through its NDR encoding.
        return hash(
            (self.key, self.gkid, self.hash_algorithm, ndr_pack(self.root_key_id))
        )