userver: userver/ugrpc/protobuf_visit.hpp Source File
Loading...
Searching...
No Matches
protobuf_visit.hpp
Go to the documentation of this file.
1#pragma once
2
3/// @file userver/ugrpc/protobuf_visit.hpp
4/// @brief Utilities for visiting the fields of protobufs
5
#include <functional>
#include <mutex>
#include <shared_mutex>
#include <string_view>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <google/protobuf/message.h>

#include <userver/engine/shared_mutex.hpp>
#include <userver/utils/assert.hpp>
#include <userver/utils/function_ref.hpp>
#include <userver/utils/impl/internal_tag_fwd.hpp>
#include <userver/utils/span.hpp>
17
// Forward declarations of the protobuf descriptor types referenced by the API below.
namespace google::protobuf {
class FieldDescriptor;
class Descriptor;
}  // namespace google::protobuf
24
25USERVER_NAMESPACE_BEGIN
26
27namespace ugrpc {
28
29using MessageVisitCallback = utils::function_ref<void(google::protobuf::Message&)>;
30
31using FieldVisitCallback = utils::function_ref<
32 void(google::protobuf::Message&, const google::protobuf::FieldDescriptor&)>;
33
34/// @brief Execute a callback for all non-empty fields of the message.
35void VisitFields(google::protobuf::Message& message, FieldVisitCallback callback);
36
37/// @brief Execute a callback for the message and its non-empty submessages.
38void VisitMessagesRecursive(google::protobuf::Message& message, MessageVisitCallback callback);
39
40/// @brief Execute a callback for all fields of the message and its non-empty submessages.
41void VisitFieldsRecursive(google::protobuf::Message& message, FieldVisitCallback callback);
42
43/// @brief Execute a callback for the submessage contained in the given field.
45 google::protobuf::Message& message,
46 const google::protobuf::FieldDescriptor& field,
47 MessageVisitCallback callback
48);
49
50using DescriptorList = std::vector<const google::protobuf::Descriptor*>;
51
52using FieldDescriptorList = std::vector<const google::protobuf::FieldDescriptor*>;
53
54/// @brief Get the descriptors of fields in the message.
55FieldDescriptorList GetFieldDescriptors(const google::protobuf::Descriptor& descriptor);
56
57/// @brief Get the descriptors of current and nested messages.
58DescriptorList GetNestedMessageDescriptors(const google::protobuf::Descriptor& descriptor);
59
60/// @brief Find a generated type by name.
61const google::protobuf::Descriptor* FindGeneratedMessage(std::string_view name);
62
63/// @brief Find the field of a generated type by name.
64const google::protobuf::FieldDescriptor* FindField(
65 const google::protobuf::Descriptor* descriptor,
66 std::string_view field
67);
68
69/// @brief Base class for @ref BaseVisitor.
70/// Constructs and manages the descriptor graph to collect the data about the messages
71/// and enable the visitors to find all selected structures.
73public:
74 enum class LockBehavior {
75 /// @brief Do not take shared_mutex locks for any operation on the visitor
76 kNone = 0,
77
78 /// @brief Take shared_lock for all read operations on the visitor
79 /// and unique_lock for all Compile operations
81 };
82
83 VisitorCompiler(VisitorCompiler&&) = delete;
84 VisitorCompiler(const VisitorCompiler&) = delete;
85
86 /// @brief Compiles the visitor for the given message type and its dependent types
87 void Compile(const google::protobuf::Descriptor* descriptor);
88
89 /// @brief Compiles the visitor for the given message types and their dependent types
90 void Compile(const DescriptorList& descriptors);
91
92 /// @brief Compiles the visitor for all message types we can find.
93 /// Not guaranteed to find all message types.
95
96 /// @brief Compiles the visitor for the given generated message type and its dependent types
97 void CompileGenerated(std::string_view message_name);
98
99 /// @brief Compiles the visitor for the given generated message type and their dependent types
100 void CompileGenerated(utils::span<std::string_view> message_names);
101
102 /// @brief Efficiently checks if the message contains any selected structures.
103 ///
104 /// You may want to call this before @ref BaseVisitor::Visit and @ref BaseVisitor::VisitRecursive
105 /// to avoid a copy of the message beforehand if you require one.
106 bool ContainsSelected(const google::protobuf::Descriptor* descriptor);
107
108 /// @cond
109 /// Only for internal use.
110 using Dependencies = std::unordered_map<
111 const google::protobuf::Descriptor*,
112 std::unordered_set<const google::protobuf::FieldDescriptor*>>;
113
114 /// Only for internal use.
115 using DescriptorSet = std::unordered_set<const google::protobuf::Descriptor*>;
116
117 /// Only for internal use.
118 using FieldDescriptorSet = std::unordered_set<const google::protobuf::FieldDescriptor*>;
119
120 /// Only for internal use.
121 const Dependencies& GetFieldsWithSelectedChildren(utils::impl::InternalTag) const;
122
123 /// Only for internal use.
124 const Dependencies& GetReverseEdges(utils::impl::InternalTag) const;
125
126 /// Only for internal use.
127 const DescriptorSet& GetPropagated(utils::impl::InternalTag) const;
128
129 /// Only for internal use.
130 const DescriptorSet& GetCompiled(utils::impl::InternalTag) const;
131
132protected:
133 explicit VisitorCompiler(LockBehavior lock_behavior)
134 : lock_behavior_(lock_behavior)
135 {}
136
137 // Disallow destruction via pointer to base
138 ~VisitorCompiler() = default;
139
140 /// @brief Lock the visitor for read
141 std::shared_lock<engine::SharedMutex> LockRead();
142
143 /// @brief Lock the visitor for write
144 std::unique_lock<engine::SharedMutex> LockWrite();
145
146 const Dependencies& GetFieldsWithSelectedChildren() const { return fields_with_selected_children_; }
147 /// @endcond
148
149 /// @brief Compile one message without nested.
150 virtual void CompileOne(const google::protobuf::Descriptor& descriptor) = 0;
151
152 /// @brief Checks if the message is selected or has anything selected.
153 virtual bool IsSelected(const google::protobuf::Descriptor&) const = 0;
154
155private:
156 /// @brief Gets all submessages of the given messages.
157 DescriptorSet GetFullSubtrees(const DescriptorList& descriptors) const;
158
159 /// @brief Propagate the selection information upwards
160 void PropagateSelected(const google::protobuf::Descriptor* descriptor);
161
162 engine::SharedMutex mutex_;
163 const LockBehavior lock_behavior_;
164
165 Dependencies fields_with_selected_children_;
166 Dependencies reverse_edges_;
167 DescriptorSet propagated_;
168 DescriptorSet compiled_;
169};
170
171/// @brief Base class for @ref FieldsVisitor and @ref MessagesVisitor.
172/// Provides the interface and contains common code to use the data collected in the @ref VisitorCompiler.
173template <typename Callback>
175public:
176 /// @brief Execute a callback without recursion
177 ///
178 /// Equivalent to @ref VisitFields
179 /// but utilizes the precompilation data from @ref Compile
180 void Visit(google::protobuf::Message& message, Callback callback) {
181 // Compile if not yet compiled
182 Compile(message.GetDescriptor());
183
184 const std::shared_lock read_lock = LockRead();
185 DoVisit(message, callback);
186 }
187
188 /// @brief Execute a callback recursively
189 ///
190 /// Equivalent to @ref VisitFieldsRecursive and @ref VisitMessagesRecursive
191 /// but utilizes the precompilation data from @ref Compile
192 void VisitRecursive(google::protobuf::Message& message, Callback callback) {
193 // Compile if not yet compiled
194 Compile(message.GetDescriptor());
195
196 constexpr int kMaxRecursionLimit = 100;
197 const std::shared_lock read_lock = LockRead();
198 VisitRecursiveImpl(message, callback, kMaxRecursionLimit);
199 }
200
201protected:
202 explicit BaseVisitor(LockBehavior lock_behavior)
203 : VisitorCompiler(lock_behavior)
204 {}
205
206 // Disallow destruction via pointer to base
207 ~BaseVisitor() = default;
208
209 /// @brief Execute a callback without recursion
210 virtual void DoVisit(google::protobuf::Message& message, Callback callback) const = 0;
211
212private:
213 /// @brief Safe version with recursion_limit
214 void VisitRecursiveImpl(google::protobuf::Message& message, Callback callback, int recursion_limit) {
215 UINVARIANT(recursion_limit > 0, "Recursion limit reached while traversing protobuf Message.");
216
217 // Loop over this message
218 DoVisit(message, callback);
219
220 // Recurse into nested messages
221 const auto it = GetFieldsWithSelectedChildren().find(message.GetDescriptor());
222 if (it == GetFieldsWithSelectedChildren().end()) {
223 return;
224 }
225
226 const FieldDescriptorSet& fields = it->second;
227 for (const google::protobuf::FieldDescriptor* field : fields) {
228 UINVARIANT(field, "field is nullptr");
229 VisitNestedMessage(message, *field, [&](google::protobuf::Message& msg) {
230 VisitRecursiveImpl(msg, callback, recursion_limit - 1);
231 });
232 }
233 }
234};
235
236/// @brief Collects knowledge of the structure of the protobuf messages
237/// allowing for efficient loops over fields to apply a callback to the ones
238/// selected by the 'selector' function.
239///
240/// If you do not have static knowledge of the required fields, you should
241/// use @ref VisitFields or @ref VisitFieldsRecursive that are equivalent to
242/// FieldsVisitor with a `return true;` selector.
243///
244/// @warning You should not construct this at runtime as it performs significant
245/// computations in the constructor to precompile the visitors.
246/// You should create these at start-up.
247///
248/// Example usage:
249/// @snippet grpc/tests/protobuf_visit_test.cpp fields visitor
250class FieldsVisitor final : public BaseVisitor<FieldVisitCallback> {
251public:
252 using Selector = utils::function_ref<bool(const google::protobuf::FieldDescriptor& field)>;
253
254 /// @brief Creates the visitor with the given selector
255 /// and compiles it for the message types we can find.
256 explicit FieldsVisitor(Selector selector);
257
258 /// @brief Creates the visitor with the given selector
259 /// and compiles it for the given message types and their fields recursively.
260 FieldsVisitor(Selector selector, const DescriptorList& descriptors);
261
262 /// @brief Creates the visitor with custom thread locking behavior
263 /// and the given selector for runtime compilation.
264 ///
265 /// @warning Do not use this unless you know what you are doing.
266 FieldsVisitor(Selector selector, LockBehavior lock_behavior);
267
268 /// @brief Creates the visitor with custom thread locking behavior
269 /// and the given selector; compiles it for the given message types.
270 ///
271 /// @warning Do not use this unless you know what you are doing.
272 FieldsVisitor(Selector selector, const DescriptorList& descriptors, LockBehavior lock_behavior);
273
274 /// @cond
275 /// Only for internal use.
276 const Dependencies& GetSelectedFields(utils::impl::InternalTag) const;
277 /// @endcond
278
279private:
280 void CompileOne(const google::protobuf::Descriptor& descriptor) override;
281
282 bool IsSelected(const google::protobuf::Descriptor& descriptor) const override {
283 return selected_fields_.find(&descriptor) != selected_fields_.end();
284 }
285
286 void DoVisit(google::protobuf::Message& message, FieldVisitCallback callback) const override;
287
288 Dependencies selected_fields_;
289 const Selector selector_;
290};
291
292/// @brief Collects knowledge of the structure of the protobuf messages
293/// allowing for efficient loops over nested messages to apply a callback
294/// to the ones selected by the 'selector' function.
295///
296/// If you do not have static knowledge of the required messages, you should
297/// use @ref VisitMessagesRecursive that is equivalent to
298/// MessagesVisitor with a 'return true' selector.
299///
300/// @warning You should not construct this at runtime as it performs significant
301/// computations in the constructor to precompile the visitors.
302/// You should create this ones at start-up.
303class MessagesVisitor final : public BaseVisitor<MessageVisitCallback> {
304public:
305 using Selector = utils::function_ref<bool(const google::protobuf::Descriptor& descriptor)>;
306
307 /// @brief Creates the visitor with the given selector for runtime compilation
308 /// and compiles it for the message types we can find.
309 explicit MessagesVisitor(Selector selector);
310
311 /// @brief Creates the visitor with the given selector
312 /// and compiles it for the given message types and their fields recursively.
313 MessagesVisitor(Selector selector, const DescriptorList& descriptors);
314
315 /// @brief Creates the visitor with custom thread locking behavior
316 /// and the given selector for runtime compilation.
317 ///
318 /// @warning Do not use this unless you know what you are doing.
319 MessagesVisitor(Selector selector, LockBehavior lock_behavior);
320
321 /// @brief Creates the visitor with custom thread locking behavior
322 /// and the given selector; compiles it for the given message types.
323 ///
324 /// @warning Do not use this unless you know what you are doing.
325 MessagesVisitor(Selector selector, const DescriptorList& descriptors, LockBehavior lock_behavior);
326
327 /// @cond
328 /// Only for internal use.
329 const DescriptorSet& GetSelectedMessages(utils::impl::InternalTag) const;
330 /// @endcond
331
332private:
333 void CompileOne(const google::protobuf::Descriptor& descriptor) override;
334
335 bool IsSelected(const google::protobuf::Descriptor& descriptor) const override {
336 return selected_messages_.find(&descriptor) != selected_messages_.end();
337 }
338
339 void DoVisit(google::protobuf::Message& message, MessageVisitCallback callback) const override;
340
341 DescriptorSet selected_messages_;
342 const Selector selector_;
343};
344
345} // namespace ugrpc
346
347USERVER_NAMESPACE_END