userver: /data/code/userver/libraries/grpc-protovalidate/src/grpc-protovalidate/client/middleware.cpp Source File
Loading...
Searching...
No Matches
middleware.cpp
1#include <grpc-protovalidate/client/middleware.hpp>
2
3#include <google/protobuf/arena.h>
4
5#include <userver/grpc-protovalidate/client/exceptions.hpp>
6#include <userver/grpc-protovalidate/validate.hpp>
7#include <userver/logging/log.hpp>
8#include <userver/ugrpc/client/exceptions.hpp>
9#include <userver/utils/assert.hpp>
10
11USERVER_NAMESPACE_BEGIN
12
13namespace grpc_protovalidate::client {
14
15const ValidationSettings& Settings::Get(std::string_view method_name) const {
16 auto it = per_method.find(method_name);
17 return it != per_method.end() ? it->second : global;
18}
19
20Middleware::Middleware(const Settings& settings)
21 : settings_(settings)
22{}
23
24Middleware::~Middleware() = default;
25
26void Middleware::PreSendMessage(ugrpc::client::MiddlewareCallContext& context, const google::protobuf::Message& request)
27 const {
28 const ValidationSettings& settings = settings_.Get(context.GetCallName());
29 if (!settings.validate_requests) {
30 return;
31 }
32 const ValidationResult result = ValidateMessage(request, {.fail_fast = settings.fail_fast});
33 if (result.IsSuccess()) {
34 return;
35 }
36 const ValidationError& error = result.GetError();
37 switch (error.GetType()) {
38 case ValidationError::Type::kInternal:
39 throw ValidatorError(context.GetCallName());
40 case ValidationError::Type::kRule:
41 LOG_WARNING() << error;
42 throw RequestError(context.GetCallName(), error.GetViolations());
43 }
44 UINVARIANT(false, "Unexpected error type");
45}
46
47void Middleware::PostRecvMessage(
48 ugrpc::client::MiddlewareCallContext& context,
49 const google::protobuf::Message& response
50) const {
51 const ValidationSettings& settings = settings_.Get(context.GetCallName());
52 const ValidationResult result = ValidateMessage(response, {.fail_fast = settings.fail_fast});
53 if (result.IsSuccess()) {
54 return;
55 }
56 const ValidationError& error = result.GetError();
57 switch (error.GetType()) {
58 case ValidationError::Type::kInternal:
59 throw ValidatorError(context.GetCallName());
60 case ValidationError::Type::kRule:
61 LOG_WARNING() << error;
62 throw ResponseError(context.GetCallName(), error.GetViolations());
63 }
64 UINVARIANT(false, "Unexpected error type");
65}
66
67} // namespace grpc_protovalidate::client
68
69USERVER_NAMESPACE_END