1#include <userver/utest/utest.hpp>
3#include <buf/validate/validate.pb.h>
5#include <userver/engine/async.hpp>
6#include <userver/ugrpc/client/exceptions.hpp>
7#include <userver/ugrpc/tests/service_fixtures.hpp>
9#include <types/unit_test_client.usrv.pb.hpp>
10#include <types/unit_test_service.usrv.pb.hpp>
12#include <grpc-protovalidate/server/middleware.hpp>
15USERVER_NAMESPACE_BEGIN
// Test implementation of types::UnitTestServiceBase registered as the server
// side of the protovalidate middleware tests in this file (see service_ in
// GrpcServerValidatorTest). The handlers perform no validation of their own;
// presumably every rejection observed by the tests comes from the
// grpc_protovalidate middleware installed by the fixture.
19class UnitTestServiceValidator
final :
public types::UnitTestServiceBase {
// Unary handler: builds a ConstrainedResponse whose `field` is copied from the
// incoming request.
// NOTE(review): the statement returning `response` is not visible in this view.
21 CheckConstraintsUnaryResult CheckConstraintsUnary(CallContext&, types::ConstrainedRequest&& request)
override {
22 types::ConstrainedResponse response;
23 response.set_field(request.field());
// Streaming handler: for each request read from the stream, writes back one
// response carrying the same `field`; returns grpc::Status::OK once Read()
// reports that no more requests are available.
27 CheckConstraintsStreamingResult CheckConstraintsStreaming(
29 CheckConstraintsStreamingReaderWriter& stream
31 types::ConstrainedRequest request;
32 while (stream.Read(request)) {
33 types::ConstrainedResponse response{};
34 response.set_field(request.field());
35 stream.Write(std::move(response));
37 return grpc::Status::OK;
// Handler for types::InvalidConstraints requests; it just returns an empty
// message. The InvalidConstraints test expects calls to this method to fail
// with InternalError, so presumably this body is never reached when the
// middleware is active — TODO confirm.
40 CheckInvalidRequestConstraintsResult CheckInvalidRequestConstraints(CallContext&, types::InvalidConstraints&&)
42 return google::protobuf::Empty{};
// Value-parameterized test fixture. Each instantiation supplies a
// grpc_protovalidate::server::Settings value (read through GetParam()) that
// configures the validation middleware placed in front of the test service.
46class GrpcServerValidatorTest
47 :
public ugrpc::tests::ServiceFixtureBase,
48 public testing::WithParamInterface<grpc_protovalidate::server::Settings> {
// Installs a protovalidate server Middleware built from the current test
// parameter as the server middleware list, then registers the echo service.
50 GrpcServerValidatorTest() {
51 SetServerMiddlewares({std::make_shared<grpc_protovalidate::server::Middleware>(GetParam())});
52 RegisterService(service_);
// Stops the test server when the fixture is torn down.
56 ~GrpcServerValidatorTest() override { StopServer(); }
// Service instance registered with the server in the constructor.
59 UnitTestServiceValidator service_;
// Instantiates every GrpcServerValidatorTest case once per Settings value
// listed below. All three values carry the same per-method overrides:
//   * "types.UnitTestService/CheckConstraintsUnary":
//       fail_fast = false, send_violations = true. The AllInvalid test relies
//       on this: it expects the unary error to carry 20 violations regardless
//       of the parameter.
//   * "/UnknownMethod": fail_fast = true, send_violations = false. This is
//       presumably a name that matches no real RPC, to check that such
//       entries are harmless — TODO confirm.
// NOTE(review): the settings that distinguish the three values (e.g. the
// `.send_violations = true` seen in the second and third entries) are only
// partially visible here; confirm which level they apply to in the full file.
64INSTANTIATE_UTEST_SUITE_P(
66 GrpcServerValidatorTest,
68 grpc_protovalidate::server::Settings{
71 {
"types.UnitTestService/CheckConstraintsUnary", {.fail_fast =
false, .send_violations =
true}},
72 {
"/UnknownMethod", {.fail_fast =
true, .send_violations =
false}},
75 grpc_protovalidate::server::Settings{
79 .send_violations =
true,
83 {
"types.UnitTestService/CheckConstraintsUnary", {.fail_fast =
false, .send_violations =
true}},
84 {
"/UnknownMethod", {.fail_fast =
true, .send_violations =
false}},
87 grpc_protovalidate::server::Settings{
91 .send_violations =
true,
95 {
"types.UnitTestService/CheckConstraintsUnary", {.fail_fast =
false, .send_violations =
true}},
96 {
"/UnknownMethod", {.fail_fast =
true, .send_violations =
false}},
// Happy path: every request is valid. Both the unary call and the streaming
// call must succeed and echo each request's `field` back unchanged.
// NOTE(review): the trailing `2` of UTEST_P_MT is presumably the number of
// engine worker threads. The streaming part runs a writer task concurrently
// with the reads done on the test's own task.
102UTEST_P_MT(GrpcServerValidatorTest, AllValid, 2) {
103 constexpr std::size_t kRequestCount = 3;
104 auto client = MakeClient<types::UnitTestServiceClient>();
105 auto stream = client.CheckConstraintsStreaming();
// Nested messages attached to every request: one with only required_rule = 1,
// plus two built by tests::CreateValidMessage(2) and (3).
107 std::vector<types::ConstrainedMessage> messages;
108 std::vector<types::ConstrainedRequest> requests(kRequestCount);
109 std::vector<types::ConstrainedResponse> responses;
111 types::ConstrainedMessage msg;
112 types::ConstrainedResponse response;
114 msg.set_required_rule(1);
115 messages.push_back(std::move(msg));
116 messages.push_back(tests::CreateValidMessage(2));
117 messages.push_back(tests::CreateValidMessage(3));
// Request i gets field = i and a copy of every message above.
119 for (std::size_t i = 0; i < kRequestCount; ++i) {
120 requests[i].set_field(
static_cast<int32_t>(i));
121 requests[i].mutable_messages()->Add(messages.begin(), messages.end());
// Unary call: must not throw and must echo requests[0].field().
126 UASSERT_NO_THROW(response = client.CheckConstraintsUnary(requests[0]));
127 EXPECT_TRUE(response.has_field());
128 EXPECT_EQ(response.field(), requests[0].field());
// Streaming call: a separate task writes every request and then signals the
// end of writes with WritesDone(); its result is the task's return value.
132 auto write_task = engine::AsyncNoSpan([&stream, &requests] {
133 for (
const auto& request : requests) {
134 const bool success = stream.Write(request);
140 return stream.WritesDone();
// Meanwhile drain every response until Read() returns false.
143 while (stream.Read(response)) {
144 responses.push_back(std::move(response));
// The writer must report success, and there must be exactly one response per
// request, each echoing its request's field in order.
147 ASSERT_TRUE(write_task.Get());
148 ASSERT_EQ(responses.size(), kRequestCount);
150 for (std::size_t i = 0; i < kRequestCount; ++i) {
151 EXPECT_TRUE(responses[i].has_field());
152 EXPECT_EQ(responses[i].field(), requests[i].field());
// Validation-failure path. requests[1] is built to fail validation, while
// requests[0] and requests[2] carry only `field` (1 and 3).
// - The unary call with requests[1] must throw InvalidArgumentError carrying
//   20 violations.
// - On the stream, the response for requests[0] must arrive first. The next
//   Read() must throw InvalidArgumentError, and its violation details depend
//   on the settings in effect for CheckConstraintsStreaming.
156UTEST_P_MT(GrpcServerValidatorTest, AllInvalid, 2) {
157 constexpr std::size_t kRequestCount = 3;
// Settings applied to the streaming RPC for the current parameter; they decide
// which violation details the stream error is expected to carry below.
158 const auto& streaming_settings = GetParam().Get(
"types.UnitTestService/CheckConstraintsStreaming");
159 auto client = MakeClient<types::UnitTestServiceClient>();
160 auto stream = client.CheckConstraintsStreaming();
162 std::vector<types::ConstrainedRequest> requests(kRequestCount);
// NOTE(review): `responses` is not used in the visible part of this test.
163 const std::vector<types::ConstrainedResponse> responses;
165 types::ConstrainedRequest request;
166 request.set_field(1);
// Invalid request: field = 100, an empty ConstrainedMessage and the message
// produced by tests::CreateInvalidMessage().
168 types::ConstrainedRequest invalid_request;
169 invalid_request.set_field(100);
170 *invalid_request.add_messages() = types::ConstrainedMessage{};
171 *invalid_request.add_messages() = tests::CreateInvalidMessage();
173 requests[0] = request;
174 requests[1] = invalid_request;
175 request.set_field(3);
176 requests[2] = request;
// Unary check. Every visible instantiation sets fail_fast = false and
// send_violations = true for CheckConstraintsUnary, so the full count of 20
// violations is expected regardless of the parameter.
181 [[maybe_unused]]
auto response = client.CheckConstraintsUnary(requests[1]);
182 ADD_FAILURE() <<
"Call must fail";
183 }
catch (
const ugrpc::client::InvalidArgumentError& err) {
184 auto violations = tests::GetViolations(err);
185 ASSERT_TRUE(violations);
186 EXPECT_EQ(violations->violations().size(), 20);
188 ADD_FAILURE() <<
"'InvalidArgumentError' exception expected";
// Streaming check: a writer task sends all three requests and then calls
// WritesDone(), concurrently with the reads below.
193 auto write_task = engine::AsyncNoSpan([&stream, &requests] {
194 for (
const auto& request : requests) {
195 const bool success = stream.Write(request);
201 return stream.WritesDone();
// The first response must echo requests[0] (field 1). The following Read()
// is expected to throw because requests[1] fails validation.
204 types::ConstrainedResponse response;
206 ASSERT_TRUE(stream.Read(response));
207 EXPECT_EQ(response.field(), 1);
210 [[maybe_unused]]
const bool result = stream.Read(response);
211 ADD_FAILURE() <<
"Call must fail";
212 }
catch (
const ugrpc::client::InvalidArgumentError& err) {
213 auto violations = tests::GetViolations(err);
// Violation details must be present only when send_violations is enabled.
// With fail_fast exactly one violation is expected, otherwise 20.
215 if (streaming_settings.send_violations) {
216 ASSERT_TRUE(violations);
218 if (streaming_settings.fail_fast) {
219 EXPECT_EQ(violations->violations().size(), 1);
221 EXPECT_EQ(violations->violations().size(), 20);
224 EXPECT_FALSE(violations);
227 ADD_FAILURE() <<
"'InvalidArgumentError' exception expected";
// Calls CheckInvalidRequestConstraints with a types::InvalidConstraints request
// (field = 1). The call must fail with ugrpc::client::InternalError.
// Presumably this is because the constraints declared on InvalidConstraints
// cannot be evaluated by the validator — TODO confirm against the .proto.
233UTEST_P(GrpcServerValidatorTest, InvalidConstraints) {
234 auto client = MakeClient<types::UnitTestServiceClient>();
235 types::InvalidConstraints request;
236 request.set_field(1);
239 [[maybe_unused]]
auto response = client.CheckInvalidRequestConstraints(std::move(request));
240 ADD_FAILURE() <<
"Call must fail";
241 }
catch (
const ugrpc::client::InternalError&) {
244 ADD_FAILURE() <<
"'InternalError' exception expected";