userver: userver/ugrpc/server/impl/service_worker_impl.hpp Source File
Loading...
Searching...
No Matches
service_worker_impl.hpp
1#pragma once
2
#include <chrono>
#include <cstddef>
#include <exception>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <type_traits>
#include <utility>

#include <grpcpp/completion_queue.h>
#include <grpcpp/impl/service_type.h>
#include <grpcpp/server_context.h>

#include <userver/engine/async.hpp>
#include <userver/engine/task/cancel.hpp>
#include <userver/engine/task/task_processor_fwd.hpp>
#include <userver/server/request/task_inherited_data.hpp>
#include <userver/tracing/in_place_span.hpp>
#include <userver/tracing/span.hpp>
#include <userver/utils/assert.hpp>
#include <userver/utils/fast_scope_guard.hpp>
#include <userver/utils/impl/wait_token_storage.hpp>
#include <userver/utils/lazy_prvalue.hpp>
#include <userver/utils/statistics/entry.hpp>

#include <userver/ugrpc/impl/static_metadata.hpp>
#include <userver/ugrpc/impl/statistics.hpp>
#include <userver/ugrpc/impl/statistics_scope.hpp>
#include <userver/ugrpc/server/impl/async_method_invocation.hpp>
#include <userver/ugrpc/server/impl/async_service.hpp>
#include <userver/ugrpc/server/impl/call_params.hpp>
#include <userver/ugrpc/server/impl/call_traits.hpp>
#include <userver/ugrpc/server/impl/error_code.hpp>
#include <userver/ugrpc/server/impl/service_worker.hpp>
#include <userver/ugrpc/server/middlewares/base.hpp>
#include <userver/ugrpc/server/rpc.hpp>
#include <userver/ugrpc/server/service_base.hpp>
34#include <userver/ugrpc/server/impl/error_code.hpp>
35#include <userver/ugrpc/server/impl/service_worker.hpp>
36#include <userver/ugrpc/server/middlewares/base.hpp>
37#include <userver/ugrpc/server/rpc.hpp>
38#include <userver/ugrpc/server/service_base.hpp>
39
40USERVER_NAMESPACE_BEGIN
41
42namespace ugrpc::server::impl {
43
44void ReportHandlerError(const std::exception& ex, std::string_view call_name,
45 tracing::Span& span) noexcept;
46
47void ReportNetworkError(const RpcInterruptedError& ex,
48 std::string_view call_name,
49 tracing::Span& span) noexcept;
50
51void ReportCustomError(
52 const USERVER_NAMESPACE::server::handlers::CustomHandlerException& ex,
53 CallAnyBase& call, tracing::Span& span);
54
55void SetupSpan(std::optional<tracing::InPlaceSpan>& span_holder,
56 grpc::ServerContext& context, std::string_view call_name);
57
58/// Per-gRPC-service data
59template <typename GrpcppService>
60struct ServiceData final {
61 ServiceData(const ServiceSettings& settings,
62 const ugrpc::impl::StaticServiceMetadata& metadata)
63 : settings(settings),
64 metadata(metadata),
65 statistics(settings.statistics_storage.GetServiceStatistics(metadata)) {
66 }
67
68 ~ServiceData() = default;
69
70 const ServiceSettings settings;
71 const ugrpc::impl::StaticServiceMetadata metadata;
72 AsyncService<GrpcppService> async_service{metadata.method_full_names.size()};
73 utils::impl::WaitTokenStorage wait_tokens;
74 ugrpc::impl::ServiceStatistics& statistics;
75};
76
77/// Per-gRPC-method data
78template <typename GrpcppService, typename CallTraits>
79struct MethodData final {
80 ServiceData<GrpcppService>& service_data;
81 const std::size_t method_id{};
82 typename CallTraits::ServiceBase& service;
83 const typename CallTraits::ServiceMethod service_method;
84
85 std::string_view call_name{
86 service_data.metadata.method_full_names[method_id]};
87 // Remove name of the service and slash
88 std::string_view method_name{
89 call_name.substr(service_data.metadata.service_full_name.size() + 1)};
90 ugrpc::impl::MethodStatistics& statistics{
91 service_data.statistics.GetMethodStatistics(method_id)};
92};
93
94template <typename GrpcppService, typename CallTraits>
95class CallData final {
96 public:
97 explicit CallData(const MethodData<GrpcppService, CallTraits>& method_data)
98 : wait_token_(method_data.service_data.wait_tokens.GetToken()),
99 method_data_(method_data) {
100 UASSERT(method_data.method_id <
101 method_data.service_data.metadata.method_full_names.size());
102 }
103
104 void operator()() && {
105 // Based on the tensorflow code, we must first call AsyncNotifyWhenDone
106 // and only then Prepare<>
107 // see
108 // https://git.ecdf.ed.ac.uk/s1886313/tensorflow/-/blob/438604fc885208ee05f9eef2d0f2c630e1360a83/tensorflow/core/distributed_runtime/rpc/grpc_call.h#L201
109 // and grpc::ServerContext::AsyncNotifyWhenDone
110 ugrpc::server::impl::RpcFinishedEvent notify_when_done(
111 engine::current_task::GetCancellationToken(), context_);
112
113 context_.AsyncNotifyWhenDone(notify_when_done.GetTag());
114
115 // the request for an incoming RPC must be performed synchronously
116 auto& queue = method_data_.service_data.settings.queue;
117 method_data_.service_data.async_service.template Prepare<CallTraits>(
118 method_data_.method_id, context_, initial_request_, raw_responder_,
119 queue, queue, prepare_.GetTag());
120
121 // Note: we ignore task cancellations here. Even if notify_when_done has
122 // already cancelled this RPC, we want to:
123 // 1. listen to further RPCs for the same method
124 // 2. handle this RPC correctly, including metrics, logs, etc.
125 if (Wait(prepare_) != impl::AsyncMethodInvocation::WaitStatus::kOk) {
126 // the CompletionQueue is shutting down
127
128 // Do not wait for notify_when_done. When queue is shutting down, it will
129 // not be called.
130 // https://github.com/grpc/grpc/issues/10136
131 return;
132 }
133
134 // start a concurrent listener immediately, as advised by gRPC docs
135 ListenAsync(method_data_);
136
137 HandleRpc();
138
139 // Even if we finished before receiving notification that call is done, we
140 // should wait on this async operation. CompletionQueue has a pointer to
141 // stack-allocated object, that object is going to be freed upon exit. To
142 // prevent segfaults, wait until queue is done with this object.
143 notify_when_done.Wait();
144 }
145
146 static void ListenAsync(const MethodData<GrpcppService, CallTraits>& data) {
147 engine::CriticalAsyncNoSpan(
148 data.service_data.settings.task_processor,
149 utils::LazyPrvalue([&] { return CallData(data); }))
150 .Detach();
151 }
152
153 private:
154 using InitialRequest = typename CallTraits::InitialRequest;
155 using RawCall = typename CallTraits::RawCall;
156 using Call = typename CallTraits::Call;
157
158 void HandleRpc() {
159 const auto call_name = method_data_.call_name;
160 auto& service = method_data_.service;
161 const auto service_method = method_data_.service_method;
162
163 const auto& service_name =
164 method_data_.service_data.metadata.service_full_name;
165 const auto& method_name = method_data_.method_name;
166
167 SetupSpan(span_, context_, call_name);
168 utils::FastScopeGuard destroy_span([&]() noexcept { span_.reset(); });
169
170 ugrpc::impl::RpcStatisticsScope statistics_scope(method_data_.statistics);
171
172 auto& access_tskv_logger =
173 method_data_.service_data.settings.access_tskv_logger;
174 Call responder(CallParams{context_, call_name, statistics_scope,
175 *access_tskv_logger, span_->Get()},
176 raw_responder_);
177 auto do_call = [&] {
178 if constexpr (std::is_same_v<InitialRequest, NoInitialRequest>) {
179 (service.*service_method)(responder);
180 } else {
181 (service.*service_method)(responder, std::move(initial_request_));
182 }
183 };
184
185 try {
186 ::google::protobuf::Message* initial_request = nullptr;
187 if constexpr (!std::is_same_v<InitialRequest, NoInitialRequest>) {
188 initial_request = &initial_request_;
189 }
190
191 auto& middlewares = method_data_.service_data.settings.middlewares;
192 MiddlewareCallContext middleware_context(
193 middlewares, responder, do_call, service_name, method_name,
194 method_data_.service_data.settings.config_source.GetSnapshot(),
195 initial_request);
196 middleware_context.Next();
197 } catch (
198 const USERVER_NAMESPACE::server::handlers::CustomHandlerException& ex) {
199 ReportCustomError(ex, responder, span_->Get());
200 } catch (const RpcInterruptedError& ex) {
201 ReportNetworkError(ex, call_name, span_->Get());
202 statistics_scope.OnNetworkError();
203 } catch (const std::exception& ex) {
204 ReportHandlerError(ex, call_name, span_->Get());
205 }
206 }
207
208 // 'wait_token_' must be the first field, because its lifetime keeps
209 // ServiceData alive during server shutdown.
210 const utils::impl::WaitTokenStorage::Token wait_token_;
211
212 MethodData<GrpcppService, CallTraits> method_data_;
213
214 grpc::ServerContext context_{};
215 InitialRequest initial_request_{};
216 RawCall raw_responder_{&context_};
217 ugrpc::impl::AsyncMethodInvocation prepare_;
218 std::optional<tracing::InPlaceSpan> span_{};
219};
220
221template <typename GrpcppService>
222class ServiceWorkerImpl final : public ServiceWorker {
223 public:
224 template <typename Service, typename... ServiceMethods>
225 ServiceWorkerImpl(ServiceSettings&& settings,
226 ugrpc::impl::StaticServiceMetadata&& metadata,
227 Service& service, ServiceMethods... service_methods)
228 : service_data_(settings, metadata),
229 start_{[this, &service, service_methods...] {
230 std::size_t method_id = 0;
231 (CallData<GrpcppService, CallTraits<ServiceMethods>>::ListenAsync(
232 {service_data_, method_id++, service, service_methods}),
233 ...);
234 }} {}
235
236 ~ServiceWorkerImpl() override {
237 service_data_.wait_tokens.WaitForAllTokens();
238 }
239
240 grpc::Service& GetService() override { return service_data_.async_service; }
241
242 const ugrpc::impl::StaticServiceMetadata& GetMetadata() const override {
243 return service_data_.metadata;
244 }
245
246 void Start() override { start_(); }
247
248 private:
249 ServiceData<GrpcppService> service_data_;
250 std::function<void()> start_;
251};
252
253// Called from 'MakeWorker' of code-generated service base classes
254template <typename GrpcppService, typename Service, typename... ServiceMethods>
255std::unique_ptr<ServiceWorker> MakeServiceWorker(
256 ServiceSettings&& settings,
257 const std::string_view (&method_full_names)[sizeof...(ServiceMethods)],
258 Service& service, ServiceMethods... service_methods) {
259 return std::make_unique<ServiceWorkerImpl<GrpcppService>>(
260 std::move(settings),
261 ugrpc::impl::MakeStaticServiceMetadata<GrpcppService>(method_full_names),
262 service, service_methods...);
263}
264
265} // namespace ugrpc::server::impl
266
267USERVER_NAMESPACE_END