userver: /home/antonyzhilin/arcadia/taxi/uservices/userver/libraries/grpc-protovalidate/tests/validator_service_test.cpp Source File
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages Concepts
validator_service_test.cpp
1#include <userver/utest/utest.hpp>
2
3#include <buf/validate/validate.pb.h>
4
5#include <userver/engine/async.hpp>
6#include <userver/ugrpc/client/exceptions.hpp>
7#include <userver/ugrpc/tests/service_fixtures.hpp>
8
9#include <types/unit_test_client.usrv.pb.hpp>
10#include <types/unit_test_service.usrv.pb.hpp>
11
12#include <grpc-protovalidate/server/middleware.hpp>
13#include "utils.hpp"
14
15USERVER_NAMESPACE_BEGIN
16
17namespace {
18
19class UnitTestServiceValidator final : public types::UnitTestServiceBase {
20public:
21 CheckConstraintsUnaryResult CheckConstraintsUnary(CallContext&, types::ConstrainedRequest&& request) override {
22 types::ConstrainedResponse response;
23 response.set_field(request.field());
24 return response;
25 }
26
27 CheckConstraintsStreamingResult
28 CheckConstraintsStreaming(CallContext&, CheckConstraintsStreamingReaderWriter& stream) override {
29 types::ConstrainedRequest request;
30 types::ConstrainedResponse response{};
31 while (stream.Read(request)) {
32 response.set_field(request.field());
33 stream.Write(response);
34 }
35 return grpc::Status::OK;
36 }
37
38 CheckInvalidRequestConstraintsResult CheckInvalidRequestConstraints(CallContext&, types::InvalidConstraints&&)
39 override {
40 return google::protobuf::Empty{};
41 }
42};
43
44class GrpcServerValidatorTest : public ugrpc::tests::ServiceFixtureBase,
45 public testing::WithParamInterface<grpc_protovalidate::server::Settings> {
46public:
47 GrpcServerValidatorTest() {
48 SetServerMiddlewares({std::make_shared<grpc_protovalidate::server::Middleware>(GetParam())});
49 RegisterService(service_);
50 StartServer();
51 }
52
53 ~GrpcServerValidatorTest() override { StopServer(); }
54
55private:
56 UnitTestServiceValidator service_;
57};
58
59} // namespace
60
61INSTANTIATE_UTEST_SUITE_P(
62 /*no prefix*/,
63 GrpcServerValidatorTest,
64 testing::Values(
65 grpc_protovalidate::server::Settings{
66 .per_method =
67 {
68 {"types.UnitTestService/CheckConstraintsUnary", {.fail_fast = false, .send_violations = true}},
69 {"/UnknownMethod", {.fail_fast = true, .send_violations = false}},
70 }},
71 grpc_protovalidate::server::Settings{
72 .global =
73 {
74 .fail_fast = false,
75 .send_violations = true,
76 },
77 .per_method =
78 {
79 {"types.UnitTestService/CheckConstraintsUnary", {.fail_fast = false, .send_violations = true}},
80 {"/UnknownMethod", {.fail_fast = true, .send_violations = false}},
81 }},
82 grpc_protovalidate::server::Settings{
83 .global =
84 {
85 .fail_fast = true,
86 .send_violations = true,
87 },
88 .per_method =
89 {
90 {"types.UnitTestService/CheckConstraintsUnary", {.fail_fast = false, .send_violations = true}},
91 {"/UnknownMethod", {.fail_fast = true, .send_violations = false}},
92 }}
93 )
94);
95
96UTEST_P_MT(GrpcServerValidatorTest, AllValid, 2) {
97 constexpr std::size_t kRequestCount = 3;
98 auto client = MakeClient<types::UnitTestServiceClient>();
99 auto stream = client.CheckConstraintsStreaming();
100
101 std::vector<types::ConstrainedMessage> messages;
102 std::vector<types::ConstrainedRequest> requests(kRequestCount);
103 std::vector<types::ConstrainedResponse> responses;
104
105 types::ConstrainedMessage msg;
106 types::ConstrainedResponse response;
107
108 msg.set_required_rule(1);
109 messages.push_back(std::move(msg));
110 messages.push_back(tests::CreateValidMessage(2));
111 messages.push_back(tests::CreateValidMessage(3));
112
113 for (std::size_t i = 0; i < kRequestCount; ++i) {
114 requests[i].set_field(static_cast<int32_t>(i));
115 requests[i].mutable_messages()->Add(messages.begin(), messages.end());
116 }
117
118 // check unary method
119
120 UASSERT_NO_THROW(response = client.CheckConstraintsUnary(requests[0]));
121 EXPECT_TRUE(response.has_field());
122 EXPECT_EQ(response.field(), requests[0].field());
123
124 // check streaming method
125
126 auto write_task = engine::AsyncNoSpan([&stream, &requests] {
127 for (const auto& request : requests) {
128 const bool success = stream.Write(request);
129 if (!success) return false;
130 }
131
132 return stream.WritesDone();
133 });
134
135 while (stream.Read(response)) {
136 responses.push_back(std::move(response));
137 }
138
139 ASSERT_TRUE(write_task.Get());
140 ASSERT_EQ(responses.size(), kRequestCount);
141
142 for (std::size_t i = 0; i < kRequestCount; ++i) {
143 EXPECT_TRUE(responses[i].has_field());
144 EXPECT_EQ(responses[i].field(), requests[i].field());
145 }
146}
147
148UTEST_P_MT(GrpcServerValidatorTest, AllInvalid, 2) {
149 constexpr std::size_t kRequestCount = 3;
150 const auto& streaming_settings = GetParam().Get("types.UnitTestService/CheckConstraintsStreaming");
151 auto client = MakeClient<types::UnitTestServiceClient>();
152 auto stream = client.CheckConstraintsStreaming();
153
154 std::vector<types::ConstrainedRequest> requests(kRequestCount);
155 std::vector<types::ConstrainedResponse> responses;
156
157 types::ConstrainedRequest request;
158 request.set_field(1);
159
160 types::ConstrainedRequest invalid_request;
161 invalid_request.set_field(100);
162 *invalid_request.add_messages() = types::ConstrainedMessage{};
163 *invalid_request.add_messages() = tests::CreateInvalidMessage();
164
165 requests[0] = request;
166 requests[1] = invalid_request;
167 request.set_field(3);
168 requests[2] = request;
169
170 // check unary method
171
172 try {
173 [[maybe_unused]] auto response = client.CheckConstraintsUnary(requests[1]);
174 ADD_FAILURE() << "Call must fail";
175 } catch (const ugrpc::client::InvalidArgumentError& err) {
176 auto violations = tests::GetViolations(err);
177 ASSERT_TRUE(violations);
178 EXPECT_EQ(violations->violations().size(), 20);
179 } catch (...) {
180 ADD_FAILURE() << "'InvalidArgumentError' exception expected";
181 }
182
183 // check streaming method
184
185 auto write_task = engine::AsyncNoSpan([&stream, &requests] {
186 for (const auto& request : requests) {
187 const bool success = stream.Write(request);
188 if (!success) return false;
189 }
190
191 return stream.WritesDone();
192 });
193
194 types::ConstrainedResponse response;
195
196 ASSERT_TRUE(stream.Read(response));
197 EXPECT_EQ(response.field(), 1);
198
199 try {
200 [[maybe_unused]] bool result = stream.Read(response);
201 ADD_FAILURE() << "Call must fail";
202 } catch (const ugrpc::client::InvalidArgumentError& err) {
203 auto violations = tests::GetViolations(err);
204
205 if (streaming_settings.send_violations) {
206 ASSERT_TRUE(violations);
207
208 if (streaming_settings.fail_fast) {
209 EXPECT_EQ(violations->violations().size(), 1);
210 } else {
211 EXPECT_EQ(violations->violations().size(), 20);
212 }
213 } else {
214 EXPECT_FALSE(violations);
215 }
216 } catch (...) {
217 ADD_FAILURE() << "'InvalidArgumentError' exception expected";
218 }
219
220 write_task.Get();
221}
222
223UTEST_P(GrpcServerValidatorTest, InvalidConstraints) {
224 auto client = MakeClient<types::UnitTestServiceClient>();
225 types::InvalidConstraints request;
226 request.set_field(1);
227
228 try {
229 [[maybe_unused]] auto response = client.CheckInvalidRequestConstraints(std::move(request));
230 ADD_FAILURE() << "Call must fail";
231 } catch (const ugrpc::client::InternalError&) {
232 // do nothing
233 } catch (...) {
234 ADD_FAILURE() << "'InternalError' exception expected";
235 }
236}
237
238USERVER_NAMESPACE_END