59#include <userver/engine/task/cancel.hpp>
60#include <userver/engine/wait_any.hpp>
61#include <userver/utils/assert.hpp>
62#include <userver/utils/async.hpp>
63#include <userver/utils/datetime.hpp>
65USERVER_NAMESPACE_BEGIN
67namespace utils::hedging {
// Type traits for a RequestStrategy: presumably the body of RequestTraits,
// which is used below as RequestTraits<RequestStrategy>::RequestType and
// RequestTraits<RequestStrategy>::ReplyType.
79template <
typename RequestStrategy>
// Deduced from RequestStrategy::Create invoked with an int attempt number:
// the value_type of its result (Context stores that result as
// std::optional<RequestType>).
// NOTE(review): Context::OnActionStartTry calls Create with a std::size_t,
// while this trait probes it with int.
82 typename std::invoke_result_t<
decltype(&RequestStrategy::Create), RequestStrategy,
int>::value_type;
// Deduced from RequestStrategy::ExtractReply: the value_type of its result
// (Context::ExtractAllReplies collects those results into
// std::vector<std::optional<ReplyType>>).
84 typename std::invoke_result_t<
decltype(&RequestStrategy::ExtractReply), RequestStrategy>::value_type;
90using TimePoint = Clock::time_point;
92enum class Action { kStartTry, kStop };
// A single scheduled action in the hedging plan: at `timepoint`, perform
// `action` for the request at index `request_index`; `attempt_id` is the
// attempt number the entry was scheduled for.
95 PlanEntry(TimePoint timepoint, std::size_t request_index, std::size_t attempt_id, Action action)
96 : timepoint(timepoint),
97 request_index(request_index),
98 attempt_id(attempt_id),
102 bool operator<(
const PlanEntry& other)
const noexcept {
return Tie() < other.Tie(); }
103 bool operator>(
const PlanEntry& other)
const noexcept {
return Tie() > other.Tie(); }
104 bool operator==(
const PlanEntry& other)
const noexcept {
return Tie() == other.Tie(); }
105 bool operator<=(
const PlanEntry& other)
const noexcept {
return Tie() <= other.Tie(); }
106 bool operator>=(
const PlanEntry& other)
const noexcept {
return Tie() >= other.Tie(); }
107 bool operator!=(
const PlanEntry& other)
const noexcept {
return Tie() != other.Tie(); }
// All fields as a tuple of const references, in comparison order:
// timepoint first, then request_index, attempt_id and action. The
// comparison operators compare entries through this tuple.
109 std::tuple<
const TimePoint&,
const size_t&,
const size_t&,
const Action&> Tie()
const noexcept {
110 return std::tie(timepoint, request_index, attempt_id, action);
// Index of the request (strategy) this entry belongs to.
114 std::size_t request_index{0};
// Attempt number this entry was scheduled for.
115 std::size_t attempt_id{0};
// Holds one sub-request (a single attempt) created by a RequestStrategy.
// Context keeps these in a vector that HedgeRequestsBulk passes to
// engine::WaitAnyUntil, which is why TryGetContextAccessor is exposed.
121template <
typename RequestStrategy>
122struct SubrequestWrapper {
123 using RequestType =
typename RequestTraits<RequestStrategy>::RequestType;
125 SubrequestWrapper() =
default;
126 SubrequestWrapper(SubrequestWrapper&&)
noexcept =
default;
// Takes the optional produced by RequestStrategy::Create.
127 explicit SubrequestWrapper(std::optional<RequestType>&& request)
128 : request(std::move(request))
// Forwards to the wrapped request's own TryGetContextAccessor().
131 engine::
impl::ContextAccessor* TryGetContextAccessor() {
135 return request->TryGetContextAccessor();
// The wrapped request. Its contents are moved out into
// RequestStrategy::ProcessReply or RequestStrategy::Finish.
138 std::optional<RequestType> request;
// Per-request bookkeeping kept by Context (presumably RequestState):
// indices into Context::subrequests_ of every sub-request started for this
// request.
142 std::vector<std::size_t> subrequest_indices;
// Attempt counter; its current value is the attempt number passed to
// RequestStrategy::Create.
143 std::size_t attempts_made = 0;
// Set once no further attempts should be started for this request
// (e.g. by Context::FinishAllSubrequests).
144 bool finished =
false;
// Context: the mutable state of one bulk hedging run. It holds the
// strategies, the per-request bookkeeping, every started sub-request and a
// time-ordered plan of pending actions.
147template <
typename RequestStrategy>
149 using RequestType =
typename RequestTraits<RequestStrategy>::RequestType;
150 using ReplyType =
typename RequestTraits<RequestStrategy>::ReplyType;
// Takes ownership of the strategies and settings and allocates one
// RequestState per strategy.
153 : inputs_(std::move(inputs)),
154 settings_(std::move(settings))
156 const std::size_t size =
this->inputs_.size();
157 request_states_.resize(size);
// A user-declared move constructor means no implicit copy constructor is
// generated for Context.
159 Context(Context&&)
noexcept =
default;
// Seeds the plan with two kinds of entries:
//  - the first attempt (attempt 0) of every request at start_time;
//  - a single kStop entry at start_time + timeout_all.
161 void Prepare(TimePoint start_time) {
162 const auto request_count = GetRequestsCount();
163 for (std::size_t request_id = 0; request_id < request_count; ++request_id) {
164 plan_.emplace(start_time, request_id, 0, Action::kStartTry);
166 plan_.emplace(start_time + settings_
.timeout_all, 0, 0, Action::kStop);
// Reserve room for max_attempts sub-requests per request, so the expected
// number of attempts does not reallocate subrequests_.
167 subrequests_.reserve(settings_
.max_attempts * request_count);
// Time of the earliest pending plan entry (plan_.top()). The result is an
// std::optional and HedgeRequestsBulk checks has_value(), presumably because
// the plan may be empty.
170 std::optional<TimePoint> NextEventTime()
const {
174 return plan_.top().timepoint;
// Returns the earliest pending plan entry (plan_.top()). The name suggests
// it is also removed from the plan. Callers handle an empty std::optional.
176 std::optional<PlanEntry> PopPlan() {
180 auto ret = plan_.top();
184 bool IsStop()
const {
return stop_; }
// Marks request `request_index` as finished and hands each of its
// sub-requests over to the strategy through RequestStrategy::Finish.
186 void FinishAllSubrequests(std::size_t request_index) {
187 auto& request_state = request_states_[request_index];
188 request_state.finished =
true;
189 const auto& subrequest_indices = request_state.subrequest_indices;
190 auto& strategy = inputs_[request_index];
191 for (
auto i : subrequest_indices) {
192 auto& request = subrequests_[i].request;
// The sub-request is passed to the strategy via std::move.
194 strategy.Finish(std::move(*request));
202 size_t GetRequestsCount()
const {
return inputs_.size(); }
// Index of the request a sub-request belongs to. unordered_map::at throws
// std::out_of_range for a sub-request index that was never registered.
204 size_t GetRequestIdxBySubrequestIdx(size_t subrequest_idx)
const {
205 return input_by_subrequests_.at(subrequest_idx);
208 RequestStrategy& GetStrategy(size_t index) {
return inputs_[index]; }
210 auto& GetSubRequests() {
return subrequests_; }
// Collects every strategy's ExtractReply() result, one entry per input
// strategy in input order. HedgeRequestsBulk returns this vector.
212 std::vector<std::optional<ReplyType>> ExtractAllReplies() {
213 std::vector<std::optional<ReplyType>> ret;
214 ret.reserve(GetRequestsCount());
215 for (
auto&& strategy : inputs_) {
216 ret.emplace_back(strategy.ExtractReply());
// Stops the run by finishing the sub-requests of every request.
// HedgeRequestsBulk calls this for a stop plan entry and also from the
// branch where WaitAnyUntil returned no index.
225 void OnActionStop() {
226 for (std::size_t i = 0; i < inputs_.size(); ++i) {
227 FinishAllSubrequests(i);
// Handles a kStartTry plan entry for request `request_index`. It asks the
// strategy for a new attempt, registers that attempt as a sub-request and
// schedules the next (hedged) attempt `hedging_delay` after `now`.
234 void OnActionStartTry(std::size_t request_index, std::size_t attempt_id, TimePoint now) {
235 auto& request_state = request_states_[request_index];
236 if (request_state.finished) {
239 auto& attempts_made = request_state.attempts_made;
// An entry for an attempt number that has already been made is presumably
// stale (that attempt was already started by another entry).
245 if (attempt_id < attempts_made) {
252 auto& strategy = inputs_[request_index];
253 auto request_opt = strategy.Create(attempts_made);
255 request_state.finished =
true;
// The new sub-request is identified by its position in subrequests_;
// remember which request it belongs to.
260 const auto idx = subrequests_.size();
261 subrequests_.emplace_back(std::move(request_opt));
262 request_state.subrequest_indices.push_back(idx);
263 input_by_subrequests_[idx] = request_index;
// Schedule the next hedged attempt for this request.
265 plan_.emplace(now + settings_
.hedging_delay, request_index, attempts_made, Action::kStartTry);
// Called when ProcessReply asked for a retry after `retry_delay`. It
// schedules another kStartTry for the request at now + retry_delay, after
// first checking the request's `finished` flag.
269 void OnRetriableReply(std::size_t request_idx, std::chrono::milliseconds retry_delay, TimePoint now) {
270 const auto& request_state = request_states_[request_idx];
271 if (request_state.finished) {
278 plan_.emplace(now + retry_delay, request_idx, request_state.attempts_made, Action::kStartTry);
281 void OnNonRetriableReply(std::size_t request_idx) { FinishAllSubrequests(request_idx); }
// One strategy per request; a request's index is its position here.
286 std::vector<RequestStrategy> inputs_;
// Pending actions ordered by PlanEntry's comparison operators. With
// std::greater<>, top() is the entry with the smallest (earliest) timepoint.
290 std::priority_queue<PlanEntry, std::vector<PlanEntry>, std::greater<>> plan_{};
// Every sub-request started so far; a sub-request's index is its position here.
291 std::vector<SubrequestWrapper<RequestStrategy>> subrequests_{};
// Maps a sub-request index to its request index.
// NOTE(review): keys are assigned densely as subrequests_.size() in
// OnActionStartTry, so a std::vector<std::size_t> would be a cheaper mapping.
293 std::unordered_map<std::size_t, std::size_t> input_by_subrequests_{};
// Bookkeeping for each request, parallel to inputs_.
294 std::vector<RequestState> request_states_{};
// HedgedRequestBulkFuture: handle to a bulk hedging run executing in a
// separate task (created by HedgeRequestsBulkAsync).
301template <
typename RequestStrategy>
303 using RequestType =
typename RequestTraits<RequestStrategy>::RequestType;
304 using ReplyType =
typename RequestTraits<RequestStrategy>::ReplyType;
307 ~HedgedRequestBulkFuture() { task_.SyncCancel(); }
313 std::vector<std::optional<ReplyType>>
Get() {
return task_.Get(); }
315 engine::
impl::ContextAccessor* TryGetContextAccessor() {
return task_.TryGetContextAccessor(); }
// HedgeRequestsBulkAsync is a friend so it can wrap its task in this future.
318 template <
typename TRequestStrategy>
319 friend auto HedgeRequestsBulkAsync(std::vector<TRequestStrategy> inputs,
HedgingSettings settings);
// The task producing one optional reply per input strategy.
320 using Task = engine::
TaskWithResult<std::vector<std::optional<ReplyType>>>;
// NOTE(review): single-argument constructor is not `explicit`.
321 HedgedRequestBulkFuture(Task&& task)
322 : task_(std::move(task))
// HedgedRequestFuture: handle to a single hedged request executing in a
// separate task (created by HedgeRequestAsync).
328template <
typename RequestStrategy>
330 using RequestType =
typename RequestTraits<RequestStrategy>::RequestType;
331 using ReplyType =
typename RequestTraits<RequestStrategy>::ReplyType;
334 ~HedgedRequestFuture() { task_.SyncCancel(); }
340 std::optional<ReplyType>
Get() {
return task_.Get(); }
/// Deliberately does nothing; the destructor still cancels the task.
void IgnoreResult()
{
}
344 engine::
impl::ContextAccessor* TryGetContextAccessor() {
return task_.TryGetContextAccessor(); }
// HedgeRequestAsync is a friend so it can wrap its task in this future.
347 template <
typename TRequestStrategy>
348 friend auto HedgeRequestAsync(TRequestStrategy input,
HedgingSettings settings);
// NOTE(review): single-argument constructor is not `explicit`.
350 HedgedRequestFuture(Task&& task)
351 : task_(std::move(task))
// HedgeRequestsBulk: runs every strategy in `inputs` with hedging and returns
// one optional reply per strategy, in input order.
// Event loop: wait until either some sub-request completes or the next plan
// entry is due (wakeup_time), then react to whichever happened.
363template <
typename RequestStrategy>
366 using Action = impl::Action;
367 using Clock = impl::Clock;
368 auto context = impl::Context(std::move(inputs), std::move(hedging_settings));
// Reference to the context's sub-request vector; it grows as attempts start.
370 auto& sub_requests = context.GetSubRequests();
372 auto wakeup_time = Clock::now();
373 context.Prepare(wakeup_time);
375 while (!context.IsStop()) {
// Index of a completed sub-request, or no value when nothing completed by
// wakeup_time (or the wait was otherwise interrupted).
376 auto wait_result = engine::WaitAnyUntil(wakeup_time, sub_requests);
377 if (!wait_result.has_value()) {
// Presumably reached when the current task is being cancelled.
379 context.OnActionStop();
// A plan entry is due: take the earliest one and perform its action.
383 auto plan_entry = context.PopPlan();
384 if (!plan_entry.has_value()) {
388 const auto [timestamp, request_index, attempt_id, action] = *plan_entry;
390 case Action::kStartTry:
// The entry's scheduled time (not Clock::now()) serves as `now`.
391 context.OnActionStartTry(request_index, attempt_id, timestamp);
// Stop action (presumably the kStop case).
394 context.OnActionStop();
// Next wake-up: the earliest remaining plan entry.
397 auto next_wakeup_time = context.NextEventTime();
398 if (!next_wakeup_time.has_value()) {
401 wakeup_time = *next_wakeup_time;
// A sub-request completed: let its strategy decide whether to retry.
404 const auto result_idx = *wait_result;
405 UASSERT(result_idx < sub_requests.size());
406 const auto request_idx = context.GetRequestIdxBySubrequestIdx(result_idx);
407 auto& strategy = context.GetStrategy(request_idx);
409 auto& request = sub_requests[result_idx].request;
410 UASSERT_MSG(request,
"Finished requests must not be empty");
// A value returned by ProcessReply is the delay before retrying; no value
// means the reply is final for this request.
411 auto reply = strategy.ProcessReply(std::move(*request));
412 if (reply.has_value()) {
415 context.OnRetriableReply(request_idx, *reply, Clock::now());
// NOTE(review): dereferenced without the has_value() check used above. This
// relies on the plan never being empty here (e.g. the stop entry is still
// pending).
417 wakeup_time = *context.NextEventTime();
419 context.OnNonRetriableReply(request_idx);
422 return context.ExtractAllReplies();
// HedgeRequestsBulkAsync: runs HedgeRequestsBulk in a separate task named
// "hedged-bulk-request" (presumably started via utils::Async). The lambda
// owns the moved-in inputs and settings. It is `mutable` so it can move them
// on into HedgeRequestsBulk.
433template <
typename RequestStrategy>
436 "hedged-bulk-request",
437 [inputs{std::move(inputs)}, settings{std::move(settings)}]()
mutable {
438 return HedgeRequestsBulk(std::move(inputs), std::move(settings));
// HedgeRequest: single-request form of HedgeRequestsBulk. It wraps `input`
// in a one-element vector and checks that exactly one result comes back.
447template <
typename RequestStrategy>
449 RequestStrategy input,
452 std::vector<RequestStrategy> inputs;
453 inputs.emplace_back(std::move(input));
454 auto bulk_ret = HedgeRequestsBulk(std::move(inputs), std::move(settings));
455 if (bulk_ret.size() != 1) {
// HedgeRequestAsync: runs HedgeRequest in a separate task. The lambda owns
// the moved-in input and settings and is `mutable` so it can move them on.
465template <
typename RequestStrategy>
469 [input{std::move(input)}, settings{std::move(settings)}]()
mutable {
470 return HedgeRequest(std::move(input), std::move(settings));