ctdb-server: Use find_public_ip_vnn() in a couple of extra places
[Samba.git] / python / samba / tests / audit_log_base.py
blob e82b9bedf5a56aa1c6fc1e3cd89fbb1e94f5d638
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Andrew Bartlett <abartlet@samba.org> 2017
4 # This program is free software; you can redistribute it and/or modify
5 # it under the terms of the GNU General Public License as published by
6 # the Free Software Foundation; either version 3 of the License, or
7 # (at your option) any later version.
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
14 # You should have received a copy of the GNU General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
17 """Tests for DSDB audit logging.
18 """
20 import samba.tests
21 from samba.messaging import Messaging
22 from samba.dcerpc.messaging import MSG_AUTH_LOG, AUTH_EVENT_NAME
23 from samba.param import LoadParm
24 from samba import string_is_guid
25 import time
26 import json
27 import os
def getAudit(message):
    """Return the audit payload of a decoded JSON log message.

    A message stores its payload under the key named by its "type" field,
    e.g. a message with type "Authorization" keeps its details under
    message["Authorization"].

    :param message: the decoded JSON message (a dict).
    :return: message[message["type"]], or None if the message has no
        "type" key.
    :raises KeyError: if "type" is present but names a key that is missing.
    """
    if "type" not in message:
        return None

    # Named msg_type (not "type") to avoid shadowing the builtin.
    msg_type = message["type"]
    return message[msg_type]
class AuditLogTestBase(samba.tests.TestCase):
    """Base class for tests that inspect audit log messages.

    setUp() listens on the server's messaging bus and collects the audit
    messages logged for this test client. The helper methods wait for,
    filter and discard the collected messages.

    Subclasses must provide ``event_type`` (passed to irpc_add_name) and
    ``message_type`` (the msg_type the audit handler is registered for).
    Both are used here but not defined by this class.
    """

    def setUp(self):
        super().setUp()

        # Read the environment before registering anything. If SERVER is
        # unset, setUp then fails before any handler is registered
        # (unittest does not call tearDown when setUp raises).
        self.server = os.environ["SERVER"]
        self.connection = None

        # Connect to the server's messaging bus. We need to explicitly load
        # a different smb.conf here, because in all other respects this test
        # wants to act as a separate remote client.
        server_conf = os.getenv('SERVERCONFFILE')
        if server_conf:
            lp_ctx = LoadParm(filename_for_non_global_lp=server_conf)
        else:
            lp_ctx = self.get_loadparm()
        self.msg_ctx = Messaging((1,), lp_ctx=lp_ctx)
        self.msg_ctx.irpc_add_name(self.event_type)

        # Now switch back to the client-side smb.conf. The tests use the
        # first interface in the client.conf, with the subnet mask
        # portion stripped off.
        lp_ctx = self.get_loadparm()
        client_ip_and_mask = lp_ctx.get('interfaces')[0]
        client_ip = client_ip_and_mask.split('/')[0]

        # The messaging ctx is the server's view of the world, so our own
        # client IP will be the remoteAddress when connections are logged.
        self.remoteAddress = client_ip

        def isRemote(message):
            """Return True if message was logged for this client's address."""
            audit = getAudit(message)
            if audit is None:
                return False

            remote = audit["remoteAddress"]
            if remote is None:
                return False

            # The address is compared using its second ':'-separated field.
            # This presumably matches a "<family>:<address>:<port>" form.
            # NOTE(review): an IPv6 address contains ':' and would be split
            # apart here. Confirm that the first client interface is IPv4.
            try:
                addr = remote.split(":")
                return addr[1] == self.remoteAddress
            except IndexError:
                return False

        # Message types collected when they come from this client.
        change_types = ("passwordChange", "dsdbChange", "groupChange")

        def messageHandler(context, msgType, src, message):
            # Print every message: this does not look like subunit output,
            # and it makes these tests much easier to debug.
            print(message)
            jsonMsg = json.loads(message)
            if jsonMsg["type"] in change_types and isRemote(jsonMsg):
                context["messages"].append(jsonMsg)
            elif jsonMsg["type"] == "dsdbTransaction":
                context["txnMessage"] = jsonMsg

        self.context = {"messages": [], "txnMessage": None}
        self.msg_handler_and_context = (messageHandler, self.context)
        self.msg_ctx.register(self.msg_handler_and_context,
                              msg_type=self.message_type)

        self.msg_ctx.irpc_add_name(AUTH_EVENT_NAME)

        def authHandler(context, msgType, src, message):
            jsonMsg = json.loads(message)
            if jsonMsg["type"] == "Authorization" and isRemote(jsonMsg):
                # Print only the matching messages: this does not look like
                # subunit output, and it makes these tests easier to debug.
                print(message)
                context["sessionId"] = jsonMsg["Authorization"]["sessionId"]
                context["serviceDescription"] = \
                    jsonMsg["Authorization"]["serviceDescription"]

        self.auth_context = {"sessionId": "", "serviceDescription": ""}
        self.auth_handler_and_context = (authHandler, self.auth_context)
        self.msg_ctx.register(self.auth_handler_and_context,
                              msg_type=MSG_AUTH_LOG)

        self.discardMessages()

    def tearDown(self):
        try:
            self.discardMessages()
            self.msg_ctx.irpc_remove_name(self.event_type)
            self.msg_ctx.irpc_remove_name(AUTH_EVENT_NAME)
            self.msg_ctx.deregister(self.msg_handler_and_context,
                                    msg_type=self.message_type)
            self.msg_ctx.deregister(self.auth_handler_and_context,
                                    msg_type=MSG_AUTH_LOG)
        finally:
            # Always run the base class clean-up, even if a step above
            # raised.
            super().tearDown()

    def _messagesForDn(self, dn):
        """Return the collected messages whose audit "dn" equals dn.

        The comparison is case-insensitive.
        """
        wanted = dn.lower()
        return [msg for msg in self.context["messages"]
                if getAudit(msg)["dn"].lower() == wanted]

    def haveExpected(self, expected, dn):
        """Return True once enough audit messages have been collected.

        :param expected: the number of messages required.
        :param dn: if None, count every collected message. Otherwise count
            only messages for this DN (case-insensitive). In that case at
            least one matching message is required, even when expected
            is 0.
        """
        if dn is None:
            return len(self.context["messages"]) >= expected
        # max(..., 1): a DN-filtered wait for 0 messages still waits for the
        # first matching message (or a timeout), as it always has.
        return len(self._messagesForDn(dn)) >= max(expected, 1)

    def waitForMessages(self, number, connection=None, dn=None, timeout=1):
        """Wait for all the expected messages to arrive.

        :param number: the number of messages to wait for.
        :param connection: kept referenced (as self.connection) while
            waiting, so the connection stays alive until all the logging
            messages have been received.
        :param dn: if not None, only count and return messages for this DN
            (case-insensitive).
        :param timeout: seconds to wait before giving up (default 1).
        :return: the collected messages (filtered by dn if given). When dn
            is None this is the live self.context["messages"] list. Returns
            an empty list if the wait timed out.
        """
        self.connection = connection
        try:
            # Use a monotonic clock so wall-clock adjustments cannot shorten
            # or extend the timeout.
            start_time = time.monotonic()
            while not self.haveExpected(number, dn):
                self.msg_ctx.loop_once(0.1)
                if time.monotonic() - start_time > timeout:
                    print("Timed out")
                    return []
        finally:
            # Release the connection however the wait ends.
            self.connection = None

        if dn is None:
            return self.context["messages"]
        return self._messagesForDn(dn)

    def discardMessages(self):
        """Discard any previously queued messages.

        Clears the collected messages and the transaction message, then
        runs the event loop. This repeats until a round of polling
        delivers nothing new.
        """
        messages = self.context["messages"]

        while True:
            messages.clear()
            self.context["txnMessage"] = None

            # tevent presumably has other tasks to run, so we might need two
            # or three loops before a message comes through.
            for _ in range(5):
                self.msg_ctx.loop_once(0.001)

            if not messages and self.context["txnMessage"] is None:
                # No new messages. We've probably got them all.
                break

    def is_guid(self, guid):
        """Is the supplied GUID string correctly formatted"""
        return string_is_guid(guid)

    def get_session(self):
        """Return the sessionId from the last matching Authorization message.

        This is "" if no such message has been seen.
        """
        return self.auth_context["sessionId"]

    def get_service_description(self):
        """Return the serviceDescription from the last matching Authorization
        message.

        This is "" if no such message has been seen.
        """
        return self.auth_context["serviceDescription"]