userver: userver/ugrpc/protobuf_visit.hpp Source File
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages Concepts
protobuf_visit.hpp
Go to the documentation of this file.
#pragma once

/// @file userver/ugrpc/protobuf_visit.hpp
/// @brief Utilities for visiting the fields of protobufs

#include <mutex>
#include <shared_mutex>
#include <string_view>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <google/protobuf/message.h>

#include <userver/utils/assert.hpp>
#include <userver/utils/function_ref.hpp>
#include <userver/utils/impl/internal_tag.hpp>
#include <userver/utils/span.hpp>

// Forward declarations of the protobuf descriptor types that appear in the
// signatures declared in this header.
namespace google {
namespace protobuf {

class Descriptor;
class FieldDescriptor;

}  // namespace protobuf
}  // namespace google

25USERVER_NAMESPACE_BEGIN
26
27namespace ugrpc {
28
29using MessageVisitCallback = utils::function_ref<void(google::protobuf::Message&)>;
30
31using FieldVisitCallback =
32 utils::function_ref<void(google::protobuf::Message&, const google::protobuf::FieldDescriptor&)>;
33
34/// @brief Execute a callback for all non-empty fields of the message.
35void VisitFields(google::protobuf::Message& message, FieldVisitCallback callback);
36
37/// @brief Execute a callback for the message and its non-empty submessages.
38void VisitMessagesRecursive(google::protobuf::Message& message, MessageVisitCallback callback);
39
40/// @brief Execute a callback for all fields of the message and its non-empty submessages.
41void VisitFieldsRecursive(google::protobuf::Message& message, FieldVisitCallback callback);
42
43/// @brief Execute a callback for the submessage contained in the given field.
45 google::protobuf::Message& message,
46 const google::protobuf::FieldDescriptor& field,
47 MessageVisitCallback callback
48);
49
50using DescriptorList = std::vector<const google::protobuf::Descriptor*>;
51
52using FieldDescriptorList = std::vector<const google::protobuf::FieldDescriptor*>;
53
54/// @brief Get the descriptors of fields in the message.
55FieldDescriptorList GetFieldDescriptors(const google::protobuf::Descriptor& descriptor);
56
57/// @brief Get the descriptors of current and nested messages.
58DescriptorList GetNestedMessageDescriptors(const google::protobuf::Descriptor& descriptor);
59
60/// @brief Find a generated type by name.
61const google::protobuf::Descriptor* FindGeneratedMessage(std::string_view name);
62
63/// @brief Find the field of a generated type by name.
64const google::protobuf::FieldDescriptor*
65FindField(const google::protobuf::Descriptor* descriptor, std::string_view field);
66
67/// @brief Base class for @ref BaseVisitor.
68/// Constructs and manages the descriptor graph to collect the data about the messages
69/// and enable the visitors to find all selected structures.
71public:
72 enum class LockBehavior {
73 /// @brief Do not take shared_mutex locks for any operation on the visitor
74 kNone = 0,
75
76 /// @brief Take shared_lock for all read operations on the visitor
77 /// and unique_lock for all Compile operations
79 };
80
81 VisitorCompiler(VisitorCompiler&&) = delete;
82 VisitorCompiler(const VisitorCompiler&) = delete;
83
84 /// @brief Compiles the visitor for the given message type and its dependent types
85 void Compile(const google::protobuf::Descriptor* descriptor);
86
87 /// @brief Compiles the visitor for the given message types and their dependent types
88 void Compile(const DescriptorList& descriptors);
89
90 /// @brief Compiles the visitor for all message types we can find.
91 /// Not guaranteed to find all message types.
93
94 /// @brief Compiles the visitor for the given generated message type and its dependent types
95 void CompileGenerated(std::string_view message_name);
96
97 /// @brief Compiles the visitor for the given generated message type and their dependent types
98 void CompileGenerated(utils::span<std::string_view> message_names);
99
100 /// @brief Efficiently checks if the message contains any selected structures.
101 ///
102 /// You may want to call this before @ref BaseVisitor::Visit and @ref BaseVisitor::VisitRecursive
103 /// to avoid a copy of the message beforehand if you require one.
104 bool ContainsSelected(const google::protobuf::Descriptor* descriptor);
105
106 /// @cond
107 /// Only for internal use.
108 using Dependencies = std::unordered_map<
109 const google::protobuf::Descriptor*,
110 std::unordered_set<const google::protobuf::FieldDescriptor*>>;
111
112 /// Only for internal use.
113 using DescriptorSet = std::unordered_set<const google::protobuf::Descriptor*>;
114
115 /// Only for internal use.
116 using FieldDescriptorSet = std::unordered_set<const google::protobuf::FieldDescriptor*>;
117
118 /// Only for internal use.
119 const Dependencies& GetFieldsWithSelectedChildren(utils::impl::InternalTag) const {
120 return fields_with_selected_children_;
121 }
122
123 /// Only for internal use.
124 const Dependencies& GetReverseEdges(utils::impl::InternalTag) const { return reverse_edges_; }
125
126 /// Only for internal use.
127 const DescriptorSet& GetPropagated(utils::impl::InternalTag) const { return propagated_; }
128
129 /// Only for internal use.
130 const DescriptorSet& GetCompiled(utils::impl::InternalTag) const { return compiled_; }
131
132protected:
133 explicit VisitorCompiler(LockBehavior lock_behavior) : lock_behavior_(lock_behavior) {}
134
135 // Disallow destruction via pointer to base
136 ~VisitorCompiler() = default;
137
138 /// @brief Lock the visitor for read
139 std::shared_lock<std::shared_mutex> LockRead();
140
141 /// @brief Lock the visitor for write
142 std::unique_lock<std::shared_mutex> LockWrite();
143
144 const Dependencies& GetFieldsWithSelectedChildren() const { return fields_with_selected_children_; }
145 /// @endcond
146
147 /// @brief Compile one message without nested.
148 virtual void CompileOne(const google::protobuf::Descriptor& descriptor) = 0;
149
150 /// @brief Checks if the message is selected or has anything selected.
151 virtual bool IsSelected(const google::protobuf::Descriptor&) const = 0;
152
153private:
154 /// @brief Gets all submessages of the given messages.
155 DescriptorSet GetFullSubtrees(const DescriptorList& descriptors) const;
156
157 /// @brief Propagate the selection information upwards
158 void PropagateSelected(const google::protobuf::Descriptor* descriptor);
159
160 std::shared_mutex mutex_;
161 const LockBehavior lock_behavior_;
162
163 Dependencies fields_with_selected_children_;
164 Dependencies reverse_edges_;
165 DescriptorSet propagated_;
166 DescriptorSet compiled_;
167};
168
169/// @brief Base class for @ref FieldsVisitor and @ref MessagesVisitor.
170/// Provides the interface and contains common code to use the data collected in the @ref VisitorCompiler.
171template <typename Callback>
173public:
174 /// @brief Execute a callback without recursion
175 ///
176 /// Equivalent to @ref VisitFields
177 /// but utilizes the precompilation data from @ref Compile
178 void Visit(google::protobuf::Message& message, Callback callback) {
179 // Compile if not yet compiled
180 Compile(message.GetDescriptor());
181
182 std::shared_lock read_lock = LockRead();
183 DoVisit(message, callback);
184 }
185
186 /// @brief Execute a callback recursively
187 ///
188 /// Equivalent to @ref VisitFieldsRecursive and @ref VisitMessagesRecursive
189 /// but utilizes the precompilation data from @ref Compile
190 void VisitRecursive(google::protobuf::Message& message, Callback callback) {
191 // Compile if not yet compiled
192 Compile(message.GetDescriptor());
193
194 constexpr int kMaxRecursionLimit = 100;
195 std::shared_lock read_lock = LockRead();
196 VisitRecursiveImpl(message, callback, kMaxRecursionLimit);
197 }
198
199protected:
200 explicit BaseVisitor(LockBehavior lock_behavior) : VisitorCompiler(lock_behavior) {}
201
202 // Disallow destruction via pointer to base
203 ~BaseVisitor() = default;
204
205 /// @brief Execute a callback without recursion
206 virtual void DoVisit(google::protobuf::Message& message, Callback callback) const = 0;
207
208private:
209 /// @brief Safe version with recursion_limit
210 void VisitRecursiveImpl(google::protobuf::Message& message, Callback callback, int recursion_limit) {
211 UINVARIANT(recursion_limit > 0, "Recursion limit reached while traversing protobuf Message.");
212
213 // Loop over this message
214 DoVisit(message, callback);
215
216 // Recurse into nested messages
217 const auto it = GetFieldsWithSelectedChildren().find(message.GetDescriptor());
218 if (it == GetFieldsWithSelectedChildren().end()) return;
219
220 const FieldDescriptorSet& fields = it->second;
221 for (const google::protobuf::FieldDescriptor* field : fields) {
222 UINVARIANT(field, "field is nullptr");
223 VisitNestedMessage(message, *field, [&](google::protobuf::Message& msg) {
224 VisitRecursiveImpl(msg, callback, recursion_limit - 1);
225 });
226 }
227 }
228};
229
230/// @brief Collects knowledge of the structure of the protobuf messages
231/// allowing for efficient loops over fields to apply a callback to the ones
232/// selected by the 'selector' function.
233///
234/// If you do not have static knowledge of the required fields, you should
235/// use @ref VisitFields or @ref VisitFieldsRecursive that are equivalent to
236/// FieldsVisitor with a `return true;` selector.
237///
238/// @warning You should not construct this at runtime as it performs significant
239/// computations in the constructor to precompile the visitors.
240/// You should create these at start-up.
241///
242/// Example usage:
243/// @snippet grpc/src/ugrpc/impl/protobuf_utils.cpp fields visitor
244class FieldsVisitor final : public BaseVisitor<FieldVisitCallback> {
245public:
246 using Selector = utils::function_ref<bool(const google::protobuf::FieldDescriptor& field)>;
247
248 /// @brief Creates the visitor with the given selector
249 /// and compiles it for the message types we can find.
250 explicit FieldsVisitor(Selector selector);
251
252 /// @brief Creates the visitor with the given selector
253 /// and compiles it for the given message types and their fields recursively.
254 FieldsVisitor(Selector selector, const DescriptorList& descriptors);
255
256 /// @brief Creates the visitor with custom thread locking behavior
257 /// and the given selector for runtime compilation.
258 ///
259 /// @warning Do not use this unless you know what you are doing.
260 FieldsVisitor(Selector selector, LockBehavior lock_behavior);
261
262 /// @brief Creates the visitor with custom thread locking behavior
263 /// and the given selector; compiles it for the given message types.
264 ///
265 /// @warning Do not use this unless you know what you are doing.
266 FieldsVisitor(Selector selector, const DescriptorList& descriptors, LockBehavior lock_behavior);
267
268 /// @cond
269 /// Only for internal use.
270 const Dependencies& GetSelectedFields(utils::impl::InternalTag) const { return selected_fields_; }
271 /// @endcond
272
273private:
274 void CompileOne(const google::protobuf::Descriptor& descriptor) override;
275
276 bool IsSelected(const google::protobuf::Descriptor& descriptor) const override {
277 return selected_fields_.find(&descriptor) != selected_fields_.end();
278 }
279
280 void DoVisit(google::protobuf::Message& message, FieldVisitCallback callback) const override;
281
282 Dependencies selected_fields_;
283 const Selector selector_;
284};
285
286/// @brief Collects knowledge of the structure of the protobuf messages
287/// allowing for efficient loops over nested messages to apply a callback
288/// to the ones selected by the 'selector' function.
289///
290/// If you do not have static knowledge of the required messages, you should
291/// use @ref VisitMessagesRecursive that is equivalent to
292/// MessagesVisitor with a 'return true' selector.
293///
294/// @warning You should not construct this at runtime as it performs significant
295/// computations in the constructor to precompile the visitors.
296/// You should create this ones at start-up.
297class MessagesVisitor final : public BaseVisitor<MessageVisitCallback> {
298public:
299 using Selector = utils::function_ref<bool(const google::protobuf::Descriptor& descriptor)>;
300
301 /// @brief Creates the visitor with the given selector for runtime compilation
302 /// and compiles it for the message types we can find.
303 explicit MessagesVisitor(Selector selector);
304
305 /// @brief Creates the visitor with the given selector
306 /// and compiles it for the given message types and their fields recursively.
307 MessagesVisitor(Selector selector, const DescriptorList& descriptors);
308
309 /// @brief Creates the visitor with custom thread locking behavior
310 /// and the given selector for runtime compilation.
311 ///
312 /// @warning Do not use this unless you know what you are doing.
313 MessagesVisitor(Selector selector, LockBehavior lock_behavior);
314
315 /// @brief Creates the visitor with custom thread locking behavior
316 /// and the given selector; compiles it for the given message types.
317 ///
318 /// @warning Do not use this unless you know what you are doing.
319 MessagesVisitor(Selector selector, const DescriptorList& descriptors, LockBehavior lock_behavior);
320
321 /// @cond
322 /// Only for internal use.
323 const DescriptorSet& GetSelectedMessages(utils::impl::InternalTag) const { return selected_messages_; }
324 /// @endcond
325
326private:
327 void CompileOne(const google::protobuf::Descriptor& descriptor) override;
328
329 bool IsSelected(const google::protobuf::Descriptor& descriptor) const override {
330 return selected_messages_.find(&descriptor) != selected_messages_.end();
331 }
332
333 void DoVisit(google::protobuf::Message& message, MessageVisitCallback callback) const override;
334
335 DescriptorSet selected_messages_;
336 const Selector selector_;
337};
338
339} // namespace ugrpc
340
341USERVER_NAMESPACE_END