userver: /data/code/userver/testsuite/pytest_plugins/pytest_userver/plugins/caches.py Source File
Loading...
Searching...
No Matches
caches.py
1"""
2Fixtures for controlling userver caches.
3"""
4
5import copy
6import enum
7import types
8import typing
9
10import pytest
11
12from testsuite.daemons.pytest_plugin import DaemonInstance
13
14
class UserverCachePlugin:
    """Pytest plugin that collects per-cache control hooks.

    Any plugin module may define a module-level dictionary
    ``USERVER_CACHE_CONTROL_HOOKS`` of ``{cache_name: fixture_name}``; the
    entries of every such module are gathered here as modules get registered.
    """

    def __init__(self):
        # cache name -> name of the fixture providing its control hook
        self._hooks = {}

    @property
    def userver_cache_control_hooks(self) -> typing.Dict[str, str]:
        """Mapping of cache name to the hook fixture name registered for it."""
        return self._hooks

    def pytest_plugin_registered(self, plugin, manager):
        """Collect ``USERVER_CACHE_CONTROL_HOOKS`` from a registered module.

        Plugins that are not modules, and modules without the variable, are
        ignored.

        @throws RuntimeError if the variable is not a dict, or if one of its
                cache names already has a hook registered.
        """
        if isinstance(plugin, types.ModuleType):
            declared = getattr(plugin, 'USERVER_CACHE_CONTROL_HOOKS', None)
        else:
            declared = None
        if declared is None:
            return
        if not isinstance(declared, dict):
            raise RuntimeError(
                'USERVER_CACHE_CONTROL_HOOKS must be dictionary: '
                f'{{cache_name: fixture_name}}, got {declared} instead',
            )
        for name, fixture in declared.items():
            if name in self._hooks:
                raise RuntimeError(
                    'USERVER_CACHE_CONTROL_HOOKS: hook already registered '
                    f'for cache {name}',
                )
            self._hooks[name] = fixture
41
42
class InvalidationState:
    """Tracks which caches are stale and must be refreshed.

    Internally ``None`` stands for "every cache is stale", while a (possibly
    empty) set lists only the stale caches. A new instance starts with every
    cache stale.
    """

    def __init__(self):
        # None => every cache is stale. A fresh state is created for each
        # test, so all caches get invalidated at the start of a test.
        self._invalidated_caches: typing.Optional[typing.Set[str]] = None

    def invalidate_all(self) -> None:
        """Mark every cache as stale."""
        self._invalidated_caches = None

    def invalidate(self, caches: typing.Iterable[str]) -> None:
        """Mark the given caches as stale."""
        pending = self._invalidated_caches
        if pending is None:
            # Everything is already stale; nothing more to record.
            return
        pending.update(caches)

    @property
    def should_update_all_caches(self) -> bool:
        """True when every cache is stale."""
        return self._invalidated_caches is None

    @property
    def caches_to_update(self) -> typing.FrozenSet[str]:
        """Snapshot of the stale cache names.

        Must only be used when should_update_all_caches is False.
        """
        pending = self._invalidated_caches
        assert pending is not None
        return frozenset(pending)

    @property
    def has_caches_to_update(self) -> bool:
        """True when at least one cache (or every cache) is stale."""
        pending = self._invalidated_caches
        if pending is None:
            return True
        return len(pending) > 0

    def on_caches_updated(self, caches: typing.Iterable[str]) -> None:
        """Mark the given caches as fresh again.

        Has no effect while every cache is considered stale.
        """
        pending = self._invalidated_caches
        if pending is None:
            return
        pending.difference_update(caches)

    def on_all_caches_updated(self) -> None:
        """Mark every cache as fresh."""
        self._invalidated_caches = set()

    def assign_copy(self, other: 'InvalidationState') -> None:
        """Replace this state with an independent copy of ``other``."""
        # pylint: disable=protected-access
        source = other._invalidated_caches
        self._invalidated_caches = copy.deepcopy(source)
80
81
class CacheControlAction(enum.Enum):
    """Kind of update a cache control hook requests for its cache."""

    # Full update; this is what a request starts with.
    FULL = 0
    # Incremental update instead of a full one.
    INCREMENTAL = 1
    # Leave the cache out of the update.
    EXCLUDE = 2
86
87
class CacheControlRequest:
    """Handle given to a cache control hook to pick the update kind.

    A request asks for a full update unless the hook calls exclude() or
    incremental().
    """

    # Class-level default; the methods below shadow it per instance.
    action = CacheControlAction.FULL

    def exclude(self) -> None:
        """Leave this cache out of the update."""
        self.action = CacheControlAction.EXCLUDE

    def incremental(self) -> None:
        """Ask for an incremental update rather than a full one."""
        self.action = CacheControlAction.INCREMENTAL
98
99
class CacheControl:
    """Decides how each cache is updated by consulting per-cache hooks.

    Instances are created per daemon by the cache control fixture. States
    returned by the hooks in query_caches() are remembered in ``context``
    only once commit_staged() is called.
    """

    def __init__(
        self,
        *,
        enabled: bool,
        context: typing.Dict,
        fixtures: typing.Dict[str, typing.Callable],
        caches_disabled: typing.Set[str],
    ):
        """
        @param enabled when False, hooks are never called and every queried
               cache is staged with ``None``.
        @param context mutable mapping cache name -> last committed hook
               state; updated in place by commit_staged().
        @param fixtures mapping cache name -> hook called as
               ``hook(request, state)``; the hook may call
               ``request.exclude()`` or ``request.incremental()`` and returns
               the new state of the cache.
        @param caches_disabled names of caches whose hooks must be skipped.
        """
        self._enabled = enabled
        self._context = context
        self._fixtures = fixtures
        self._caches_disabled = caches_disabled

    def query_caches(
        self, cache_names: typing.Optional[typing.Iterable[str]],
    ) -> typing.Tuple[
        typing.Dict, typing.List[typing.Tuple[str, CacheControlAction]],
    ]:
        """Query cache control handlers.

        @param cache_names caches to query; ``None`` queries every cache with
               a registered hook, an empty collection queries none.
        @returns pair (staged, [(cache_name, action), ...]). ``staged`` maps a
                 cache name to the state returned by its hook, or ``None``
                 when cache control is disabled for it. Actions are listed
                 only for caches whose hook was actually called.
        """
        if not self._enabled:
            if cache_names is None:
                cache_names = self._fixtures.keys()
            return {cache_name: None for cache_name in cache_names}, []

        # Materialize the filter once: O(1) membership checks, and a one-shot
        # iterator is not consumed by the first lookup. The explicit None
        # check makes an empty collection select no caches (as the disabled
        # branch above does) instead of being treated like None.
        wanted = None if cache_names is None else frozenset(cache_names)
        staged = {}
        actions = []
        for cache_name, fixture in self._fixtures.items():
            if wanted is not None and cache_name not in wanted:
                continue
            if cache_name in self._caches_disabled:
                # Hook disabled by a marker: stage no state, report no action.
                staged[cache_name] = None
                continue
            context = self._context.get(cache_name)
            request = CacheControlRequest()
            staged[cache_name] = fixture(request, context)
            actions.append((cache_name, request.action))
        return staged, actions

    def commit_staged(self, staged: typing.Dict[str, typing.Any]) -> None:
        """Remember states staged by query_caches().

        The committed states are passed to the hooks on the next query.
        """
        self._context.update(staged)
144
145
def pytest_configure(config):
    """Install the cache hook collector plugin and declare the marker."""
    plugin_manager = config.pluginmanager
    plugin_manager.register(UserverCachePlugin(), 'userver_cache')
    marker_description = 'userver_cache_control_disabled: disable cache control'
    config.addinivalue_line('markers', marker_description)
151
152
@pytest.fixture
def cache_invalidation_state() -> InvalidationState:
    """
    Fixture used to tell the service that cache data sources have changed.

    Meant for fixtures that model those data sources; tests are not expected
    to use it directly. Each test gets a new state in which every cache is
    invalidated.

    @ingroup userver_testsuite_fixtures
    """
    state = InvalidationState()
    return state
164
165
@pytest.fixture(scope='session')
def _userver_cache_control_context() -> typing.Dict:
    """Session-wide storage for cache control hook states.

    Keyed by daemon id; each value maps a cache name to the state last
    committed for it.
    """
    return dict()
169
170
@pytest.fixture
def _userver_cache_fixtures(
    pytestconfig, request,
) -> typing.Dict[str, typing.Callable]:
    """Resolve the hook fixture of every cache with a registered hook.

    @returns mapping cache name -> value of the fixture named for it in some
             plugin's USERVER_CACHE_CONTROL_HOOKS.
    """
    plugin: UserverCachePlugin = pytestconfig.pluginmanager.get_plugin(
        'userver_cache',
    )
    hooks = plugin.userver_cache_control_hooks
    return {
        name: request.getfixturevalue(fixture)
        for name, fixture in hooks.items()
    }
182
183
@pytest.fixture
def userver_cache_control(
    _userver_cache_control_context, _userver_cache_fixtures, request,
) -> typing.Callable[[DaemonInstance], CacheControl]:
    """Userver cache control handler.

    To install per cache handler use USERVER_CACHE_CONTROL_HOOKS variable
    in your pytest plugin:

    @code
    USERVER_CACHE_CONTROL_HOOKS = {
        'my-cache-name': 'my_cache_cc',
    }

    @pytest.fixture
    def my_cache_cc(my_cache_context):
        def cache_control(request, state):
            new_state = my_cache_context.get_state()
            if state == new_state:
                # Cache is already up to date, no need to update
                request.exclude()
            else:
                # Request incremental update, if your cache supports it
                request.incremental()
            return new_state
        return cache_control
    @endcode

    Cache control can be turned off for a test with the
    `userver_cache_control_disabled` marker (keyword `reason` is required).
    Without positional arguments it disables cache control completely. With
    a cache name or a list of cache names it only skips the hooks of those
    caches:

    @code
    @pytest.mark.userver_cache_control_disabled(reason='...')
    @pytest.mark.userver_cache_control_disabled(['my-cache'], reason='...')
    @endcode

    Returns a function that builds a CacheControl for the given daemon.

    @ingroup userver_testsuite_fixtures
    """
    enabled = True
    caches_disabled = set()

    def userver_cache_control_disabled(
        caches: typing.Optional[typing.Iterable[str]] = None, *, reason: str,
    ):
        # Returns the new value of `enabled` for the marker being applied.
        if caches is None:
            return False
        if isinstance(caches, str):
            # A single cache name: without wrapping, set.update() would add
            # each character of the name as a separate cache name.
            caches = [caches]
        caches_disabled.update(caches)
        return enabled

    for mark in request.node.iter_markers('userver_cache_control_disabled'):
        enabled = userver_cache_control_disabled(*mark.args, **mark.kwargs)

    def get_cache_control(daemon: DaemonInstance):
        # Hook states live in the session-scoped context, one dict per daemon.
        context = _userver_cache_control_context.setdefault(daemon.id, {})
        return CacheControl(
            context=context,
            fixtures=_userver_cache_fixtures,
            enabled=enabled,
            caches_disabled=caches_disabled,
        )

    return get_cache_control