59#include <userver/engine/task/cancel.hpp>
60#include <userver/engine/wait_any.hpp>
61#include <userver/utils/assert.hpp>
62#include <userver/utils/async.hpp>
63#include <userver/utils/datetime.hpp>
65USERVER_NAMESPACE_BEGIN
67namespace utils::hedging {
// Compile-time traits deduced from a RequestStrategy type (presumably the
// RequestTraits<RequestStrategy> referenced by the classes below - confirm).
79template <
typename RequestStrategy>
// Request type: value_type of what RequestStrategy::Create(int) returns.
// Create's result is moved into SubrequestWrapper's
// std::optional<RequestType> constructor, so it is a std::optional<...>.
82 typename std::invoke_result_t<
decltype(&RequestStrategy::Create), RequestStrategy,
int>::value_type;
// Reply type: value_type of what RequestStrategy::ExtractReply() returns
// (collected into std::vector<std::optional<ReplyType>> by
// Context::ExtractAllReplies).
84 typename std::invoke_result_t<
decltype(&RequestStrategy::ExtractReply), RequestStrategy>::value_type;
// Time point of the scheduling clock (impl::Clock); every PlanEntry is
// scheduled in these units.
90using TimePoint = Clock::time_point;
/// Kind of work a scheduled PlanEntry asks the hedging loop to perform.
enum class Action {
  /// Start the next attempt of a request (handled by OnActionStartTry).
  kStartTry,
  /// Finish all outstanding subrequests of every request (OnActionStop).
  kStop,
};
// Builds a plan entry: perform `action` for request `request_index`
// (attempt `attempt_id`) once `timepoint` is reached.
96 PlanEntry(TimePoint timepoint, std::size_t request_index, std::size_t attempt_id, Action action)
97 : timepoint(timepoint),
98 request_index(request_index),
99 attempt_id(attempt_id),
103 bool operator<(
const PlanEntry& other)
const noexcept {
return Tie() < other.Tie(); }
104 bool operator>(
const PlanEntry& other)
const noexcept {
return Tie() > other.Tie(); }
105 bool operator==(
const PlanEntry& other)
const noexcept {
return Tie() == other.Tie(); }
106 bool operator<=(
const PlanEntry& other)
const noexcept {
return Tie() <= other.Tie(); }
107 bool operator>=(
const PlanEntry& other)
const noexcept {
return Tie() >= other.Tie(); }
108 bool operator!=(
const PlanEntry& other)
const noexcept {
return Tie() != other.Tie(); }
/// Index of the input request this entry belongs to.
std::size_t request_index = 0;
/// Attempt counter value the entry was planned with.
std::size_t attempt_id = 0;
// Returns const references to every field, in comparison order; all the
// comparison operators compare these tuples, giving a lexicographic order
// over (timepoint, request_index, attempt_id, action).
116 std::tuple<
const TimePoint&,
const size_t&,
const size_t&,
const Action&> Tie()
const noexcept {
117 return std::tie(timepoint, request_index, attempt_id, action);
// Holds one attempt ("subrequest") of a hedged request in an optional.
// TryGetContextAccessor forwards to the held request, which is what lets
// HedgeRequestsBulk pass a vector of these to engine::WaitAnyUntil.
123template <
typename RequestStrategy>
124struct SubrequestWrapper {
125 using RequestType =
typename RequestTraits<RequestStrategy>::RequestType;
127 SubrequestWrapper() =
default;
128 SubrequestWrapper(SubrequestWrapper&&)
noexcept =
default;
// Takes ownership of the value produced by RequestStrategy::Create.
129 explicit SubrequestWrapper(std::optional<RequestType>&& request)
130 : request(std::move(request))
133 engine::impl::ContextAccessor* TryGetContextAccessor() {
// NOTE(review): dereferences `request` with no emptiness check on this
// line - confirm the optional is always engaged here.
137 return request->TryGetContextAccessor();
// The attempt itself; the hedging loop moves out of it when the attempt
// finishes or is cancelled.
140 std::optional<RequestType> request;
// Indices into Context::subrequests_ of every attempt created for this
// request.
144 std::vector<std::size_t> subrequest_indices;
// Attempt counter: passed to RequestStrategy::Create and used as the
// attempt_id of the next planned kStartTry entry.
145 std::size_t attempts_made = 0;
// Once true, OnActionStartTry / OnRetriableReply do nothing for this request.
146 bool finished =
false;
// Mutable state of one HedgeRequestsBulk run: the input strategies, the
// time-ordered plan of actions, every created subrequest and per-request
// bookkeeping.
149template <
typename RequestStrategy>
151 using RequestType =
typename RequestTraits<RequestStrategy>::RequestType;
152 using ReplyType =
typename RequestTraits<RequestStrategy>::ReplyType;
// Takes ownership of the inputs and settings and creates one RequestState
// per input, so request_states_ is indexed like inputs_.
155 : inputs_(std::move(inputs)),
156 settings_(std::move(settings))
158 const std::size_t size =
this->inputs_.size();
159 request_states_.resize(size);
// Move-constructible (noexcept); copying is not declared here.
161 Context(Context&&)
noexcept =
default;
// Seeds the plan before the event loop starts.
// - A first try (attempt 0) is planned for every request at `start_time`.
// - A single kStop entry is planned at start_time + timeout_all, bounding
//   the whole run.
// - Capacity is reserved for max_attempts subrequests per request.
163 void Prepare(TimePoint start_time) {
164 const auto request_count = GetRequestsCount();
165 for (std::size_t request_id = 0; request_id < request_count; ++request_id) {
166 plan_.emplace(start_time, request_id, 0, Action::kStartTry);
168 plan_.emplace(start_time + settings_
.timeout_all, 0, 0, Action::kStop);
169 subrequests_.reserve(settings_
.max_attempts * request_count);
// Time of the earliest planned action. plan_ is ordered with std::greater<>,
// so top() is the earliest entry. Callers treat std::nullopt as "nothing
// planned" (presumably returned for an empty plan - confirm).
172 std::optional<TimePoint> NextEventTime()
const {
176 return plan_.top().timepoint;
// Takes the earliest planned entry (presumably popping it from plan_ -
// confirm). HedgeRequestsBulk treats an empty result as "no planned
// actions left".
178 std::optional<PlanEntry> PopPlan() {
182 auto ret = plan_.top();
186 bool IsStop()
const {
return stop_; }
// Marks request `request_index` finished (no further attempts will be
// started). Every subrequest created for it is handed back to its strategy
// through RequestStrategy::Finish.
188 void FinishAllSubrequests(std::size_t request_index) {
189 auto& request_state = request_states_[request_index];
190 request_state.finished =
true;
191 const auto& subrequest_indices = request_state.subrequest_indices;
192 auto& strategy = inputs_[request_index];
193 for (
auto i : subrequest_indices) {
194 auto& request = subrequests_[i].request;
// NOTE(review): `*request` is dereferenced; HedgeRequestsBulk already moves
// out of *request for attempts whose reply it processed - confirm every
// optional is engaged (and valid to Finish) at this point.
196 strategy.Finish(std::move(*request));
204 size_t GetRequestsCount()
const {
return inputs_.size(); }
// Maps a subrequest index back to the request (input) that created it.
// Uses unordered_map::at, so an unknown index throws std::out_of_range.
206 size_t GetRequestIdxBySubrequestIdx(size_t subrequest_idx)
const {
207 return input_by_subrequests_.at(subrequest_idx);
210 RequestStrategy& GetStrategy(size_t index) {
return inputs_[index]; }
212 auto& GetSubRequests() {
return subrequests_; }
// Collects RequestStrategy::ExtractReply() from every input, in input order,
// so element i is the (optional) reply of request i.
214 std::vector<std::optional<ReplyType>> ExtractAllReplies() {
215 std::vector<std::optional<ReplyType>> ret;
216 ret.reserve(GetRequestsCount());
217 for (
auto&& strategy : inputs_) {
218 ret.emplace_back(strategy.ExtractReply());
// Handles a stop: finishes all subrequests of every request.
227 void OnActionStop() {
228 for (std::size_t i = 0; i < inputs_.size(); ++i) {
229 FinishAllSubrequests(i);
// Handles a kStartTry plan entry for request `request_index` at time `now`.
// 1. Finished requests are checked first.
// 2. `attempt_id` is compared with the attempts already made (presumably
//    to drop stale entries - confirm).
// 3. The strategy is asked to Create the next attempt.
// 4. The attempt is stored as a new subrequest and mapped back to its
//    request.
// 5. The next hedged try is planned at now + hedging_delay.
236 void OnActionStartTry(std::size_t request_index, std::size_t attempt_id, TimePoint now) {
237 auto& request_state = request_states_[request_index];
238 if (request_state.finished) {
241 auto& attempts_made = request_state.attempts_made;
247 if (attempt_id < attempts_made) {
254 auto& strategy = inputs_[request_index];
255 auto request_opt = strategy.Create(attempts_made);
// Presumably reached when Create yields no request: stop attempting it.
257 request_state.finished =
true;
// The new subrequest's index is its position in subrequests_.
262 const auto idx = subrequests_.size();
263 subrequests_.emplace_back(std::move(request_opt));
264 request_state.subrequest_indices.push_back(idx);
265 input_by_subrequests_[idx] = request_index;
// Hedging: plan another try of the same request after hedging_delay.
267 plan_.emplace(now + settings_
.hedging_delay, request_index, attempts_made, Action::kStartTry);
// Called when the strategy wants request `request_idx` retried after
// `retry_delay`. Unless the request is already finished, the next try is
// planned at now + retry_delay.
271 void OnRetriableReply(std::size_t request_idx, std::chrono::milliseconds retry_delay, TimePoint now) {
272 const auto& request_state = request_states_[request_idx];
273 if (request_state.finished) {
280 plan_.emplace(now + retry_delay, request_idx, request_state.attempts_made, Action::kStartTry);
283 void OnNonRetriableReply(std::size_t request_idx) { FinishAllSubrequests(request_idx); }
// One strategy per logical request; the request index is the position here.
288 std::vector<RequestStrategy> inputs_;
// Pending actions; std::greater<> turns the priority_queue into a min-heap,
// so top() is the earliest PlanEntry.
292 std::priority_queue<PlanEntry, std::vector<PlanEntry>, std::greater<>> plan_{};
// Every attempt created so far, across all requests.
293 std::vector<SubrequestWrapper<RequestStrategy>> subrequests_{};
// Subrequest index -> request index.
295 std::unordered_map<std::size_t, std::size_t> input_by_subrequests_{};
// Per-request bookkeeping, indexed like inputs_.
296 std::vector<RequestState> request_states_{};
// Future for an asynchronous bulk hedged request; constructed from the
// background task (presumably by HedgeRequestsBulkAsync, declared friend
// below).
303template <
typename RequestStrategy>
305 using RequestType =
typename RequestTraits<RequestStrategy>::RequestType;
306 using ReplyType =
typename RequestTraits<RequestStrategy>::ReplyType;
309 ~HedgedRequestBulkFuture() { task_.SyncCancel(); }
315 std::vector<std::optional<ReplyType>>
Get() {
return task_.Get(); }
317 engine::impl::ContextAccessor* TryGetContextAccessor() {
return task_.TryGetContextAccessor(); }
// HedgeRequestsBulkAsync may use the constructor below.
320 template <
typename TRequestStrategy>
321 friend auto HedgeRequestsBulkAsync(std::vector<TRequestStrategy> inputs,
HedgingSettings settings);
// Wraps the task running the bulk hedging loop.
// NOTE(review): single-argument constructor is not `explicit`.
323 HedgedRequestBulkFuture(Task&& task)
324 : task_(std::move(task))
// Future for an asynchronous single hedged request; constructed from the
// background task (presumably by HedgeRequestAsync, declared friend below).
330template <
typename RequestStrategy>
332 using RequestType =
typename RequestTraits<RequestStrategy>::RequestType;
333 using ReplyType =
typename RequestTraits<RequestStrategy>::ReplyType;
336 ~HedgedRequestFuture() { task_.SyncCancel(); }
342 std::optional<ReplyType>
Get() {
return task_.Get(); }
344 void IgnoreResult() {}
346 engine::impl::ContextAccessor* TryGetContextAccessor() {
return task_.TryGetContextAccessor(); }
// HedgeRequestAsync may use the constructor below.
349 template <
typename TRequestStrategy>
350 friend auto HedgeRequestAsync(TRequestStrategy input,
HedgingSettings settings);
// Wraps the task running the single-request hedging loop.
// NOTE(review): single-argument constructor is not `explicit`.
352 HedgedRequestFuture(Task&& task)
353 : task_(std::move(task))
// Runs hedged requests for all input strategies and returns one optional
// reply per input, in input order (via Context::ExtractAllReplies).
//
// Each iteration of the event loop waits for any subrequest, but only until
// the next planned action is due:
// - When nothing finished in time, planned entries are executed
//   (kStartTry starts an attempt, kStop stops everything).
// - When a subrequest finished, its strategy's ProcessReply decides what
//   happens next. An engaged result is a retry delay; an empty one
//   finishes the request.
365template <
typename RequestStrategy>
368 using Action = impl::Action;
369 using Clock = impl::Clock;
370 auto context = impl::Context(std::move(inputs), std::move(hedging_settings));
// Reference to the context's vector, so newly started attempts are waited on.
372 auto& sub_requests = context.GetSubRequests();
374 auto wakeup_time = Clock::now();
375 context.Prepare(wakeup_time);
377 while (!context.IsStop()) {
// Wait until some subrequest finishes or the next planned event is due.
378 auto wait_result =
engine::WaitAnyUntil(wakeup_time, sub_requests);
379 if (!wait_result.has_value()) {
381 context.OnActionStop();
// No subrequest finished before wakeup_time: execute planned entries.
385 auto plan_entry = context.PopPlan();
386 if (!plan_entry.has_value()) {
390 const auto [timestamp, request_index, attempt_id, action] = *plan_entry;
// Attempts are started with their planned time, not Clock::now().
392 case Action::kStartTry:
393 context.OnActionStartTry(request_index, attempt_id, timestamp);
396 context.OnActionStop();
399 auto next_wakeup_time = context.NextEventTime();
400 if (!next_wakeup_time.has_value()) {
403 wakeup_time = *next_wakeup_time;
// A subrequest finished: hand its result to the strategy owning it.
406 const auto result_idx = *wait_result;
407 UASSERT(result_idx < sub_requests.size());
408 const auto request_idx = context.GetRequestIdxBySubrequestIdx(result_idx);
409 auto& strategy = context.GetStrategy(request_idx);
411 auto& request = sub_requests[result_idx].request;
412 UASSERT_MSG(request,
"Finished requests must not be empty");
413 auto reply = strategy.ProcessReply(std::move(*request));
// An engaged reply is the delay (converted to std::chrono::milliseconds
// by OnRetriableReply) after which the request should be retried.
414 if (reply.has_value()) {
417 context.OnRetriableReply(request_idx, *reply, Clock::now());
// Wake up for the earliest planned event, which may be the retry just added.
// NOTE(review): dereferenced without a check - relies on the plan being
// non-empty here (the kStop entry from Prepare remains until popped).
419 wakeup_time = *context.NextEventTime();
421 context.OnNonRetriableReply(request_idx);
424 return context.ExtractAllReplies();
// Asynchronous variant of HedgeRequestsBulk. The work runs in a task named
// "hedged-bulk-request" (presumably spawned via utils::Async - confirm).
// Inputs and settings are moved into the closure; it is `mutable` so they
// can be moved out again when the task runs.
435template <
typename RequestStrategy>
438 "hedged-bulk-request",
439 [inputs{std::move(inputs)}, settings{std::move(settings)}]()
mutable {
440 return HedgeRequestsBulk(std::move(inputs), std::move(settings));
// Single-request convenience wrapper (presumably HedgeRequest - confirm).
// It runs HedgeRequestsBulk on a one-element input vector and checks that
// exactly one reply came back.
449template <
typename RequestStrategy>
451 RequestStrategy input,
454 std::vector<RequestStrategy> inputs;
455 inputs.emplace_back(std::move(input));
456 auto bulk_ret = HedgeRequestsBulk(std::move(inputs), std::move(settings));
457 if (bulk_ret.size() != 1) {
467template <
typename RequestStrategy>
471 [input{std::move(input)}, settings{std::move(settings)}]()
mutable {
472 return HedgeRequest(std::move(input), std::move(settings));