2 * Copyright (c) Meta Platforms, Inc. and affiliates.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
24 #include <boost/intrusive/unordered_set.hpp>
26 #include <folly/IntrusiveList.h>
27 #include <folly/Likely.h>
28 #include <folly/Portability.h>
29 #include <folly/fibers/Baton.h>
31 #include <thrift/lib/cpp2/transport/rocket/Types.h>
32 #include <thrift/lib/cpp2/transport/rocket/framing/FrameType.h>
33 #include <thrift/lib/cpp2/transport/rocket/framing/Frames.h>
34 #include <thrift/lib/cpp2/transport/rocket/framing/Serializer.h>
35 #include <thrift/lib/thrift/gen-cpp2/RpcMetadata_types.h>
// Forward declaration: RequestContext stores a reference to its owning
// RequestContextQueue (queue_) and declares it a friend class.
40 class RequestContextQueue
;
42 class RequestContext
{
44 class WriteSuccessCallback
{
46 virtual ~WriteSuccessCallback() = default;
47 virtual void onWriteSuccess() noexcept
= 0;
50 enum class State
: uint8_t {
51 DEFERRED_INIT
, /* still needs to be initialized with server version */
54 WRITE_SENDING
, /* AsyncSocket::writeChain() called, but WriteCallback has
56 WRITE_SENT
, /* Write to socket completed (possibly with error) */
57 COMPLETE
, /* Terminal state. Result stored in responsePayload_ */
60 template <class Frame
>
63 RequestContextQueue
& queue
,
64 SetupFrame
* setupFrame
= nullptr,
65 WriteSuccessCallback
* writeSuccessCallback
= nullptr)
67 streamId_(frame
.streamId()),
68 frameType_(Frame::frameType()),
69 writeSuccessCallback_(writeSuccessCallback
) {
70 serialize(std::forward
<Frame
>(frame
), setupFrame
);
73 template <class InitFunc
>
76 int32_t serverVersion
,
78 RequestContextQueue
& queue
,
79 WriteSuccessCallback
* writeSuccessCallback
= nullptr)
82 writeSuccessCallback_(writeSuccessCallback
) {
83 if (UNLIKELY(serverVersion
== -1)) {
84 deferredInit_
= std::forward
<InitFunc
>(initFunc
);
85 state_
= State::DEFERRED_INIT
;
87 std::tie(serializedFrame_
, frameType_
) = initFunc(serverVersion
);
91 RequestContext(const RequestContext
&) = delete;
92 RequestContext(RequestContext
&&) = delete;
93 RequestContext
& operator=(const RequestContext
&) = delete;
94 RequestContext
& operator=(RequestContext
&&) = delete;
96 // For REQUEST_RESPONSE contexts, where an immediate matching response is
// expected.
//
// Waits up to `timeout` for the matching response and returns it as a Try.
// NOTE(review): exact behavior on timeout is defined in the .cpp — confirm.
98 FOLLY_NODISCARD
folly::Try
<Payload
> waitForResponse(
99 std::chrono::milliseconds timeout
);
// Returns the response held by this context. The `&&` qualifier means it can
// only be called on an rvalue, e.g. std::move(ctx).getResponse().
100 FOLLY_NODISCARD
folly::Try
<Payload
> getResponse() &&;
102 // For request types for which an immediate matching response is not
103 // necessarily expected, e.g., REQUEST_FNF and REQUEST_STREAM
//
// Presumably waits until the frame write finishes and reports its outcome
// as Try<void> — confirm in the .cpp.
104 FOLLY_NODISCARD
folly::Try
<void> waitForWriteToComplete();
// Appears to be a two-step form of waitForWriteToComplete(): register
// `waiter` here, then read the outcome via waitForWriteToCompleteResult().
// NOTE(review): presumably `waiter` is attached to baton_ — confirm.
106 void waitForWriteToCompleteSchedule(folly::fibers::Baton::Waiter
* waiter
);
107 FOLLY_NODISCARD
folly::Try
<void> waitForWriteToCompleteResult();
110 folly::HHWheelTimer
& timer
,
111 folly::HHWheelTimer::Callback
& callback
,
112 std::chrono::milliseconds timeout
) {
114 timeoutCallback_
= &callback
;
115 requestTimeout_
= timeout
;
118 void scheduleTimeoutForResponse() {
119 DCHECK(isRequestResponse());
120 // In some edge cases, response may arrive before write to socket finishes.
121 if (state_
!= State::COMPLETE
&&
122 requestTimeout_
!= std::chrono::milliseconds::zero()) {
123 timer_
->scheduleTimeout(timeoutCallback_
, requestTimeout_
);
127 std::unique_ptr
<folly::IOBuf
> serializedChain() {
128 DCHECK(serializedFrame_
);
129 return std::move(serializedFrame_
);
132 size_t endOffsetInBatch() const {
133 DCHECK_GT(endOffsetInBatch_
, 0);
134 return endOffsetInBatch_
;
137 void setEndOffsetInBatch(ssize_t offset
) { endOffsetInBatch_
= offset
; }
139 State
state() const { return state_
; }
141 StreamId
streamId() const { return streamId_
; }
143 bool isRequestResponse() const {
144 return frameType_
== FrameType::REQUEST_RESPONSE
;
// Handles an incoming PAYLOAD frame for this request.
// NOTE(review): presumably stores the data in responsePayload_ (where the
// State::COMPLETE comment says the result lives) — confirm in the .cpp.
147 void onPayloadFrame(PayloadFrame
&& payloadFrame
);
// Handles an incoming ERROR frame for this request.
// NOTE(review): presumably records the error in responsePayload_ — confirm.
148 void onErrorFrame(ErrorFrame
&& errorFrame
);
// Notification that this context's frame write succeeded; declared noexcept.
// NOTE(review): likely forwards to writeSuccessCallback_ when it is non-null
// — confirm in the .cpp.
150 void onWriteSuccess() noexcept
;
152 bool hasPartialPayload() const { return responsePayload_
.hasValue(); }
154 void initWithVersion(int32_t serverVersion
) {
155 if (!deferredInit_
) {
158 DCHECK(state_
== State::DEFERRED_INIT
);
159 std::tie(serializedFrame_
, frameType_
) = deferredInit_(serverVersion
);
160 DCHECK(serializedFrame_
&& frameType_
!= FrameType::RESERVED
);
161 state_
= State::WRITE_NOT_SCHEDULED
;
// Reference to the owning queue, bound at construction.
165 RequestContextQueue
& queue_
;
// Hook used by the CountedIntrusiveList alias that names
// &RequestContext::queueHook_.
166 folly::SafeIntrusiveListHook queueHook_
;
// Serialized frame bytes. When a SETUP frame is supplied, they are preceded
// by it. Filled by serialize() or initWithVersion(), and moved out by
// serializedChain().
167 std::unique_ptr
<folly::IOBuf
> serializedFrame_
;
// End offset of this context within its write batch; written by
// setEndOffsetInBatch().
// NOTE(review): stored as ssize_t but returned as size_t by
// endOffsetInBatch().
168 ssize_t endOffsetInBatch_
{};
// Request frame type; isRequestResponse() compares it against
// FrameType::REQUEST_RESPONSE.
170 FrameType frameType_
;
// Lifecycle state. It defaults to WRITE_NOT_SCHEDULED; the deferred-init
// constructor sets DEFERRED_INIT when serverVersion == -1.
171 State state_
{State::WRITE_NOT_SCHEDULED
};
// Both flags are set to true by createDummyEndOfBatchMarker().
172 bool lastInWriteBatch_
{false};
173 bool isDummyEndOfBatchMarker_
{false};
// Hook for the boost::intrusive UnorderedSet alias, which compares and
// hashes contexts by streamId_.
175 boost::intrusive::unordered_set_member_hook
<> setHook_
;
// NOTE(review): presumably posted when the write/response completes so that
// the waitFor* methods wake up — confirm in the .cpp.
176 folly::fibers::Baton baton_
;
// Response timeout. It defaults to 1000 ms and is overwritten when timeout
// info is configured. A zero value makes scheduleTimeoutForResponse() skip
// scheduling.
177 std::chrono::milliseconds requestTimeout_
{1000};
// Timer and callback passed to scheduleTimeout() by
// scheduleTimeoutForResponse(); both are null until timeout info is set.
178 folly::HHWheelTimer
* timer_
{nullptr};
179 folly::HHWheelTimer::Callback
* timeoutCallback_
{nullptr};
// Response value or error. The State::COMPLETE comment says the final
// result is stored here.
180 folly::Try
<Payload
> responsePayload_
;
// Optional callback supplied at construction; may be null.
181 WriteSuccessCallback
* const writeSuccessCallback_
{nullptr};
// Deferred serializer, stored when the context is created with
// serverVersion == -1. initWithVersion() invokes it with the real server
// version to produce the serialized frame and its type.
182 folly::Function
<std::pair
<std::unique_ptr
<folly::IOBuf
>, FrameType
>(int32_t)>
183 deferredInit_
{nullptr};
185 template <class Frame
>
186 void serialize(Frame
&& frame
, SetupFrame
* setupFrame
) {
187 DCHECK(!serializedFrame_
);
189 serializedFrame_
= std::move(frame
).serialize();
191 if (UNLIKELY(setupFrame
!= nullptr)) {
193 std::move(*setupFrame
).serialize(writer
);
194 auto setupBuffer
= std::move(writer
).move();
195 setupBuffer
->prependChain(std::move(serializedFrame_
));
196 serializedFrame_
= std::move(setupBuffer
);
200 explicit RequestContext(RequestContextQueue
& queue
)
201 : queue_(queue
), frameType_(FrameType::REQUEST_RESPONSE
) {}
203 static RequestContext
& createDummyEndOfBatchMarker(
204 RequestContextQueue
& queue
) {
205 auto* rctx
= new RequestContext(queue
);
206 rctx
->lastInWriteBatch_
= true;
207 rctx
->isDummyEndOfBatchMarker_
= true;
208 rctx
->state_
= State::WRITE_SENDING
;
214 const RequestContext
& ctxa
, const RequestContext
& ctxb
) const noexcept
{
215 return ctxa
.streamId_
== ctxb
.streamId_
;
220 size_t operator()(const RequestContext
& ctx
) const noexcept
{
221 return std::hash
<StreamId::underlying_type
>()(
222 static_cast<uint32_t>(ctx
.streamId_
));
228 folly::CountedIntrusiveList
<RequestContext
, &RequestContext::queueHook_
>;
230 using UnorderedSet
= boost::intrusive::unordered_set
<
232 boost::intrusive::member_hook
<
235 &RequestContext::setHook_
>,
236 boost::intrusive::equal
<Equal
>,
237 boost::intrusive::hash
<Hash
>>;
240 friend class RequestContextQueue
;
243 } // namespace rocket
244 } // namespace thrift
245 } // namespace apache