Bug 1888590 - Mark some subtests on trusted-types-event-handlers.html as failing...
[gecko.git] / third_party / content_analysis_sdk / agent_improvements.patch
blobccded0803282bc99e55a5fc85a68ab3a1e7e086a
1 commit 4ad63eb3aa65ce7baa08190aac2770540dc25f43
2 Author: Greg Stoll <gstoll@mozilla.com>
3 Date: Wed, 27 Mar 2024 12:13:56 -0500
5 Mozilla improvements to content_analysis_sdk
7 - add ability for demo agent to block/warn/report specific regexes
8 - add ability for demo agent to chose a sequence of delays to apply
9 - add a "misbehaving" demo agent for use in gtests
11 diff --git a/CMakeLists.txt b/CMakeLists.txt
12 index 39477223f031c..5dacc81031117 100644
13 --- a/CMakeLists.txt
14 +++ b/CMakeLists.txt
15 @@ -203,6 +203,7 @@ add_executable(agent
16 ./demo/agent.cc
17 ./demo/handler.h
19 +target_compile_features(agent PRIVATE cxx_std_17)
20 target_include_directories(agent PRIVATE ${AGENT_INCLUDES})
21 target_link_libraries(agent PRIVATE cac_agent)
23 diff --git a/agent/src/event_win.h b/agent/src/event_win.h
24 index 9f8b6903566f2..f631f693dcd9c 100644
25 --- a/agent/src/event_win.h
26 +++ b/agent/src/event_win.h
27 @@ -28,6 +28,12 @@ class ContentAnalysisEventWin : public ContentAnalysisEventBase {
28 ResultCode Close() override;
29 ResultCode Send() override;
30 std::string DebugString() const override;
31 + std::string SerializeStringToSendToBrowser() {
32 + return agent_to_chrome()->SerializeAsString();
33 + }
34 + void SetResponseSent() { response_sent_ = true; }
36 + HANDLE Pipe() const { return hPipe_; }
38 private:
39 void Shutdown();
40 diff --git a/browser/src/client_win.cc b/browser/src/client_win.cc
41 index 9d3d7e8c52662..039946d131398 100644
42 --- a/browser/src/client_win.cc
43 +++ b/browser/src/client_win.cc
44 @@ -418,7 +418,11 @@ DWORD ClientWin::ConnectToPipe(const std::string& pipename, HANDLE* handle) {
46 void ClientWin::Shutdown() {
47 if (hPipe_ != INVALID_HANDLE_VALUE) {
48 - FlushFileBuffers(hPipe_);
49 + // TODO: This trips the LateWriteObserver. We could move this earlier
50 + // (before the LateWriteObserver is created) or just remove it, although
51 + // the later could mean an ACK message is not processed by the agent
52 + // in time.
53 + // FlushFileBuffers(hPipe_);
54 CloseHandle(hPipe_);
55 hPipe_ = INVALID_HANDLE_VALUE;
57 diff --git a/demo/agent.cc b/demo/agent.cc
58 index ff8b93f647ebd..3e168b0915a0c 100644
59 --- a/demo/agent.cc
60 +++ b/demo/agent.cc
61 @@ -2,12 +2,18 @@
62 // Use of this source code is governed by a BSD-style license that can be
63 // found in the LICENSE file.
65 +#include <algorithm>
66 #include <fstream>
67 #include <iostream>
68 #include <string>
69 +#include <regex>
70 +#include <vector>
72 #include "content_analysis/sdk/analysis_agent.h"
73 #include "demo/handler.h"
74 +#include "demo/handler_misbehaving.h"
76 +using namespace content_analysis::sdk;
78 // Different paths are used depending on whether this agent should run as a
79 // use specific agent or not. These values are chosen to match the test
80 @@ -19,19 +25,50 @@ constexpr char kPathSystem[] = "brcm_chrm_cas";
81 std::string path = kPathSystem;
82 bool use_queue = false;
83 bool user_specific = false;
84 -unsigned long delay = 0; // In seconds.
85 +std::vector<unsigned long> delays = {0}; // In seconds.
86 unsigned long num_threads = 8u;
87 std::string save_print_data_path = "";
88 +RegexArray toBlock, toWarn, toReport;
89 +static bool useMisbehavingHandler = false;
90 +static std::string modeStr;
92 // Command line parameters.
93 -constexpr const char* kArgDelaySpecific = "--delay=";
94 +constexpr const char* kArgDelaySpecific = "--delays=";
95 constexpr const char* kArgPath = "--path=";
96 constexpr const char* kArgQueued = "--queued";
97 constexpr const char* kArgThreads = "--threads=";
98 constexpr const char* kArgUserSpecific = "--user";
99 +constexpr const char* kArgToBlock = "--toblock=";
100 +constexpr const char* kArgToWarn = "--towarn=";
101 +constexpr const char* kArgToReport = "--toreport=";
102 +constexpr const char* kArgMisbehave = "--misbehave=";
103 constexpr const char* kArgHelp = "--help";
104 constexpr const char* kArgSavePrintRequestDataTo = "--save-print-request-data-to=";
106 +std::map<std::string, Mode> sStringToMode = {
107 +#define AGENT_MODE(name) {#name, Mode::Mode_##name},
108 +#include "modes.h"
109 +#undef AGENT_MODE
112 +std::map<Mode, std::string> sModeToString = {
113 +#define AGENT_MODE(name) {Mode::Mode_##name, #name},
114 +#include "modes.h"
115 +#undef AGENT_MODE
118 +std::vector<std::pair<std::string, std::regex>>
119 +ParseRegex(const std::string str) {
120 + std::vector<std::pair<std::string, std::regex>> ret;
121 + for (auto it = str.begin(); it != str.end(); /* nop */) {
122 + auto it2 = std::find(it, str.end(), ',');
123 + ret.push_back(std::make_pair(std::string(it, it2), std::regex(it, it2)));
124 + it = it2 == str.end() ? it2 : it2 + 1;
127 + return ret;
130 bool ParseCommandLine(int argc, char* argv[]) {
131 for (int i = 1; i < argc; ++i) {
132 const std::string arg = argv[i];
133 @@ -44,16 +81,38 @@ bool ParseCommandLine(int argc, char* argv[]) {
134 path = kPathUser;
135 user_specific = true;
136 } else if (arg.find(kArgDelaySpecific) == 0) {
137 - delay = std::stoul(arg.substr(strlen(kArgDelaySpecific)));
138 + std::string delaysStr = arg.substr(strlen(kArgDelaySpecific));
139 + delays.clear();
140 + size_t posStart = 0, posEnd;
141 + unsigned long delay;
142 + while ((posEnd = delaysStr.find(',', posStart)) != std::string::npos) {
143 + delay = std::stoul(delaysStr.substr(posStart, posEnd - posStart));
144 + if (delay > 30) {
145 + delay = 30;
147 + delays.push_back(delay);
148 + posStart = posEnd + 1;
150 + delay = std::stoul(delaysStr.substr(posStart));
151 if (delay > 30) {
152 delay = 30;
154 + delays.push_back(delay);
155 } else if (arg.find(kArgPath) == 0) {
156 path = arg.substr(strlen(kArgPath));
157 } else if (arg.find(kArgQueued) == 0) {
158 use_queue = true;
159 } else if (arg.find(kArgThreads) == 0) {
160 num_threads = std::stoul(arg.substr(strlen(kArgThreads)));
161 + } else if (arg.find(kArgToBlock) == 0) {
162 + toBlock = ParseRegex(arg.substr(strlen(kArgToBlock)));
163 + } else if (arg.find(kArgToWarn) == 0) {
164 + toWarn = ParseRegex(arg.substr(strlen(kArgToWarn)));
165 + } else if (arg.find(kArgToReport) == 0) {
166 + toReport = ParseRegex(arg.substr(strlen(kArgToReport)));
167 + } else if (arg.find(kArgMisbehave) == 0) {
168 + modeStr = arg.substr(strlen(kArgMisbehave));
169 + useMisbehavingHandler = true;
170 } else if (arg.find(kArgHelp) == 0) {
171 return false;
172 } else if (arg.find(kArgSavePrintRequestDataTo) == 0) {
173 @@ -72,13 +131,17 @@ void PrintHelp() {
174 << "A simple agent to process content analysis requests." << std::endl
175 << "Data containing the string 'block' blocks the request data from being used." << std::endl
176 << std::endl << "Options:" << std::endl
177 - << kArgDelaySpecific << "<delay> : Add a delay to request processing in seconds (max 30)." << std::endl
178 + << kArgDelaySpecific << "<delay1,delay2,...> : Add delays to request processing in seconds. Delays are limited to 30 seconds and are applied round-robin to requests. Default is 0." << std::endl
179 << kArgPath << " <path> : Used the specified path instead of default. Must come after --user." << std::endl
180 << kArgQueued << " : Queue requests for processing in a background thread" << std::endl
181 << kArgThreads << " : When queued, number of threads in the request processing thread pool" << std::endl
182 << kArgUserSpecific << " : Make agent OS user specific." << std::endl
183 << kArgHelp << " : prints this help message" << std::endl
184 - << kArgSavePrintRequestDataTo << " : saves the PDF data to the given file path for print requests";
185 + << kArgSavePrintRequestDataTo << " : saves the PDF data to the given file path for print requests" << std::endl
186 + << kArgToBlock << "<regex> : Regular expression matching file and text content to block." << std::endl
187 + << kArgToWarn << "<regex> : Regular expression matching file and text content to warn about." << std::endl
188 + << kArgToReport << "<regex> : Regular expression matching file and text content to report." << std::endl
189 + << kArgMisbehave << "<mode> : Use 'misbehaving' agent in given mode for testing purposes." << std::endl;
192 int main(int argc, char* argv[]) {
193 @@ -87,9 +150,17 @@ int main(int argc, char* argv[]) {
194 return 1;
197 - auto handler = use_queue
198 - ? std::make_unique<QueuingHandler>(num_threads, delay, save_print_data_path)
199 - : std::make_unique<Handler>(delay, save_print_data_path);
200 + auto handler =
201 + useMisbehavingHandler
202 + ? MisbehavingHandler::Create(modeStr, std::move(delays), save_print_data_path, std::move(toBlock), std::move(toWarn), std::move(toReport))
203 + : use_queue
204 + ? std::make_unique<QueuingHandler>(num_threads, std::move(delays), save_print_data_path, std::move(toBlock), std::move(toWarn), std::move(toReport))
205 + : std::make_unique<Handler>(std::move(delays), save_print_data_path, std::move(toBlock), std::move(toWarn), std::move(toReport));
207 + if (!handler) {
208 + std::cout << "[Demo] Failed to construct handler." << std::endl;
209 + return 1;
212 // Each agent uses a unique name to identify itself with Google Chrome.
213 content_analysis::sdk::ResultCode rc;
214 diff --git a/demo/handler.h b/demo/handler.h
215 index 9d1ccfdf9857a..88599963c51b0 100644
216 --- a/demo/handler.h
217 +++ b/demo/handler.h
218 @@ -7,31 +7,51 @@
220 #include <time.h>
222 +#include <algorithm>
223 +#include <atomic>
224 #include <chrono>
225 #include <cstdio>
226 #include <fstream>
227 #include <iostream>
228 +#include <optional>
229 #include <thread>
230 #include <utility>
231 +#include <regex>
232 #include <vector>
234 #include "content_analysis/sdk/analysis_agent.h"
235 #include "demo/atomic_output.h"
236 #include "demo/request_queue.h"
238 +using RegexArray = std::vector<std::pair<std::string, std::regex>>;
240 // An AgentEventHandler that dumps requests information to stdout and blocks
241 // any requests that have the keyword "block" in their data
242 class Handler : public content_analysis::sdk::AgentEventHandler {
243 public:
244 using Event = content_analysis::sdk::ContentAnalysisEvent;
246 - Handler(unsigned long delay, const std::string& print_data_file_path) :
247 - delay_(delay), print_data_file_path_(print_data_file_path) {
249 + Handler(std::vector<unsigned long>&& delays, const std::string& print_data_file_path,
250 + RegexArray&& toBlock = RegexArray(),
251 + RegexArray&& toWarn = RegexArray(),
252 + RegexArray&& toReport = RegexArray()) :
253 + toBlock_(std::move(toBlock)), toWarn_(std::move(toWarn)), toReport_(std::move(toReport)),
254 + delays_(std::move(delays)), print_data_file_path_(print_data_file_path) {}
256 - unsigned long delay() { return delay_; }
257 + const std::vector<unsigned long> delays() { return delays_; }
258 + size_t nextDelayIndex() const { return nextDelayIndex_; }
260 protected:
261 + // subclasses can override this
262 + // returns whether the response has been set
263 + virtual bool SetCustomResponse(AtomicCout& aout, std::unique_ptr<Event>& event) {
264 + return false;
266 + // subclasses can override this
267 + // returns whether the response has been sent
268 + virtual bool SendCustomResponse(std::unique_ptr<Event>& event) {
269 + return false;
271 // Analyzes one request from Google Chrome and responds back to the browser
272 // with either an allow or block verdict.
273 void AnalyzeContent(AtomicCout& aout, std::unique_ptr<Event> event) {
274 @@ -43,29 +63,25 @@ class Handler : public content_analysis::sdk::AgentEventHandler {
276 DumpEvent(aout.stream(), event.get());
278 - bool block = false;
279 bool success = true;
280 - unsigned long delay = delay_;
282 - if (event->GetRequest().has_text_content()) {
283 - block = ShouldBlockRequest(
284 - event->GetRequest().text_content());
285 - GetFileSpecificDelay(event->GetRequest().text_content(), &delay);
286 - } else if (event->GetRequest().has_file_path()) {
287 - std::string content;
288 - success =
289 - ReadContentFromFile(event->GetRequest().file_path(),
290 - &content);
291 - if (success) {
292 - block = ShouldBlockRequest(content);
293 - GetFileSpecificDelay(content, &delay);
294 + std::optional<content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action> caResponse;
295 + bool setResponse = SetCustomResponse(aout, event);
296 + if (!setResponse) {
297 + caResponse = content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action_BLOCK;
298 + if (event->GetRequest().has_text_content()) {
299 + caResponse = DecideCAResponse(
300 + event->GetRequest().text_content(), aout.stream());
301 + } else if (event->GetRequest().has_file_path()) {
302 + // TODO: Fix downloads to store file *first* so we can check contents.
303 + // Until then, just check the file name:
304 + caResponse = DecideCAResponse(
305 + event->GetRequest().file_path(), aout.stream());
306 + } else if (event->GetRequest().has_print_data()) {
307 + // In the case of print request, normally the PDF bytes would be parsed
308 + // for sensitive data violations. To keep this class simple, only the
309 + // URL is checked for the word "block".
310 + caResponse = DecideCAResponse(event->GetRequest().request_data().url(), aout.stream());
312 - } else if (event->GetRequest().has_print_data()) {
313 - // In the case of print request, normally the PDF bytes would be parsed
314 - // for sensitive data violations. To keep this class simple, only the
315 - // URL is checked for the word "block".
316 - block = ShouldBlockRequest(event->GetRequest().request_data().url());
317 - GetFileSpecificDelay(event->GetRequest().request_data().url(), &delay);
320 if (!success) {
321 @@ -75,22 +91,44 @@ class Handler : public content_analysis::sdk::AgentEventHandler {
322 content_analysis::sdk::ContentAnalysisResponse::Result::FAILURE);
323 aout.stream() << " Verdict: failed to reach verdict: ";
324 aout.stream() << event->DebugString() << std::endl;
325 - } else if (block) {
326 - auto rc = content_analysis::sdk::SetEventVerdictToBlock(event.get());
327 - aout.stream() << " Verdict: block";
328 - if (rc != content_analysis::sdk::ResultCode::OK) {
329 - aout.stream() << " error: "
330 - << content_analysis::sdk::ResultCodeToString(rc) << std::endl;
331 - aout.stream() << " " << event->DebugString() << std::endl;
332 + } else {
333 + aout.stream() << " Verdict: ";
334 + if (caResponse) {
335 + switch (caResponse.value()) {
336 + case content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action_BLOCK:
337 + aout.stream() << "BLOCK";
338 + break;
339 + case content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action_WARN:
340 + aout.stream() << "WARN";
341 + break;
342 + case content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action_REPORT_ONLY:
343 + aout.stream() << "REPORT_ONLY";
344 + break;
345 + case content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action_ACTION_UNSPECIFIED:
346 + aout.stream() << "ACTION_UNSPECIFIED";
347 + break;
348 + default:
349 + aout.stream() << "<error>";
350 + break;
352 + auto rc =
353 + content_analysis::sdk::SetEventVerdictTo(event.get(), caResponse.value());
354 + if (rc != content_analysis::sdk::ResultCode::OK) {
355 + aout.stream() << " error: "
356 + << content_analysis::sdk::ResultCodeToString(rc) << std::endl;
357 + aout.stream() << " " << event->DebugString() << std::endl;
359 + aout.stream() << std::endl;
360 + } else {
361 + aout.stream() << " Verdict: allow" << std::endl;
363 aout.stream() << std::endl;
364 - } else {
365 - aout.stream() << " Verdict: allow" << std::endl;
368 aout.stream() << std::endl;
370 // If a delay is specified, wait that much.
371 + size_t nextDelayIndex = nextDelayIndex_.fetch_add(1);
372 + unsigned long delay = delays_[nextDelayIndex % delays_.size()];
373 if (delay > 0) {
374 aout.stream() << "Delaying response to " << event->GetRequest().request_token()
375 << " for " << delay << "s" << std::endl<< std::endl;
376 @@ -99,16 +137,19 @@ class Handler : public content_analysis::sdk::AgentEventHandler {
379 // Send the response back to Google Chrome.
380 - auto rc = event->Send();
381 - if (rc != content_analysis::sdk::ResultCode::OK) {
382 - aout.stream() << "[Demo] Error sending response: "
383 - << content_analysis::sdk::ResultCodeToString(rc)
384 - << std::endl;
385 - aout.stream() << event->DebugString() << std::endl;
386 + bool sentCustomResponse = SendCustomResponse(event);
387 + if (!sentCustomResponse) {
388 + auto rc = event->Send();
389 + if (rc != content_analysis::sdk::ResultCode::OK) {
390 + aout.stream() << "[Demo] Error sending response: "
391 + << content_analysis::sdk::ResultCodeToString(rc)
392 + << std::endl;
393 + aout.stream() << event->DebugString() << std::endl;
398 - private:
399 + protected:
400 void OnBrowserConnected(
401 const content_analysis::sdk::BrowserInfo& info) override {
402 AtomicCout aout;
403 @@ -362,21 +403,40 @@ class Handler : public content_analysis::sdk::AgentEventHandler {
404 return true;
407 - bool ShouldBlockRequest(const std::string& content) {
408 - // Determines if the request should be blocked. For this simple example
409 - // the content is blocked if the string "block" is found. Otherwise the
410 - // content is allowed.
411 - return content.find("block") != std::string::npos;
414 - void GetFileSpecificDelay(const std::string& content, unsigned long* delay) {
415 - auto pos = content.find("delay=");
416 - if (pos != std::string::npos) {
417 - std::sscanf(content.substr(pos).c_str(), "delay=%lu", delay);
418 + std::optional<content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action>
419 + DecideCAResponse(const std::string& content, std::stringstream& stream) {
420 + for (auto& r : toBlock_) {
421 + if (std::regex_search(content, r.second)) {
422 + stream << "'" << content << "' matches BLOCK regex '"
423 + << r.first << "'" << std::endl;
424 + return content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action_BLOCK;
427 + for (auto& r : toWarn_) {
428 + if (std::regex_search(content, r.second)) {
429 + stream << "'" << content << "' matches WARN regex '"
430 + << r.first << "'" << std::endl;
431 + return content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action_WARN;
434 + for (auto& r : toReport_) {
435 + if (std::regex_search(content, r.second)) {
436 + stream << "'" << content << "' matches REPORT_ONLY regex '"
437 + << r.first << "'" << std::endl;
438 + return content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action_REPORT_ONLY;
441 + stream << "'" << content << "' was ALLOWed\n";
442 + return {};
445 - unsigned long delay_;
446 + // For the demo, block any content that matches these wildcards.
447 + RegexArray toBlock_;
448 + RegexArray toWarn_;
449 + RegexArray toReport_;
451 + std::vector<unsigned long> delays_;
452 + std::atomic<size_t> nextDelayIndex_;
453 std::string print_data_file_path_;
456 @@ -384,8 +444,11 @@ class Handler : public content_analysis::sdk::AgentEventHandler {
457 // any requests that have the keyword "block" in their data
458 class QueuingHandler : public Handler {
459 public:
460 - QueuingHandler(unsigned long threads, unsigned long delay, const std::string& print_data_file_path)
461 - : Handler(delay, print_data_file_path) {
462 + QueuingHandler(unsigned long threads, std::vector<unsigned long>&& delays, const std::string& print_data_file_path,
463 + RegexArray&& toBlock = RegexArray(),
464 + RegexArray&& toWarn = RegexArray(),
465 + RegexArray&& toReport = RegexArray())
466 + : Handler(std::move(delays), print_data_file_path, std::move(toBlock), std::move(toWarn), std::move(toReport)) {
467 StartBackgroundThreads(threads);
470 @@ -421,6 +484,8 @@ class QueuingHandler : public Handler {
471 aout.stream() << std::endl << "----------" << std::endl;
472 aout.stream() << "Thread: " << std::this_thread::get_id()
473 << std::endl;
474 + aout.stream() << "Delaying request processing for "
475 + << handler->delays()[handler->nextDelayIndex() % handler->delays().size()] << "s" << std::endl << std::endl;
476 aout.flush();
478 handler->AnalyzeContent(aout, std::move(event));
479 diff --git a/demo/handler_misbehaving.h b/demo/handler_misbehaving.h
480 new file mode 100644
481 index 0000000000000..bb0b4f18adcff
482 --- /dev/null
483 +++ b/demo/handler_misbehaving.h
484 @@ -0,0 +1,290 @@
485 +/* This Source Code Form is subject to the terms of the Mozilla Public
486 + * License, v. 2.0. If a copy of the MPL was not distributed with this
487 + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
489 +#ifndef CONTENT_ANALYSIS_DEMO_HANDLER_MISBEHAVING_H_
490 +#define CONTENT_ANALYSIS_DEMO_HANDLER_MISBEHAVING_H_
492 +#include <time.h>
494 +#include <algorithm>
495 +#include <chrono>
496 +#include <fstream>
497 +#include <map>
498 +#include <iostream>
499 +#include <utility>
500 +#include <vector>
501 +#include <regex>
502 +#include <windows.h>
504 +#include "content_analysis/sdk/analysis.pb.h"
505 +#include "content_analysis/sdk/analysis_agent.h"
506 +#include "agent/src/event_win.h"
507 +#include "handler.h"
509 +enum class Mode {
510 +// Have to use a "Mode_" prefix to avoid preprocessing problems in StringToMode
511 +#define AGENT_MODE(name) Mode_##name,
512 +#include "modes.h"
513 +#undef AGENT_MODE
516 +extern std::map<std::string, Mode> sStringToMode;
517 +extern std::map<Mode, std::string> sModeToString;
519 +// Writes a string to the pipe. Returns ERROR_SUCCESS if successful, else
520 +// returns GetLastError() of the write. This function does not return until
521 +// the entire message has been sent (or an error occurs).
522 +static DWORD WriteBigMessageToPipe(HANDLE pipe, const std::string& message) {
523 + std::cout << "[demo] WriteBigMessageToPipe top, message size is "
524 + << message.size() << std::endl;
525 + if (message.empty()) {
526 + return ERROR_SUCCESS;
529 + OVERLAPPED overlapped;
530 + memset(&overlapped, 0, sizeof(overlapped));
531 + overlapped.hEvent = CreateEvent(/*securityAttr=*/nullptr,
532 + /*manualReset=*/TRUE,
533 + /*initialState=*/FALSE,
534 + /*name=*/nullptr);
535 + if (overlapped.hEvent == nullptr) {
536 + return GetLastError();
539 + DWORD err = ERROR_SUCCESS;
540 + const char* cursor = message.data();
541 + for (DWORD size = message.length(); size > 0;) {
542 + std::cout << "[demo] WriteBigMessageToPipe top of loop, remaining size "
543 + << size << std::endl;
544 + if (WriteFile(pipe, cursor, size, /*written=*/nullptr, &overlapped)) {
545 + std::cout << "[demo] WriteBigMessageToPipe: success" << std::endl;
546 + err = ERROR_SUCCESS;
547 + break;
550 + // If an I/O is not pending, return the error.
551 + err = GetLastError();
552 + if (err != ERROR_IO_PENDING) {
553 + std::cout
554 + << "[demo] WriteBigMessageToPipe: returning error from WriteFile "
555 + << err << std::endl;
556 + break;
559 + DWORD written;
560 + if (!GetOverlappedResult(pipe, &overlapped, &written, /*wait=*/TRUE)) {
561 + err = GetLastError();
562 + std::cout << "[demo] WriteBigMessageToPipe: returning error from "
563 + "GetOverlappedREsult "
564 + << err << std::endl;
565 + break;
568 + // reset err for the next loop iteration
569 + err = ERROR_SUCCESS;
570 + std::cout << "[demo] WriteBigMessageToPipe: bottom of loop, wrote "
571 + << written << std::endl;
572 + cursor += written;
573 + size -= written;
576 + CloseHandle(overlapped.hEvent);
577 + return err;
580 +// An AgentEventHandler that does various misbehaving things
581 +class MisbehavingHandler final : public Handler {
582 + public:
583 + using Event = content_analysis::sdk::ContentAnalysisEvent;
585 + static
586 + std::unique_ptr<AgentEventHandler> Create(
587 + const std::string& modeStr,
588 + std::vector<unsigned long>&& delays,
589 + const std::string& print_data_file_path,
590 + RegexArray&& toBlock = RegexArray(),
591 + RegexArray&& toWarn = RegexArray(),
592 + RegexArray&& toReport = RegexArray()) {
593 + auto it = sStringToMode.find(modeStr);
594 + if (it == sStringToMode.end()) {
595 + std::cout << "\"" << modeStr << "\""
596 + << " is not a valid mode!" << std::endl;
597 + return nullptr;
600 + return std::unique_ptr<AgentEventHandler>(new MisbehavingHandler(it->second, std::move(delays), print_data_file_path, std::move(toBlock), std::move(toWarn), std::move(toReport)));
603 + private:
604 + MisbehavingHandler(Mode mode, std::vector<unsigned long>&& delays, const std::string& print_data_file_path,
605 + RegexArray&& toBlock = RegexArray(),
606 + RegexArray&& toWarn = RegexArray(),
607 + RegexArray&& toReport = RegexArray()) :
608 + Handler(std::move(delays), print_data_file_path, std::move(toBlock), std::move(toWarn), std::move(toReport)),
609 + mode_(mode) {}
612 + template <size_t N>
613 + DWORD SendBytesOverPipe(const unsigned char (&bytes)[N],
614 + const std::unique_ptr<Event>& event) {
615 + content_analysis::sdk::ContentAnalysisEventWin* eventWin =
616 + static_cast<content_analysis::sdk::ContentAnalysisEventWin*>(
617 + event.get());
618 + HANDLE pipe = eventWin->Pipe();
619 + std::string s(reinterpret_cast<const char*>(bytes), N);
620 + return WriteBigMessageToPipe(pipe, s);
623 + bool SetCustomResponse(AtomicCout& aout, std::unique_ptr<Event>& event) override {
624 + std::cout << std::endl << "----------" << std::endl << std::endl;
625 + std::cout << "Mode is " << sModeToString[mode_] << std::endl;
627 + bool handled = true;
628 + if (mode_ == Mode::Mode_largeResponse) {
629 + for (size_t i = 0; i < 1000; ++i) {
630 + content_analysis::sdk::ContentAnalysisResponse_Result* result =
631 + event->GetResponse().add_results();
632 + result->set_tag("someTag");
633 + content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule*
634 + triggeredRule = result->add_triggered_rules();
635 + triggeredRule->set_rule_id("some_id");
636 + triggeredRule->set_rule_name("some_name");
638 + } else if (mode_ ==
639 + Mode::Mode_invalidUtf8StringStartByteIsContinuationByte) {
640 + // protobuf docs say
641 + // "A string must always contain UTF-8 encoded text."
642 + // So let's try something invalid
643 + // Anything with bits 10xxxxxx is only a continuation code point
644 + event->GetResponse().set_request_token("\x80\x41\x41\x41");
645 + } else if (mode_ ==
646 + Mode::Mode_invalidUtf8StringEndsInMiddleOfMultibyteSequence) {
647 + // f0 byte indicates there should be 3 bytes following it, but here
648 + // there are only 2
649 + event->GetResponse().set_request_token("\x41\xf0\x90\x8d");
650 + } else if (mode_ == Mode::Mode_invalidUtf8StringOverlongEncoding) {
651 + // codepoint U+20AC, should be encoded in 3 bytes (E2 82 AC)
652 + // instead of 4
653 + event->GetResponse().set_request_token("\xf0\x82\x82\xac");
654 + } else if (mode_ == Mode::Mode_invalidUtf8StringMultibyteSequenceTooShort) {
655 + // f0 byte indicates there should be 3 bytes following it, but here
656 + // there are only 2 (\x41 is not a continuation byte)
657 + event->GetResponse().set_request_token("\xf0\x90\x8d\x41");
658 + } else if (mode_ == Mode::Mode_invalidUtf8StringDecodesToInvalidCodePoint) {
659 + // decodes to U+1FFFFF, but only up to U+10FFFF is a valid code point
660 + event->GetResponse().set_request_token("\xf7\xbf\xbf\xbf");
661 + } else if (mode_ == Mode::Mode_stringWithEmbeddedNull) {
662 + event->GetResponse().set_request_token("\x41\x00\x41");
663 + } else if (mode_ == Mode::Mode_zeroResults) {
664 + event->GetResponse().clear_results();
665 + } else if (mode_ == Mode::Mode_resultWithInvalidStatus) {
666 + // This causes an assertion failure and the process exits
667 + // So we just serialize this ourselves in SendCustomResponse()
668 + /*content_analysis::sdk::ContentAnalysisResponse_Result* result =
669 + event->GetResponse().mutable_results(0);
670 + result->set_status(
671 + static_cast<
672 + ::content_analysis::sdk::ContentAnalysisResponse_Result_Status>(
673 + 100));*/
674 + } else {
675 + handled = false;
677 + return handled;
680 + bool SendCustomResponse(std::unique_ptr<Event>& event) override {
681 + if (mode_ == Mode::Mode_largeResponse) {
682 + content_analysis::sdk::ContentAnalysisEventWin* eventWin =
683 + static_cast<content_analysis::sdk::ContentAnalysisEventWin*>(
684 + event.get());
685 + HANDLE pipe = eventWin->Pipe();
686 + std::cout << "largeResponse about to write" << std::endl;
687 + DWORD result = WriteBigMessageToPipe(
688 + pipe, eventWin->SerializeStringToSendToBrowser());
689 + std::cout << "largeResponse done writing with error " << result
690 + << std::endl;
691 + eventWin->SetResponseSent();
692 + } else if (mode_ == Mode::Mode_resultWithInvalidStatus) {
693 + content_analysis::sdk::ContentAnalysisEventWin* eventWin =
694 + static_cast<content_analysis::sdk::ContentAnalysisEventWin*>(
695 + event.get());
696 + HANDLE pipe = eventWin->Pipe();
697 + std::string serializedString = eventWin->SerializeStringToSendToBrowser();
698 + // The last byte is the status value. Set it to 100
699 + serializedString[serializedString.length() - 1] = 100;
700 + WriteBigMessageToPipe(pipe, serializedString);
701 + } else if (mode_ == Mode::Mode_messageTruncatedInMiddleOfString) {
702 + unsigned char bytes[5];
703 + bytes[0] = 10; // field 1 (request_token), LEN encoding
704 + bytes[1] = 13; // length 13
705 + bytes[2] = 65; // "A"
706 + bytes[3] = 66; // "B"
707 + bytes[4] = 67; // "C"
708 + SendBytesOverPipe(bytes, event);
709 + } else if (mode_ == Mode::Mode_messageWithInvalidWireType) {
710 + unsigned char bytes[5];
711 + bytes[0] = 15; // field 1 (request_token), "7" encoding (invalid value)
712 + bytes[1] = 3; // length 3
713 + bytes[2] = 65; // "A"
714 + bytes[3] = 66; // "B"
715 + bytes[4] = 67; // "C"
716 + SendBytesOverPipe(bytes, event);
717 + } else if (mode_ == Mode::Mode_messageWithUnusedFieldNumber) {
718 + unsigned char bytes[5];
719 + bytes[0] = 82; // field 10 (this is invalid), LEN encoding
720 + bytes[1] = 3; // length 3
721 + bytes[2] = 65; // "A"
722 + bytes[3] = 66; // "B"
723 + bytes[4] = 67; // "C"
724 + SendBytesOverPipe(bytes, event);
725 + } else if (mode_ == Mode::Mode_messageWithWrongStringWireType) {
726 + unsigned char bytes[2];
727 + bytes[0] = 10; // field 1 (request_token), VARINT encoding (but should be
728 + // a string/LEN)
729 + bytes[1] = 42; // value 42
730 + SendBytesOverPipe(bytes, event);
731 + } else if (mode_ == Mode::Mode_messageWithZeroTag) {
732 + unsigned char bytes[1];
733 + // The protobuf deserialization code seems to handle this
734 + // in a special case.
735 + bytes[0] = 0;
736 + SendBytesOverPipe(bytes, event);
737 + } else if (mode_ == Mode::Mode_messageWithZeroFieldButNonzeroWireType) {
738 + // The protobuf deserialization code seems to handle this
739 + // in a special case.
740 + unsigned char bytes[5];
741 + bytes[0] = 2; // field 0 (invalid), LEN encoding
742 + bytes[1] = 3; // length 13
743 + bytes[2] = 65; // "A"
744 + bytes[3] = 66; // "B"
745 + bytes[4] = 67; // "C"
746 + SendBytesOverPipe(bytes, event);
747 + } else if (mode_ == Mode::Mode_messageWithGroupEnd) {
748 + // GROUP_ENDs are obsolete and the deserialization code
749 + // handles them in a special case.
750 + unsigned char bytes[1];
751 + bytes[0] = 12; // field 1 (request_token), GROUP_END encoding
752 + SendBytesOverPipe(bytes, event);
753 + } else if (mode_ == Mode::Mode_messageTruncatedInMiddleOfVarint) {
754 + unsigned char bytes[2];
755 + bytes[0] = 16; // field 2 (status), VARINT encoding
756 + bytes[1] = 128; // high bit is set, indicating there
757 + // should be a byte after this
758 + SendBytesOverPipe(bytes, event);
759 + } else if (mode_ == Mode::Mode_messageTruncatedInMiddleOfTag) {
760 + unsigned char bytes[1];
761 + bytes[0] = 128; // tag is actually encoded as a VARINT, so set the high
762 + // bit, indicating there should be a byte after this
763 + SendBytesOverPipe(bytes, event);
764 + } else {
765 + return false;
767 + return true;
770 + private:
771 + Mode mode_;
774 +#endif // CONTENT_ANALYSIS_DEMO_HANDLER_MISBEHAVING_H_
775 diff --git a/demo/modes.h b/demo/modes.h
776 new file mode 100644
777 index 0000000000000..debefc9d1a66c
778 --- /dev/null
779 +++ b/demo/modes.h
780 @@ -0,0 +1,25 @@
781 +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
782 +/* This Source Code Form is subject to the terms of the Mozilla Public
783 + * License, v. 2.0. If a copy of the MPL was not distributed with this
784 + * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
786 +// #define AGENT_MODE(name) to do what you want and then #include this file
788 +AGENT_MODE(largeResponse)
789 +AGENT_MODE(invalidUtf8StringStartByteIsContinuationByte)
790 +AGENT_MODE(invalidUtf8StringEndsInMiddleOfMultibyteSequence)
791 +AGENT_MODE(invalidUtf8StringOverlongEncoding)
792 +AGENT_MODE(invalidUtf8StringMultibyteSequenceTooShort)
793 +AGENT_MODE(invalidUtf8StringDecodesToInvalidCodePoint)
794 +AGENT_MODE(stringWithEmbeddedNull)
795 +AGENT_MODE(zeroResults)
796 +AGENT_MODE(resultWithInvalidStatus)
797 +AGENT_MODE(messageTruncatedInMiddleOfString)
798 +AGENT_MODE(messageWithInvalidWireType)
799 +AGENT_MODE(messageWithUnusedFieldNumber)
800 +AGENT_MODE(messageWithWrongStringWireType)
801 +AGENT_MODE(messageWithZeroTag)
802 +AGENT_MODE(messageWithZeroFieldButNonzeroWireType)
803 +AGENT_MODE(messageWithGroupEnd)
804 +AGENT_MODE(messageTruncatedInMiddleOfVarint)
805 +AGENT_MODE(messageTruncatedInMiddleOfTag)
807 2.42.0.windows.2