userver: /data/code/userver/testsuite/pytest_plugins/pytest_userver/client.py Source File
Loading...
Searching...
No Matches
client.py
1"""
2Python module that provides clients for functional tests with
3testsuite; see
4@ref scripts/docs/en/userver/functional_testing.md for an introduction.
5
6@ingroup userver_testsuite
7"""
8
9# pylint: disable=too-many-lines
10
11from __future__ import annotations
12
13from collections.abc import Awaitable
14from collections.abc import Iterator
15import contextlib
16import copy
17import dataclasses
18import json
19import logging
20import typing
21from typing import Any
22from typing import TypeAlias
23import warnings
24
25import aiohttp
26
27from testsuite import logcapture
28from testsuite import utils
29from testsuite.daemons import service_client
30from testsuite.utils import approx
31from testsuite.utils import http
32
33import pytest_userver.metrics as metric_module # pylint: disable=import-error
34from pytest_userver.plugins import caches
35
36# @cond
37logger = logging.getLogger(__name__)
38# @endcond
39
40JsonAny: TypeAlias = int | float | str | list | dict
41JsonAnyOptional: TypeAlias = JsonAny | None
42
43_UNKNOWN_STATE = '__UNKNOWN__'
44
45CACHE_INVALIDATION_MESSAGE = (
46 'Direct cache invalidation is deprecated.\n'
47 '\n'
48 ' - Use client.update_server_state() to synchronize service state\n'
49 ' - Explicitly pass cache names to invalidate, e.g.: '
50 'invalidate_caches(cache_names=[...]).'
51)
52
53
class BaseError(Exception):
    """Root of the exception hierarchy defined by this module."""
56
57
59 pass
60
61
63 pass
64
65
75 pass
76
77
79 pass
80
81
83 def __init__(self) -> None:
84 self.suspended_tasks: set[str] = set()
85 self.tasks_to_suspend: set[str] = set()
86
87
89 def __init__(self, name, reason):
90 self.name = name
91 self.reason = reason
92 super().__init__(f'Testsuite task {name!r} failed: {reason}')
93
94
95@dataclasses.dataclass(frozen=True)
97 testsuite_action_path: str | None = None
98 server_monitor_path: str | None = None
99
100
101Metric: TypeAlias = metric_module.Metric
102
103
105 """
106 Base asyncio userver client that implements HTTP requests to service.
107
108 Compatible with werkzeug interface.
109
110 @ingroup userver_testsuite
111 """
112
113 def __init__(self, client):
114 self._client = client
115
116 async def post(
117 self,
118 path: str,
119 # pylint: disable=redefined-outer-name
120 json: JsonAnyOptional = None,
121 data: Any = None,
122 params: dict[str, str] | None = None,
123 bearer: str | None = None,
124 x_real_ip: str | None = None,
125 headers: dict[str, str] | None = None,
126 **kwargs,
127 ) -> http.ClientResponse:
128 """
129 Make a HTTP POST request
130 """
131 response = await self._client.post(
132 path,
133 json=json,
134 data=data,
135 params=params,
136 headers=headers,
137 bearer=bearer,
138 x_real_ip=x_real_ip,
139 **kwargs,
140 )
141 return await self._wrap_client_response(response)
142
143 async def put(
144 self,
145 path,
146 # pylint: disable=redefined-outer-name
147 json: JsonAnyOptional = None,
148 data: Any = None,
149 params: dict[str, str] | None = None,
150 bearer: str | None = None,
151 x_real_ip: str | None = None,
152 headers: dict[str, str] | None = None,
153 **kwargs,
154 ) -> http.ClientResponse:
155 """
156 Make a HTTP PUT request
157 """
158 response = await self._client.put(
159 path,
160 json=json,
161 data=data,
162 params=params,
163 headers=headers,
164 bearer=bearer,
165 x_real_ip=x_real_ip,
166 **kwargs,
167 )
168 return await self._wrap_client_response(response)
169
170 async def patch(
171 self,
172 path,
173 # pylint: disable=redefined-outer-name
174 json: JsonAnyOptional = None,
175 data: Any = None,
176 params: dict[str, str] | None = None,
177 bearer: str | None = None,
178 x_real_ip: str | None = None,
179 headers: dict[str, str] | None = None,
180 **kwargs,
181 ) -> http.ClientResponse:
182 """
183 Make a HTTP PATCH request
184 """
185 response = await self._client.patch(
186 path,
187 json=json,
188 data=data,
189 params=params,
190 headers=headers,
191 bearer=bearer,
192 x_real_ip=x_real_ip,
193 **kwargs,
194 )
195 return await self._wrap_client_response(response)
196
197 async def get(
198 self,
199 path: str,
200 headers: dict[str, str] | None = None,
201 bearer: str | None = None,
202 x_real_ip: str | None = None,
203 **kwargs,
204 ) -> http.ClientResponse:
205 """
206 Make a HTTP GET request
207 """
208 response = await self._client.get(
209 path,
210 headers=headers,
211 bearer=bearer,
212 x_real_ip=x_real_ip,
213 **kwargs,
214 )
215 return await self._wrap_client_response(response)
216
217 async def delete(
218 self,
219 path: str,
220 headers: dict[str, str] | None = None,
221 bearer: str | None = None,
222 x_real_ip: str | None = None,
223 **kwargs,
224 ) -> http.ClientResponse:
225 """
226 Make a HTTP DELETE request
227 """
228 response = await self._client.delete(
229 path,
230 headers=headers,
231 bearer=bearer,
232 x_real_ip=x_real_ip,
233 **kwargs,
234 )
235 return await self._wrap_client_response(response)
236
237 async def options(
238 self,
239 path: str,
240 headers: dict[str, str] | None = None,
241 bearer: str | None = None,
242 x_real_ip: str | None = None,
243 **kwargs,
244 ) -> http.ClientResponse:
245 """
246 Make a HTTP OPTIONS request
247 """
248 response = await self._client.options(
249 path,
250 headers=headers,
251 bearer=bearer,
252 x_real_ip=x_real_ip,
253 **kwargs,
254 )
255 return await self._wrap_client_response(response)
256
257 async def request(
258 self,
259 http_method: str,
260 path: str,
261 **kwargs,
262 ) -> http.ClientResponse:
263 """
264 Make a HTTP request with the specified method
265 """
266 response = await self._client.request(http_method, path, **kwargs)
267 return await self._wrap_client_response(response)
268
269 @property
271 """
272 @deprecated Use pytest_userver.client.Client directly instead.
273 """
274 return self._client
275
276 def _wrap_client_response(self, response: aiohttp.ClientResponse) -> Awaitable[http.ClientResponse]:
277 return http.wrap_client_response(response)
278
279
280# @cond
281
282
def _wrap_client_error(func):
    """
    Decorate a coroutine function so that aiohttp `ClientResponseError` is
    re-raised as `http.HttpResponseError` carrying the same URL and status.

    The wrapper keeps the name and docstring of `func`.

    @param func Coroutine function to wrap
    @returns Coroutine function accepting the same arguments as `func`
    """
    # Imported locally so the helper does not depend on module-level imports.
    import functools

    @functools.wraps(func)
    async def _wrapper(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except aiohttp.client_exceptions.ClientResponseError as exc:
            # Chain explicitly: the aiohttp error is the direct cause.
            raise http.HttpResponseError(
                url=exc.request_info.url,
                status=exc.status,
            ) from exc

    return _wrapper
294
295
class AiohttpClientMonitor(service_client.AiohttpClient):
    """
    Low-level aiohttp client that fetches metrics from the handler
    configured as `TestsuiteClientConfig.server_monitor_path`.
    """

    _config: TestsuiteClientConfig

    def __init__(self, base_url, *, config: TestsuiteClientConfig, **kwargs):
        """
        @param base_url Passed to the base aiohttp client
        @param config Client configuration; `server_monitor_path` is read
        on every metrics request
        """
        super().__init__(base_url, **kwargs)
        self._config = config

    async def get_metrics(self, prefix=None):
        """
        Fetch metrics in the 'internal' format and return the decoded JSON.

        @param prefix Optional metric prefix filter
        @throws ConfigurationError if the monitor path is not configured
        """
        if not self._config.server_monitor_path:
            raise ConfigurationError(
                'handler-server-monitor component is not configured',
            )
        params = {'format': 'internal'}
        if prefix is not None:
            params['prefix'] = prefix
        response = await self.get(
            self._config.server_monitor_path,
            params=params,
        )
        async with response:
            response.raise_for_status()
            return await response.json(content_type=None)

    async def get_metric(self, metric_name):
        """
        Fetch metrics with prefix `metric_name` and return the entry stored
        under exactly that key.

        @throws AssertionError if there is no such key in the response
        """
        metrics = await self.get_metrics(metric_name)
        assert metric_name in metrics, (
            f'No metric with name {metric_name!r}. Use "single_metric" function instead of "get_metric"'
        )
        return metrics[metric_name]

    async def metrics_raw(
        self,
        output_format,
        *,
        path: str | None = None,
        prefix: str | None = None,
        labels: dict[str, str] | None = None,
    ) -> str:
        """
        Fetch metrics in `output_format` and return the response body text.

        Empty `prefix`, `path` and `labels` are not sent; `labels` are
        JSON-encoded into a single query parameter.

        @throws ConfigurationError if the monitor path is not configured
        """
        if not self._config.server_monitor_path:
            raise ConfigurationError(
                'handler-server-monitor component is not configured',
            )

        params = {'format': output_format}
        if prefix:
            params['prefix'] = prefix

        if path:
            params['path'] = path

        if labels:
            params['labels'] = json.dumps(labels)

        response = await self.get(
            self._config.server_monitor_path,
            params=params,
        )
        async with response:
            response.raise_for_status()
            return await response.text()

    async def metrics(
        self,
        *,
        path: str | None = None,
        prefix: str | None = None,
        labels: dict[str, str] | None = None,
    ) -> metric_module.MetricsSnapshot:
        """
        Fetch metrics in the 'json' format and parse them into a
        `MetricsSnapshot`.
        """
        response = await self.metrics_raw(
            output_format='json',
            path=path,
            prefix=prefix,
            labels=labels,
        )
        return metric_module.MetricsSnapshot.from_json(str(response))

    async def single_metric_optional(
        self,
        path: str,
        *,
        labels: dict[str, str] | None = None,
    ) -> Metric | None:
        """
        Return the only metric at `path` matching `labels`, or None if
        there is none.

        @throws AssertionError if more than one metric matches
        """
        response = await self.metrics(path=path, labels=labels)
        metrics_list = response.get(path, [])

        # The message is a plain string: a one-element tuple would be
        # rendered as `('...',)` in the assertion error.
        assert len(metrics_list) <= 1, (
            f'More than one metric found for path {path} and labels {labels}: {response}'
        )

        if not metrics_list:
            return None

        return next(iter(metrics_list))

    async def single_metric(
        self,
        path: str,
        *,
        labels: dict[str, str] | None = None,
    ) -> Metric:
        """
        Return the only metric at `path` matching `labels`.

        @throws AssertionError if no metric or more than one metric matches
        """
        value = await self.single_metric_optional(path, labels=labels)
        assert value is not None, f'No metric was found for path {path} and labels {labels}'
        return value
397
398
399# @endcond
400
401
403 """
404 Asyncio userver client for monitor listeners, typically retrieved from
405 plugins.service_client.monitor_client fixture.
406
407 Compatible with werkzeug interface.
408
409 @ingroup userver_testsuite
410 """
411
413 self,
414 *,
415 path: str | None = None,
416 prefix: str | None = None,
417 labels: dict[str, str] | None = None,
418 diff_gauge: bool = False,
419 ) -> MetricsDiffer:
420 """
421 Creates a `MetricsDiffer` that fetches metrics using this client.
422 It's recommended to use this method over `metrics` to make sure
423 the tests don't affect each other.
424
425 With `diff_gauge` off, only RATE metrics are differentiated.
426 With `diff_gauge` on, GAUGE metrics are differentiated as well,
427 which may lead to nonsensical results for those.
428
429 @param path Optional full metric path
430 @param prefix Optional prefix on which the metric paths should start
431 @param labels Optional dictionary of labels that must be in the metric
432 @param diff_gauge Whether to differentiate GAUGE metrics
433
434 @code
435 async with monitor_client.metrics_diff(prefix='foo') as differ:
436 # Do something that makes the service update its metrics
437 assert differ.value_at('path-suffix', {'label'}) == 42
438 @endcode
439 """
440 return MetricsDiffer(
441 _client=self,
442 _path=path,
443 _prefix=prefix,
444 _labels=labels,
445 _diff_gauge=diff_gauge,
446 )
447
448 @_wrap_client_error
449 async def metrics(
450 self,
451 *,
452 path: str | None = None,
453 prefix: str | None = None,
454 labels: dict[str, str] | None = None,
455 ) -> metric_module.MetricsSnapshot:
456 """
457 Returns a dict of metric names to Metric.
458
459 @param path Optional full metric path
460 @param prefix Optional prefix on which the metric paths should start
461 @param labels Optional dictionary of labels that must be in the metric
462 """
463 return await self._client.metrics(
464 path=path,
465 prefix=prefix,
466 labels=labels,
467 )
468
469 @_wrap_client_error
471 self,
472 path: str,
473 *,
474 labels: dict[str, str] | None = None,
475 ) -> Metric | None:
476 """
477 Either return a Metric or None if there's no such metric.
478
479 @param path Full metric path
480 @param labels Optional dictionary of labels that must be in the metric
481
482 @throws AssertionError if more than one metric returned
483 """
484 return await self._client.single_metric_optional(path, labels=labels)
485
486 @_wrap_client_error
487 async def single_metric(
488 self,
489 path: str,
490 *,
491 labels: dict[str, str] | None = None,
492 ) -> Metric | None:
493 """
494 Returns the Metric.
495
496 @param path Full metric path
497 @param labels Optional dictionary of labels that must be in the metric
498
499 @throws AssertionError if more than one metric or no metric found
500 """
501 return await self._client.single_metric(path, labels=labels)
502
503 @_wrap_client_error
504 async def metrics_raw(
505 self,
506 output_format: str,
507 *,
508 path: str | None = None,
509 prefix: str | None = None,
510 labels: dict[str, str] | None = None,
511 ) -> dict[str, Metric]:
512 """
513 Low level function that returns metrics in a specific format.
514 Use `metrics` and `single_metric` instead if possible.
515
516 @param output_format Metric output format. See
517 server::handlers::ServerMonitor for a list of supported formats.
518 @param path Optional full metric path
519 @param prefix Optional prefix on which the metric paths should start
520 @param labels Optional dictionary of labels that must be in the metric
521 """
522 return await self._client.metrics_raw(
523 output_format=output_format,
524 path=path,
525 prefix=prefix,
526 labels=labels,
527 )
528
529 @_wrap_client_error
530 async def get_metrics(self, prefix=None):
531 """
532 @deprecated Use metrics() or single_metric() instead
533 """
534 return await self._client.get_metrics(prefix=prefix)
535
536 @_wrap_client_error
537 async def get_metric(self, metric_name):
538 """
539 @deprecated Use metrics() or single_metric() instead
540 """
541 return await self._client.get_metric(metric_name)
542
543 @_wrap_client_error
544 async def fired_alerts(self):
545 response = await self._client.get('/service/fired-alerts')
546 assert response.status == 200
547 return (await response.json())['alerts']
548
549
551 """
552 A helper class for computing metric differences.
553
554 @see ClientMonitor.metrics_diff
555 @ingroup userver_testsuite
556 """
557
558 # @cond
559 def __init__(
560 self,
561 _client: ClientMonitor,
562 _path: str | None,
563 _prefix: str | None,
564 _labels: dict[str, str] | None,
565 _diff_gauge: bool,
566 ):
567 self._client = _client
568 self._path = _path
569 self._prefix = _prefix
570 self._labels = _labels
571 self._diff_gauge = _diff_gauge
572 self._baseline: metric_module.MetricsSnapshot | None = None
573 self._current: metric_module.MetricsSnapshot | None = None
574 self._diff: metric_module.MetricsSnapshot | None = None
575
576 # @endcond
577
578 @property
579 def baseline(self) -> metric_module.MetricsSnapshot:
580 assert self._baseline is not None
581 return self._baseline
582
583 @baseline.setter
584 def baseline(self, value: metric_module.MetricsSnapshot) -> None:
585 self._baseline = value
586 if self._current is not None:
587 self._diff = _subtract_metrics_snapshots(
588 self._current,
589 self._baseline,
590 self._diff_gauge,
591 )
592
593 @property
594 def current(self) -> metric_module.MetricsSnapshot:
595 assert self._current is not None, 'Set self.current first'
596 return self._current
597
598 @current.setter
599 def current(self, value: metric_module.MetricsSnapshot) -> None:
600 self._current = value
601 assert self._baseline is not None, 'Set self.baseline first'
602 self._diff = _subtract_metrics_snapshots(
603 self._current,
604 self._baseline,
605 self._diff_gauge,
606 )
607
608 @property
609 def diff(self) -> metric_module.MetricsSnapshot:
610 assert self._diff is not None, 'Set self.current first'
611 return self._diff
612
614 self,
615 subpath: str | None = None,
616 add_labels: dict[str, str] | None = None,
617 *,
618 default: float | None = None,
619 ) -> metric_module.MetricValue:
620 """
621 Returns a single metric value at the specified path, prepending
622 the path provided at construction. If a dict of labels is provided,
623 does an exact match of labels, prepending the labels provided
624 at construction.
625
626 @param subpath Suffix of the metric path; the path provided
627 at construction is prepended
628 @param add_labels Labels that the metric must have in addition
629 to the labels provided at construction
630 @param default An optional default value in case the metric is missing
631 @throws AssertionError if not one metric by path
632 """
633 base_path = self._path or self._prefix
634 if base_path and subpath:
635 path = f'{base_path}.{subpath}'
636 else:
637 assert base_path or subpath, 'No path provided'
638 path = base_path or subpath or ''
639 labels: dict | None = None
640 if self._labels is not None or add_labels is not None:
641 labels = {**(self._labels or {}), **(add_labels or {})}
642 return self.diff.value_at(path, labels, default=default)
643
644 async def fetch(self) -> metric_module.MetricsSnapshot:
645 """
646 Fetches metric values from the service.
647 """
648 return await self._client.metrics(
649 path=self._path,
650 prefix=self._prefix,
651 labels=self._labels,
652 )
653
654 async def __aenter__(self) -> MetricsDiffer:
655 self._baseline = await self.fetch()
656 self._current = None
657 return self
658
659 async def __aexit__(self, exc_type, exc, exc_tb) -> None:
660 self.current = await self.fetch()
661
662
663# @cond
664
665
def _subtract_metrics_snapshots(
    current: metric_module.MetricsSnapshot,
    initial: metric_module.MetricsSnapshot,
    diff_gauge: bool,
) -> metric_module.MetricsSnapshot:
    """
    Build a snapshot with the same paths as `current`, where each metric
    is replaced by the result of `_subtract_metrics` against `initial`.
    """
    subtracted = {}
    for path, current_group in current.items():
        subtracted[path] = {
            _subtract_metrics(path, current_metric, initial, diff_gauge) for current_metric in current_group
        }
    return metric_module.MetricsSnapshot(subtracted)
675
676
def _subtract_metrics(
    path: str,
    current_metric: metric_module.Metric,
    initial: metric_module.MetricsSnapshot,
    diff_gauge: bool,
) -> metric_module.Metric:
    """
    Subtract from `current_metric` the metric of `initial` that has the same
    path and labels. Returns `current_metric` unchanged if there is no match.
    """
    initial_group = initial.get(path, None)
    if initial_group is None:
        return current_metric

    # Find the first initial metric with identical labels.
    for candidate in initial_group:
        if candidate.labels == current_metric.labels:
            matching = candidate
            break
    else:
        return current_metric

    return metric_module.Metric(
        labels=current_metric.labels,
        value=_subtract_metric_values(
            current=current_metric,
            initial=matching,
            diff_gauge=diff_gauge,
        ),
        _type=current_metric.type(),
    )
702
703
def _subtract_metric_values(
    current: metric_module.Metric,
    initial: metric_module.Metric,
    diff_gauge: bool,
) -> metric_module.MetricValue:
    """
    Dispatch the subtraction to the histogram or numeric implementation,
    depending on the type of `current.value`.

    Both metrics must have the same, specified metric type.
    """
    assert current.type() is not metric_module.MetricType.UNSPECIFIED
    assert initial.type() is not metric_module.MetricType.UNSPECIFIED
    assert current.type() == initial.type()

    if not isinstance(current.value, metric_module.Histogram):
        assert not isinstance(initial.value, metric_module.Histogram)
        return _subtract_metric_values_num(
            current=current,
            initial=initial,
            diff_gauge=diff_gauge,
        )

    assert isinstance(initial.value, metric_module.Histogram)
    return _subtract_metric_values_hist(current=current, initial=initial)
723
724
def _subtract_metric_values_num(
    current: metric_module.Metric,
    initial: metric_module.Metric,
    diff_gauge: bool,
) -> float:
    """
    Return `current - initial` if either metric is RATE or `diff_gauge`
    is set; otherwise return the current value as is.
    """
    current_value = typing.cast(float, current.value)
    initial_value = typing.cast(float, initial.value)
    rate = metric_module.MetricType.RATE
    if current.type() is rate or initial.type() is rate or diff_gauge:
        return current_value - initial_value
    return current_value
736
737
def _subtract_metric_values_hist(
    current: metric_module.Metric,
    initial: metric_module.Metric,
) -> metric_module.Histogram:
    """
    Subtract two histograms bucket by bucket, including the `inf` bucket.

    The histograms must have equal bounds and the same number of buckets.
    """
    current_hist = typing.cast(metric_module.Histogram, current.value)
    initial_hist = typing.cast(metric_module.Histogram, initial.value)
    assert current_hist.bounds == initial_hist.bounds
    bucket_diffs = [
        now - before for now, before in zip(current_hist.buckets, initial_hist.buckets, strict=True)
    ]
    inf_diff = current_hist.inf - initial_hist.inf
    return metric_module.Histogram(
        bounds=current_hist.bounds,
        buckets=bucket_diffs,
        inf=inf_diff,
    )
750
751
class AiohttpClient(service_client.AiohttpClient):
    """
    Low-level aiohttp client for the service under test.

    Besides plain HTTP requests it posts testsuite actions to the path from
    `TestsuiteClientConfig.testsuite_action_path` and sends pending state
    updates before every request, unless `testsuite_skip_prepare` is passed.
    """

    PeriodicTaskFailed = PeriodicTaskFailed
    TestsuiteActionFailed = TestsuiteActionFailed
    TestsuiteTaskNotFound = TestsuiteTaskNotFound
    TestsuiteTaskConflict = TestsuiteTaskConflict
    TestsuiteTaskFailed = TestsuiteTaskFailed

    def __init__(
        self,
        base_url: str,
        *,
        config: TestsuiteClientConfig,
        mocked_time,
        log_capture_fixture: logcapture.CaptureServer,
        testpoint,
        testpoint_control,
        cache_invalidation_state,
        span_id_header=None,
        api_coverage_report=None,
        periodic_tasks_state: PeriodicTasksState | None = None,
        allow_all_caches_invalidation: bool = True,
        cache_control: caches.CacheControl | None = None,
        asyncexc_check=None,
        **kwargs,
    ):
        """
        Store the collaborators and create the state manager that tracks
        pending server state updates.
        """
        super().__init__(base_url, span_id_header=span_id_header, **kwargs)
        self._config = config
        self._periodic_tasks = periodic_tasks_state
        self._testpoint = testpoint
        self._log_capture_fixture = log_capture_fixture
        self._state_manager = _StateManager(
            mocked_time=mocked_time,
            testpoint=self._testpoint,
            testpoint_control=testpoint_control,
            invalidation_state=cache_invalidation_state,
            cache_control=cache_control,
        )
        self._api_coverage_report = api_coverage_report
        self._allow_all_caches_invalidation = allow_all_caches_invalidation
        self._asyncexc_check = asyncexc_check

    async def run_periodic_task(self, name):
        """
        Run the periodic task `name` once.

        @throws PeriodicTaskFailed if the response 'status' field is falsy
        """
        response = await self._testsuite_action('run_periodic_task', name=name)
        if not response['status']:
            raise self.PeriodicTaskFailed(f'Periodic task {name} failed')

    async def suspend_periodic_tasks(self, names: list[str]) -> None:
        """
        Add `names` to the set of tasks to suspend and synchronize it.

        @throws ConfigurationError if no periodic tasks state was given
        """
        if not self._periodic_tasks:
            raise ConfigurationError('No periodic_tasks_state given')
        self._periodic_tasks.tasks_to_suspend.update(names)
        await self._suspend_periodic_tasks()

    async def resume_periodic_tasks(self, names: list[str]) -> None:
        """
        Remove `names` from the set of tasks to suspend and synchronize it.

        @throws ConfigurationError if no periodic tasks state was given
        """
        if not self._periodic_tasks:
            raise ConfigurationError('No periodic_tasks_state given')
        self._periodic_tasks.tasks_to_suspend.difference_update(names)
        await self._suspend_periodic_tasks()

    async def resume_all_periodic_tasks(self) -> None:
        """
        Clear the set of tasks to suspend and synchronize it.

        @throws ConfigurationError if no periodic tasks state was given
        """
        if not self._periodic_tasks:
            raise ConfigurationError('No periodic_tasks_state given')
        self._periodic_tasks.tasks_to_suspend.clear()
        await self._suspend_periodic_tasks()

    async def write_cache_dumps(
        self,
        names: list[str],
        *,
        testsuite_skip_prepare=False,
    ) -> None:
        """Post the 'write_cache_dumps' testsuite action for `names`."""
        await self._testsuite_action(
            'write_cache_dumps',
            names=names,
            testsuite_skip_prepare=testsuite_skip_prepare,
        )

    async def read_cache_dumps(
        self,
        names: list[str],
        *,
        testsuite_skip_prepare=False,
    ) -> None:
        """Post the 'read_cache_dumps' testsuite action for `names`."""
        await self._testsuite_action(
            'read_cache_dumps',
            names=names,
            testsuite_skip_prepare=testsuite_skip_prepare,
        )

    async def run_distlock_task(self, name: str) -> None:
        """Run the testsuite task named 'distlock/<name>'."""
        await self.run_task(f'distlock/{name}')

    async def reset_metrics(self) -> None:
        """Post the 'reset_metrics' testsuite action."""
        await self._testsuite_action('reset_metrics')

    async def metrics_portability(
        self,
        *,
        prefix: str | None = None,
    ) -> dict[str, list[dict[str, str]]]:
        """Post the 'metrics_portability' action and return its JSON body."""
        return await self._testsuite_action(
            'metrics_portability',
            prefix=prefix,
        )

    async def list_tasks(self) -> list[str]:
        """Return the 'tasks' field of the 'tasks_list' action response."""
        response = await self._do_testsuite_action('tasks_list')
        async with response:
            response.raise_for_status()
            body = await response.json(content_type=None)
            return body['tasks']

    async def run_task(self, name: str) -> None:
        """Post 'task_run' for `name` and check the response."""
        response = await self._do_testsuite_action(
            'task_run',
            json={'name': name},
        )
        await _task_check_response(name, response)

    @contextlib.asynccontextmanager
    async def spawn_task(self, name: str):
        """
        Spawn the testsuite task on entry and stop it on exit, even if the
        body raised.
        """
        task_id = await self._task_spawn(name)
        try:
            yield
        finally:
            await self._task_stop_spawned(task_id)

    async def _task_spawn(self, name: str) -> str:
        """Post 'task_spawn' and return 'task_id' from the checked response."""
        response = await self._do_testsuite_action(
            'task_spawn',
            json={'name': name},
        )
        data = await _task_check_response(name, response)
        return data['task_id']

    async def _task_stop_spawned(self, task_id: str) -> None:
        """Post 'task_stop' for `task_id` and check the response."""
        response = await self._do_testsuite_action(
            'task_stop',
            json={'task_id': task_id},
        )
        await _task_check_response(task_id, response)

    async def http_allowed_urls_extra(
        self,
        http_allowed_urls_extra: list[str],
    ) -> None:
        """
        Post the 'http_allowed_urls_extra' action with the given URL list.

        @throws aiohttp.ClientResponseError on an error HTTP status
        """
        response = await self._do_testsuite_action(
            'http_allowed_urls_extra',
            json={'allowed_urls_extra': http_allowed_urls_extra},
            testsuite_skip_prepare=True,
        )
        # Release the response and surface failures, like the other
        # testsuite actions do, instead of silently dropping it.
        async with response:
            response.raise_for_status()

    @contextlib.asynccontextmanager
    async def capture_logs(
        self,
        *,
        log_level: str = 'DEBUG',
        testsuite_skip_prepare: bool = False,
    ):
        """
        Capture service logs at `log_level` for the duration of the context.

        Socket log duplication is switched on via the 'log_capture' action
        and switched off, restoring the default log level, on exit.
        """
        async with self._log_capture_fixture.capture(
            log_level=logcapture.LogLevel.from_string(log_level),
        ) as capture:
            logger.debug('Starting logcapture')
            await self._testsuite_action(
                'log_capture',
                log_level=log_level,
                socket_logging_duplication=True,
                testsuite_skip_prepare=testsuite_skip_prepare,
            )

            try:
                await self._log_capture_fixture.wait_for_client()
                yield capture
            finally:
                await self._testsuite_action(
                    'log_capture',
                    log_level=self._log_capture_fixture.default_log_level.name,
                    socket_logging_duplication=False,
                    testsuite_skip_prepare=testsuite_skip_prepare,
                )

    async def log_flush(self, logger_name: str | None = None):
        """Post the 'log_flush' action without synchronizing state first."""
        await self._testsuite_action(
            'log_flush',
            logger_name=logger_name,
            testsuite_skip_prepare=True,
        )

    async def invalidate_caches(
        self,
        *,
        clean_update: bool = True,
        cache_names: list[str] | None = None,
        testsuite_skip_prepare: bool = False,
    ) -> None:
        """
        Invalidate caches with a full or incremental update.

        A full invalidation of all caches (no `cache_names`) emits a
        DeprecationWarning, or raises RuntimeError when it is not allowed.
        """
        if cache_names is None and clean_update:
            if self._allow_all_caches_invalidation:
                warnings.warn(CACHE_INVALIDATION_MESSAGE, DeprecationWarning, stacklevel=2)
            else:
                __tracebackhide__ = True
                raise RuntimeError(CACHE_INVALIDATION_MESSAGE)

        if testsuite_skip_prepare:
            # Send the request directly, bypassing pending state updates.
            await self._tests_control({
                'invalidate_caches': {
                    'update_type': ('full' if clean_update else 'incremental'),
                    **({'names': cache_names} if cache_names else {}),
                },
            })
        else:
            await self.tests_control(
                invalidate_caches=True,
                clean_update=clean_update,
                cache_names=cache_names,
            )

    async def tests_control(
        self,
        *,
        invalidate_caches: bool = True,
        clean_update: bool = True,
        cache_names: list[str] | None = None,
        http_allowed_urls_extra: list[str] | None = None,
    ) -> dict[str, Any]:
        """
        Send the pending state update, optionally with a cache invalidation
        request, and return the JSON response of the 'control' action.
        """
        body: dict[str, Any] = self._state_manager.get_pending_update()

        if 'invalidate_caches' in body and invalidate_caches:
            if not clean_update or cache_names:
                logger.warning(
                    'Manual cache invalidation leads to indirect initial full cache invalidation',
                )
                # Flush the pending update first, then send only the
                # explicitly requested invalidation.
                await self._prepare()
                body = {}

        if invalidate_caches:
            body['invalidate_caches'] = {
                'update_type': ('full' if clean_update else 'incremental'),
            }
            if cache_names:
                body['invalidate_caches']['names'] = cache_names

        if http_allowed_urls_extra is not None:
            await self.http_allowed_urls_extra(http_allowed_urls_extra)

        return await self._tests_control(body)

    async def update_server_state(self) -> None:
        """Send the pending state update to the service, if there is one."""
        await self._prepare()

    async def enable_testpoints(self, *, no_auto_cache_cleanup=False) -> None:
        """
        Send the registered testpoint names to the service.

        Does nothing when there are no testpoints. With
        `no_auto_cache_cleanup` only the testpoint list is sent; otherwise
        the whole pending state is synchronized.
        """
        if not self._testpoint:
            return
        if no_auto_cache_cleanup:
            await self._tests_control({
                'testpoints': sorted(self._testpoint.keys()),
            })
        else:
            await self.update_server_state()

    async def get_dynamic_config_defaults(
        self,
    ) -> dict[str, Any]:
        """Return the JSON body of the 'get_dynamic_config_defaults' action."""
        return await self._testsuite_action(
            'get_dynamic_config_defaults',
            testsuite_skip_prepare=True,
        )

    async def _tests_control(self, body: dict) -> dict[str, Any]:
        """
        Post `body` to the 'control' action and return the JSON response.

        @throws ConfigurationError if the service responds with 404
        """
        with self._state_manager.updating_state(body):
            async with await self._do_testsuite_action(
                'control',
                json=body,
                testsuite_skip_prepare=True,
            ) as response:
                if response.status == 404:
                    raise ConfigurationError(
                        'It seems that testsuite support is not enabled for your service',
                    )
                response.raise_for_status()
                return await response.json(content_type=None)

    async def _suspend_periodic_tasks(self):
        """Send the set of tasks to suspend if it differs from the last one sent."""
        if self._periodic_tasks.tasks_to_suspend != self._periodic_tasks.suspended_tasks:
            await self._testsuite_action(
                'suspend_periodic_tasks',
                names=sorted(self._periodic_tasks.tasks_to_suspend),
            )
            # Store a copy so later changes to `tasks_to_suspend` are detected.
            self._periodic_tasks.suspended_tasks = set(
                self._periodic_tasks.tasks_to_suspend,
            )

    def _do_testsuite_action(self, action, **kwargs):
        """
        Start a POST to the testsuite action path and return the awaitable.

        @throws ConfigurationError if the action path is not configured
        """
        if not self._config.testsuite_action_path:
            raise ConfigurationError(
                'tests-control component is not properly configured',
            )
        path = self._config.testsuite_action_path.format(action=action)
        return self.post(path, **kwargs)

    async def _testsuite_action(
        self,
        action,
        *,
        testsuite_skip_prepare=False,
        **kwargs,
    ):
        """
        Post `kwargs` as the JSON body of `action` and return the JSON
        response.

        @throws TestsuiteActionFailed if the service responds with 500
        """
        async with await self._do_testsuite_action(
            action,
            json=kwargs,
            testsuite_skip_prepare=testsuite_skip_prepare,
        ) as response:
            if response.status == 500:
                raise TestsuiteActionFailed
            response.raise_for_status()
            return await response.json(content_type=None)

    async def _prepare(self) -> None:
        """Send the pending state update, if the state manager has one."""
        with self._state_manager.cache_control_update() as pending_update:
            if pending_update:
                await self._tests_control(pending_update)

    async def _request(  # pylint: disable=arguments-differ
        self,
        http_method: str,
        path: str,
        headers: dict[str, str] | None = None,
        bearer: str | None = None,
        x_real_ip: str | None = None,
        *,
        testsuite_skip_prepare: bool = False,
        **kwargs,
    ) -> aiohttp.ClientResponse:
        """
        Perform the request, synchronizing pending state first unless
        `testsuite_skip_prepare` is set, and report it to API coverage.
        """
        if self._asyncexc_check:
            # Check for pending background exceptions before call.
            self._asyncexc_check()

        if not testsuite_skip_prepare:
            await self._prepare()

        response = await super()._request(
            http_method,
            path,
            headers,
            bearer,
            x_real_ip,
            **kwargs,
        )
        if self._api_coverage_report:
            self._api_coverage_report.update_usage_stat(
                path,
                http_method,
                response.status,
                response.content_type,
            )

        if self._asyncexc_check:
            # Check for pending background exceptions after call.
            self._asyncexc_check()

        return response
1111
1112
1113# @endcond
1114
1115
1117 """
1118 Asyncio userver client, typically retrieved from
1119 @ref service_client "plugins.service_client.service_client"
1120 fixture.
1121
1122 Compatible with werkzeug interface.
1123
1124 @ingroup userver_testsuite
1125 """
1126
1127 PeriodicTaskFailed = PeriodicTaskFailed
1128 TestsuiteActionFailed = TestsuiteActionFailed
1129 TestsuiteTaskNotFound = TestsuiteTaskNotFound
1130 TestsuiteTaskConflict = TestsuiteTaskConflict
1131 TestsuiteTaskFailed = TestsuiteTaskFailed
1132
1133 def _wrap_client_response(self, response: aiohttp.ClientResponse) -> Awaitable[http.ClientResponse]:
1134 return http.wrap_client_response(
1135 response,
1136 json_loads=approx.json_loads,
1137 )
1138
1139 @_wrap_client_error
1140 async def run_periodic_task(self, name):
1141 await self._client.run_periodic_task(name)
1142
1143 @_wrap_client_error
1144 async def suspend_periodic_tasks(self, names: list[str]) -> None:
1145 await self._client.suspend_periodic_tasks(names)
1146
1147 @_wrap_client_error
1148 async def resume_periodic_tasks(self, names: list[str]) -> None:
1149 await self._client.resume_periodic_tasks(names)
1150
1151 @_wrap_client_error
1152 async def resume_all_periodic_tasks(self) -> None:
1153 await self._client.resume_all_periodic_tasks()
1154
@_wrap_client_error
async def write_cache_dumps(
    self,
    names: list[str],
    *,
    testsuite_skip_prepare: bool = False,
) -> None:
    """
    Delegate to `write_cache_dumps` of the wrapped client.

    @param names Cache names, forwarded as the `names` keyword argument.
    @param testsuite_skip_prepare Forwarded unchanged; presumably skips
    the automatic `update_server_state` step - TODO confirm against the
    wrapped client.
    """
    await self._client.write_cache_dumps(
        names=names,
        testsuite_skip_prepare=testsuite_skip_prepare,
    )
1166
@_wrap_client_error
async def read_cache_dumps(
    self,
    names: list[str],
    *,
    testsuite_skip_prepare: bool = False,
) -> None:
    """
    Delegate to `read_cache_dumps` of the wrapped client.

    @param names Cache names, forwarded as the `names` keyword argument.
    @param testsuite_skip_prepare Forwarded unchanged; presumably skips
    the automatic `update_server_state` step - TODO confirm against the
    wrapped client.
    """
    await self._client.read_cache_dumps(
        names=names,
        testsuite_skip_prepare=testsuite_skip_prepare,
    )
1178
async def run_task(self, name: str) -> None:
    """
    Delegate to `run_task` of the wrapped client.

    @param name Name of the testsuite task, forwarded as is.
    """
    await self._client.run_task(name)
1181
async def run_distlock_task(self, name: str) -> None:
    """
    Delegate to `run_distlock_task` of the wrapped client.

    @param name Name of the distlock task, forwarded as is.
    """
    await self._client.run_distlock_task(name)
1184
async def reset_metrics(self) -> None:
    """
    Calls `ResetMetric(metric);` for each metric that has such a C++
    function.

    The call itself is delegated to `reset_metrics` of the wrapped client.
    """
    await self._client.reset_metrics()
1190
async def metrics_portability(
    self,
    *,
    prefix: str | None = None,
) -> dict[str, list[dict[str, str]]]:
    """
    Reports metrics related issues that could be encountered on
    different monitoring systems.

    The request is delegated to `metrics_portability` of the wrapped
    client.

    @param prefix Forwarded unchanged as the `prefix` keyword argument;
    presumably restricts the report to metrics with this prefix -
    TODO confirm against the wrapped client.
    @return Whatever the wrapped client returns for the request.

    @sa @ref utils::statistics::GetPortabilityWarnings
    """
    return await self._client.metrics_portability(prefix=prefix)
1203
def list_tasks(self) -> list[str]:
    """
    Return the result of `list_tasks` of the wrapped client.

    The call is returned as is, without awaiting it here.
    """
    return self._client.list_tasks()
1206
def spawn_task(self, name: str):
    """
    Return the result of `spawn_task` of the wrapped client.

    @param name Name of the testsuite task, forwarded as is.
    """
    return self._client.spawn_task(name)
1209
def capture_logs(
    self,
    *,
    log_level: str = 'DEBUG',
    testsuite_skip_prepare: bool = False,
):
    """
    Captures logs from the service.

    The call is delegated to `capture_logs` of the wrapped client and its
    result is returned unchanged.

    @param log_level Do not capture logs below this level.
    @param testsuite_skip_prepare An advanced parameter to skip auto-`update_server_state`.

    @see @ref testsuite_logs_capture
    """
    return self._client.capture_logs(
        log_level=log_level,
        testsuite_skip_prepare=testsuite_skip_prepare,
    )
1228
def log_flush(self, logger_name: str | None = None):
    """
    Flush service logs.

    @param logger_name Forwarded unchanged as the `logger_name` keyword
    argument of the wrapped client's `log_flush`.
    @return Whatever the wrapped client's `log_flush` returns.
    """
    return self._client.log_flush(logger_name=logger_name)
1234
@_wrap_client_error
async def invalidate_caches(
    self,
    *,
    clean_update: bool = True,
    cache_names: list[str] | None = None,
    testsuite_skip_prepare: bool = False,
) -> None:
    """
    Send request to service to update caches.

    @param clean_update if False, service will do a faster incremental
    update of caches whenever possible.
    @param cache_names which caches specifically should be updated;
    update all if None.
    @param testsuite_skip_prepare if False, service will automatically do
    update_server_state().
    """
    # pytest convention: keep this wrapper frame out of failure tracebacks.
    __tracebackhide__ = True
    await self._client.invalidate_caches(
        clean_update=clean_update,
        cache_names=cache_names,
        testsuite_skip_prepare=testsuite_skip_prepare,
    )
1259
@_wrap_client_error
async def tests_control(
    self,
    invalidate_caches: bool = True,
    clean_update: bool = True,
    cache_names: list[str] | None = None,
    http_allowed_urls_extra: list[str] | None = None,
) -> dict[str, Any]:
    """
    Delegate to `tests_control` of the wrapped client.

    Every argument is forwarded unchanged as a keyword argument of the
    same name.

    @return Result of the wrapped client's `tests_control` call.
    """
    return await self._client.tests_control(
        invalidate_caches=invalidate_caches,
        clean_update=clean_update,
        cache_names=cache_names,
        http_allowed_urls_extra=http_allowed_urls_extra,
    )
1274
@_wrap_client_error
async def update_server_state(self) -> None:
    """
    Update service-side state through an http call to 'tests/control':
    - clear dirty (from other tests) caches,
    - set service-side mocked time,
    - resume / suspend periodic tasks,
    - enable testpoints.

    If the service is up-to-date, does nothing.

    The work is delegated to `update_server_state` of the wrapped client.
    """
    await self._client.update_server_state()
1286
@_wrap_client_error
async def enable_testpoints(self, no_auto_cache_cleanup: bool = False) -> None:
    """
    Send the list of handled testpoint paths to the service. For these
    paths the service will no longer skip http calls from the
    TESTPOINT(...) macro.

    @param no_auto_cache_cleanup prevent automatic cache cleanup.
    When calling service client first time in scope of current test, client
    makes additional http call to `tests/control` to update caches, to get
    rid of data from previous test.
    """
    await self._client.enable_testpoints(no_auto_cache_cleanup=no_auto_cache_cleanup)
1299
@_wrap_client_error
async def get_dynamic_config_defaults(
    self,
) -> dict[str, Any]:
    """
    Return the result of `get_dynamic_config_defaults` of the wrapped
    client.
    """
    return await self._client.get_dynamic_config_defaults()
1305
1306
@dataclasses.dataclass
class _State:
    """Reflects the (supposed) current service state."""

    # Cache invalidation bookkeeping believed to match the service.
    invalidation_state: caches.InvalidationState
    # Mocked time last sent to the service. Starts as the `_UNKNOWN_STATE`
    # sentinel, which never equals a real timestring or None.
    now: str | None = _UNKNOWN_STATE
    # Testpoints believed to be enabled on the service side. Starts as a
    # set holding only the `_UNKNOWN_STATE` sentinel.
    testpoints: frozenset[str] = frozenset([_UNKNOWN_STATE])
1314
1315
1317 """
1318 Used for computing the requests that we need to automatically align
1319 the service state with the test fixtures state.
1320 """
1321
def __init__(
    self,
    *,
    mocked_time,
    testpoint,
    testpoint_control,
    invalidation_state: caches.InvalidationState,
    cache_control: caches.CacheControl | None,
):
    """
    Remember the test fixtures and start tracking the service state.

    @param mocked_time Fixture queried for the desired mocked time.
    @param testpoint Fixture whose keys are the desired testpoints.
    @param testpoint_control Object receiving the enabled testpoints.
    @param invalidation_state Live invalidation state of the fixtures.
    @param cache_control Optional cache control helper; may be None.
    """
    # The tracked state owns a private copy, so later changes of the live
    # fixture state can be detected by comparison.
    initial_snapshot = copy.deepcopy(invalidation_state)

    self._state = _State(invalidation_state=initial_snapshot)
    self._mocked_time = mocked_time
    self._testpoint = testpoint
    self._testpoint_control = testpoint_control
    self._invalidation_state = invalidation_state
    self._cache_control = cache_control
1339
@contextlib.contextmanager
def updating_state(self, body: dict[str, Any]):
    """
    Keep `_state` in sync with the (supposed) service state around a
    `tests_control` call.

    The handler may be invoked by the client itself during `prepare` or
    manually by the user. The state change is decoded from the request
    body and applied before the managed block runs. If the update or the
    block raises an `Exception`, the previous state is restored and
    re-applied, and the error is re-raised.

    @param body Body of the `tests_control` request being sent.
    """
    # Deep copy, so that the in-place update below cannot leak into it.
    rollback_point = copy.deepcopy(self._state)
    try:
        self._update_state(body)
        self._apply_new_state()
        yield
    except Exception:  # noqa
        # The request presumably did not reach the service: roll back.
        self._state = rollback_point
        self._apply_new_state()
        raise
1357
def get_pending_update(self) -> dict[str, Any]:
    """
    Build the `tests_control` request body that would bring the service
    state fully in line with the state of the test fixtures.

    @return Dict with any of the keys `invalidate_caches`, `testpoints`
    and `mock_now`; empty when nothing differs.
    """
    update: dict[str, Any] = {}

    # Caches: always request a full update, optionally limited by name.
    invalidation = self._invalidation_state
    if invalidation.has_caches_to_update:
        caches_request: dict[str, Any] = {'update_type': 'full'}
        update['invalidate_caches'] = caches_request
        if not invalidation.should_update_all_caches:
            caches_request['names'] = list(invalidation.caches_to_update)

    # Testpoints: resend the whole list when the enabled set differs.
    wanted_testpoints = self._testpoint.keys()
    if self._state.testpoints != frozenset(wanted_testpoints):
        update['testpoints'] = sorted(wanted_testpoints)

    # Mocked time: resend when it differs from the last known value.
    wanted_now = self._get_desired_now()
    if self._state.now != wanted_now:
        update['mock_now'] = wanted_now

    return update
1381
@contextlib.contextmanager
def cache_control_update(self) -> Iterator[dict[Any, Any]]:
    """
    Yield the pending `tests_control` body, adjusted by cache control.

    If the pending update invalidates caches and a cache control helper
    is set, the helper is queried and its actions are applied to the
    request before yielding. The staged changes are committed only after
    the managed block finishes without raising.
    """
    update = self.get_pending_update()
    caches_request = update.get('invalidate_caches')

    if not caches_request or not self._cache_control:
        # Nothing for cache control to adjust.
        yield update
        return

    requested_names = caches_request.get('names')
    staged, actions = self._cache_control.query_caches(requested_names)
    self._apply_cache_control_actions(caches_request, actions)
    yield update
    # Reached only when the managed block did not raise.
    self._cache_control.commit_staged(staged)
1394
1395 @staticmethod
1396 def _apply_cache_control_actions(
1397 invalidate_caches: dict[Any, Any],
1398 actions: list[tuple[str, caches.CacheControlAction]],
1399 ) -> None:
1400 cache_names = invalidate_caches.get('names')
1401 exclude_names = invalidate_caches.setdefault('exclude_names', [])
1402 force_incremental_names = invalidate_caches.setdefault(
1403 'force_incremental_names',
1404 [],
1405 )
1406 for cache_name, action in actions:
1407 match action:
1408 case caches.CacheControlAction.FULL:
1409 pass
1410 case caches.CacheControlAction.INCREMENTAL:
1411 force_incremental_names.append(cache_name)
1412 case caches.CacheControlAction.EXCLUDE:
1413 if cache_names is not None:
1414 cache_names.remove(cache_name)
1415 else:
1416 exclude_names.append(cache_name)
1417
1418 def _update_state(self, body: dict[str, Any]) -> None:
1419 body_invalidate_caches = body.get('invalidate_caches', {})
1420 update_type = body_invalidate_caches.get('update_type', 'full')
1421 body_cache_names = body_invalidate_caches.get('names', None)
1422 # An incremental update is considered insufficient to bring a cache
1423 # to a known state.
1424 if body_invalidate_caches and update_type == 'full':
1425 if body_cache_names is None:
1426 self._state.invalidation_state.on_all_caches_updated()
1427 else:
1428 self._state.invalidation_state.on_caches_updated(
1429 body_cache_names,
1430 )
1431
1432 if 'mock_now' in body:
1433 self._state.now = body['mock_now']
1434
1435 testpoints: list[str] | None = body.get('testpoints')
1436 if testpoints is not None:
1437 self._state.testpoints = frozenset(testpoints)
1438
1440 """Apply new state to related components."""
1441 self._testpoint_control.enabled_testpoints = self._state.testpoints
1442 self._invalidation_state.assign_copy(self._state.invalidation_state)
1443
1444 def _get_desired_now(self) -> str | None:
1445 if self._mocked_time.is_enabled:
1446 return utils.timestring(self._mocked_time.now())
1447 return None
1448
1449
1450async def _task_check_response(name: str, response) -> dict:
1451 async with response:
1452 if response.status == 404:
1453 raise TestsuiteTaskNotFound(f'Testsuite task {name!r} not found')
1454 if response.status == 409:
1455 raise TestsuiteTaskConflict(f'Testsuite task {name!r} conflict')
1456 assert response.status == 200
1457 data = await response.json()
1458 if not data.get('status', True):
1459 raise TestsuiteTaskFailed(name, data['reason'])
1460 return data