userver: /data/code/service_template/third_party/userver/testsuite/pytest_plugins/pytest_userver/plugins/caches.py Source File
Loading...
Searching...
No Matches
caches.py
1"""
2Fixtures for controlling userver caches.
3"""
4import copy
5import enum
6import types
7import typing
8
9import pytest
10
11from testsuite.daemons.pytest_plugin import DaemonInstance
12
13
def __init__(self):
    # cache name -> name of the fixture that controls that cache;
    # populated by pytest_plugin_registered as plugins are registered.
    self._hooks: typing.Dict[str, str] = dict()
17
@property
def userver_cache_control_hooks(self) -> typing.Dict[str, str]:
    """Registered hooks: cache name -> cache-control fixture name.

    The internal dict itself is returned, not a copy.
    """
    registered = self._hooks
    return registered
21
def pytest_plugin_registered(self, plugin, manager):
    """Collect ``USERVER_CACHE_CONTROL_HOOKS`` from a newly registered plugin.

    Only plugins that are Python modules are inspected. Such a module may
    declare ``USERVER_CACHE_CONTROL_HOOKS = {cache_name: fixture_name}``;
    every entry is recorded in ``self._hooks``. ``manager`` is unused.

    Raises:
        RuntimeError: if the attribute is not a dict, or if a cache name
            already has a registered hook.
    """
    hooks = (
        getattr(plugin, 'USERVER_CACHE_CONTROL_HOOKS', None)
        if isinstance(plugin, types.ModuleType)
        else None
    )
    if hooks is None:
        # Either not a module, or a module that declares no hooks.
        return
    if not isinstance(hooks, dict):
        raise RuntimeError(
            f'USERVER_CACHE_CONTROL_HOOKS must be dictionary: '
            f'{{cache_name: fixture_name}}, got {hooks} instead',
        )
    for name in hooks:
        if name in self._hooks:
            raise RuntimeError(
                f'USERVER_CACHE_CONTROL_HOOKS: hook already registered '
                f'for cache {name}',
            )
        self._hooks[name] = hooks[name]
40
41
def __init__(self):
    # Set of cache names awaiting an update; None stands for "every cache".
    # A new state starts at None, so all caches get updated at the start of
    # each test.
    self._invalidated_caches: typing.Optional[typing.Set[str]] = None
47
def invalidate_all(self) -> None:
    """Mark every cache as needing an update."""
    # None is the sentinel for "all caches".
    self._invalidated_caches = None
50
def invalidate(self, caches: typing.Iterable[str]) -> None:
    """Mark the given caches as needing an update.

    Args:
        caches: names of the caches to invalidate. A single ``str`` is
            treated as one cache name, not as an iterable of characters.

    Does nothing while every cache is already pending an update.
    """
    if self._invalidated_caches is None:
        # All caches are already scheduled for an update.
        return
    if isinstance(caches, str):
        # set.update('name') would add each character separately.
        caches = (caches,)
    self._invalidated_caches.update(caches)
54
@property
def should_update_all_caches(self) -> bool:
    """True when every cache, rather than a subset, must be updated."""
    pending = self._invalidated_caches
    return pending is None
58
@property
def caches_to_update(self) -> typing.FrozenSet[str]:
    """Immutable snapshot of the individually invalidated cache names.

    Only meaningful when not every cache is pending; the assert checks that
    the pending set is not None.
    """
    pending = self._invalidated_caches
    assert pending is not None
    return frozenset(pending)
63
@property
def has_caches_to_update(self) -> bool:
    """True if all caches, or at least one specific cache, await an update."""
    pending = self._invalidated_caches
    if pending is None:
        return True
    return len(pending) > 0
68
def on_caches_updated(self, caches: typing.Iterable[str]) -> None:
    """Drop the given caches from the pending set after they were updated.

    Does nothing while the state means "all caches" (None).
    """
    pending = self._invalidated_caches
    if pending is None:
        return
    pending.difference_update(caches)
72
def on_all_caches_updated(self) -> None:
    """Record that every cache is up to date, leaving nothing pending."""
    # An empty set (not None) means "no cache needs an update".
    self._invalidated_caches = set()
75
def assign_copy(self, other: 'InvalidationState') -> None:
    """Overwrite this state with an independent copy of ``other``'s state."""
    # pylint: disable=protected-access
    source = other._invalidated_caches
    # Elements are str (immutable), so copying the set itself is enough.
    self._invalidated_caches = None if source is None else set(source)
79
80
class CacheControlAction(enum.Enum):
    """Update action a per-cache control handler requests for its cache."""

    # Full update of the cache; the default for a new request.
    FULL = 0
    # Incremental update instead of a full one.
    INCREMENTAL = 1
    # Leave the cache out of the update.
    EXCLUDE = 2
85
86
# Action requested for a cache; starts as a full update and may be switched
# by the handler via exclude() or incremental().
action: CacheControlAction = CacheControlAction.FULL
89
def exclude(self) -> None:
    """Exclude cache from update.

    Sets ``self.action`` to ``CacheControlAction.EXCLUDE``.
    """
    chosen = CacheControlAction.EXCLUDE
    self.action = chosen
93
def incremental(self) -> None:
    """Request incremental update instead of full.

    Sets ``self.action`` to ``CacheControlAction.INCREMENTAL``.
    """
    chosen = CacheControlAction.INCREMENTAL
    self.action = chosen
97
98
def __init__(
    self,
    *,
    enabled: bool,
    context: typing.Dict[str, typing.Any],
    fixtures: typing.Dict[str, typing.Callable],
    caches_disabled: typing.Set[str],
):
    """Store the cache-control settings for one daemon.

    Args:
        enabled: when False, no per-cache handler is called and every
            requested cache is staged with None.
        context: mapping of cache name -> last committed handler state;
            updated in place by ``commit_staged``.
        fixtures: mapping of cache name -> control handler, called as
            ``handler(request, state)``. It is a mapping (iterated with
            ``.items()``/``.keys()``), not a list of names.
        caches_disabled: names of caches whose handlers are skipped.
    """
    self._enabled = enabled
    self._context = context
    self._fixtures = fixtures
    self._caches_disabled = caches_disabled
112
114 self, cache_names: typing.Optional[typing.List[str]],
115 ) -> typing.Tuple[
116 typing.Dict, typing.List[typing.Tuple[str, CacheControlAction]],
117 ]:
118 """Query cache control handlers.
119
120 Returns pair (staged, [(cache_name, action), ...])
where ``staged`` maps each considered cache name to the state returned by
its handler (None for disabled caches, or for every cache when cache
control is disabled), and the list holds the action each called handler
requested. If ``cache_names`` is None, every registered cache is considered.
NOTE(review): ``staged`` is presumably committed via ``commit_staged`` by
the caller once the update succeeds -- confirm.
121 """
# Cache control disabled: no handler is called and no actions are returned;
# every requested cache (all known caches when cache_names is None) is
# staged with None.
122 if not self._enabled:
123 if cache_names is None:
124 cache_names = self._fixtures.keys()
125 return {cache_name: None for cache_name in cache_names}, []
126 staged = {}
127 actions = []
# NOTE(review): an empty cache_names list is falsy, so this loop treats it as
# "all caches", while the disabled branch above returns an empty dict for it.
# Confirm which meaning callers rely on.
128 for cache_name, fixture in self._fixtures.items():
129 if cache_names and cache_name not in cache_names:
130 continue
# Disabled caches are staged as None without calling their handler.
131 if cache_name in self._caches_disabled:
132 staged[cache_name] = None
133 continue
# The handler gets a fresh request and the previously committed state (None
# if absent). Its return value becomes the staged state, and the action it
# set on the request is recorded.
134 context = self._context.get(cache_name)
135 request = CacheControlRequest()
136 staged[cache_name] = fixture(request, context)
137 actions.append((cache_name, request.action))
138 return staged, actions
139
def commit_staged(self, staged: typing.Dict[str, typing.Any]) -> None:
    """Commit previously staged handler states.

    Merges ``staged`` (cache name -> new handler state) into the stored
    context in place. Later handler calls receive these values as their
    previous ``state``.
    """
    self._context.update(staged)
143
144
def pytest_configure(config):
    """Register the cache-control plugin and declare its marker.

    The ``UserverCachePlugin`` instance is registered under the name
    ``userver_cache``, which is how fixtures look it up later.
    """
    cache_plugin = UserverCachePlugin()
    config.pluginmanager.register(cache_plugin, 'userver_cache')
    marker_line = 'userver_cache_control_disabled: disable cache control'
    config.addinivalue_line('markers', marker_line)
150
151
@pytest.fixture
def cache_invalidation_state() -> InvalidationState:
    """
    A fixture for notifying the service of changes in cache data sources.

    Intended to be used by other fixtures that represent those data sources,
    not by tests directly. A new state is created for every test and starts
    with all caches considered invalidated.

    @ingroup userver_testsuite_fixtures
    """
    fresh_state = InvalidationState()
    return fresh_state
163
164
@pytest.fixture(scope='session')
def _userver_cache_control_context() -> typing.Dict:
    """Session-wide store of cache-control contexts, keyed by daemon id."""
    storage: typing.Dict = dict()
    return storage
168
169
@pytest.fixture
def _userver_cache_fixtures(
        pytestconfig, request,
) -> typing.Dict[str, typing.Callable]:
    """Resolve every registered cache-control hook to its fixture value.

    Returns a mapping of cache name -> value of the fixture named for that
    cache in some plugin's ``USERVER_CACHE_CONTROL_HOOKS``.
    """
    plugin: UserverCachePlugin = pytestconfig.pluginmanager.get_plugin(
        'userver_cache',
    )
    hooks = plugin.userver_cache_control_hooks
    return {
        cache_name: request.getfixturevalue(fixture_name)
        for cache_name, fixture_name in hooks.items()
    }
181
182
183@pytest.fixture
185 _userver_cache_control_context, _userver_cache_fixtures, request,
186) -> typing.Callable[[DaemonInstance], CacheControl]:
187 """Userver cache control handler.
188
189 To install a per-cache handler, use the USERVER_CACHE_CONTROL_HOOKS variable
190 in your pytest plugin:
191
192 @code
193 USERVER_CACHE_CONTROL_HOOKS = {
194 'my-cache-name': 'my_cache_cc',
195 }
196
197 @pytest.fixture
198 def my_cache_cc(my_cache_context):
199 def cache_control(request, state):
200 new_state = my_cache_context.get_state()
201 if state == new_state:
202 # Cache is already up to date, no need to update
203 request.exclude()
204 else:
205 # Request incremental update, if your cache supports it
206 request.incremental()
207 return new_state
208 return cache_control
209 @endcode
210
The handler is called as ``cache_control(request, state)``. ``state`` is the
value last committed for that cache, or None if nothing was committed yet.
``request`` asks for a full update unless the handler calls
``request.exclude()`` or ``request.incremental()``. The handler's return
value is staged as the cache's new state.

To disable cache control for a test, mark it with
``@pytest.mark.userver_cache_control_disabled(reason='...')``. To disable
only some caches, pass a sequence of their names as the first argument:
``@pytest.mark.userver_cache_control_disabled(['my-cache-name'], reason='...')``.

Returns a function that takes a DaemonInstance and builds a CacheControl
bound to the cache-control state stored for that daemon.

211 @ingroup userver_testsuite_fixtures
212 """
# enabled: whether per-cache handlers run at all.
# caches_disabled: caches whose handlers are skipped even when enabled.
# Both are read by get_cache_control below.
213 enabled = True
214 caches_disabled = set()
215
# Mirrors the marker's signature, so a marker without ``reason`` raises
# TypeError. With a caches sequence, only those caches are disabled and the
# current value of ``enabled`` is returned unchanged, so an earlier blanket
# disable is not undone. Without caches, cache control is turned off
# entirely.
216 def userver_cache_control_disabled(
217 caches: typing.Optional[typing.Sequence[str]] = None, *, reason: str,
218 ):
219 if caches is not None:
220 caches_disabled.update(caches)
221 return enabled
222 return False
223
224 for mark in request.node.iter_markers('userver_cache_control_disabled'):
225 enabled = userver_cache_control_disabled(*mark.args, **mark.kwargs)
226
# One context dict per daemon id, created on first use inside the
# session-scoped _userver_cache_control_context. Committed handler states
# therefore persist across tests.
227 def get_cache_control(daemon: DaemonInstance) -> CacheControl:
228 context = _userver_cache_control_context.setdefault(daemon.id, {})
229 return CacheControl(
230 context=context,
231 fixtures=_userver_cache_fixtures,
232 enabled=enabled,
233 caches_disabled=caches_disabled,
234 )
235
236 return get_cache_control