msctf: Maintain context reference in ranges.
[wine.git] / dlls / rpcrt4 / rpc_assoc.c
blobab865c2de14a33d61554de1a793ddf2e2820b8dc
1 /*
2 * Associations
4 * Copyright 2007 Robert Shearman (for CodeWeavers)
6 * This library is free software; you can redistribute it and/or
7 * modify it under the terms of the GNU Lesser General Public
8 * License as published by the Free Software Foundation; either
9 * version 2.1 of the License, or (at your option) any later version.
11 * This library is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 * Lesser General Public License for more details.
16 * You should have received a copy of the GNU Lesser General Public
17 * License along with this library; if not, write to the Free Software
18 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
22 #include <stdarg.h>
23 #include <assert.h>
25 #include "rpc.h"
26 #include "rpcndr.h"
28 #include "wine/debug.h"
30 #include "rpc_binding.h"
31 #include "rpc_assoc.h"
32 #include "rpc_message.h"
34 WINE_DEFAULT_DEBUG_CHANNEL(rpc);
36 static CRITICAL_SECTION assoc_list_cs;
37 static CRITICAL_SECTION_DEBUG assoc_list_cs_debug =
39 0, 0, &assoc_list_cs,
40 { &assoc_list_cs_debug.ProcessLocksList, &assoc_list_cs_debug.ProcessLocksList },
41 0, 0, { (DWORD_PTR)(__FILE__ ": assoc_list_cs") }
43 static CRITICAL_SECTION assoc_list_cs = { &assoc_list_cs_debug, -1, 0, 0, 0, 0 };
45 static struct list client_assoc_list = LIST_INIT(client_assoc_list);
46 static struct list server_assoc_list = LIST_INIT(server_assoc_list);
48 static LONG last_assoc_group_id;
50 typedef struct _RpcContextHandle
52 struct list entry;
53 void *user_context;
54 NDR_RUNDOWN rundown_routine;
55 void *ctx_guard;
56 UUID uuid;
57 CRITICAL_SECTION lock;
58 unsigned int refs;
59 } RpcContextHandle;
61 static void RpcContextHandle_Destroy(RpcContextHandle *context_handle);
63 static RPC_STATUS RpcAssoc_Alloc(LPCSTR Protseq, LPCSTR NetworkAddr,
64 LPCSTR Endpoint, LPCWSTR NetworkOptions,
65 RpcAssoc **assoc_out)
67 RpcAssoc *assoc;
68 assoc = HeapAlloc(GetProcessHeap(), 0, sizeof(*assoc));
69 if (!assoc)
70 return RPC_S_OUT_OF_RESOURCES;
71 assoc->refs = 1;
72 list_init(&assoc->free_connection_pool);
73 list_init(&assoc->context_handle_list);
74 InitializeCriticalSection(&assoc->cs);
75 assoc->cs.DebugInfo->Spare[0] = (DWORD_PTR)(__FILE__ ": RpcAssoc.cs");
76 assoc->Protseq = RPCRT4_strdupA(Protseq);
77 assoc->NetworkAddr = RPCRT4_strdupA(NetworkAddr);
78 assoc->Endpoint = RPCRT4_strdupA(Endpoint);
79 assoc->NetworkOptions = NetworkOptions ? RPCRT4_strdupW(NetworkOptions) : NULL;
80 assoc->assoc_group_id = 0;
81 assoc->connection_cnt = 0;
82 UuidCreate(&assoc->http_uuid);
83 list_init(&assoc->entry);
84 *assoc_out = assoc;
85 return RPC_S_OK;
88 static BOOL compare_networkoptions(LPCWSTR opts1, LPCWSTR opts2)
90 if ((opts1 == NULL) && (opts2 == NULL))
91 return TRUE;
92 if ((opts1 == NULL) || (opts2 == NULL))
93 return FALSE;
94 return !wcscmp(opts1, opts2);
97 RPC_STATUS RPCRT4_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr,
98 LPCSTR Endpoint, LPCWSTR NetworkOptions,
99 RpcAssoc **assoc_out)
101 RpcAssoc *assoc;
102 RPC_STATUS status;
104 EnterCriticalSection(&assoc_list_cs);
105 LIST_FOR_EACH_ENTRY(assoc, &client_assoc_list, RpcAssoc, entry)
107 if (!strcmp(Protseq, assoc->Protseq) &&
108 !strcmp(NetworkAddr, assoc->NetworkAddr) &&
109 !strcmp(Endpoint, assoc->Endpoint) &&
110 compare_networkoptions(NetworkOptions, assoc->NetworkOptions))
112 assoc->refs++;
113 *assoc_out = assoc;
114 LeaveCriticalSection(&assoc_list_cs);
115 TRACE("using existing assoc %p\n", assoc);
116 return RPC_S_OK;
120 status = RpcAssoc_Alloc(Protseq, NetworkAddr, Endpoint, NetworkOptions, &assoc);
121 if (status != RPC_S_OK)
123 LeaveCriticalSection(&assoc_list_cs);
124 return status;
126 list_add_head(&client_assoc_list, &assoc->entry);
127 *assoc_out = assoc;
129 LeaveCriticalSection(&assoc_list_cs);
131 TRACE("new assoc %p\n", assoc);
133 return RPC_S_OK;
136 RPC_STATUS RpcServerAssoc_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr,
137 LPCSTR Endpoint, LPCWSTR NetworkOptions,
138 ULONG assoc_gid,
139 RpcAssoc **assoc_out)
141 RpcAssoc *assoc;
142 RPC_STATUS status;
144 EnterCriticalSection(&assoc_list_cs);
145 if (assoc_gid)
147 LIST_FOR_EACH_ENTRY(assoc, &server_assoc_list, RpcAssoc, entry)
149 /* FIXME: NetworkAddr shouldn't be NULL */
150 if (assoc->assoc_group_id == assoc_gid &&
151 !strcmp(Protseq, assoc->Protseq) &&
152 (!NetworkAddr || !assoc->NetworkAddr || !strcmp(NetworkAddr, assoc->NetworkAddr)) &&
153 !strcmp(Endpoint, assoc->Endpoint) &&
154 ((!assoc->NetworkOptions == !NetworkOptions) &&
155 (!NetworkOptions || !wcscmp(NetworkOptions, assoc->NetworkOptions))))
157 assoc->refs++;
158 *assoc_out = assoc;
159 LeaveCriticalSection(&assoc_list_cs);
160 TRACE("using existing assoc %p\n", assoc);
161 return RPC_S_OK;
164 *assoc_out = NULL;
165 LeaveCriticalSection(&assoc_list_cs);
166 return RPC_S_NO_CONTEXT_AVAILABLE;
169 status = RpcAssoc_Alloc(Protseq, NetworkAddr, Endpoint, NetworkOptions, &assoc);
170 if (status != RPC_S_OK)
172 LeaveCriticalSection(&assoc_list_cs);
173 return status;
175 assoc->assoc_group_id = InterlockedIncrement(&last_assoc_group_id);
176 list_add_head(&server_assoc_list, &assoc->entry);
177 *assoc_out = assoc;
179 LeaveCriticalSection(&assoc_list_cs);
181 TRACE("new assoc %p\n", assoc);
183 return RPC_S_OK;
186 ULONG RpcAssoc_Release(RpcAssoc *assoc)
188 ULONG refs;
190 EnterCriticalSection(&assoc_list_cs);
191 refs = --assoc->refs;
192 if (!refs)
193 list_remove(&assoc->entry);
194 LeaveCriticalSection(&assoc_list_cs);
196 if (!refs)
198 RpcConnection *Connection, *cursor2;
199 RpcContextHandle *context_handle, *context_handle_cursor;
201 TRACE("destroying assoc %p\n", assoc);
203 LIST_FOR_EACH_ENTRY_SAFE(Connection, cursor2, &assoc->free_connection_pool, RpcConnection, conn_pool_entry)
205 list_remove(&Connection->conn_pool_entry);
206 RPCRT4_ReleaseConnection(Connection);
209 LIST_FOR_EACH_ENTRY_SAFE(context_handle, context_handle_cursor, &assoc->context_handle_list, RpcContextHandle, entry)
210 RpcContextHandle_Destroy(context_handle);
212 HeapFree(GetProcessHeap(), 0, assoc->NetworkOptions);
213 HeapFree(GetProcessHeap(), 0, assoc->Endpoint);
214 HeapFree(GetProcessHeap(), 0, assoc->NetworkAddr);
215 HeapFree(GetProcessHeap(), 0, assoc->Protseq);
217 assoc->cs.DebugInfo->Spare[0] = 0;
218 DeleteCriticalSection(&assoc->cs);
220 HeapFree(GetProcessHeap(), 0, assoc);
223 return refs;
226 #define ROUND_UP(value, alignment) (((value) + ((alignment) - 1)) & ~((alignment)-1))
228 static RPC_STATUS RpcAssoc_BindConnection(const RpcAssoc *assoc, RpcConnection *conn,
229 const RPC_SYNTAX_IDENTIFIER *InterfaceId,
230 const RPC_SYNTAX_IDENTIFIER *TransferSyntax)
232 RpcPktHdr *hdr;
233 RpcPktHdr *response_hdr;
234 RPC_MESSAGE msg;
235 RPC_STATUS status;
236 unsigned char *auth_data = NULL;
237 ULONG auth_length;
239 TRACE("sending bind request to server\n");
241 hdr = RPCRT4_BuildBindHeader(NDR_LOCAL_DATA_REPRESENTATION,
242 RPC_MAX_PACKET_SIZE, RPC_MAX_PACKET_SIZE,
243 assoc->assoc_group_id,
244 InterfaceId, TransferSyntax);
246 status = RPCRT4_Send(conn, hdr, NULL, 0);
247 RPCRT4_FreeHeader(hdr);
248 if (status != RPC_S_OK)
249 return status;
251 status = RPCRT4_ReceiveWithAuth(conn, &response_hdr, &msg, &auth_data, &auth_length);
252 if (status != RPC_S_OK)
254 ERR("receive failed with error %d\n", status);
255 return status;
258 switch (response_hdr->common.ptype)
260 case PKT_BIND_ACK:
262 RpcAddressString *server_address = msg.Buffer;
263 if ((msg.BufferLength >= FIELD_OFFSET(RpcAddressString, string[0])) ||
264 (msg.BufferLength >= ROUND_UP(FIELD_OFFSET(RpcAddressString, string[server_address->length]), 4)))
266 unsigned short remaining = msg.BufferLength -
267 ROUND_UP(FIELD_OFFSET(RpcAddressString, string[server_address->length]), 4);
268 RpcResultList *results = (RpcResultList*)((ULONG_PTR)server_address +
269 ROUND_UP(FIELD_OFFSET(RpcAddressString, string[server_address->length]), 4));
270 if ((results->num_results == 1) &&
271 (remaining >= FIELD_OFFSET(RpcResultList, results[results->num_results])))
273 switch (results->results[0].result)
275 case RESULT_ACCEPT:
276 /* respond to authorization request */
277 if (auth_length > sizeof(RpcAuthVerifier))
278 status = RPCRT4_ClientConnectionAuth(conn,
279 auth_data + sizeof(RpcAuthVerifier),
280 auth_length);
281 if (status == RPC_S_OK)
283 conn->assoc_group_id = response_hdr->bind_ack.assoc_gid;
284 conn->MaxTransmissionSize = response_hdr->bind_ack.max_tsize;
285 conn->ActiveInterface = *InterfaceId;
287 break;
288 case RESULT_PROVIDER_REJECTION:
289 switch (results->results[0].reason)
291 case REASON_ABSTRACT_SYNTAX_NOT_SUPPORTED:
292 ERR("syntax %s, %d.%d not supported\n",
293 debugstr_guid(&InterfaceId->SyntaxGUID),
294 InterfaceId->SyntaxVersion.MajorVersion,
295 InterfaceId->SyntaxVersion.MinorVersion);
296 status = RPC_S_UNKNOWN_IF;
297 break;
298 case REASON_TRANSFER_SYNTAXES_NOT_SUPPORTED:
299 ERR("transfer syntax not supported\n");
300 status = RPC_S_SERVER_UNAVAILABLE;
301 break;
302 case REASON_NONE:
303 default:
304 status = RPC_S_CALL_FAILED_DNE;
306 break;
307 case RESULT_USER_REJECTION:
308 default:
309 ERR("rejection result %d\n", results->results[0].result);
310 status = RPC_S_CALL_FAILED_DNE;
313 else
315 ERR("incorrect results size\n");
316 status = RPC_S_CALL_FAILED_DNE;
319 else
321 ERR("bind ack packet too small (%d)\n", msg.BufferLength);
322 status = RPC_S_PROTOCOL_ERROR;
324 break;
326 case PKT_BIND_NACK:
327 switch (response_hdr->bind_nack.reject_reason)
329 case REJECT_LOCAL_LIMIT_EXCEEDED:
330 case REJECT_TEMPORARY_CONGESTION:
331 ERR("server too busy\n");
332 status = RPC_S_SERVER_TOO_BUSY;
333 break;
334 case REJECT_PROTOCOL_VERSION_NOT_SUPPORTED:
335 ERR("protocol version not supported\n");
336 status = RPC_S_PROTOCOL_ERROR;
337 break;
338 case REJECT_UNKNOWN_AUTHN_SERVICE:
339 ERR("unknown authentication service\n");
340 status = RPC_S_UNKNOWN_AUTHN_SERVICE;
341 break;
342 case REJECT_INVALID_CHECKSUM:
343 ERR("invalid checksum\n");
344 status = RPC_S_ACCESS_DENIED;
345 break;
346 default:
347 ERR("rejected bind for reason %d\n", response_hdr->bind_nack.reject_reason);
348 status = RPC_S_CALL_FAILED_DNE;
350 break;
351 default:
352 ERR("wrong packet type received %d\n", response_hdr->common.ptype);
353 status = RPC_S_PROTOCOL_ERROR;
354 break;
357 I_RpcFree(msg.Buffer);
358 RPCRT4_FreeHeader(response_hdr);
359 HeapFree(GetProcessHeap(), 0, auth_data);
360 return status;
363 static RpcConnection *RpcAssoc_GetIdleConnection(RpcAssoc *assoc,
364 const RPC_SYNTAX_IDENTIFIER *InterfaceId,
365 const RPC_SYNTAX_IDENTIFIER *TransferSyntax, const RpcAuthInfo *AuthInfo,
366 const RpcQualityOfService *QOS)
368 RpcConnection *Connection;
369 EnterCriticalSection(&assoc->cs);
370 /* try to find a compatible connection from the connection pool */
371 LIST_FOR_EACH_ENTRY(Connection, &assoc->free_connection_pool, RpcConnection, conn_pool_entry)
373 if (!memcmp(&Connection->ActiveInterface, InterfaceId,
374 sizeof(RPC_SYNTAX_IDENTIFIER)) &&
375 RpcAuthInfo_IsEqual(Connection->AuthInfo, AuthInfo) &&
376 RpcQualityOfService_IsEqual(Connection->QOS, QOS))
378 list_remove(&Connection->conn_pool_entry);
379 LeaveCriticalSection(&assoc->cs);
380 TRACE("got connection from pool %p\n", Connection);
381 return Connection;
385 LeaveCriticalSection(&assoc->cs);
386 return NULL;
389 RPC_STATUS RpcAssoc_GetClientConnection(RpcAssoc *assoc,
390 const RPC_SYNTAX_IDENTIFIER *InterfaceId,
391 const RPC_SYNTAX_IDENTIFIER *TransferSyntax, RpcAuthInfo *AuthInfo,
392 RpcQualityOfService *QOS, LPCWSTR CookieAuth,
393 RpcConnection **Connection, BOOL *from_cache)
395 RpcConnection *NewConnection;
396 RPC_STATUS status;
398 *Connection = RpcAssoc_GetIdleConnection(assoc, InterfaceId, TransferSyntax, AuthInfo, QOS);
399 if (*Connection) {
400 TRACE("return idle connection %p for association %p\n", *Connection, assoc);
401 if (from_cache) *from_cache = TRUE;
402 return RPC_S_OK;
405 /* create a new connection */
406 status = RPCRT4_CreateConnection(&NewConnection, FALSE /* is this a server connection? */,
407 assoc->Protseq, assoc->NetworkAddr,
408 assoc->Endpoint, assoc->NetworkOptions,
409 AuthInfo, QOS, CookieAuth);
410 if (status != RPC_S_OK)
411 return status;
413 NewConnection->assoc = assoc;
414 status = RPCRT4_OpenClientConnection(NewConnection);
415 if (status != RPC_S_OK)
417 RPCRT4_ReleaseConnection(NewConnection);
418 return status;
421 status = RpcAssoc_BindConnection(assoc, NewConnection, InterfaceId, TransferSyntax);
422 if (status != RPC_S_OK)
424 RPCRT4_ReleaseConnection(NewConnection);
425 return status;
428 InterlockedIncrement(&assoc->connection_cnt);
430 TRACE("return new connection %p for association %p\n", *Connection, assoc);
431 *Connection = NewConnection;
432 if (from_cache) *from_cache = FALSE;
433 return RPC_S_OK;
436 void RpcAssoc_ReleaseIdleConnection(RpcAssoc *assoc, RpcConnection *Connection)
438 assert(!Connection->server);
439 Connection->async_state = NULL;
440 EnterCriticalSection(&assoc->cs);
441 if (!assoc->assoc_group_id) assoc->assoc_group_id = Connection->assoc_group_id;
442 list_add_head(&assoc->free_connection_pool, &Connection->conn_pool_entry);
443 LeaveCriticalSection(&assoc->cs);
446 void RpcAssoc_ConnectionReleased(RpcAssoc *assoc)
448 if (InterlockedDecrement(&assoc->connection_cnt))
449 return;
451 TRACE("Last %p connection released\n", assoc);
452 assoc->assoc_group_id = 0;
455 RPC_STATUS RpcServerAssoc_AllocateContextHandle(RpcAssoc *assoc, void *CtxGuard,
456 NDR_SCONTEXT *SContext)
458 RpcContextHandle *context_handle;
460 context_handle = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(*context_handle));
461 if (!context_handle)
462 return RPC_S_OUT_OF_MEMORY;
464 context_handle->ctx_guard = CtxGuard;
465 InitializeCriticalSection(&context_handle->lock);
466 context_handle->lock.DebugInfo->Spare[0] = (DWORD_PTR)(__FILE__ ": RpcContextHandle.lock");
467 context_handle->refs = 1;
469 /* lock here to mirror unmarshall, so we don't need to special-case the
470 * freeing of a non-marshalled context handle */
471 EnterCriticalSection(&context_handle->lock);
473 EnterCriticalSection(&assoc->cs);
474 list_add_tail(&assoc->context_handle_list, &context_handle->entry);
475 LeaveCriticalSection(&assoc->cs);
477 *SContext = (NDR_SCONTEXT)context_handle;
478 return RPC_S_OK;
481 BOOL RpcContextHandle_IsGuardCorrect(NDR_SCONTEXT SContext, void *CtxGuard)
483 RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
484 return context_handle->ctx_guard == CtxGuard;
487 RPC_STATUS RpcServerAssoc_FindContextHandle(RpcAssoc *assoc, const UUID *uuid,
488 void *CtxGuard, ULONG Flags, NDR_SCONTEXT *SContext)
490 RpcContextHandle *context_handle;
492 EnterCriticalSection(&assoc->cs);
493 LIST_FOR_EACH_ENTRY(context_handle, &assoc->context_handle_list, RpcContextHandle, entry)
495 if (RpcContextHandle_IsGuardCorrect((NDR_SCONTEXT)context_handle, CtxGuard) &&
496 !memcmp(&context_handle->uuid, uuid, sizeof(*uuid)))
498 *SContext = (NDR_SCONTEXT)context_handle;
499 if (context_handle->refs++)
501 LeaveCriticalSection(&assoc->cs);
502 TRACE("found %p\n", context_handle);
503 EnterCriticalSection(&context_handle->lock);
504 return RPC_S_OK;
508 LeaveCriticalSection(&assoc->cs);
510 ERR("no context handle found for uuid %s, guard %p\n",
511 debugstr_guid(uuid), CtxGuard);
512 return ERROR_INVALID_HANDLE;
515 RPC_STATUS RpcServerAssoc_UpdateContextHandle(RpcAssoc *assoc,
516 NDR_SCONTEXT SContext,
517 void *CtxGuard,
518 NDR_RUNDOWN rundown_routine)
520 RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
521 RPC_STATUS status;
523 if (!RpcContextHandle_IsGuardCorrect((NDR_SCONTEXT)context_handle, CtxGuard))
524 return ERROR_INVALID_HANDLE;
526 EnterCriticalSection(&assoc->cs);
527 if (UuidIsNil(&context_handle->uuid, &status))
529 /* add a ref for the data being valid */
530 context_handle->refs++;
531 UuidCreate(&context_handle->uuid);
532 context_handle->rundown_routine = rundown_routine;
533 TRACE("allocated uuid %s for context handle %p\n",
534 debugstr_guid(&context_handle->uuid), context_handle);
536 LeaveCriticalSection(&assoc->cs);
538 return RPC_S_OK;
541 void RpcContextHandle_GetUuid(NDR_SCONTEXT SContext, UUID *uuid)
543 RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
544 *uuid = context_handle->uuid;
547 static void RpcContextHandle_Destroy(RpcContextHandle *context_handle)
549 TRACE("freeing %p\n", context_handle);
551 if (context_handle->user_context && context_handle->rundown_routine)
553 TRACE("calling rundown routine %p with user context %p\n",
554 context_handle->rundown_routine, context_handle->user_context);
555 context_handle->rundown_routine(context_handle->user_context);
558 context_handle->lock.DebugInfo->Spare[0] = 0;
559 DeleteCriticalSection(&context_handle->lock);
561 HeapFree(GetProcessHeap(), 0, context_handle);
564 unsigned int RpcServerAssoc_ReleaseContextHandle(RpcAssoc *assoc, NDR_SCONTEXT SContext, BOOL release_lock)
566 RpcContextHandle *context_handle = (RpcContextHandle *)SContext;
567 unsigned int refs;
569 if (release_lock)
570 LeaveCriticalSection(&context_handle->lock);
572 EnterCriticalSection(&assoc->cs);
573 refs = --context_handle->refs;
574 if (!refs)
575 list_remove(&context_handle->entry);
576 LeaveCriticalSection(&assoc->cs);
578 if (!refs)
579 RpcContextHandle_Destroy(context_handle);
581 return refs;