userver: /data/code/userver/libraries/grpc-protovalidate/tests/validator_service_test.cpp Source File
Loading...
Searching...
No Matches
validator_service_test.cpp
1#include <userver/utest/utest.hpp>
2
3#include <buf/validate/validate.pb.h>
4
5#include <userver/engine/async.hpp>
6#include <userver/ugrpc/client/exceptions.hpp>
7#include <userver/ugrpc/tests/service_fixtures.hpp>
8
9#include <types/unit_test_client.usrv.pb.hpp>
10#include <types/unit_test_service.usrv.pb.hpp>
11
12#include <grpc-protovalidate/server/middleware.hpp>
13#include "utils.hpp"
14
15USERVER_NAMESPACE_BEGIN
16
17namespace {
18
19class UnitTestServiceValidator final : public types::UnitTestServiceBase {
20public:
21 CheckConstraintsUnaryResult CheckConstraintsUnary(CallContext&, types::ConstrainedRequest&& request) override {
22 types::ConstrainedResponse response;
23 response.set_field(request.field());
24 return response;
25 }
26
27 CheckConstraintsStreamingResult CheckConstraintsStreaming(
28 CallContext&,
29 CheckConstraintsStreamingReaderWriter& stream
30 ) override {
31 types::ConstrainedRequest request;
32 while (stream.Read(request)) {
33 types::ConstrainedResponse response{};
34 response.set_field(request.field());
35 stream.Write(std::move(response));
36 }
37 return grpc::Status::OK;
38 }
39
40 CheckInvalidRequestConstraintsResult CheckInvalidRequestConstraints(CallContext&, types::InvalidConstraints&&)
41 override {
42 return google::protobuf::Empty{};
43 }
44};
45
46class GrpcServerValidatorTest
47 : public ugrpc::tests::ServiceFixtureBase,
48 public testing::WithParamInterface<grpc_protovalidate::server::Settings> {
49public:
50 GrpcServerValidatorTest() {
51 SetServerMiddlewares({std::make_shared<grpc_protovalidate::server::Middleware>(GetParam())});
52 RegisterService(service_);
53 StartServer();
54 }
55
56 ~GrpcServerValidatorTest() override { StopServer(); }
57
58private:
59 UnitTestServiceValidator service_;
60};
61
62} // namespace
63
64INSTANTIATE_UTEST_SUITE_P(
65 /*no prefix*/,
66 GrpcServerValidatorTest,
67 testing::Values(
68 grpc_protovalidate::server::Settings{
69 .per_method =
70 {
71 {"types.UnitTestService/CheckConstraintsUnary", {.fail_fast = false, .send_violations = true}},
72 {"/UnknownMethod", {.fail_fast = true, .send_violations = false}},
73 }
74 },
75 grpc_protovalidate::server::Settings{
76 .global =
77 {
78 .fail_fast = false,
79 .send_violations = true,
80 },
81 .per_method =
82 {
83 {"types.UnitTestService/CheckConstraintsUnary", {.fail_fast = false, .send_violations = true}},
84 {"/UnknownMethod", {.fail_fast = true, .send_violations = false}},
85 }
86 },
87 grpc_protovalidate::server::Settings{
88 .global =
89 {
90 .fail_fast = true,
91 .send_violations = true,
92 },
93 .per_method =
94 {
95 {"types.UnitTestService/CheckConstraintsUnary", {.fail_fast = false, .send_violations = true}},
96 {"/UnknownMethod", {.fail_fast = true, .send_violations = false}},
97 }
98 }
99 )
100);
101
102UTEST_P_MT(GrpcServerValidatorTest, AllValid, 2) {
103 constexpr std::size_t kRequestCount = 3;
104 auto client = MakeClient<types::UnitTestServiceClient>();
105 auto stream = client.CheckConstraintsStreaming();
106
107 std::vector<types::ConstrainedMessage> messages;
108 std::vector<types::ConstrainedRequest> requests(kRequestCount);
109 std::vector<types::ConstrainedResponse> responses;
110
111 types::ConstrainedMessage msg;
112 types::ConstrainedResponse response;
113
114 msg.set_required_rule(1);
115 messages.push_back(std::move(msg));
116 messages.push_back(tests::CreateValidMessage(2));
117 messages.push_back(tests::CreateValidMessage(3));
118
119 for (std::size_t i = 0; i < kRequestCount; ++i) {
120 requests[i].set_field(static_cast<int32_t>(i));
121 requests[i].mutable_messages()->Add(messages.begin(), messages.end());
122 }
123
124 // check unary method
125
126 UASSERT_NO_THROW(response = client.CheckConstraintsUnary(requests[0]));
127 EXPECT_TRUE(response.has_field());
128 EXPECT_EQ(response.field(), requests[0].field());
129
130 // check streaming method
131
132 auto write_task = engine::AsyncNoSpan([&stream, &requests] {
133 for (const auto& request : requests) {
134 const bool success = stream.Write(request);
135 if (!success) {
136 return false;
137 }
138 }
139
140 return stream.WritesDone();
141 });
142
143 while (stream.Read(response)) {
144 responses.push_back(std::move(response));
145 }
146
147 ASSERT_TRUE(write_task.Get());
148 ASSERT_EQ(responses.size(), kRequestCount);
149
150 for (std::size_t i = 0; i < kRequestCount; ++i) {
151 EXPECT_TRUE(responses[i].has_field());
152 EXPECT_EQ(responses[i].field(), requests[i].field());
153 }
154}
155
156UTEST_P_MT(GrpcServerValidatorTest, AllInvalid, 2) {
157 constexpr std::size_t kRequestCount = 3;
158 const auto& streaming_settings = GetParam().Get("types.UnitTestService/CheckConstraintsStreaming");
159 auto client = MakeClient<types::UnitTestServiceClient>();
160 auto stream = client.CheckConstraintsStreaming();
161
162 std::vector<types::ConstrainedRequest> requests(kRequestCount);
163 const std::vector<types::ConstrainedResponse> responses;
164
165 types::ConstrainedRequest request;
166 request.set_field(1);
167
168 types::ConstrainedRequest invalid_request;
169 invalid_request.set_field(100);
170 *invalid_request.add_messages() = types::ConstrainedMessage{};
171 *invalid_request.add_messages() = tests::CreateInvalidMessage();
172
173 requests[0] = request;
174 requests[1] = invalid_request;
175 request.set_field(3);
176 requests[2] = request;
177
178 // check unary method
179
180 try {
181 [[maybe_unused]] auto response = client.CheckConstraintsUnary(requests[1]);
182 ADD_FAILURE() << "Call must fail";
183 } catch (const ugrpc::client::InvalidArgumentError& err) {
184 auto violations = tests::GetViolations(err);
185 ASSERT_TRUE(violations);
186 EXPECT_EQ(violations->violations().size(), 20);
187 } catch (...) {
188 ADD_FAILURE() << "'InvalidArgumentError' exception expected";
189 }
190
191 // check streaming method
192
193 auto write_task = engine::AsyncNoSpan([&stream, &requests] {
194 for (const auto& request : requests) {
195 const bool success = stream.Write(request);
196 if (!success) {
197 return false;
198 }
199 }
200
201 return stream.WritesDone();
202 });
203
204 types::ConstrainedResponse response;
205
206 ASSERT_TRUE(stream.Read(response));
207 EXPECT_EQ(response.field(), 1);
208
209 try {
210 [[maybe_unused]] const bool result = stream.Read(response);
211 ADD_FAILURE() << "Call must fail";
212 } catch (const ugrpc::client::InvalidArgumentError& err) {
213 auto violations = tests::GetViolations(err);
214
215 if (streaming_settings.send_violations) {
216 ASSERT_TRUE(violations);
217
218 if (streaming_settings.fail_fast) {
219 EXPECT_EQ(violations->violations().size(), 1);
220 } else {
221 EXPECT_EQ(violations->violations().size(), 20);
222 }
223 } else {
224 EXPECT_FALSE(violations);
225 }
226 } catch (...) {
227 ADD_FAILURE() << "'InvalidArgumentError' exception expected";
228 }
229
230 write_task.Get();
231}
232
233UTEST_P(GrpcServerValidatorTest, InvalidConstraints) {
234 auto client = MakeClient<types::UnitTestServiceClient>();
235 types::InvalidConstraints request;
236 request.set_field(1);
237
238 try {
239 [[maybe_unused]] auto response = client.CheckInvalidRequestConstraints(std::move(request));
240 ADD_FAILURE() << "Call must fail";
241 } catch (const ugrpc::client::InternalError&) {
242 // do nothing
243 } catch (...) {
244 ADD_FAILURE() << "'InternalError' exception expected";
245 }
246}
247
248USERVER_NAMESPACE_END