traffic_replay: Make sure naming assumptions are in a single place
[Samba.git] / python / samba / emulate / traffic.py
blob77a1862f1a8090188cd0dd32f7e54bf4f176b344
1 # -*- encoding: utf-8 -*-
2 # Samba traffic replay and learning
4 # Copyright (C) Catalyst IT Ltd. 2017
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 # GNU General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 from __future__ import print_function, division
21 import time
22 import os
23 import random
24 import json
25 import math
26 import sys
27 import signal
28 import itertools
30 from collections import OrderedDict, Counter, defaultdict
31 from samba.emulate import traffic_packets
32 from samba.samdb import SamDB
33 import ldb
34 from ldb import LdbError
35 from samba.dcerpc import ClientConnection
36 from samba.dcerpc import security, drsuapi, lsa
37 from samba.dcerpc import netlogon
38 from samba.dcerpc.netlogon import netr_Authenticator
39 from samba.dcerpc import srvsvc
40 from samba.dcerpc import samr
41 from samba.drs_utils import drs_DsBind
42 import traceback
43 from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
44 from samba.auth import system_session
45 from samba.dsdb import (
46 UF_NORMAL_ACCOUNT,
47 UF_SERVER_TRUST_ACCOUNT,
48 UF_TRUSTED_FOR_DELEGATION,
49 UF_WORKSTATION_TRUST_ACCOUNT
51 from samba.dcerpc.misc import SEC_CHAN_BDC
52 from samba import gensec
53 from samba import sd_utils
54 from samba.compat import get_string
55 from samba.logger import get_samba_logger
56 import bisect
# Approximate fixed cost (seconds) of a time.sleep() call; subtracted
# from requested gaps before sleeping.
SLEEP_OVERHEAD = 3e-4

# we don't use None, because it complicates [de]serialisation
NON_PACKET = '-'

# (protocol, opcode) pairs whose presence suggests the sending side is
# the client; weights are summed into Conversation.client_balance.
CLIENT_CLUES = {
    ('dns', '0'): 1.0,      # query
    ('smb', '0x72'): 1.0,   # Negotiate protocol
    ('ldap', '0'): 1.0,     # bind
    ('ldap', '3'): 1.0,     # searchRequest
    ('ldap', '2'): 1.0,     # unbindRequest
    ('cldap', '3'): 1.0,
    ('dcerpc', '11'): 1.0,  # bind
    ('dcerpc', '14'): 1.0,  # Alter_context
    ('nbns', '0'): 1.0,     # query
}

# (protocol, opcode) pairs whose presence suggests the sending side is
# the server.
SERVER_CLUES = {
    ('dns', '1'): 1.0,      # response
    ('ldap', '1'): 1.0,     # bind response
    ('ldap', '4'): 1.0,     # search result
    ('ldap', '5'): 1.0,     # search done
    ('cldap', '5'): 1.0,
    ('dcerpc', '12'): 1.0,  # bind_ack
    ('dcerpc', '13'): 1.0,  # bind_nak
    ('dcerpc', '15'): 1.0,  # Alter_context response
}

# Protocols we never replay; packets for these are dropped on ingest.
SKIPPED_PROTOCOLS = {"smb", "smb2", "browser", "smb_netlogon"}

# Client waits longer than WAIT_THRESHOLD seconds are modelled as
# explicit "wait" states in the traffic model.
WAIT_SCALE = 10.0
WAIT_THRESHOLD = (1.0 / WAIT_SCALE)
NO_WAIT_LOG_TIME_RANGE = (-10, -3)

# DEBUG_LEVEL can be changed by scripts with -d
DEBUG_LEVEL = 0
95 LOGGER = get_samba_logger(name=__name__)
def debug(level, msg, *args):
    """Print a formatted debug message to standard error.

    :param level: The debug level, message will be printed if it is <= the
                  currently set debug level. The debug level can be set with
                  the -d option.

    :param msg: The message to be logged, can contain C-Style format
                specifiers

    :param args: The parameters required by the format specifiers
    """
    if level <= DEBUG_LEVEL:
        if not args:
            print(msg, file=sys.stderr)
        else:
            print(msg % tuple(args), file=sys.stderr)
def debug_lineno(*args):
    """Print an unformatted log message to stderr, prefixed with the
    calling function's name and line number (highlighted in yellow)."""
    # limit=2 gives [caller frame, this frame]; [0] is the caller.
    caller = traceback.extract_stack(limit=2)[0]
    header = " %s:" "\033[01;33m" "%s " "\033[00m" % (caller[2], caller[1])
    print(header, end=' ', file=sys.stderr)
    for item in args:
        print(item, file=sys.stderr)
    print(file=sys.stderr)
    sys.stderr.flush()
def random_colour_print():
    """Return a function that prints a randomly coloured line to stderr"""
    # Pick one of the 214 "safe" xterm-256 colours (skipping the first 18).
    colour = 18 + random.randrange(214)
    prefix = "\033[38;5;%dm" % colour

    def printer(*args):
        for item in args:
            print("%s%s\033[00m" % (prefix, item), file=sys.stderr)

    return printer
class FakePacketError(Exception):
    """Raised when a packet is inconsistent with the conversation it is
    being added to (e.g. its endpoints don't match)."""
    pass
class Packet(object):
    """Details of a network packet"""
    def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
                 protocol, opcode, desc, extra):
        self.timestamp = timestamp
        self.ip_protocol = ip_protocol
        self.stream_number = stream_number
        self.src = src
        self.dest = dest
        self.protocol = protocol
        self.opcode = opcode
        self.desc = desc
        self.extra = extra
        # canonical (low, high) ordering, so both directions of a
        # conversation share the same endpoints key.
        if self.src < self.dest:
            self.endpoints = (self.src, self.dest)
        else:
            self.endpoints = (self.dest, self.src)

    @classmethod
    def from_line(cls, line):
        """Parse one tab-separated traffic-summary line into a Packet.

        Fields: timestamp, ip_protocol, stream_number, src, dest,
        protocol, opcode, desc; anything beyond that is kept as extra.
        """
        fields = line.rstrip('\n').split('\t')
        (timestamp,
         ip_protocol,
         stream_number,
         src,
         dest,
         protocol,
         opcode,
         desc) = fields[:8]
        extra = fields[8:]

        timestamp = float(timestamp)
        src = int(src)
        dest = int(dest)

        return cls(timestamp, ip_protocol, stream_number, src, dest,
                   protocol, opcode, desc, extra)

    def as_summary(self, time_offset=0.0):
        """Format the packet as a traffic_summary line.

        :param time_offset: added to the packet timestamp
        :return: a (timestamp, line) tuple
        """
        extra = '\t'.join(self.extra)
        t = self.timestamp + time_offset
        return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
                (t,
                 self.ip_protocol,
                 self.stream_number or '',
                 self.src,
                 self.dest,
                 self.protocol,
                 self.opcode,
                 self.desc,
                 extra))

    def __str__(self):
        return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
                (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
                 self.stream_number, self.protocol, self.opcode, self.desc,
                 ('«' + ' '.join(self.extra) + '»' if self.extra else '')))

    def __repr__(self):
        return "<Packet @%s>" % self

    def copy(self):
        """Return a new Packet carrying the same field values."""
        return self.__class__(self.timestamp,
                              self.ip_protocol,
                              self.stream_number,
                              self.src,
                              self.dest,
                              self.protocol,
                              self.opcode,
                              self.desc,
                              self.extra)

    def as_packet_type(self):
        """Return the 'protocol:opcode' string used as a model key."""
        t = '%s:%s' % (self.protocol, self.opcode)
        return t

    def client_score(self):
        """A positive number means we think it is a client; a negative number
        means we think it is a server. Zero means no idea. range: -1 to 1.
        """
        key = (self.protocol, self.opcode)
        if key in CLIENT_CLUES:
            return CLIENT_CLUES[key]
        if key in SERVER_CLUES:
            return -SERVER_CLUES[key]
        return 0.0

    def play(self, conversation, context):
        """Send the packet over the network, if required.

        Some packets are ignored, i.e. for protocols not handled,
        server response messages, or messages that are generated by the
        protocol layer associated with other packets.
        """
        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        try:
            fn = getattr(traffic_packets, fn_name)
        except AttributeError:
            print("Conversation(%s) Missing handler %s" %
                  (conversation.conversation_id, fn_name),
                  file=sys.stderr)
            return

        # Don't display a message for kerberos packets, they're not directly
        # generated they're used to indicate kerberos should be used
        if self.protocol != "kerberos":
            debug(2, "Conversation(%s) Calling handler %s" %
                     (conversation.conversation_id, fn_name))

        start = time.time()
        try:
            if fn(self, conversation, context):
                # Only collect timing data for functions that generate
                # network traffic, or fail
                end = time.time()
                duration = end - start
                print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
                      (end, conversation.conversation_id, self.protocol,
                       self.opcode, duration))
        except Exception as e:
            end = time.time()
            duration = end - start
            print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
                  (end, conversation.conversation_id, self.protocol,
                   self.opcode, duration, e))

    def __cmp__(self, other):
        # Python 2 ordering by timestamp (ignored by Python 3).
        return self.timestamp - other.timestamp

    def __lt__(self, other):
        # Python 3 rich comparison providing the same timestamp
        # ordering that __cmp__ gives under Python 2.
        return self.timestamp < other.timestamp

    def is_really_a_packet(self, missing_packet_stats=None):
        """Is the packet one that can be ignored?

        If so removing it will have no effect on the replay
        """
        if self.protocol in SKIPPED_PROTOCOLS:
            # Ignore any packets for the protocols we're not interested in.
            return False
        if self.protocol == "ldap" and self.opcode == '':
            # skip ldap continuation packets
            return False

        fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
        fn = getattr(traffic_packets, fn_name, None)
        if not fn:
            print("missing packet %s" % fn_name, file=sys.stderr)
            return False
        if fn is traffic_packets.null_packet:
            return False
        return True
class ReplayContext(object):
    """State/Context for an individual conversation between a simulated client
    and a server. Each connection type is cached so that conversations
    can reuse them.
    """

    def __init__(self,
                 server=None,
                 lp=None,
                 creds=None,
                 badpassword_frequency=None,
                 prefer_kerberos=None,
                 tempdir=None,
                 statsdir=None,
                 ou=None,
                 base_dn=None,
                 domain=None,
                 domain_sid=None):
        self.server = server
        self.ldap_connections = []
        self.dcerpc_connections = []
        self.lsarpc_connections = []
        self.lsarpc_connections_named = []
        self.drsuapi_connections = []
        self.srvsvc_connections = []
        self.samr_contexts = []
        self.netlogon_connection = None
        self.creds = creds
        self.lp = lp
        self.prefer_kerberos = prefer_kerberos
        self.ou = ou
        self.base_dn = base_dn
        self.domain = domain
        self.statsdir = statsdir
        self.global_tempdir = tempdir
        self.domain_sid = domain_sid
        self.realm = lp.get('realm')

        # Bad password attempt controls.  The last_*_bad flags record
        # whether the previous attempt on each connection type used bad
        # credentials, so bad attempts are never consecutive.
        self.badpassword_frequency = badpassword_frequency
        self.last_lsarpc_bad = False
        self.last_lsarpc_named_bad = False
        self.last_simple_bind_bad = False
        self.last_bind_bad = False
        self.last_srvsvc_bad = False
        self.last_drsuapi_bad = False
        self.last_netlogon_bad = False
        self.last_samlogon_bad = False
        self.generate_ldap_search_tables()
        self.next_conversation_id = itertools.count()

    def generate_ldap_search_tables(self):
        """Query the server once and build tables mapping DN "shape"
        patterns (e.g. CN,CN,DC,DC) to concrete DNs, used to choose
        plausible search bases during replay."""
        session = system_session()

        db = SamDB(url="ldap://%s" % self.server,
                   session_info=session,
                   credentials=self.creds,
                   lp=self.lp)

        res = db.search(db.domain_dn(),
                        scope=ldb.SCOPE_SUBTREE,
                        controls=["paged_results:1:1000"],
                        attrs=['dn'])

        # find a list of dns for each pattern
        # e.g. CN,CN,CN,DC,DC
        dn_map = {}
        attribute_clue_map = {
            'invocationId': []
        }

        for r in res:
            dn = str(r.dn)
            pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
            dns = dn_map.setdefault(pattern, [])
            dns.append(dn)
            if dn.startswith('CN=NTDS Settings,'):
                attribute_clue_map['invocationId'].append(dn)

        # extend the map in case we are working with a different
        # number of DC components.
        # for k, v in self.dn_map.items():
        #     print >>sys.stderr, k, len(v)

        for k in list(dn_map.keys()):
            if k[-3:] != ',DC':
                continue
            p = k[:-3]
            while p[-3:] == ',DC':
                p = p[:-3]
            for i in range(5):
                p += ',DC'
                if p != k and p in dn_map:
                    print('dn_map collison %s %s' % (k, p),
                          file=sys.stderr)
                    continue
                dn_map[p] = dn_map[k]

        self.dn_map = dn_map
        self.attribute_clue_map = attribute_clue_map

    def generate_process_local_config(self, account, conversation):
        """Set up the per-process state (account names, tempdir, smb.conf
        overrides, credentials) for one forked conversation."""
        if account is None:
            return
        self.netbios_name = account.netbios_name
        self.machinepass = account.machinepass
        self.username = account.username
        self.userpass = account.userpass

        self.tempdir = mk_masked_dir(self.global_tempdir,
                                     'conversation-%d' %
                                     conversation.conversation_id)

        self.lp.set("private dir", self.tempdir)
        self.lp.set("lock dir", self.tempdir)
        self.lp.set("state directory", self.tempdir)
        self.lp.set("tls verify peer", "no_check")

        # If the domain was not specified, check for the environment
        # variable.
        if self.domain is None:
            self.domain = os.environ["DOMAIN"]

        self.remoteAddress = "/root/ncalrpc_as_system"
        self.samlogon_dn = ("cn=%s,%s" %
                            (self.netbios_name, self.ou))
        self.user_dn = ("cn=%s,%s" %
                        (self.username, self.ou))

        self.generate_machine_creds()
        self.generate_user_creds()

    def with_random_bad_credentials(self, f, good, bad, failed_last_time):
        """Execute the supplied logon function, randomly choosing the
        bad credentials.

        Based on the frequency in badpassword_frequency randomly perform the
        function with the supplied bad credentials.
        If run with bad credentials, the function is re-run with the good
        credentials.
        failed_last_time is used to prevent consecutive bad credential
        attempts. So the over all bad credential frequency will be lower
        than that requested, but not significantly.
        """
        if not failed_last_time:
            if (self.badpassword_frequency and self.badpassword_frequency > 0
                and random.random() < self.badpassword_frequency):
                try:
                    f(bad)
                except Exception:
                    # Ignore any exceptions as the operation may fail
                    # as it's being performed with bad credentials
                    pass
                failed_last_time = True
        else:
            # Reset the flag so a bad attempt may happen on the *next*
            # call; this is what keeps bad attempts non-consecutive.
            failed_last_time = False

        result = f(good)
        return (result, failed_last_time)

    def generate_user_creds(self):
        """Generate the conversation specific user Credentials.

        Each Conversation has an associated user account used to simulate
        any non Administrative user traffic.

        Generates user credentials with good and bad passwords and ldap
        simple bind credentials with good and bad passwords.
        """
        self.user_creds = Credentials()
        self.user_creds.guess(self.lp)
        self.user_creds.set_workstation(self.netbios_name)
        self.user_creds.set_password(self.userpass)
        self.user_creds.set_username(self.username)
        self.user_creds.set_domain(self.domain)
        if self.prefer_kerberos:
            self.user_creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.user_creds.set_kerberos_state(DONT_USE_KERBEROS)

        # "Bad" variant: same identity, truncated password, for
        # simulated failed logons.
        self.user_creds_bad = Credentials()
        self.user_creds_bad.guess(self.lp)
        self.user_creds_bad.set_workstation(self.netbios_name)
        self.user_creds_bad.set_password(self.userpass[:-4])
        self.user_creds_bad.set_username(self.username)
        if self.prefer_kerberos:
            self.user_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.user_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)

        # Credentials for ldap simple bind.
        self.simple_bind_creds = Credentials()
        self.simple_bind_creds.guess(self.lp)
        self.simple_bind_creds.set_workstation(self.netbios_name)
        self.simple_bind_creds.set_password(self.userpass)
        self.simple_bind_creds.set_username(self.username)
        self.simple_bind_creds.set_gensec_features(
            self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL)
        if self.prefer_kerberos:
            self.simple_bind_creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.simple_bind_creds.set_kerberos_state(DONT_USE_KERBEROS)
        self.simple_bind_creds.set_bind_dn(self.user_dn)

        self.simple_bind_creds_bad = Credentials()
        self.simple_bind_creds_bad.guess(self.lp)
        self.simple_bind_creds_bad.set_workstation(self.netbios_name)
        self.simple_bind_creds_bad.set_password(self.userpass[:-4])
        self.simple_bind_creds_bad.set_username(self.username)
        self.simple_bind_creds_bad.set_gensec_features(
            self.simple_bind_creds_bad.get_gensec_features() |
            gensec.FEATURE_SEAL)
        if self.prefer_kerberos:
            self.simple_bind_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.simple_bind_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)
        self.simple_bind_creds_bad.set_bind_dn(self.user_dn)

    def generate_machine_creds(self):
        """Generate the conversation specific machine Credentials.

        Each Conversation has an associated machine account.

        Generates machine credentials with good and bad passwords.
        """
        self.machine_creds = Credentials()
        self.machine_creds.guess(self.lp)
        self.machine_creds.set_workstation(self.netbios_name)
        self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
        self.machine_creds.set_password(self.machinepass)
        self.machine_creds.set_username(self.netbios_name + "$")
        self.machine_creds.set_domain(self.domain)
        if self.prefer_kerberos:
            self.machine_creds.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.machine_creds.set_kerberos_state(DONT_USE_KERBEROS)

        self.machine_creds_bad = Credentials()
        self.machine_creds_bad.guess(self.lp)
        self.machine_creds_bad.set_workstation(self.netbios_name)
        self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
        self.machine_creds_bad.set_password(self.machinepass[:-4])
        self.machine_creds_bad.set_username(self.netbios_name + "$")
        if self.prefer_kerberos:
            self.machine_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
        else:
            self.machine_creds_bad.set_kerberos_state(DONT_USE_KERBEROS)

    def get_matching_dn(self, pattern, attributes=None):
        """Return a DN matching the pattern, or the base DN as a fallback.

        If the pattern is an empty string, we assume ROOTDSE,
        Otherwise we try adding or removing DC suffixes, then
        shorter leading patterns until we hit one.
        e.g if there is no CN,CN,CN,CN,DC,DC
        we first try CN,CN,CN,CN,DC
        and CN,CN,CN,CN,DC,DC,DC
        then change to CN,CN,CN,DC,DC
        and as last resort we use the base_dn
        """
        attr_clue = self.attribute_clue_map.get(attributes)
        if attr_clue:
            return random.choice(attr_clue)

        pattern = pattern.upper()
        while pattern:
            if pattern in self.dn_map:
                return random.choice(self.dn_map[pattern])
            # chop one off the front and try it all again.
            pattern = pattern[3:]

        return self.base_dn

    def get_dcerpc_connection(self, new=False):
        """Return a dcerpc connection, reusing the last one unless new
        is True."""
        guid = '12345678-1234-abcd-ef00-01234567cffb'  # RPC_NETLOGON UUID
        if self.dcerpc_connections and not new:
            return self.dcerpc_connections[-1]
        c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
                             (guid, 1), self.lp)
        self.dcerpc_connections.append(c)
        return c

    def get_srvsvc_connection(self, new=False):
        """Return a srvsvc connection (user creds, possibly preceded by a
        bad-password attempt), reusing the last one unless new is True."""
        if self.srvsvc_connections and not new:
            return self.srvsvc_connections[-1]

        def connect(creds):
            return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
                                 self.lp,
                                 creds)

        (c, self.last_srvsvc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_srvsvc_bad)

        self.srvsvc_connections.append(c)
        return c

    def get_lsarpc_connection(self, new=False):
        """Return a schannel lsarpc connection (machine creds), reusing
        the last one unless new is True."""
        if self.lsarpc_connections and not new:
            return self.lsarpc_connections[-1]

        def connect(creds):
            binding_options = 'schannel,seal,sign'
            return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
                              (self.server, binding_options),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_bad)

        self.lsarpc_connections.append(c)
        return c

    def get_lsarpc_named_pipe_connection(self, new=False):
        """Return a named-pipe lsarpc connection (machine creds), reusing
        the last one unless new is True."""
        if self.lsarpc_connections_named and not new:
            return self.lsarpc_connections_named[-1]

        def connect(creds):
            return lsa.lsarpc("ncacn_np:%s" % (self.server),
                              self.lp,
                              creds)

        (c, self.last_lsarpc_named_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_lsarpc_named_bad)

        self.lsarpc_connections_named.append(c)
        return c

    def get_drsuapi_connection_pair(self, new=False, unbind=False):
        """get a (drs, drs_handle) tuple

        The unbind parameter is currently unused; it is kept for
        interface compatibility.
        """
        if self.drsuapi_connections and not new:
            c = self.drsuapi_connections[-1]
            return c

        def connect(creds):
            binding_options = 'seal'
            binding_string = "ncacn_ip_tcp:%s[%s]" %\
                             (self.server, binding_options)
            return drsuapi.drsuapi(binding_string, self.lp, creds)

        (drs, self.last_drsuapi_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.user_creds,
                                             self.user_creds_bad,
                                             self.last_drsuapi_bad)

        (drs_handle, supported_extensions) = drs_DsBind(drs)
        c = (drs, drs_handle)
        self.drsuapi_connections.append(c)
        return c

    def get_ldap_connection(self, new=False, simple=False):
        """Return a SamDB ldap connection, via simple bind over ldaps when
        simple is True, otherwise SASL; reuses the last one unless new
        is True."""
        if self.ldap_connections and not new:
            return self.ldap_connections[-1]

        def simple_bind(creds):
            """
            To run simple bind against Windows, we need to run
            following commands in PowerShell:

            Install-windowsfeature ADCS-Cert-Authority
            Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
            Restart-Computer
            """
            return SamDB('ldaps://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)

        def sasl_bind(creds):
            return SamDB('ldap://%s' % self.server,
                         credentials=creds,
                         lp=self.lp)
        if simple:
            (samdb, self.last_simple_bind_bad) = \
                self.with_random_bad_credentials(simple_bind,
                                                 self.simple_bind_creds,
                                                 self.simple_bind_creds_bad,
                                                 self.last_simple_bind_bad)
        else:
            (samdb, self.last_bind_bad) = \
                self.with_random_bad_credentials(sasl_bind,
                                                 self.user_creds,
                                                 self.user_creds_bad,
                                                 self.last_bind_bad)

        self.ldap_connections.append(samdb)
        return samdb

    def get_samr_context(self, new=False):
        """Return a SamrContext, reusing the last one unless new is True."""
        if not self.samr_contexts or new:
            self.samr_contexts.append(
                SamrContext(self.server, lp=self.lp, creds=self.creds))
        return self.samr_contexts[-1]

    def get_netlogon_connection(self):
        """Return the (single, cached) schannel netlogon connection."""
        if self.netlogon_connection:
            return self.netlogon_connection

        def connect(creds):
            return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
                                     (self.server),
                                     self.lp,
                                     creds)
        (c, self.last_netlogon_bad) = \
            self.with_random_bad_credentials(connect,
                                             self.machine_creds,
                                             self.machine_creds_bad,
                                             self.last_netlogon_bad)
        self.netlogon_connection = c
        return c

    def guess_a_dns_lookup(self):
        """Return a (name, type) pair for a plausible dns query."""
        return (self.realm, 'A')

    def get_authenticator(self):
        """Return a (current, subsequent) netr_Authenticator pair built
        from the machine credentials."""
        auth = self.machine_creds.new_client_authenticator()
        current = netr_Authenticator()
        # the credential may come back as str or bytes depending on the
        # Python version; normalise each element to an int.
        current.cred.data = [x if isinstance(x, int) else ord(x) for x in auth["credential"]]
        current.timestamp = auth["timestamp"]

        subsequent = netr_Authenticator()
        return (current, subsequent)
class SamrContext(object):
    """State/Context associated with a samr connection.
    """
    def __init__(self, server, lp=None, creds=None):
        # Connection state is created lazily by the get_* methods below;
        # the remaining handles are filled in by the samr packet handlers.
        self.connection = None
        self.handle = None
        self.domain_handle = None
        self.domain_sid = None
        self.group_handle = None
        self.user_handle = None
        self.rids = None
        self.server = server
        self.lp = lp
        self.creds = creds

    def get_connection(self):
        # Lazily create, then cache, the samr RPC connection.
        if not self.connection:
            self.connection = samr.samr(
                "ncacn_ip_tcp:%s[seal]" % (self.server),
                lp_ctx=self.lp,
                credentials=self.creds)

        return self.connection

    def get_handle(self):
        # Lazily obtain, then cache, a samr connect handle.
        if not self.handle:
            c = self.get_connection()
            self.handle = c.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
        return self.handle
class Conversation(object):
    """Details of a conversation between a simulated client and a server."""
    # Assigned (from context.next_conversation_id) just before forking
    # in replay_in_fork_with_delay(); None until then.
    conversation_id = None

    def __init__(self, start_time=None, endpoints=None):
        self.start_time = start_time
        self.endpoints = endpoints
        self.packets = []
        self.msg = random_colour_print()
        # sum of client/server clue scores; sign decides who is client.
        self.client_balance = 0.0

    def __cmp__(self, other):
        # Python 2 ordering by start time; conversations without a start
        # time sort first.  NOTE(review): __cmp__ is ignored by Python 3.
        if self.start_time is None:
            if other.start_time is None:
                return 0
            return -1
        if other.start_time is None:
            return 1
        return self.start_time - other.start_time

    def add_packet(self, packet):
        """Add a packet object to this conversation, making a local copy with
        a conversation-relative timestamp.

        :raises FakePacketError: if the packet's endpoints don't match
            this conversation's endpoints.
        """
        p = packet.copy()

        if self.start_time is None:
            self.start_time = p.timestamp

        if self.endpoints is None:
            self.endpoints = p.endpoints

        if p.endpoints != self.endpoints:
            raise FakePacketError("Conversation endpoints %s don't match"
                                  "packet endpoints %s" %
                                  (self.endpoints, p.endpoints))

        p.timestamp -= self.start_time

        # endpoints[0] is the lower-numbered address; a packet *from* it
        # that looks client-ish pushes the balance negative.
        if p.src == p.endpoints[0]:
            self.client_balance -= p.client_score()
        else:
            self.client_balance += p.client_score()

        # we only keep the packets we can actually replay.
        if p.is_really_a_packet():
            self.packets.append(p)

    def add_short_packet(self, timestamp, protocol, opcode, extra,
                         client=True):
        """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
        (possibly empty) list of extra data. If client is True, assume
        this packet is from the client to the server.
        """
        src, dest = self.guess_client_server()
        if not client:
            src, dest = dest, src
        key = (protocol, opcode)
        # OP_DESCRIPTIONS / IP_PROTOCOLS are module-level tables defined
        # elsewhere in this file; '06' (TCP) is the fallback IP protocol.
        desc = OP_DESCRIPTIONS[key] if key in OP_DESCRIPTIONS else ''
        if protocol in IP_PROTOCOLS:
            ip_protocol = IP_PROTOCOLS[protocol]
        else:
            ip_protocol = '06'
        packet = Packet(timestamp - self.start_time, ip_protocol,
                        '', src, dest,
                        protocol, opcode, desc, extra)
        # XXX we're assuming the timestamp is already adjusted for
        # this conversation?
        # XXX should we adjust client balance for guessed packets?
        if packet.src == packet.endpoints[0]:
            self.client_balance -= packet.client_score()
        else:
            self.client_balance += packet.client_score()
        if packet.is_really_a_packet():
            self.packets.append(packet)

    def __str__(self):
        return ("<Conversation %s %s starting %.3f %d packets>" %
                (self.conversation_id, self.endpoints, self.start_time,
                 len(self.packets)))

    __repr__ = __str__

    def __iter__(self):
        return iter(self.packets)

    def __len__(self):
        return len(self.packets)

    def get_duration(self):
        # Time between first and last replayable packet (0 if fewer
        # than two packets).
        if len(self.packets) < 2:
            return 0
        return self.packets[-1].timestamp - self.packets[0].timestamp

    def replay_as_summary_lines(self):
        """Return the conversation as a list of traffic-summary
        (timestamp, line) tuples."""
        lines = []
        for p in self.packets:
            lines.append(p.as_summary(self.start_time))
        return lines

    def replay_in_fork_with_delay(self, start, context=None, account=None):
        """Fork a new process and replay the conversation.

        Returns the child pid in the parent; the child never returns
        (it exits via os._exit()).
        """
        def signal_handler(signal, frame):
            """Signal handler closes standard out and error.

            Triggered by a sigterm, ensures that the log messages are flushed
            to disk and not lost.
            """
            sys.stderr.close()
            sys.stdout.close()
            os._exit(0)

        t = self.start_time
        now = time.time() - start
        gap = t - now
        # we are replaying strictly in order, so it is safe to sleep
        # in the main process if the gap is big enough. This reduces
        # the number of concurrent threads, which allows us to make
        # larger loads.
        # NOTE(review): this branch is deliberately disabled by the
        # "and False" below.
        if gap > 0.15 and False:
            print("sleeping for %f in main process" % (gap - 0.1),
                  file=sys.stderr)
            time.sleep(gap - 0.1)
            now = time.time() - start
            gap = t - now
            print("gap is now %f" % gap, file=sys.stderr)

        self.conversation_id = next(context.next_conversation_id)
        pid = os.fork()
        if pid != 0:
            # parent: hand the child pid back to the caller.
            return pid
        pid = os.getpid()
        signal.signal(signal.SIGTERM, signal_handler)
        # we must never return, or we'll end up running parts of the
        # parent's clean-up code. So we work in a try...finally, and
        # try to print any exceptions.
        try:
            context.generate_process_local_config(account, self)
            sys.stdin.close()
            os.close(0)
            # redirect stdout to this conversation's stats file.
            filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
                                    self.conversation_id)
            sys.stdout.close()
            sys.stdout = open(filename, 'w')

            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            miss = t - (time.time() - start)
            self.msg("starting %s [miss %.3f pid %d]" % (self, miss, pid))
            self.replay(context)
        except Exception:
            print(("EXCEPTION in child PID %d, conversation %s" % (pid, self)),
                  file=sys.stderr)
            # NOTE(review): print_exc's first positional parameter is
            # `limit`, not `file` — this looks like it was meant to be
            # traceback.print_exc(file=sys.stderr); confirm before changing.
            traceback.print_exc(sys.stderr)
        finally:
            sys.stderr.close()
            sys.stdout.close()
            os._exit(0)

    def replay(self, context=None):
        """Play the conversation's packets in order, sleeping to honour
        their relative timestamps. With no context, just log what would
        have been sent."""
        start = time.time()

        for p in self.packets:
            now = time.time() - start
            gap = p.timestamp - now
            sleep_time = gap - SLEEP_OVERHEAD
            if sleep_time > 0:
                time.sleep(sleep_time)

            miss = p.timestamp - (time.time() - start)
            if context is None:
                self.msg("packet %s [miss %.3f pid %d]" % (p, miss,
                                                           os.getpid()))
                continue
            p.play(self, context)

    def guess_client_server(self, server_clue=None):
        """Have a go at deciding who is the server and who is the client.
        returns (client, server)
        """
        a, b = self.endpoints

        if self.client_balance < 0:
            return (a, b)

        # in the absense of a clue, we will fall through to assuming
        # the lowest number is the server (which is usually true).

        if self.client_balance == 0 and server_clue == b:
            return (a, b)

        return (b, a)

    def forget_packets_outside_window(self, s, e):
        """Prune any packets outside the time window we're interested in

        :param s: start of the window
        :param e: end of the window
        """
        self.packets = [p for p in self.packets if s <= p.timestamp <= e]
        self.start_time = self.packets[0].timestamp if self.packets else None

    def renormalise_times(self, start_time):
        """Adjust the packet start times relative to the new start time."""
        for p in self.packets:
            p.timestamp -= start_time

        if self.start_time is not None:
            self.start_time -= start_time
class DnsHammer(Conversation):
    """A lightweight conversation that generates a lot of dns:0 packets on
    the fly"""

    def __init__(self, dns_rate, duration):
        # Pre-compute the (sorted) send times for the whole run.
        packet_count = int(dns_rate * duration)
        self.times = sorted(random.uniform(0, duration)
                            for _ in range(packet_count))
        self.rate = dns_rate
        self.duration = duration
        self.start_time = 0
        self.msg = random_colour_print()

    def __str__(self):
        return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
                (len(self.times), self.duration, self.rate))

    def replay_in_fork_with_delay(self, start, context=None, account=None):
        """Delegate to the standard Conversation fork-and-replay logic."""
        return Conversation.replay_in_fork_with_delay(self,
                                                      start,
                                                      context,
                                                      account)

    def replay(self, context=None):
        """Fire dns:0 packets at the pre-computed times, printing one
        stats line per packet. With no context, just log each send."""
        begin = time.time()
        handler = traffic_packets.packet_dns_0
        for when in self.times:
            pause = (when - (time.time() - begin)) - SLEEP_OVERHEAD
            if pause > 0:
                time.sleep(pause)

            if context is None:
                miss = when - (time.time() - begin)
                self.msg("packet %s [miss %.3f pid %d]" % (when, miss,
                                                           os.getpid()))
                continue

            packet_start = time.time()
            try:
                handler(self, self, context)
                end = time.time()
                duration = end - packet_start
                print("%f\tDNS\tdns\t0\t%f\tTrue\t" % (end, duration))
            except Exception as e:
                end = time.time()
                duration = end - packet_start
                print("%f\tDNS\tdns\t0\t%f\tFalse\t%s" % (end, duration, e))
def ingest_summaries(files, dns_mode='count'):
    """Load summary traffic files and generate Conversations from them.

    :param files: a list of filenames or already-open file objects
    :param dns_mode: unless 'include', dns packets are tallied separately
        instead of being added to conversations.
    :return: a (conversations, mean_interval, duration, dns_counts) tuple
    """
    dns_counts = defaultdict(int)
    packets = []
    for f in files:
        if isinstance(f, str):
            f = open(f)
        print("Ingesting %s" % (f.name,), file=sys.stderr)
        for line in f:
            p = Packet.from_line(line)
            if p.protocol == 'dns' and dns_mode != 'include':
                dns_counts[p.opcode] += 1
            else:
                packets.append(p)

        f.close()

    if not packets:
        # Match the 4-tuple shape of the normal return below, so callers
        # unpacking four values don't crash on empty input.
        return [], 0, 0, dns_counts

    start_time = min(p.timestamp for p in packets)
    last_packet = max(p.timestamp for p in packets)

    print("gathering packets into conversations", file=sys.stderr)
    conversations = OrderedDict()
    for p in packets:
        p.timestamp -= start_time
        c = conversations.get(p.endpoints)
        if c is None:
            c = Conversation()
            conversations[p.endpoints] = c
        c.add_packet(p)

    # We only care about conversations with actual traffic, so we
    # filter out conversations with nothing to say. We do that here,
    # rather than earlier, because those empty packets contain useful
    # hints as to which end of the conversation was the client.
    conversation_list = []
    for c in conversations.values():
        if len(c) != 0:
            conversation_list.append(c)

    # This is obviously not correct, as many conversations will appear
    # to start roughly simultaneously at the beginning of the snapshot.
    # To which we say: oh well, so be it.
    duration = float(last_packet - start_time)
    # NOTE(review): despite the name, this is conversations per second.
    mean_interval = len(conversations) / duration

    return conversation_list, mean_interval, duration, dns_counts
def guess_server_address(conversations):
    """Guess the server: the address seen in the most endpoints.

    Returns an (address, count) pair as produced by
    Counter.most_common(), or None if there are no conversations.
    """
    tally = Counter()
    for conversation in conversations:
        tally.update(conversation.endpoints)
    if not tally:
        return None
    return tally.most_common(1)[0]
def stringify_keys(x):
    """Return a copy of dict x with each tuple key joined into a single
    tab-separated string (for JSON serialisation)."""
    return dict(('\t'.join(k), v) for k, v in x.items())
def unstringify_keys(x):
    """Return a copy of dict x with each tab-separated string key split
    back into a tuple (the inverse of stringify_keys)."""
    return dict((tuple(str(k).split('\t')), v) for k, v in x.items())
1107 class TrafficModel(object):
1108 def __init__(self, n=3):
1109 self.ngrams = {}
1110 self.query_details = {}
1111 self.n = n
1112 self.dns_opcounts = defaultdict(int)
1113 self.cumulative_duration = 0.0
1114 self.conversation_rate = [0, 1]
1116 def learn(self, conversations, dns_opcounts={}):
1117 prev = 0.0
1118 cum_duration = 0.0
1119 key = (NON_PACKET,) * (self.n - 1)
1121 server = guess_server_address(conversations)
1123 for k, v in dns_opcounts.items():
1124 self.dns_opcounts[k] += v
1126 if len(conversations) > 1:
1127 elapsed =\
1128 conversations[-1].start_time - conversations[0].start_time
1129 self.conversation_rate[0] = len(conversations)
1130 self.conversation_rate[1] = elapsed
1132 for c in conversations:
1133 client, server = c.guess_client_server(server)
1134 cum_duration += c.get_duration()
1135 key = (NON_PACKET,) * (self.n - 1)
1136 for p in c:
1137 if p.src != client:
1138 continue
1140 elapsed = p.timestamp - prev
1141 prev = p.timestamp
1142 if elapsed > WAIT_THRESHOLD:
1143 # add the wait as an extra state
1144 wait = 'wait:%d' % (math.log(max(1.0,
1145 elapsed * WAIT_SCALE)))
1146 self.ngrams.setdefault(key, []).append(wait)
1147 key = key[1:] + (wait,)
1149 short_p = p.as_packet_type()
1150 self.query_details.setdefault(short_p,
1151 []).append(tuple(p.extra))
1152 self.ngrams.setdefault(key, []).append(short_p)
1153 key = key[1:] + (short_p,)
1155 self.cumulative_duration += cum_duration
1156 # add in the end
1157 self.ngrams.setdefault(key, []).append(NON_PACKET)
1159 def save(self, f):
1160 ngrams = {}
1161 for k, v in self.ngrams.items():
1162 k = '\t'.join(k)
1163 ngrams[k] = dict(Counter(v))
1165 query_details = {}
1166 for k, v in self.query_details.items():
1167 query_details[k] = dict(Counter('\t'.join(x) if x else '-'
1168 for x in v))
1170 d = {
1171 'ngrams': ngrams,
1172 'query_details': query_details,
1173 'cumulative_duration': self.cumulative_duration,
1174 'conversation_rate': self.conversation_rate,
1176 d['dns'] = self.dns_opcounts
1178 if isinstance(f, str):
1179 f = open(f, 'w')
1181 json.dump(d, f, indent=2)
1183 def load(self, f):
1184 if isinstance(f, str):
1185 f = open(f)
1187 d = json.load(f)
1189 for k, v in d['ngrams'].items():
1190 k = tuple(str(k).split('\t'))
1191 values = self.ngrams.setdefault(k, [])
1192 for p, count in v.items():
1193 values.extend([str(p)] * count)
1195 for k, v in d['query_details'].items():
1196 values = self.query_details.setdefault(str(k), [])
1197 for p, count in v.items():
1198 if p == '-':
1199 values.extend([()] * count)
1200 else:
1201 values.extend([tuple(str(p).split('\t'))] * count)
1203 if 'dns' in d:
1204 for k, v in d['dns'].items():
1205 self.dns_opcounts[k] += v
1207 self.cumulative_duration = d['cumulative_duration']
1208 self.conversation_rate = d['conversation_rate']
1210 def construct_conversation(self, timestamp=0.0, client=2, server=1,
1211 hard_stop=None, packet_rate=1):
1212 """Construct a individual converation from the model."""
1214 c = Conversation(timestamp, (server, client))
1216 key = (NON_PACKET,) * (self.n - 1)
1218 while key in self.ngrams:
1219 p = random.choice(self.ngrams.get(key, NON_PACKET))
1220 if p == NON_PACKET:
1221 break
1222 if p in self.query_details:
1223 extra = random.choice(self.query_details[p])
1224 else:
1225 extra = []
1227 protocol, opcode = p.split(':', 1)
1228 if protocol == 'wait':
1229 log_wait_time = int(opcode) + random.random()
1230 wait = math.exp(log_wait_time) / (WAIT_SCALE * packet_rate)
1231 timestamp += wait
1232 else:
1233 log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
1234 wait = math.exp(log_wait) / packet_rate
1235 timestamp += wait
1236 if hard_stop is not None and timestamp > hard_stop:
1237 break
1238 c.add_short_packet(timestamp, protocol, opcode, extra)
1240 key = key[1:] + (p,)
1242 return c
1244 def generate_conversations(self, rate, duration, packet_rate=1):
1245 """Generate a list of conversations from the model."""
1247 # We run the simulation for at least ten times as long as our
1248 # desired duration, and take a section near the start.
1249 rate_n, rate_t = self.conversation_rate
1251 duration2 = max(rate_t, duration * 2)
1252 n = rate * duration2 * rate_n / rate_t
1254 server = 1
1255 client = 2
1257 conversations = []
1258 end = duration2
1259 start = end - duration
1261 while client < n + 2:
1262 start = random.uniform(0, duration2)
1263 c = self.construct_conversation(start,
1264 client,
1265 server,
1266 hard_stop=(duration2 * 5),
1267 packet_rate=packet_rate)
1269 c.forget_packets_outside_window(start, end)
1270 c.renormalise_times(start)
1271 if len(c) != 0:
1272 conversations.append(c)
1273 client += 1
1275 print(("we have %d conversations at rate %f" %
1276 (len(conversations), rate)), file=sys.stderr)
1277 conversations.sort()
1278 return conversations
# Map each application protocol to its IP protocol number, as a two-digit
# hex string: '06' is TCP, '11' (0x11 == 17) is UDP.  Used when expanding
# short packets back into summary-log lines.
IP_PROTOCOLS = {
    'dns': '11',
    'rpc_netlogon': '06',
    'kerberos': '06',      # ratio 16248:258
    'smb': '06',
    'smb2': '06',
    'ldap': '06',
    'cldap': '11',
    'lsarpc': '06',
    'samr': '06',
    'dcerpc': '06',
    'epm': '06',
    'drsuapi': '06',
    'browser': '11',
    'smb_netlogon': '11',
    'srvsvc': '06',
    'nbns': '11',
# Human-readable descriptions for (protocol, opcode) pairs, used in
# expanded summary lines and in the generate_stats() report.
OP_DESCRIPTIONS = {
    ('browser', '0x01'): 'Host Announcement (0x01)',
    ('browser', '0x02'): 'Request Announcement (0x02)',
    ('browser', '0x08'): 'Browser Election Request (0x08)',
    ('browser', '0x09'): 'Get Backup List Request (0x09)',
    ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
    ('browser', '0x0f'): 'Local Master Announcement (0x0f)',
    ('cldap', '3'): 'searchRequest',
    ('cldap', '5'): 'searchResDone',
    ('dcerpc', '0'): 'Request',
    ('dcerpc', '11'): 'Bind',
    ('dcerpc', '12'): 'Bind_ack',
    ('dcerpc', '13'): 'Bind_nak',
    ('dcerpc', '14'): 'Alter_context',
    ('dcerpc', '15'): 'Alter_context_resp',
    ('dcerpc', '16'): 'AUTH3',
    ('dcerpc', '2'): 'Response',
    ('dns', '0'): 'query',
    ('dns', '1'): 'response',
    ('drsuapi', '0'): 'DsBind',
    ('drsuapi', '12'): 'DsCrackNames',
    ('drsuapi', '13'): 'DsWriteAccountSpn',
    ('drsuapi', '1'): 'DsUnbind',
    ('drsuapi', '2'): 'DsReplicaSync',
    ('drsuapi', '3'): 'DsGetNCChanges',
    ('drsuapi', '4'): 'DsReplicaUpdateRefs',
    ('epm', '3'): 'Map',
    ('kerberos', ''): '',
    ('ldap', '0'): 'bindRequest',
    ('ldap', '1'): 'bindResponse',
    ('ldap', '2'): 'unbindRequest',
    ('ldap', '3'): 'searchRequest',
    ('ldap', '4'): 'searchResEntry',
    ('ldap', '5'): 'searchResDone',
    ('ldap', ''): '*** Unknown ***',
    ('lsarpc', '14'): 'lsa_LookupNames',
    ('lsarpc', '15'): 'lsa_LookupSids',
    ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
    ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
    ('lsarpc', '6'): 'lsa_OpenPolicy',
    ('lsarpc', '76'): 'lsa_LookupSids3',
    ('lsarpc', '77'): 'lsa_LookupNames4',
    ('nbns', '0'): 'query',
    ('nbns', '1'): 'response',
    ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
    ('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
    ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
    ('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
    ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
    ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
    ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
    ('rpc_netlogon', '4'): 'NetrServerReqChallenge',
    ('samr', '0',): 'Connect',
    ('samr', '16'): 'GetAliasMembership',
    ('samr', '17'): 'LookupNames',
    ('samr', '18'): 'LookupRids',
    ('samr', '19'): 'OpenGroup',
    ('samr', '1'): 'Close',
    ('samr', '25'): 'QueryGroupMember',
    ('samr', '34'): 'OpenUser',
    ('samr', '36'): 'QueryUserInfo',
    ('samr', '39'): 'GetGroupsForUser',
    ('samr', '3'): 'QuerySecurity',
    ('samr', '5'): 'LookupDomain',
    ('samr', '64'): 'Connect5',
    ('samr', '6'): 'EnumDomains',
    ('samr', '7'): 'OpenDomain',
    ('samr', '8'): 'QueryDomainInfo',
    ('smb', '0x04'): 'Close (0x04)',
    ('smb', '0x24'): 'Locking AndX (0x24)',
    ('smb', '0x2e'): 'Read AndX (0x2e)',
    ('smb', '0x32'): 'Trans2 (0x32)',
    ('smb', '0x71'): 'Tree Disconnect (0x71)',
    ('smb', '0x72'): 'Negotiate Protocol (0x72)',
    ('smb', '0x73'): 'Session Setup AndX (0x73)',
    ('smb', '0x74'): 'Logoff AndX (0x74)',
    ('smb', '0x75'): 'Tree Connect AndX (0x75)',
    ('smb', '0xa2'): 'NT Create AndX (0xa2)',
    ('smb2', '0'): 'NegotiateProtocol',
    ('smb2', '11'): 'Ioctl',
    ('smb2', '14'): 'Find',
    ('smb2', '16'): 'GetInfo',
    ('smb2', '18'): 'Break',
    ('smb2', '1'): 'SessionSetup',
    ('smb2', '2'): 'SessionLogoff',
    ('smb2', '3'): 'TreeConnect',
    ('smb2', '4'): 'TreeDisconnect',
    ('smb2', '5'): 'Create',
    ('smb2', '6'): 'Close',
    ('smb2', '8'): 'Read',
    ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
    ('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
                               'user unknown (0x17)'),
    ('srvsvc', '16'): 'NetShareGetInfo',
    ('srvsvc', '21'): 'NetSrvGetInfo',
def expand_short_packet(p, timestamp, src, dest, extra):
    """Rebuild a tab-separated summary-log line from a short packet.

    All fields (including timestamp) must already be strings.
    """
    protocol, opcode = p.split(':', 1)
    desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
    ip_protocol = IP_PROTOCOLS.get(protocol, '06')

    head = (timestamp, ip_protocol, '', src, dest, protocol, opcode, desc)
    return '\t'.join(itertools.chain(head, extra))
def replay(conversations,
           host=None,
           creds=None,
           lp=None,
           accounts=None,
           dns_rate=0,
           duration=None,
           **kwargs):
    """Replay the given conversations against host, forking one child
    process per conversation.

    :param conversations: the conversations to replay
    :param accounts: one ConversationAccounts per conversation
    :param dns_rate: if non-zero, also run a DnsHammer at this rate
    :param duration: seconds to run for (default: until 1 second after
        the last conversation starts)
    Remaining keyword arguments are passed to ReplayContext.
    """
    context = ReplayContext(server=host,
                            creds=creds,
                            lp=lp,
                            **kwargs)

    if len(accounts) < len(conversations):
        # not fatal: the zip() below silently drops the conversations
        # that have no account to use.
        # (this message used to apply %d to the lists themselves)
        print(("we have %d accounts but %d conversations" %
               (len(accounts), len(conversations))), file=sys.stderr)

    # a stack with the earliest-starting conversation on top
    cstack = list(zip(
        sorted(conversations, key=lambda x: x.start_time, reverse=True),
        accounts))

    # Set the process group so that the calling scripts are not killed
    # when the forked child processes are killed.
    os.setpgrp()

    start = time.time()

    if duration is None:
        # end 1 second after the last packet of the last conversation
        # to start. Conversations other than the last could still be
        # going, but we don't care.
        duration = cstack[0][0].packets[-1].timestamp + 1.0
        print("We will stop after %.1f seconds" % duration,
              file=sys.stderr)

    end = start + duration

    LOGGER.info("Replaying traffic for %u conversations over %d seconds"
                % (len(conversations), duration))

    children = {}
    if dns_rate:
        dns_hammer = DnsHammer(dns_rate, duration)
        cstack.append((dns_hammer, None))

    try:
        while True:
            # we spawn a batch, wait for finishers, then spawn another
            now = time.time()
            batch_end = min(now + 2.0, end)
            fork_time = 0.0
            fork_n = 0
            while cstack:
                c, account = cstack.pop()
                if c.start_time + start > batch_end:
                    # not due yet; push it back for the next batch
                    cstack.append((c, account))
                    break

                st = time.time()
                pid = c.replay_in_fork_with_delay(start, context, account)
                children[pid] = c
                t = time.time()
                elapsed = t - st
                fork_time += elapsed
                fork_n += 1
                print("forked %s in pid %s (in %fs)" % (c, pid,
                                                        elapsed),
                      file=sys.stderr)

            if fork_n:
                print(("forked %d times in %f seconds (avg %f)" %
                       (fork_n, fork_time, fork_time / fork_n)),
                      file=sys.stderr)
            elif cstack:
                # (this used to call an undefined debug() helper)
                LOGGER.debug("no forks in batch ending %f" % batch_end)

            # reap finished children until shortly before the batch ends
            while time.time() < batch_end - 1.0:
                time.sleep(0.01)
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != 10:  # ECHILD: no child processes
                        raise
                    break
                if pid:
                    c = children.pop(pid, None)
                    print(("process %d finished conversation %s;"
                           " %d to go" %
                           (pid, c, len(children))), file=sys.stderr)

            if time.time() >= end:
                print("time to stop", file=sys.stderr)
                break

    except Exception:
        print("EXCEPTION in parent", file=sys.stderr)
        traceback.print_exc()
    finally:
        # shut the children down: SIGTERM twice, then SIGKILL
        for s in (15, 15, 9):
            print(("killing %d children with -%d" %
                   (len(children), s)), file=sys.stderr)
            for pid in children:
                try:
                    os.kill(pid, s)
                except OSError as e:
                    if e.errno != 3:  # ESRCH: it has already died
                        raise
            time.sleep(0.5)
            end = time.time() + 1
            while children:
                try:
                    pid, status = os.waitpid(-1, os.WNOHANG)
                except OSError as e:
                    if e.errno != 10:  # ECHILD
                        raise
                if pid != 0:
                    c = children.pop(pid, None)
                    print(("kill -%d %d KILLED conversation %s; "
                           "%d to go" %
                           (s, pid, c, len(children))),
                          file=sys.stderr)
                if time.time() >= end:
                    break

            if not children:
                break
            time.sleep(1)

        if children:
            print("%d children are missing" % len(children),
                  file=sys.stderr)

        # there may be stragglers that were forked just as ^C was hit
        # and don't appear in the list of children. We can get them
        # with killpg, but that will also kill us, so this is^H^H would be
        # goodbye, except we cheat and pretend to use ^C (SIG_INTERRUPT),
        # so as not to have to fuss around writing signal handlers.
        try:
            os.killpg(0, 2)
        except KeyboardInterrupt:
            print("ignoring fake ^C", file=sys.stderr)
def openLdb(host, creds, lp):
    """Return a SamDB connected to the given host over LDAP, using
    paged searches."""
    return SamDB(url="ldap://%s" % host,
                 session_info=system_session(),
                 options=['modules:paged_searches'],
                 credentials=creds,
                 lp=lp)
def ou_name(ldb, instance_id):
    """Generate an ou name from the instance id"""
    dn = ldb.domain_dn()
    return "ou=instance-%d,ou=traffic_replay,%s" % (instance_id, dn)
def create_ou(ldb, instance_id):
    """Create an ou, all created user and machine accounts will belong to it.

    This allows all the created resources to be cleaned up easily.
    """
    ou = ou_name(ldb, instance_id)

    def add_ou(dn):
        # LDB status 68 means the entry already exists, which is fine
        try:
            ldb.add({"dn": dn,
                     "objectclass": "organizationalunit"})
        except LdbError as e:
            (status, _) = e.args
            if status != 68:
                raise

    # the ou=traffic_replay parent first, then the per-instance ou
    add_ou(ou.split(',', 1)[1])
    add_ou(ou)
    return ou
class ConversationAccounts(object):
    """Details of the machine and user accounts associated with a conversation.
    """
    def __init__(self, netbios_name, machinepass, username, userpass):
        # machine account name and password
        self.netbios_name = netbios_name
        self.machinepass = machinepass
        # user account name and password
        self.username = username
        self.userpass = userpass
def generate_replay_accounts(ldb, instance_id, number, password):
    """Generate a series of unique machine and user account names.

    The accounts are created on the server (if not already present),
    and returned as one ConversationAccounts per conversation.
    """
    generate_traffic_accounts(ldb, instance_id, number, password)
    return [ConversationAccounts(machine_name(instance_id, i),
                                 password,
                                 user_name(instance_id, i),
                                 password)
            for i in range(1, number + 1)]
def generate_traffic_accounts(ldb, instance_id, number, password):
    """Create the specified number of user and machine accounts.

    As accounts are not explicitly deleted between runs, this function
    starts with the last account and iterates backwards, stopping either
    when it finds an already existing account or when it has generated
    all the required accounts.
    """
    print(("Generating machine and conversation accounts, "
           "as required for %d conversations" % number),
          file=sys.stderr)

    def _backfill(create, progress_noun, final_noun):
        # Walk from the highest account number downwards; LDB status 68
        # (entry already exists) means this and the earlier accounts are
        # already there, so we can stop.
        added = 0
        for i in range(number, 0, -1):
            try:
                create(i)
                added += 1
                if added % 50 == 0:
                    LOGGER.info("Created %u/%u %s" %
                                (added, number, progress_noun))
            except LdbError as e:
                (status, _) = e.args
                if status == 68:
                    break
                else:
                    raise
        if added > 0:
            LOGGER.info("Added %d new %s" % (added, final_noun))

    _backfill(lambda i: create_machine_account(ldb, instance_id,
                                               machine_name(instance_id, i),
                                               password),
              "machine accounts", "machine accounts")

    _backfill(lambda i: create_user_account(ldb, instance_id,
                                            user_name(instance_id, i),
                                            password),
              "users", "user accounts")
def create_machine_account(ldb, instance_id, netbios_name, machinepass,
                           traffic_account=True):
    """Create a machine account via ldap.

    With traffic_account=True the account gets the extra trust bits the
    replayed traffic needs; otherwise it is a plain workstation account.
    """
    ou = ou_name(ldb, instance_id)
    utf16pw = ('"%s"' % get_string(machinepass)).encode('utf-16-le')

    if traffic_account:
        # we set these bits for the machine account otherwise the replayed
        # traffic throws up NT_STATUS_NO_TRUST_SAM_ACCOUNT errors
        flags = UF_TRUSTED_FOR_DELEGATION | UF_SERVER_TRUST_ACCOUNT
    else:
        flags = UF_WORKSTATION_TRUST_ACCOUNT

    ldb.add({
        "dn": "cn=%s,%s" % (netbios_name, ou),
        "objectclass": "computer",
        "sAMAccountName": "%s$" % netbios_name,
        "userAccountControl": str(flags),
        "unicodePwd": utf16pw})
def create_user_account(ldb, instance_id, username, userpass):
    """Create a user account via ldap."""
    ou = ou_name(ldb, instance_id)
    user_dn = "cn=%s,%s" % (username, ou)
    utf16pw = ('"%s"' % get_string(userpass)).encode('utf-16-le')

    ldb.add({
        "dn": user_dn,
        "objectclass": "user",
        "sAMAccountName": username,
        "userAccountControl": str(UF_NORMAL_ACCOUNT),
        "unicodePwd": utf16pw
    })

    # grant user write permission to do things like write account SPN
    sd_utils.SDUtils(ldb).dacl_add_ace(user_dn, "(A;;WP;;;PS)")
def create_group(ldb, instance_id, name):
    """Create a group via ldap."""

    # Groups live under the per-instance ou (see ou_name).  Unlike
    # create_ou, an "already exists" error is not swallowed here --
    # callers are expected to only create groups that are missing.
    ou = ou_name(ldb, instance_id)
    dn = "cn=%s,%s" % (name, ou)
    ldb.add({
        "dn": dn,
        "objectclass": "group",
        "sAMAccountName": name,
def user_name(instance_id, i):
    """Generate a user name based in the instance id"""
    prefix = "STGU"
    return "%s-%d-%d" % (prefix, instance_id, i)
def search_objectclass(ldb, objectclass='user', attr='sAMAccountName'):
    """Search for objects of the given objectclass, returning the attr
    values as a set of strings."""
    res = ldb.search(expression="(objectClass={})".format(objectclass),
                     attrs=[attr])
    return set(str(obj[attr]) for obj in res)
def generate_users(ldb, instance_id, number, password):
    """Add users to the server, skipping ones that already exist.

    Returns the number of users actually created.
    """
    existing = search_objectclass(ldb, objectclass='user')
    created = 0
    for i in range(number, 0, -1):
        candidate = user_name(instance_id, i)
        if candidate in existing:
            continue
        create_user_account(ldb, instance_id, candidate, password)
        created += 1
        if created % 50 == 0:
            LOGGER.info("Created %u/%u users" % (created, number))

    return created
def machine_name(instance_id, i):
    """Generate a machine account name from instance id."""
    prefix = "STGM"
    return "%s-%d-%d" % (prefix, instance_id, i)
def generate_machine_accounts(ldb, instance_id, number, password,
                              traffic_account=True):
    """Add machine accounts to the server, skipping existing ones.

    Returns the number of accounts actually created.
    """
    existing = search_objectclass(ldb, objectclass='computer')
    created = 0
    for i in range(number, 0, -1):
        name = machine_name(instance_id, i)
        # the sAMAccountName of a computer has a trailing '$'
        if name + "$" in existing:
            continue
        create_machine_account(ldb, instance_id, name, password,
                               traffic_account)
        created += 1
        if created % 50 == 0:
            LOGGER.info("Created %u/%u machine accounts" % (created, number))

    return created
def group_name(instance_id, i):
    """Generate a group name from instance id."""
    prefix = "STGG"
    return "%s-%d-%d" % (prefix, instance_id, i)
def generate_groups(ldb, instance_id, number):
    """Create the required number of groups on the server, skipping
    existing ones.  Returns the number of groups actually created."""
    existing = search_objectclass(ldb, objectclass='group')
    created = 0
    for i in range(number, 0, -1):
        name = group_name(instance_id, i)
        if name in existing:
            continue
        create_group(ldb, instance_id, name)
        created += 1
        if created % 1000 == 0:
            LOGGER.info("Created %u/%u groups" % (created, number))

    return created
def clean_up_accounts(ldb, instance_id):
    """Remove the created accounts and groups from the server."""
    try:
        ldb.delete(ou_name(ldb, instance_id), ["tree_delete:1"])
    except LdbError as e:
        (status, _) = e.args
        # status 32 is "no such object": nothing to clean up, then
        if status != 32:
            raise
def generate_users_and_groups(ldb, instance_id, password,
                              number_of_users, number_of_groups,
                              group_memberships, machine_accounts=0,
                              traffic_accounts=True):
    """Generate the required users and groups, allocating the users to
    those groups.

    Everything is created under the per-instance ou, so it can be
    removed in one go by clean_up_accounts().
    """
    memberships_added = 0
    groups_added = 0
    computers_added = 0

    create_ou(ldb, instance_id)

    LOGGER.info("Generating dummy user accounts")
    users_added = generate_users(ldb, instance_id, number_of_users, password)

    if machine_accounts > 0:
        LOGGER.info("Generating dummy machine accounts")
        computers_added = generate_machine_accounts(ldb, instance_id,
                                                    machine_accounts, password,
                                                    traffic_accounts)

    if number_of_groups > 0:
        LOGGER.info("Generating dummy groups")
        groups_added = generate_groups(ldb, instance_id, number_of_groups)

    if group_memberships > 0:
        LOGGER.info("Assigning users to groups")
        assignments = GroupAssignments(number_of_groups,
                                       groups_added,
                                       number_of_users,
                                       users_added,
                                       group_memberships)
        LOGGER.info("Adding users to groups")
        add_users_to_groups(ldb, instance_id, assignments)
        memberships_added = assignments.total()

    # GroupAssignments scales the requested memberships by the fraction
    # of users that are new, so no new users means no new memberships
    if (groups_added > 0 and users_added == 0 and
       number_of_groups != groups_added):
        LOGGER.warning("The added groups will contain no members")

    LOGGER.info("Added %d users (%d machines), %d groups and %d memberships" %
                (users_added, computers_added, groups_added,
                 memberships_added))
class GroupAssignments(object):
    """Compute a random allocation of users to groups.

    Both users and groups are drawn from weighted probability
    distributions, so that a few users belong to many groups and a few
    groups contain many users.  Only memberships involving newly-added
    users or groups are generated.
    """
    def __init__(self, number_of_groups, groups_added, number_of_users,
                 users_added, group_memberships):
        self.count = 0
        self.generate_group_distribution(number_of_groups)
        self.generate_user_distribution(number_of_users, group_memberships)
        self.assignments = self.assign_groups(number_of_groups,
                                              groups_added,
                                              number_of_users,
                                              users_added,
                                              group_memberships)

    def cumulative_distribution(self, weights):
        """Turn a list of weights into a cumulative distribution in
        (0.0, 1.0], suitable for bisecting with random.random()."""
        # make sure the probabilities conform to a cumulative distribution
        # spread between 0.0 and 1.0. Dividing by the weighted total gives each
        # probability a proportional share of 1.0. Higher probabilities get a
        # bigger share, so are more likely to be picked. We use the cumulative
        # value, so we can use random.random() as a simple index into the list
        dist = []
        total = sum(weights)
        cumulative = 0.0
        for probability in weights:
            cumulative += probability
            dist.append(cumulative / total)
        return dist

    def generate_user_distribution(self, num_users, num_memberships):
        """Probability distribution of a user belonging to a group."""
        # Assign a weighted probability to each user. Use the Pareto
        # Distribution so that some users are in a lot of groups, and the
        # bulk of users are in only a few groups. If we're assigning a large
        # number of group memberships, use a higher shape. This means slightly
        # fewer outlying users that are in large numbers of groups. The aim is
        # to have no users belonging to more than ~500 groups.
        if num_memberships > 5000000:
            shape = 3.0
        elif num_memberships > 2000000:
            shape = 2.5
        elif num_memberships > 300000:
            shape = 2.25
        else:
            shape = 1.75

        weights = []
        for x in range(1, num_users + 1):
            p = random.paretovariate(shape)
            weights.append(p)

        # convert the weights to a cumulative distribution between 0.0 and 1.0
        self.user_dist = self.cumulative_distribution(weights)

    def generate_group_distribution(self, n):
        """Probability distribution of a group containing a user."""

        # Assign a weighted probability to each user. Probability decreases
        # as the group-ID increases (note: true division, from __future__)
        weights = []
        for x in range(1, n + 1):
            p = 1 / (x**1.3)
            weights.append(p)

        # convert the weights to a cumulative distribution between 0.0 and 1.0
        self.group_dist = self.cumulative_distribution(weights)

    def generate_random_membership(self):
        """Returns a randomly generated user-group membership"""

        # the list items are cumulative distribution values between 0.0 and
        # 1.0, which makes random() a handy way to index the list to get a
        # weighted random user/group. (Here the user/group returned are
        # zero-based array indexes)
        user = bisect.bisect(self.user_dist, random.random())
        group = bisect.bisect(self.group_dist, random.random())

        return user, group

    def users_in_group(self, group):
        # list of user numbers assigned to the given group number
        return self.assignments[group]

    def get_groups(self):
        # the group numbers that have at least one assignment
        return self.assignments.keys()

    def assign_groups(self, number_of_groups, groups_added,
                      number_of_users, users_added, group_memberships):
        """Allocate users to groups.

        The intention is to have a few users that belong to most groups,
        while the majority of users belong to a few groups.

        A few groups will contain most users, with the remaining only
        having a few users.
        """
        assignments = set()
        if group_memberships <= 0:
            return {}

        # Calculate the number of group memberships required, scaled by
        # the fraction of users that were actually added in this run
        group_memberships = math.ceil(
            float(group_memberships) *
            (float(users_added) / float(number_of_users)))

        existing_users = number_of_users - users_added - 1
        existing_groups = number_of_groups - groups_added - 1
        while len(assignments) < group_memberships:
            user, group = self.generate_random_membership()

            # only keep memberships that involve a new user or group
            if group > existing_groups or user > existing_users:
                # the + 1 converts the array index to the corresponding
                # group or user number
                assignments.add(((user + 1), (group + 1)))

        # convert the set into a dictionary, where key=group, value=list-of-
        # users-in-group (indexing by group-ID allows us to optimize for
        # DB membership writes)
        assignment_dict = defaultdict(list)
        for (user, group) in assignments:
            assignment_dict[group].append(user)
            self.count += 1

        return assignment_dict

    def total(self):
        # the number of memberships that were generated
        return self.count
def add_users_to_groups(db, instance_id, assignments):
    """Apply the computed user-to-group assignments to the DB."""
    total = assignments.total()
    written = 0
    chunks = 0

    for group in assignments.get_groups():
        members = assignments.users_in_group(group)
        if len(members) == 0:
            continue

        # Write no more than 1K members per modify: minimising the DB
        # modifies is more efficient, but writing 10K+ users to a single
        # group becomes inefficient memory-wise.
        for offset in range(0, len(members), 1000):
            batch = members[offset:offset + 1000]
            add_group_members(db, instance_id, group, batch)

            written += len(batch)
            chunks += 1
            if chunks % 50 == 0:
                LOGGER.info("Added %u/%u memberships" % (written, total))
def add_group_members(db, instance_id, group, users_in_group):
    """Add the given users to the specified group in a single modify."""
    ou = ou_name(db, instance_id)

    group_dn = "cn=%s,%s" % (group_name(instance_id, group), ou)
    msg = ldb.Message()
    msg.dn = ldb.Dn(db, group_dn)

    for user in users_in_group:
        user_dn = "cn=%s,%s" % (user_name(instance_id, user), ou)
        # each element needs a distinct message key; the real attribute
        # name is the third MessageElement argument ("member")
        msg["member-" + str(user)] = ldb.MessageElement(user_dn,
                                                        ldb.FLAG_MOD_ADD,
                                                        "member")

    db.modify(msg)
def generate_stats(statsdir, timing_file):
    """Generate and print the summary stats for a run.

    :param statsdir: directory of tab-separated per-process stats files
    :param timing_file: an open file to copy the raw timing lines into,
        or None to discard them
    """
    first = sys.float_info.max
    last = 0
    successful = 0
    failed = 0
    latencies = {}
    failures = {}
    unique_converations = set()
    conversations = 0

    if timing_file is not None:
        tw = timing_file.write
    else:
        def tw(x):
            pass

    tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")

    for filename in os.listdir(statsdir):
        path = os.path.join(statsdir, filename)
        with open(path, 'r') as f:
            for line in f:
                try:
                    fields = line.rstrip('\n').split('\t')
                    conversation = fields[1]
                    protocol = fields[2]
                    packet_type = fields[3]
                    latency = float(fields[4])
                    # fields[0] is the packet end time; subtracting the
                    # latency gives the start time
                    first = min(float(fields[0]) - latency, first)
                    last = max(float(fields[0]), last)

                    if protocol not in latencies:
                        latencies[protocol] = {}
                    if packet_type not in latencies[protocol]:
                        latencies[protocol][packet_type] = []

                    latencies[protocol][packet_type].append(latency)

                    if protocol not in failures:
                        failures[protocol] = {}
                    if packet_type not in failures[protocol]:
                        failures[protocol][packet_type] = 0

                    if fields[5] == 'True':
                        successful += 1
                    else:
                        failed += 1
                        failures[protocol][packet_type] += 1

                    if conversation not in unique_converations:
                        unique_converations.add(conversation)
                        conversations += 1

                    tw(line)
                except (ValueError, IndexError):
                    # not a valid line print and ignore
                    print(line, file=sys.stderr)
                    pass
    duration = last - first
    if successful == 0:
        success_rate = 0
    else:
        success_rate = successful / duration
    if failed == 0:
        failure_rate = 0
    else:
        failure_rate = failed / duration

    print("Total conversations: %10d" % conversations)
    print("Successful operations: %10d (%.3f per second)"
          % (successful, success_rate))
    print("Failed operations: %10d (%.3f per second)"
          % (failed, failure_rate))

    print("Protocol Op Code Description "
          " Count Failed Mean Median "
          "95% Range Max")

    protocols = sorted(latencies.keys())
    for protocol in protocols:
        packet_types = sorted(latencies[protocol], key=opcode_key)
        for packet_type in packet_types:
            values = latencies[protocol][packet_type]
            values = sorted(values)
            count = len(values)
            failed = failures[protocol][packet_type]
            mean = sum(values) / count
            median = calc_percentile(values, 0.50)
            percentile = calc_percentile(values, 0.95)
            rng = values[-1] - values[0]
            maxv = values[-1]
            desc = OP_DESCRIPTIONS.get((protocol, packet_type), '')
            # this used to test the bound method object itself
            # (sys.stdout.isatty), which is always truthy; it must be
            # called to detect a terminal
            if sys.stdout.isatty():
                print("%-12s %4s %-35s %12d %12d %12.6f "
                      "%12.6f %12.6f %12.6f %12.6f"
                      % (protocol,
                         packet_type,
                         desc,
                         count,
                         failed,
                         mean,
                         median,
                         percentile,
                         rng,
                         maxv))
            else:
                print("%s\t%s\t%s\t%d\t%d\t%f\t%f\t%f\t%f\t%f"
                      % (protocol,
                         packet_type,
                         desc,
                         count,
                         failed,
                         mean,
                         median,
                         percentile,
                         rng,
                         maxv))
def opcode_key(v):
    """Sort key for the operation code to ensure that it sorts numerically

    Decimal opcodes are zero-padded to three digits; anything else (such
    as hex SMB opcodes like '0x04') sorts by its literal string value.
    """
    try:
        return "%03d" % int(v)
    except (TypeError, ValueError):
        # not a plain decimal opcode (the old bare `except:` would also
        # have swallowed KeyboardInterrupt etc.)
        return v
def calc_percentile(values, percentile):
    """Calculate the specified percentile from the list of values.

    Assumes the list is sorted in ascending order.  Returns 0 for an
    empty list; otherwise linearly interpolates between the two values
    bracketing the requested percentile.
    """
    if not values:
        return 0
    k = (len(values) - 1) * percentile
    lower = math.floor(k)
    upper = math.ceil(k)
    if lower == upper:
        # k lands exactly on an element
        return values[int(k)]
    return (values[int(lower)] * (upper - k) +
            values[int(upper)] * (k - lower))
def mk_masked_dir(*path):
    """In a testenv we end up with 0777 diectories that look an alarming
    green colour with ls. Use umask to avoid that.

    The path components are joined and the directory created with group
    and other bits masked off; the previous umask is always restored.
    """
    d = os.path.join(*path)
    mask = os.umask(0o077)
    try:
        os.mkdir(d)
    finally:
        # restore the umask even if mkdir fails (the old code left the
        # restrictive umask in place on error)
        os.umask(mask)
    return d