#include <chrono>
#include <optional>
#include <queue>
#include <tuple>
#include <unordered_map>
#include <vector>

#include <userver/engine/task/cancel.hpp>
#include <userver/engine/wait_any.hpp>
#include <userver/utils/assert.hpp>
#include <userver/utils/async.hpp>
#include <userver/utils/datetime.hpp>

USERVER_NAMESPACE_BEGIN

namespace utils::hedging {
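
/// The templates below are parameterized by a user-provided RequestStrategy.
/// The concept is defined by how the strategy is used in this header; an
/// illustrative sketch (all names in it are placeholders, not a concrete
/// userver API) is:
///
/// @code
/// struct SampleStrategy {
///   // Awaitable per-attempt request; must provide TryGetContextAccessor()
///   using RequestType = SampleRequest;
///   // Final reply type returned to the caller
///   using ReplyType = SampleReply;
///
///   // Start attempt number `attempt`; std::nullopt means "no more attempts"
///   std::optional<RequestType> Create(std::size_t attempt);
///   // Inspect a completed request; return a delay to retry after, or
///   // std::nullopt if the reply is final
///   std::optional<std::chrono::milliseconds> ProcessReply(RequestType&&);
///   // Best reply seen so far, if any
///   std::optional<ReplyType> ExtractReply();
///   // Cancel/clean up a request that is no longer needed
///   void Finish(RequestType&&);
/// };
/// @endcode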
/// Hedging settings
struct HedgingSettings {
  /// Maximum number of attempts per request strategy
  std::size_t max_attempts{3};
  /// Delay before the next hedged attempt is started
  std::chrono::milliseconds hedging_delay{7};
  /// Overall deadline for the whole hedged request
  std::chrono::milliseconds timeout_all{100};
};

/// Per-strategy request and reply types used throughout this header
template <typename RequestStrategy>
struct RequestTraits {
  using RequestType = typename RequestStrategy::RequestType;
  using ReplyType = typename RequestStrategy::ReplyType;
};

namespace impl {

using Clock = std::chrono::steady_clock;
using TimePoint = Clock::time_point;

enum class Action { StartTry, Stop };
/// A single scheduled action of the hedging plan
struct PlanEntry {
 public:
  PlanEntry(TimePoint timepoint, std::size_t request_index,
            std::size_t attempt_id, Action action)
      : timepoint(timepoint),
        request_index(request_index),
        attempt_id(attempt_id),
        action(action) {}

  bool operator<(const PlanEntry& other) const noexcept {
    return tie() < other.tie();
  }
  bool operator>(const PlanEntry& other) const noexcept {
    return tie() > other.tie();
  }
  bool operator==(const PlanEntry& other) const noexcept {
    return tie() == other.tie();
  }
  bool operator<=(const PlanEntry& other) const noexcept {
    return tie() <= other.tie();
  }
  bool operator>=(const PlanEntry& other) const noexcept {
    return tie() >= other.tie();
  }
  bool operator!=(const PlanEntry& other) const noexcept {
    return tie() != other.tie();
  }

  TimePoint timepoint;
  std::size_t request_index{0};
  std::size_t attempt_id{0};
  Action action{Action::StartTry};

 private:
  std::tuple<const TimePoint&, const size_t&, const size_t&, const Action&>
  tie() const noexcept {
    return std::tie(timepoint, request_index, attempt_id, action);
  }
};
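
// PlanEntry objects are stored in Context::plan_, a std::priority_queue with
// std::greater<>, i.e. a min-heap: the entry with the earliest timepoint is
// processed first and the remaining fields only break ties.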
/// Wraps an optional request so that a vector of subrequests can be passed to
/// engine::WaitAnyUntil(); empty wrappers expose no ContextAccessor and are
/// simply skipped.
template <typename RequestStrategy>
struct SubrequestWrapper {
  using RequestType = typename RequestTraits<RequestStrategy>::RequestType;

  SubrequestWrapper() = default;
  SubrequestWrapper(SubrequestWrapper&&) noexcept = default;
  explicit SubrequestWrapper(std::optional<RequestType>&& request)
      : request(std::move(request)) {}

  engine::impl::ContextAccessor* TryGetContextAccessor() {
    if (!request) return nullptr;
    return request->TryGetContextAccessor();
  }

  std::optional<RequestType> request;
};
/// Per-input bookkeeping of the hedging algorithm
struct RequestState {
  std::vector<std::size_t> subrequest_indices;
  std::size_t attempts_made = 0;
  bool finished = false;
};
template <typename RequestStrategy>
struct Context {
  using RequestType = typename RequestTraits<RequestStrategy>::RequestType;
  using ReplyType = typename RequestTraits<RequestStrategy>::ReplyType;

  Context(std::vector<RequestStrategy> inputs, HedgingSettings settings)
      : inputs_(std::move(inputs)), settings(std::move(settings)) {
    const std::size_t size = this->inputs_.size();
    request_states_.resize(size);
  }
  Context(Context&&) noexcept = default;

  void Prepare(TimePoint start_time) {
    const auto request_count = GetRequestsCount();
    for (std::size_t request_id = 0; request_id < request_count;
         ++request_id) {
      plan_.emplace(start_time, request_id, 0, Action::StartTry);
    }
    plan_.emplace(start_time + settings.timeout_all, 0, 0, Action::Stop);
    subrequests_.reserve(settings.max_attempts * request_count);
  }

  std::optional<TimePoint> NextEventTime() const {
    if (plan_.empty()) return std::nullopt;
    return plan_.top().timepoint;
  }

  std::optional<PlanEntry> PopPlan() {
    if (plan_.empty()) return std::nullopt;
    auto ret = plan_.top();
    plan_.pop();
    return ret;
  }

  bool IsStop() const { return stop_; }

  /// Cancel every outstanding subrequest of the given input and mark it done
  void FinishAllSubrequests(std::size_t request_index) {
    auto& request_state = request_states_[request_index];
    request_state.finished = true;
    const auto& subrequest_indices = request_state.subrequest_indices;
    auto& strategy = inputs_[request_index];
    for (auto i : subrequest_indices) {
      auto& request = subrequests_[i].request;
      if (request) {
        strategy.Finish(std::move(*request));
      }
    }
  }

  size_t GetRequestsCount() const { return inputs_.size(); }

  size_t GetRequestIdxBySubrequestIdx(size_t subrequest_idx) const {
    return input_by_subrequests_.at(subrequest_idx);
  }

  RequestStrategy& GetStrategy(size_t index) { return inputs_[index]; }

  auto& GetSubRequests() { return subrequests_; }

  std::vector<std::optional<ReplyType>> ExtractAllReplies() {
    std::vector<std::optional<ReplyType>> ret;
    ret.reserve(GetRequestsCount());
    for (auto&& strategy : inputs_) {
      ret.emplace_back(strategy.ExtractReply());
    }
    return ret;
  }

  void OnActionStop() {
    for (std::size_t i = 0; i < inputs_.size(); ++i) FinishAllSubrequests(i);
    stop_ = true;
  }

  void OnActionStartTry(std::size_t request_index, std::size_t attempt_id,
                        TimePoint now) {
    auto& request_state = request_states_[request_index];
    if (request_state.finished) {
      return;
    }
    auto& attempts_made = request_state.attempts_made;
    // Skip plan entries that were superseded by attempts already started
    // (e.g. by an earlier retry).
    if (attempt_id < attempts_made) {
      return;
    }
    if (attempts_made >= settings.max_attempts) {
      return;
    }
    auto& strategy = inputs_[request_index];
    auto request_opt = strategy.Create(attempts_made);
    if (!request_opt) {
      // The strategy refused to create another attempt
      request_state.finished = true;
      return;
    }
    ++attempts_made;

    const auto idx = subrequests_.size();
    subrequests_.emplace_back(std::move(request_opt));
    request_state.subrequest_indices.push_back(idx);
    input_by_subrequests_[idx] = request_index;

    // Schedule the next hedged attempt
    plan_.emplace(now + settings.hedging_delay, request_index, attempts_made,
                  Action::StartTry);
  }

  void OnRetriableReply(std::size_t request_idx,
                        std::chrono::milliseconds retry_delay, TimePoint now) {
    const auto& request_state = request_states_[request_idx];
    if (request_state.finished) return;
    if (request_state.attempts_made >= settings.max_attempts) return;

    plan_.emplace(now + retry_delay, request_idx, request_state.attempts_made,
                  Action::StartTry);
  }

  void OnNonRetriableReply(std::size_t request_idx) {
    FinishAllSubrequests(request_idx);
  }

 private:
  std::vector<RequestStrategy> inputs_;
  HedgingSettings settings;
  bool stop_{false};
  /// Scheduled actions ordered by time (earliest on top)
  std::priority_queue<PlanEntry, std::vector<PlanEntry>, std::greater<>>
      plan_{};
  std::vector<SubrequestWrapper<RequestStrategy>> subrequests_{};
  /// Maps a subrequest index to the index of its originating input
  std::unordered_map<std::size_t, std::size_t> input_by_subrequests_{};
  std::vector<RequestState> request_states_{};
};

}  // namespace impl
/// Future of a bulk of hedged requests
template <typename RequestStrategy>
class HedgedRequestBulkFuture {
 public:
  using RequestType = typename RequestTraits<RequestStrategy>::RequestType;
  using ReplyType = typename RequestTraits<RequestStrategy>::ReplyType;

  ~HedgedRequestBulkFuture() { task_.SyncCancel(); }

  void Wait() { task_.Wait(); }
  std::vector<std::optional<ReplyType>> Get() { return task_.Get(); }
  engine::impl::ContextAccessor* TryGetContextAccessor() {
    return task_.TryGetContextAccessor();
  }

 private:
  template <typename RequestStrategy_>
  friend auto HedgeRequestsBulkAsync(std::vector<RequestStrategy_> inputs,
                                     HedgingSettings settings);

  using Task = engine::TaskWithResult<std::vector<std::optional<ReplyType>>>;
  HedgedRequestBulkFuture(Task&& task) : task_(std::move(task)) {}

  Task task_;
};
/// Future of a single hedged request
template <typename RequestStrategy>
class HedgedRequestFuture {
 public:
  using RequestType = typename RequestTraits<RequestStrategy>::RequestType;
  using ReplyType = typename RequestTraits<RequestStrategy>::ReplyType;

  ~HedgedRequestFuture() { task_.SyncCancel(); }

  void Wait() { task_.Wait(); }
  std::optional<ReplyType> Get() { return task_.Get(); }
  void IgnoreResult() {}
  engine::impl::ContextAccessor* TryGetContextAccessor() {
    return task_.TryGetContextAccessor();
  }

 private:
  template <typename RequestStrategy_>
  friend auto HedgeRequestAsync(RequestStrategy_ input,
                                HedgingSettings settings);

  using Task = engine::TaskWithResult<std::optional<ReplyType>>;
  HedgedRequestFuture(Task&& task) : task_(std::move(task)) {}

  Task task_;
};
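
/// Runs the hedging loop for all strategies in `inputs`:
///  - at start every input gets one attempt, and a global Stop action is
///    scheduled at start + `timeout_all`;
///  - each started attempt schedules the next (hedged) attempt after
///    `hedging_delay`, up to `max_attempts` attempts per input;
///  - a retriable reply (ProcessReply() returned a delay) schedules one more
///    attempt after that delay, a non-retriable reply finishes the input and
///    passes its outstanding attempts to Finish();
///  - cancellation of the current task stops everything early.
/// Returns one ExtractReply() result per input, in the same order as `inputs`.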
template <typename RequestStrategy>
std::vector<std::optional<typename RequestTraits<RequestStrategy>::ReplyType>>
HedgeRequestsBulk(std::vector<RequestStrategy> inputs,
                  HedgingSettings hedging_settings) {
  using Action = impl::Action;
  using Clock = impl::Clock;

  auto context = impl::Context(std::move(inputs), std::move(hedging_settings));
  auto& sub_requests = context.GetSubRequests();

  auto wakeup_time = Clock::now();
  context.Prepare(wakeup_time);

  while (!context.IsStop()) {
    auto wait_result = engine::WaitAnyUntil(wakeup_time, sub_requests);
    if (!wait_result.has_value()) {
      if (engine::current_task::ShouldCancel()) {
        context.OnActionStop();
        break;
      }
      // The wait timed out: perform the next planned action
      auto plan_entry = context.PopPlan();
      if (!plan_entry.has_value()) {
        break;
      }
      const auto [timestamp, request_index, attempt_id, action] = *plan_entry;
      switch (action) {
        case Action::StartTry:
          context.OnActionStartTry(request_index, attempt_id, timestamp);
          break;
        case Action::Stop:
          context.OnActionStop();
          break;
      }
      auto next_wakeup_time = context.NextEventTime();
      if (!next_wakeup_time.has_value()) {
        break;
      }
      wakeup_time = *next_wakeup_time;
      continue;
    }

    const auto result_idx = *wait_result;
    UASSERT(result_idx < sub_requests.size());
    const auto request_idx = context.GetRequestIdxBySubrequestIdx(result_idx);
    auto& strategy = context.GetStrategy(request_idx);

    auto& request = sub_requests[result_idx].request;
    UASSERT_MSG(request, "Finished requests must not be empty");
    auto reply = strategy.ProcessReply(std::move(*request));
    if (reply.has_value()) {
      // Retriable reply: schedule one more attempt after the suggested delay
      context.OnRetriableReply(request_idx, *reply, Clock::now());
      // OnRetriableReply() has just added a plan entry, so the plan is
      // non-empty here
      wakeup_time = *context.NextEventTime();
    } else {
      context.OnNonRetriableReply(request_idx);
    }
  }
  return context.ExtractAllReplies();
}
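
/// Same as HedgeRequestsBulk(), but runs in a separate task started with
/// utils::Async() and returns a HedgedRequestBulkFuture. Destroying the
/// future synchronously cancels the underlying task.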
template <typename RequestStrategy>
auto HedgeRequestsBulkAsync(std::vector<RequestStrategy> inputs,
                            HedgingSettings settings) {
  return HedgedRequestBulkFuture<RequestStrategy>(utils::Async(
      "hedged-bulk-request",
      [inputs{std::move(inputs)}, settings{std::move(settings)}]() mutable {
        return HedgeRequestsBulk(std::move(inputs), std::move(settings));
      }));
}
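
/// Single-input convenience wrapper over HedgeRequestsBulk(): performs the
/// hedged request for one strategy and returns its reply, or std::nullopt if
/// no reply was obtained.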
template <typename RequestStrategy>
std::optional<typename RequestTraits<RequestStrategy>::ReplyType> HedgeRequest(
    RequestStrategy input, HedgingSettings settings) {
  std::vector<RequestStrategy> inputs;
  inputs.emplace_back(std::move(input));
  auto bulk_ret = HedgeRequestsBulk(std::move(inputs), std::move(settings));
  if (bulk_ret.size() != 1) {
    return std::nullopt;
  }
  return std::move(bulk_ret[0]);
}
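
/// Same as HedgeRequest(), but runs in a separate task started with
/// utils::Async() and returns a HedgedRequestFuture. Destroying the future
/// cancels the underlying task.
///
/// An illustrative usage sketch; `SampleStrategy` is a hypothetical type
/// implementing the strategy interface sketched at the top of this header:
///
/// @code
/// utils::hedging::HedgingSettings settings;
/// settings.hedging_delay = std::chrono::milliseconds{10};
/// settings.timeout_all = std::chrono::milliseconds{500};
///
/// auto future =
///     utils::hedging::HedgeRequestAsync(SampleStrategy{/* ... */}, settings);
/// auto reply = future.Get();  // std::optional<SampleStrategy::ReplyType>
/// @endcode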
template <typename RequestStrategy>
auto HedgeRequestAsync(RequestStrategy input, HedgingSettings settings) {
  return HedgedRequestFuture<RequestStrategy>(utils::Async(
      "hedged-request",
      [input{std::move(input)}, settings{std::move(settings)}]() mutable {
        return HedgeRequest(std::move(input), std::move(settings));
      }));
}

}  // namespace utils::hedging

USERVER_NAMESPACE_END