1#include <userver/utest/utest.hpp>
3#include <buf/validate/validate.pb.h>
5#include <userver/engine/async.hpp>
6#include <userver/ugrpc/client/exceptions.hpp>
7#include <userver/ugrpc/tests/service_fixtures.hpp>
9#include <types/unit_test_client.usrv.pb.hpp>
10#include <types/unit_test_service.usrv.pb.hpp>
12#include <grpc-protovalidate/server/middleware.hpp>
15USERVER_NAMESPACE_BEGIN
19class UnitTestServiceValidator
final :
public types::UnitTestServiceBase {
21 CheckConstraintsUnaryResult CheckConstraintsUnary(CallContext&, types::ConstrainedRequest&& request)
override {
22 types::ConstrainedResponse response;
23 response.set_field(request.field());
27 CheckConstraintsStreamingResult
28 CheckConstraintsStreaming(CallContext&, CheckConstraintsStreamingReaderWriter& stream)
override {
29 types::ConstrainedRequest request;
30 types::ConstrainedResponse response{};
31 while (stream.Read(request)) {
32 response.set_field(request.field());
33 stream.Write(response);
35 return grpc::Status::OK;
38 CheckInvalidRequestConstraintsResult CheckInvalidRequestConstraints(CallContext&, types::InvalidConstraints&&)
40 return google::protobuf::Empty{};
44class GrpcServerValidatorTest :
public ugrpc::tests::ServiceFixtureBase,
45 public testing::WithParamInterface<grpc_protovalidate::server::Settings> {
47 GrpcServerValidatorTest() {
48 SetServerMiddlewares({std::make_shared<grpc_protovalidate::server::Middleware>(GetParam())});
49 RegisterService(service_);
53 ~GrpcServerValidatorTest() override { StopServer(); }
56 UnitTestServiceValidator service_;
61INSTANTIATE_UTEST_SUITE_P(
63 GrpcServerValidatorTest,
65 grpc_protovalidate::server::Settings{
68 {
"types.UnitTestService/CheckConstraintsUnary", {.fail_fast =
false, .send_violations =
true}},
69 {
"/UnknownMethod", {.fail_fast =
true, .send_violations =
false}},
71 grpc_protovalidate::server::Settings{
75 .send_violations =
true,
79 {
"types.UnitTestService/CheckConstraintsUnary", {.fail_fast =
false, .send_violations =
true}},
80 {
"/UnknownMethod", {.fail_fast =
true, .send_violations =
false}},
82 grpc_protovalidate::server::Settings{
86 .send_violations =
true,
90 {
"types.UnitTestService/CheckConstraintsUnary", {.fail_fast =
false, .send_violations =
true}},
91 {
"/UnknownMethod", {.fail_fast =
true, .send_violations =
false}},
96UTEST_P_MT(GrpcServerValidatorTest, AllValid, 2) {
97 constexpr std::size_t kRequestCount = 3;
98 auto client = MakeClient<types::UnitTestServiceClient>();
99 auto stream = client.CheckConstraintsStreaming();
101 std::vector<types::ConstrainedMessage> messages;
102 std::vector<types::ConstrainedRequest> requests(kRequestCount);
103 std::vector<types::ConstrainedResponse> responses;
105 types::ConstrainedMessage msg;
106 types::ConstrainedResponse response;
108 msg.set_required_rule(1);
109 messages.push_back(std::move(msg));
110 messages.push_back(tests::CreateValidMessage(2));
111 messages.push_back(tests::CreateValidMessage(3));
113 for (std::size_t i = 0; i < kRequestCount; ++i) {
114 requests[i].set_field(
static_cast<int32_t>(i));
115 requests[i].mutable_messages()->Add(messages.begin(), messages.end());
120 UASSERT_NO_THROW(response = client.CheckConstraintsUnary(requests[0]));
121 EXPECT_TRUE(response.has_field());
122 EXPECT_EQ(response.field(), requests[0].field());
126 auto write_task = engine::AsyncNoSpan([&stream, &requests] {
127 for (
const auto& request : requests) {
128 const bool success = stream.Write(request);
129 if (!success)
return false;
132 return stream.WritesDone();
135 while (stream.Read(response)) {
136 responses.push_back(std::move(response));
139 ASSERT_TRUE(write_task.Get());
140 ASSERT_EQ(responses.size(), kRequestCount);
142 for (std::size_t i = 0; i < kRequestCount; ++i) {
143 EXPECT_TRUE(responses[i].has_field());
144 EXPECT_EQ(responses[i].field(), requests[i].field());
148UTEST_P_MT(GrpcServerValidatorTest, AllInvalid, 2) {
149 constexpr std::size_t kRequestCount = 3;
150 const auto& streaming_settings = GetParam().Get(
"types.UnitTestService/CheckConstraintsStreaming");
151 auto client = MakeClient<types::UnitTestServiceClient>();
152 auto stream = client.CheckConstraintsStreaming();
154 std::vector<types::ConstrainedRequest> requests(kRequestCount);
155 std::vector<types::ConstrainedResponse> responses;
157 types::ConstrainedRequest request;
158 request.set_field(1);
160 types::ConstrainedRequest invalid_request;
161 invalid_request.set_field(100);
162 *invalid_request.add_messages() = types::ConstrainedMessage{};
163 *invalid_request.add_messages() = tests::CreateInvalidMessage();
165 requests[0] = request;
166 requests[1] = invalid_request;
167 request.set_field(3);
168 requests[2] = request;
173 [[maybe_unused]]
auto response = client.CheckConstraintsUnary(requests[1]);
174 ADD_FAILURE() <<
"Call must fail";
175 }
catch (
const ugrpc::client::InvalidArgumentError& err) {
176 auto violations = tests::GetViolations(err);
177 ASSERT_TRUE(violations);
178 EXPECT_EQ(violations->violations().size(), 20);
180 ADD_FAILURE() <<
"'InvalidArgumentError' exception expected";
185 auto write_task = engine::AsyncNoSpan([&stream, &requests] {
186 for (
const auto& request : requests) {
187 const bool success = stream.Write(request);
188 if (!success)
return false;
191 return stream.WritesDone();
194 types::ConstrainedResponse response;
196 ASSERT_TRUE(stream.Read(response));
197 EXPECT_EQ(response.field(), 1);
200 [[maybe_unused]]
bool result = stream.Read(response);
201 ADD_FAILURE() <<
"Call must fail";
202 }
catch (
const ugrpc::client::InvalidArgumentError& err) {
203 auto violations = tests::GetViolations(err);
205 if (streaming_settings.send_violations) {
206 ASSERT_TRUE(violations);
208 if (streaming_settings.fail_fast) {
209 EXPECT_EQ(violations->violations().size(), 1);
211 EXPECT_EQ(violations->violations().size(), 20);
214 EXPECT_FALSE(violations);
217 ADD_FAILURE() <<
"'InvalidArgumentError' exception expected";
223UTEST_P(GrpcServerValidatorTest, InvalidConstraints) {
224 auto client = MakeClient<types::UnitTestServiceClient>();
225 types::InvalidConstraints request;
226 request.set_field(1);
229 [[maybe_unused]]
auto response = client.CheckInvalidRequestConstraints(std::move(request));
230 ADD_FAILURE() <<
"Call must fail";
231 }
catch (
const ugrpc::client::InternalError&) {
234 ADD_FAILURE() <<
"'InternalError' exception expected";