#include <chrono>
#include <functional>
#include <optional>
#include <queue>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include <userver/engine/task/cancel.hpp>
#include <userver/engine/wait_any.hpp>
#include <userver/utils/assert.hpp>
#include <userver/utils/async.hpp>
#include <userver/utils/datetime.hpp>
USERVER_NAMESPACE_BEGIN

namespace utils::hedging {
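
// The helpers in this header are written against a user-supplied
// RequestStrategy type. Judging by the calls made below (this is an inferred
// sketch, not the authoritative contract), a strategy provides roughly:
//
//   std::optional<RequestType> Create(std::size_t attempt);
//       start attempt number `attempt`, or return std::nullopt to stop trying;
//   std::optional<std::chrono::milliseconds> ProcessReply(RequestType&&);
//       return a retry delay for a retriable failure, std::nullopt for a final reply;
//   std::optional<ReplyType> ExtractReply();
//       return the best reply collected so far, if any;
//   void Finish(RequestType&&);
//       dispose of an in-flight request that is no longer needed.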
template <typename RequestStrategy>
struct RequestTraits {
    using RequestType =
        typename std::invoke_result_t<decltype(&RequestStrategy::Create), RequestStrategy, int>::value_type;
    using ReplyType =
        typename std::invoke_result_t<decltype(&RequestStrategy::ExtractReply), RequestStrategy>::value_type;
};
namespace impl {

using Clock = utils::datetime::SteadyClock;
using TimePoint = Clock::time_point;

enum class Action { StartTry, Stop };
// An entry in the time-ordered plan of actions processed by Context below.
struct PlanEntry {
public:
    PlanEntry(TimePoint timepoint, std::size_t request_index, std::size_t attempt_id, Action action)
        : timepoint(timepoint), request_index(request_index), attempt_id(attempt_id), action(action) {}

    bool operator<(const PlanEntry& other) const noexcept { return tie() < other.tie(); }
    bool operator>(const PlanEntry& other) const noexcept { return tie() > other.tie(); }
    bool operator==(const PlanEntry& other) const noexcept { return tie() == other.tie(); }
    bool operator<=(const PlanEntry& other) const noexcept { return tie() <= other.tie(); }
    bool operator>=(const PlanEntry& other) const noexcept { return tie() >= other.tie(); }
    bool operator!=(const PlanEntry& other) const noexcept { return tie() != other.tie(); }

    TimePoint timepoint{};
    std::size_t request_index{0};
    std::size_t attempt_id{0};
    Action action{};

private:
    std::tuple<const TimePoint&, const size_t&, const size_t&, const Action&> tie() const noexcept {
        return std::tie(timepoint, request_index, attempt_id, action);
    }
};
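
// Note: Context below stores PlanEntry objects in a std::priority_queue with
// std::greater<>, so the entry with the earliest timepoint is always on top;
// the comparison operators above order entries by timepoint first.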
template <typename RequestStrategy>
struct SubrequestWrapper {
    using RequestType = typename RequestTraits<RequestStrategy>::RequestType;

    SubrequestWrapper() = default;
    SubrequestWrapper(SubrequestWrapper&&) noexcept = default;
    explicit SubrequestWrapper(std::optional<RequestType>&& request) : request(std::move(request)) {}

    engine::impl::ContextAccessor* TryGetContextAccessor() {
        if (!request) return nullptr;
        return request->TryGetContextAccessor();
    }

    std::optional<RequestType> request;
};
// Per-input bookkeeping for a single logical request.
struct RequestState {
    std::vector<std::size_t> subrequest_indices;
    std::size_t attempts_made = 0;
    bool finished = false;
};
// Context keeps the per-input state and a time-ordered plan of actions
// (start another attempt / stop everything) for the hedging event loop in
// HedgeRequestsBulk below.
template <typename RequestStrategy>
struct Context {
    using RequestType = typename RequestTraits<RequestStrategy>::RequestType;
    using ReplyType = typename RequestTraits<RequestStrategy>::ReplyType;

    Context(std::vector<RequestStrategy>&& inputs, HedgingSettings settings)
        : inputs_(std::move(inputs)), settings(std::move(settings)) {
        const std::size_t size = this->inputs_.size();
        request_states_.resize(size);
    }
    Context(Context&&) noexcept = default;
    // Schedules the initial attempt for every input at start_time and a
    // global Stop at start_time + timeout_all.
    void Prepare(TimePoint start_time) {
        const auto request_count = GetRequestsCount();
        for (std::size_t request_id = 0; request_id < request_count; ++request_id) {
            plan_.emplace(start_time, request_id, 0, Action::StartTry);
        }
        plan_.emplace(start_time + settings.timeout_all, 0, 0, Action::Stop);
    }
    std::optional<TimePoint> NextEventTime() const {
        if (plan_.empty()) return std::nullopt;
        return plan_.top().timepoint;
    }

    std::optional<PlanEntry> PopPlan() {
        if (plan_.empty()) return std::nullopt;
        auto ret = plan_.top();
        plan_.pop();
        return ret;
    }
    bool IsStop() const { return stop_; }
    // Hands every still-open subrequest of the given input back to its
    // strategy and marks the input as finished.
    void FinishAllSubrequests(std::size_t request_index) {
        auto& request_state = request_states_[request_index];
        request_state.finished = true;
        const auto& subrequest_indices = request_state.subrequest_indices;
        auto& strategy = inputs_[request_index];
        for (auto i : subrequest_indices) {
            auto& request = subrequests_[i].request;
            if (!request) continue;
            strategy.Finish(std::move(*request));
            request.reset();
        }
    }
    size_t GetRequestsCount() const { return inputs_.size(); }

    size_t GetRequestIdxBySubrequestIdx(size_t subrequest_idx) const {
        return input_by_subrequests_.at(subrequest_idx);
    }

    RequestStrategy& GetStrategy(size_t index) { return inputs_[index]; }

    auto& GetSubRequests() { return subrequests_; }
    std::vector<std::optional<ReplyType>> ExtractAllReplies() {
        std::vector<std::optional<ReplyType>> ret;
        ret.reserve(GetRequestsCount());
        for (auto&& strategy : inputs_) {
            ret.emplace_back(strategy.ExtractReply());
        }
        return ret;
    }
    void OnActionStop() {
        for (std::size_t i = 0; i < inputs_.size(); ++i) FinishAllSubrequests(i);
        stop_ = true;
    }
    // Starts one more attempt for the given input (unless it is already
    // finished or this attempt was superseded) and schedules the next hedged
    // attempt after hedging_delay.
    void OnActionStartTry(std::size_t request_index, std::size_t attempt_id, TimePoint now) {
        auto& request_state = request_states_[request_index];
        if (request_state.finished) {
            return;
        }
        auto& attempts_made = request_state.attempts_made;
        // This attempt could have been started already by the retry logic.
        if (attempt_id < attempts_made) {
            return;
        }
        auto& strategy = inputs_[request_index];
        auto request_opt = strategy.Create(attempts_made);
        if (!request_opt) {
            // The strategy does not want to start more attempts for this input.
            request_state.finished = true;
            return;
        }
        ++attempts_made;
        const auto idx = subrequests_.size();
        subrequests_.emplace_back(std::move(request_opt));
        request_state.subrequest_indices.push_back(idx);
        input_by_subrequests_[idx] = request_index;

        plan_.emplace(now + settings.hedging_delay, request_index, attempts_made, Action::StartTry);
    }
    // A retriable error: schedule one more attempt after retry_delay, unless
    // the input is finished or max_attempts is exhausted.
    void OnRetriableReply(std::size_t request_idx, std::chrono::milliseconds retry_delay, TimePoint now) {
        const auto& request_state = request_states_[request_idx];
        if (request_state.finished) return;
        if (request_state.attempts_made >= settings.max_attempts) return;

        plan_.emplace(now + retry_delay, request_idx, request_state.attempts_made, Action::StartTry);
    }

    void OnNonRetriableReply(std::size_t request_idx) { FinishAllSubrequests(request_idx); }
    std::vector<RequestStrategy> inputs_;
    HedgingSettings settings;

    bool stop_{false};
    // Plan of what we are going to do and when.
    std::priority_queue<PlanEntry, std::vector<PlanEntry>, std::greater<>> plan_{};
    std::vector<SubrequestWrapper<RequestStrategy>> subrequests_{};
    // Maps a subrequest index to the index of the input it belongs to.
    std::unordered_map<std::size_t, std::size_t> input_by_subrequests_{};
    std::vector<RequestState> request_states_{};
};

}  // namespace impl
// Future of a bulk hedged request. Cancels the underlying task on destruction.
template <typename RequestStrategy>
class HedgedRequestBulkFuture {
public:
    using RequestType = typename RequestTraits<RequestStrategy>::RequestType;
    using ReplyType = typename RequestTraits<RequestStrategy>::ReplyType;

    ~HedgedRequestBulkFuture() { task_.SyncCancel(); }

    std::vector<std::optional<ReplyType>> Get() { return task_.Get(); }
    engine::impl::ContextAccessor* TryGetContextAccessor() { return task_.TryGetContextAccessor(); }

private:
    template <typename TRequestStrategy>
    friend auto HedgeRequestsBulkAsync(std::vector<TRequestStrategy> inputs, HedgingSettings settings);

    using Task = engine::TaskWithResult<std::vector<std::optional<ReplyType>>>;
    HedgedRequestBulkFuture(Task&& task) : task_(std::move(task)) {}

    Task task_;
};
// Future of a single hedged request. Cancels the underlying task on destruction.
template <typename RequestStrategy>
class HedgedRequestFuture {
public:
    using RequestType = typename RequestTraits<RequestStrategy>::RequestType;
    using ReplyType = typename RequestTraits<RequestStrategy>::ReplyType;

    ~HedgedRequestFuture() { task_.SyncCancel(); }

    std::optional<ReplyType> Get() { return task_.Get(); }
    void IgnoreResult() {}
    engine::impl::ContextAccessor* TryGetContextAccessor() { return task_.TryGetContextAccessor(); }

private:
    template <typename TRequestStrategy>
    friend auto HedgeRequestAsync(TRequestStrategy input, HedgingSettings settings);

    using Task = engine::TaskWithResult<std::optional<ReplyType>>;
    HedgedRequestFuture(Task&& task) : task_(std::move(task)) {}

    Task task_;
};
// Synchronously performs hedged requests for every strategy in `inputs` and
// returns one (possibly empty) reply per input.
template <typename RequestStrategy>
std::vector<std::optional<typename RequestTraits<RequestStrategy>::ReplyType>> HedgeRequestsBulk(
    std::vector<RequestStrategy> inputs, HedgingSettings hedging_settings) {
    using Action = impl::Action;
    using Clock = impl::Clock;
    auto context = impl::Context(std::move(inputs), std::move(hedging_settings));

    auto& sub_requests = context.GetSubRequests();

    auto wakeup_time = Clock::now();
    context.Prepare(wakeup_time);

    while (!context.IsStop()) {
        auto wait_result = engine::WaitAnyUntil(wakeup_time, sub_requests);
        if (!wait_result.has_value()) {
            // Either we were cancelled or the next planned action is due.
            if (engine::current_task::ShouldCancel()) {
                context.OnActionStop();
                break;
            }
            auto plan_entry = context.PopPlan();
            if (!plan_entry.has_value()) {
                // Timeout, but there are no more planned actions.
                break;
            }
            const auto [timestamp, request_index, attempt_id, action] = *plan_entry;
            switch (action) {
                case Action::StartTry:
                    context.OnActionStartTry(request_index, attempt_id, timestamp);
                    break;
                case Action::Stop:
                    context.OnActionStop();
                    break;
            }
            auto next_wakeup_time = context.NextEventTime();
            if (!next_wakeup_time.has_value()) {
                break;
            }
            wakeup_time = *next_wakeup_time;
            continue;
        }

        // One of the subrequests is ready: hand it to its strategy.
        const auto result_idx = *wait_result;
        UASSERT(result_idx < sub_requests.size());
        const auto request_idx = context.GetRequestIdxBySubrequestIdx(result_idx);
        auto& strategy = context.GetStrategy(request_idx);

        auto& request = sub_requests[result_idx].request;
        UASSERT_MSG(request, "Finished requests must not be empty");
        auto reply = strategy.ProcessReply(std::move(*request));
        if (reply.has_value()) {
            // Retriable error: *reply is the delay before the next attempt.
            context.OnRetriableReply(request_idx, *reply, Clock::now());
            // The plan is non-empty here: at least the global Stop entry is still pending.
            wakeup_time = *context.NextEventTime();
        } else {
            context.OnNonRetriableReply(request_idx);
        }
    }
    return context.ExtractAllReplies();
}
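
// A minimal usage sketch (SampleStrategy and MakeStrategies are hypothetical;
// any type satisfying the strategy contract sketched near the top of this
// header would be used the same way, given a HedgingSettings `settings`):
//
//   std::vector<SampleStrategy> strategies = MakeStrategies();
//   auto replies = utils::hedging::HedgeRequestsBulk(std::move(strategies), settings);
//   // replies[i] corresponds to strategies[i]; std::nullopt means no reply was obtained.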
// Asynchronously performs hedged requests, returning a future that cancels
// the work if destroyed before Get().
template <typename RequestStrategy>
auto HedgeRequestsBulkAsync(std::vector<RequestStrategy> inputs, HedgingSettings settings) {
    return HedgedRequestBulkFuture<RequestStrategy>(utils::Async(
        "hedged-bulk-request",
        [inputs{std::move(inputs)}, settings{std::move(settings)}]() mutable {
            return HedgeRequestsBulk(std::move(inputs), std::move(settings));
        }));
}
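
// The asynchronous variant returns a future-like wrapper instead of blocking;
// a sketch, under the same assumptions as the example above:
//
//   auto future = utils::hedging::HedgeRequestsBulkAsync(std::move(strategies), settings);
//   // ... do other work ...
//   auto replies = future.Get();  // the destructor cancels the task if Get() is never called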
// Synchronously performs a single hedged request.
template <typename RequestStrategy>
std::optional<typename RequestTraits<RequestStrategy>::ReplyType> HedgeRequest(
    RequestStrategy input, HedgingSettings settings) {
    std::vector<RequestStrategy> inputs;
    inputs.emplace_back(std::move(input));
    auto bulk_ret = HedgeRequestsBulk(std::move(inputs), std::move(settings));
    if (bulk_ret.size() != 1) {
        return std::nullopt;
    }
    return std::move(bulk_ret[0]);
}
// Asynchronously performs a single hedged request, returning a future.
template <typename RequestStrategy>
auto HedgeRequestAsync(RequestStrategy input, HedgingSettings settings) {
    return HedgedRequestFuture<RequestStrategy>(utils::Async(
        "hedged-request",
        [input{std::move(input)}, settings{std::move(settings)}]() mutable {
            return HedgeRequest(std::move(input), std::move(settings));