userver: /home/user/userver/testsuite/pytest_plugins/pytest_userver/client.py Source File
Loading...
Searching...
No Matches
client.py
1"""
2Python module that provides clients for functional tests with
3testsuite; see
4@ref scripts/docs/en/userver/functional_testing.md for an introduction.
5
6@ingroup userver_testsuite
7"""
8
9# pylint: disable=too-many-lines
10
11from __future__ import annotations
12
13from collections.abc import Awaitable
14from collections.abc import Iterator
15import contextlib
16import copy
17import dataclasses
18import json
19import logging
20import typing
21from typing import Any
22from typing import TypeAlias
23import warnings
24
25import aiohttp
26
27from testsuite import logcapture
28from testsuite import utils
29from testsuite.daemons import service_client
30from testsuite.utils import approx
31from testsuite.utils import http
32
33from pytest_userver import userver_warnings
34import pytest_userver.metrics # pylint: disable=import-error
35from pytest_userver.plugins import caches
36
37# @cond
38logger = logging.getLogger(__name__)
39# @endcond
40
41JsonAny: TypeAlias = int | float | str | list | dict
42JsonAnyOptional: TypeAlias = JsonAny | None
43
44_UNKNOWN_STATE = '__UNKNOWN__'
45
46CACHE_INVALIDATION_MESSAGE = (
47 'Direct cache invalidation is deprecated.\n'
48 '\n'
49 ' - Use client.update_server_state() to synchronize service state\n'
50 ' - Explicitly pass cache names to invalidate, e.g.: '
51 'invalidate_caches(cache_names=[...]).'
52)
53
54
55class BaseError(Exception):
56 """Base class for exceptions of this module."""
57
58
60 pass
61
62
64 pass
65
66
76 pass
77
78
80 pass
81
82
84 def __init__(self) -> None:
85 self.suspended_tasks: set[str] = set()
86 self.tasks_to_suspend: set[str] = set()
87
88
90 def __init__(self, name, reason):
91 self.name = name
92 self.reason = reason
93 super().__init__(f'Testsuite task {name!r} failed: {reason}')
94
95
96@dataclasses.dataclass(frozen=True)
98 testsuite_action_path: str | None = None
99 server_monitor_path: str | None = None
100
101
102Metric: TypeAlias = pytest_userver.metrics.Metric
103
104
106 """
107 Base asyncio userver client that implements HTTP requests to service.
108
109 Compatible with werkzeug interface.
110
111 @ingroup userver_testsuite
112 """
113
114 def __init__(self, client):
115 self._client = client
116
117 async def post(
118 self,
119 path: str,
120 # pylint: disable=redefined-outer-name
121 json: JsonAnyOptional = None,
122 data: Any = None,
123 params: dict[str, str] | None = None,
124 bearer: str | None = None,
125 x_real_ip: str | None = None,
126 headers: dict[str, str] | None = None,
127 **kwargs,
128 ) -> http.ClientResponse:
129 """
130 Make a HTTP POST request
131 """
132 response = await self._client.post(
133 path,
134 json=json,
135 data=data,
136 params=params,
137 headers=headers,
138 bearer=bearer,
139 x_real_ip=x_real_ip,
140 **kwargs,
141 )
142 return await self._wrap_client_response(response)
143
144 async def put(
145 self,
146 path,
147 # pylint: disable=redefined-outer-name
148 json: JsonAnyOptional = None,
149 data: Any = None,
150 params: dict[str, str] | None = None,
151 bearer: str | None = None,
152 x_real_ip: str | None = None,
153 headers: dict[str, str] | None = None,
154 **kwargs,
155 ) -> http.ClientResponse:
156 """
157 Make a HTTP PUT request
158 """
159 response = await self._client.put(
160 path,
161 json=json,
162 data=data,
163 params=params,
164 headers=headers,
165 bearer=bearer,
166 x_real_ip=x_real_ip,
167 **kwargs,
168 )
169 return await self._wrap_client_response(response)
170
171 async def patch(
172 self,
173 path,
174 # pylint: disable=redefined-outer-name
175 json: JsonAnyOptional = None,
176 data: Any = None,
177 params: dict[str, str] | None = None,
178 bearer: str | None = None,
179 x_real_ip: str | None = None,
180 headers: dict[str, str] | None = None,
181 **kwargs,
182 ) -> http.ClientResponse:
183 """
184 Make a HTTP PATCH request
185 """
186 response = await self._client.patch(
187 path,
188 json=json,
189 data=data,
190 params=params,
191 headers=headers,
192 bearer=bearer,
193 x_real_ip=x_real_ip,
194 **kwargs,
195 )
196 return await self._wrap_client_response(response)
197
198 async def get(
199 self,
200 path: str,
201 headers: dict[str, str] | None = None,
202 bearer: str | None = None,
203 x_real_ip: str | None = None,
204 **kwargs,
205 ) -> http.ClientResponse:
206 """
207 Make a HTTP GET request
208 """
209 response = await self._client.get(
210 path,
211 headers=headers,
212 bearer=bearer,
213 x_real_ip=x_real_ip,
214 **kwargs,
215 )
216 return await self._wrap_client_response(response)
217
218 async def delete(
219 self,
220 path: str,
221 headers: dict[str, str] | None = None,
222 bearer: str | None = None,
223 x_real_ip: str | None = None,
224 **kwargs,
225 ) -> http.ClientResponse:
226 """
227 Make a HTTP DELETE request
228 """
229 response = await self._client.delete(
230 path,
231 headers=headers,
232 bearer=bearer,
233 x_real_ip=x_real_ip,
234 **kwargs,
235 )
236 return await self._wrap_client_response(response)
237
238 async def options(
239 self,
240 path: str,
241 headers: dict[str, str] | None = None,
242 bearer: str | None = None,
243 x_real_ip: str | None = None,
244 **kwargs,
245 ) -> http.ClientResponse:
246 """
247 Make a HTTP OPTIONS request
248 """
249 response = await self._client.options(
250 path,
251 headers=headers,
252 bearer=bearer,
253 x_real_ip=x_real_ip,
254 **kwargs,
255 )
256 return await self._wrap_client_response(response)
257
258 async def request(
259 self,
260 http_method: str,
261 path: str,
262 **kwargs,
263 ) -> http.ClientResponse:
264 """
265 Make a HTTP request with the specified method
266 """
267 response = await self._client.request(http_method, path, **kwargs)
268 return await self._wrap_client_response(response)
269
270 @property
272 """
273 @deprecated Use pytest_userver.client.Client directly instead.
274 """
275 return self._client
276
277 def _wrap_client_response(self, response: aiohttp.ClientResponse) -> Awaitable[http.ClientResponse]:
278 return http.wrap_client_response(response)
279
280
281# @cond
282
283
284def _wrap_client_error(func):
285 async def _wrapper(*args, **kwargs):
286 try:
287 return await func(*args, **kwargs)
288 except aiohttp.client_exceptions.ClientResponseError as exc:
289 raise http.HttpResponseError(
290 url=exc.request_info.url,
291 status=exc.status,
292 )
293
294 return _wrapper
295
296
297class AiohttpClientMonitor(service_client.AiohttpClient):
298 _config: TestsuiteClientConfig
299
300 def __init__(self, base_url, *, config: TestsuiteClientConfig, **kwargs):
301 super().__init__(base_url, **kwargs)
302 self._config = config
303
304 async def get_metrics(self, prefix=None):
305 if not self._config.server_monitor_path:
306 raise ConfigurationError(
307 'handler-server-monitor component is not configured',
308 )
309 params = {'format': 'internal'}
310 if prefix is not None:
311 params['prefix'] = prefix
312 response = await self.get(
313 self._config.server_monitor_path,
314 params=params,
315 )
316 async with response:
317 response.raise_for_status()
318 return await response.json(content_type=None)
319
320 async def get_metric(self, metric_name):
321 metrics = await self.get_metrics(metric_name)
322 assert metric_name in metrics, (
323 f'No metric with name {metric_name!r}. Use "single_metric" function instead of "get_metric"'
324 )
325 return metrics[metric_name]
326
327 async def metrics_raw(
328 self,
329 output_format,
330 *,
331 path: str = None,
332 prefix: str = None,
333 labels: dict[str, str] | None = None,
334 ) -> str:
335 if not self._config.server_monitor_path:
336 raise ConfigurationError(
337 'handler-server-monitor component is not configured',
338 )
339
340 params = {'format': output_format}
341 if prefix:
342 params['prefix'] = prefix
343
344 if path:
345 params['path'] = path
346
347 if labels:
348 params['labels'] = json.dumps(labels)
349
350 response = await self.get(
351 self._config.server_monitor_path,
352 params=params,
353 )
354 async with response:
355 response.raise_for_status()
356 return await response.text()
357
358 async def metrics(
359 self,
360 *,
361 path: str = None,
362 prefix: str = None,
363 labels: dict[str, str] | None = None,
365 response = await self.metrics_raw(
366 output_format='json',
367 path=path,
368 prefix=prefix,
369 labels=labels,
370 )
372
373 async def single_metric_optional(
374 self,
375 path: str,
376 *,
377 labels: dict[str, str] | None = None,
379 response = await self.metrics(path=path, labels=labels)
380 metrics_list = response.get(path, [])
381
382 assert len(metrics_list) <= 1, (f'More than one metric found for path {path} and labels {labels}: {response}',)
383
384 if not metrics_list:
385 return None
386
387 return next(iter(metrics_list))
388
389 async def single_metric(
390 self,
391 path: str,
392 *,
393 labels: dict[str, str] | None = None,
395 value = await self.single_metric_optional(path, labels=labels)
396 assert value is not None, (f'No metric was found for path {path} and labels {labels}',)
397 return value
398
399
400# @endcond
401
402
404 """
405 Asyncio userver client for monitor listeners, typically retrieved from
406 plugins.service_client.monitor_client fixture.
407
408 Compatible with werkzeug interface.
409
410 @ingroup userver_testsuite
411 """
412
414 self,
415 *,
416 path: str | None = None,
417 prefix: str | None = None,
418 labels: dict[str, str] | None = None,
419 diff_gauge: bool = False,
420 ) -> MetricsDiffer:
421 """
422 Creates a `MetricsDiffer` that fetches metrics using this client.
423 It's recommended to use this method over `metrics` to make sure
424 the tests don't affect each other.
425
426 With `diff_gauge` off, only `RATE` metrics are differentiated.
427 With `diff_gauge` on, `GAUGE` metrics are differentiated as well,
428 which may lead to nonsensical results for those.
429
430 @param path Optional full metric path
431 @param prefix Optional prefix on which the metric paths should start
432 @param labels Optional dictionary of labels that must be in the metric
433 @param diff_gauge Whether to differentiate GAUGE metrics
434
435 @snippet samples/testsuite-support/tests/test_metrics.py metrics diff
436 """
437 return MetricsDiffer(
438 _client=self,
439 _path=path,
440 _prefix=prefix,
441 _labels=labels,
442 _diff_gauge=diff_gauge,
443 )
444
445 @_wrap_client_error
446 async def metrics(
447 self,
448 *,
449 path: str | None = None,
450 prefix: str | None = None,
451 labels: dict[str, str] | None = None,
453 """
454 Returns a dict of metric names to Metric.
455
456 @param path Optional full metric path
457 @param prefix Optional prefix on which the metric paths should start
458 @param labels Optional dictionary of labels that must be in the metric
459
460 @snippet samples/testsuite-support/tests/test_metrics.py metrics metrics
461 """
462 return await self._client.metrics(
463 path=path,
464 prefix=prefix,
465 labels=labels,
466 )
467
468 @_wrap_client_error
470 self,
471 path: str,
472 *,
473 labels: dict[str, str] | None = None,
475 """
476 Either return a pytest_userver.metrics.Metric or None if there's no such metric.
477
478 @param path Full metric path
479 @param labels Optional dictionary of labels that must be in the metric
480
481 @throws AssertionError if more than one metric returned
482
483 @snippet samples/testsuite-support/tests/test_metrics.py metrics single_metric_optional
484 """
485 return await self._client.single_metric_optional(path, labels=labels)
486
487 @_wrap_client_error
488 async def single_metric(
489 self,
490 path: str,
491 *,
492 labels: dict[str, str] | None = None,
494 """
495 Returns the pytest_userver.metrics.Metric.
496
497 @param path Full metric path
498 @param labels Optional dictionary of labels that must be in the metric
499
500 @throws AssertionError if more than one metric or no metric found
501
502 @snippet samples/testsuite-support/tests/test_metrics.py metrics single_metric
503 """
504 return await self._client.single_metric(path, labels=labels)
505
506 @_wrap_client_error
507 async def metrics_raw(
508 self,
509 output_format: str,
510 *,
511 path: str | None = None,
512 prefix: str | None = None,
513 labels: dict[str, str] | None = None,
514 ) -> dict[str, pytest_userver.metrics.Metric]:
515 """
516 Low level function that returns metrics in a specific format.
517 Use `metrics` and `single_metric` instead if possible.
518
519 @param output_format pytest_userver.metrics.Metric output format. See
520 @ref server::handlers::ServerMonitor for a list of supported formats.
521 @param path Optional full metric path
522 @param prefix Optional prefix on which the metric paths should start
523 @param labels Optional dictionary of labels that must be in the metric
524 """
525 return await self._client.metrics_raw(
526 output_format=output_format,
527 path=path,
528 prefix=prefix,
529 labels=labels,
530 )
531
532 @_wrap_client_error
533 async def get_metrics(self, prefix=None):
534 """
535 @deprecated Use metrics() or single_metric() instead
536 """
537 return await self._client.get_metrics(prefix=prefix)
538
539 @_wrap_client_error
540 async def get_metric(self, metric_name):
541 """
542 @deprecated Use metrics() or single_metric() instead
543 """
544 return await self._client.get_metric(metric_name)
545
546 @_wrap_client_error
547 async def fired_alerts(self):
548 response = await self._client.get('/service/fired-alerts')
549 assert response.status == 200
550 return (await response.json())['alerts']
551
552
554 """
555 A helper class for computing metric differences.
556
557 @see ClientMonitor.metrics_diff
558
559 Example with @ref pytest_userver.client.ClientMonitor.metrics_diff "await monitor_client.metrics_diff()":
560 @snippet samples/testsuite-support/tests/test_metrics.py metrics diff
561
562 @ingroup userver_testsuite
563 """
564
565 # @cond
566 def __init__(
567 self,
568 _client: ClientMonitor,
569 _path: str | None,
570 _prefix: str | None,
571 _labels: dict[str, str] | None,
572 _diff_gauge: bool,
573 ):
574 self._client = _client
575 self._path = _path
576 self._prefix = _prefix
577 self._labels = _labels
578 self._diff_gauge = _diff_gauge
582
583 # @endcond
584
585 @property
586 def baseline(self) -> pytest_userver.metrics.MetricsSnapshot:
587 assert self._baseline is not None
588 return self._baseline
589
590 @baseline.setter
591 def baseline(self, value: pytest_userver.metrics.MetricsSnapshot) -> None:
592 self._baseline = value
593 if self._current is not None:
594 self._diff = _subtract_metrics_snapshots(
595 self._current,
596 self._baseline,
597 self._diff_gauge,
598 )
599
600 @property
601 def current(self) -> pytest_userver.metrics.MetricsSnapshot:
602 assert self._current is not None, 'Set self.current first'
603 return self._current
604
605 @current.setter
606 def current(self, value: pytest_userver.metrics.MetricsSnapshot) -> None:
607 self._current = value
608 assert self._baseline is not None, 'Set self.baseline first'
609 self._diff = _subtract_metrics_snapshots(
610 self._current,
611 self._baseline,
612 self._diff_gauge,
613 )
614
615 @property
616 def diff(self) -> pytest_userver.metrics.MetricsSnapshot:
617 assert self._diff is not None, 'Set self.current first'
618 return self._diff
619
621 self,
622 subpath: str | None = None,
623 add_labels: dict[str, str] | None = None,
624 *,
625 default: float | None = None,
626 ) -> pytest_userver.metrics.MetricValue:
627 """
628 Returns a single metric value at the specified path, prepending
629 the path provided at construction. If a dict of labels is provided,
630 does en exact match of labels, prepending the labels provided at construction.
631
632 @param subpath Suffix of the metric path; the path provided at construction is prepended
633 @param add_labels Labels that the metric must have in addition to the labels provided at construction
634 @param default An optional default value in case the metric is missing
635 @throws AssertionError if not one metric by path
636 """
637 base_path = self._path or self._prefix
638 if base_path and subpath:
639 path = f'{base_path}.{subpath}'
640 else:
641 assert base_path or subpath, 'No path provided'
642 path = base_path or subpath or ''
643 labels: dict | None = None
644 if self._labels is not None or add_labels is not None:
645 labels = {**(self._labels or {}), **(add_labels or {})}
646 return self.diff.value_at(path, labels, default=default)
647
648 async def fetch(self) -> pytest_userver.metrics.MetricsSnapshot:
649 """
650 Fetches metric values from the service.
651 """
652 return await self._client.metrics(
653 path=self._path,
654 prefix=self._prefix,
655 labels=self._labels,
656 )
657
658 async def __aenter__(self) -> MetricsDiffer:
659 self._baseline = await self.fetch()
660 self._current = None
661 return self
662
663 async def __aexit__(self, exc_type, exc, exc_tb) -> None:
664 self.current = await self.fetch()
665
666
667# @cond
668
669
670def _subtract_metrics_snapshots(
671 current: pytest_userver.metrics.MetricsSnapshot,
673 diff_gauge: bool,
676 path: {_subtract_metrics(path, current_metric, initial, diff_gauge) for current_metric in current_group}
677 for path, current_group in current.items()
678 })
679
680
681def _subtract_metrics(
682 path: str,
683 current_metric: pytest_userver.metrics.Metric,
685 diff_gauge: bool,
687 initial_group = initial.get(path, None)
688 if initial_group is None:
689 return current_metric
690 initial_metric = next(
691 (x for x in initial_group if x.labels == current_metric.labels),
692 None,
693 )
694 if initial_metric is None:
695 return current_metric
696
698 labels=current_metric.labels,
699 value=_subtract_metric_values(
700 current=current_metric,
701 initial=initial_metric,
702 diff_gauge=diff_gauge,
703 ),
704 _type=current_metric.type(),
705 )
706
707
708def _subtract_metric_values(
709 current: pytest_userver.metrics.Metric,
711 diff_gauge: bool,
712) -> pytest_userver.metrics.MetricValue:
713 assert current.type() is not pytest_userver.metrics.MetricType.UNSPECIFIED
714 assert initial.type() is not pytest_userver.metrics.MetricType.UNSPECIFIED
715 assert current.type() == initial.type()
716
717 if isinstance(current.value, pytest_userver.metrics.Histogram):
718 assert isinstance(initial.value, pytest_userver.metrics.Histogram)
719 return _subtract_metric_values_hist(current=current, initial=initial)
720 else:
721 assert not isinstance(initial.value, pytest_userver.metrics.Histogram)
722 return _subtract_metric_values_num(
723 current=current,
724 initial=initial,
725 diff_gauge=diff_gauge,
726 )
727
728
729def _subtract_metric_values_num(
730 current: pytest_userver.metrics.Metric,
732 diff_gauge: bool,
733) -> float:
734 current_value = typing.cast(float, current.value)
735 initial_value = typing.cast(float, initial.value)
736 should_diff = (
737 current.type() is pytest_userver.metrics.MetricType.RATE
738 or initial.type() is pytest_userver.metrics.MetricType.RATE
739 or diff_gauge
740 )
741 return current_value - initial_value if should_diff else current_value
742
743
744def _subtract_metric_values_hist(
745 current: pytest_userver.metrics.Metric,
748 current_value = typing.cast(pytest_userver.metrics.Histogram, current.value)
749 initial_value = typing.cast(pytest_userver.metrics.Histogram, initial.value)
750 assert current_value.bounds == initial_value.bounds
752 bounds=current_value.bounds,
753 buckets=[t[0] - t[1] for t in zip(current_value.buckets, initial_value.buckets, strict=True)],
754 inf=current_value.inf - initial_value.inf,
755 )
756
757
758class AiohttpClient(service_client.AiohttpClient):
759 PeriodicTaskFailed = PeriodicTaskFailed
760 TestsuiteActionFailed = TestsuiteActionFailed
761 TestsuiteTaskNotFound = TestsuiteTaskNotFound
762 TestsuiteTaskConflict = TestsuiteTaskConflict
763 TestsuiteTaskFailed = TestsuiteTaskFailed
764
765 def __init__(
766 self,
767 base_url: str,
768 *,
769 config: TestsuiteClientConfig,
770 mocked_time,
771 log_capture_fixture: logcapture.CaptureServer,
772 testpoint,
773 testpoint_control,
774 cache_invalidation_state,
775 span_id_header=None,
776 api_coverage_report=None,
777 periodic_tasks_state: PeriodicTasksState | None = None,
778 allow_all_caches_invalidation: bool = True,
779 cache_control: caches.CacheControl | None = None,
780 asyncexc_check=None,
781 **kwargs,
782 ):
783 super().__init__(base_url, span_id_header=span_id_header, **kwargs)
784 self._config = config
785 self._periodic_tasks = periodic_tasks_state
786 self._testpoint = testpoint
787 self._log_capture_fixture = log_capture_fixture
788 self._state_manager = _StateManager(
789 mocked_time=mocked_time,
790 testpoint=self._testpoint,
791 testpoint_control=testpoint_control,
792 invalidation_state=cache_invalidation_state,
793 cache_control=cache_control,
794 )
795 self._api_coverage_report = api_coverage_report
796 self._allow_all_caches_invalidation = allow_all_caches_invalidation
797 self._asyncexc_check = asyncexc_check
798
799 async def run_periodic(self, name) -> None:
800 await self.run_task(f'periodic/{name}')
801
802 async def run_periodic_task(self, name):
803 warnings.warn(userver_warnings.WARN_PERIODIC_DEPRECATION, DeprecationWarning)
804 response = await self._testsuite_action('run_periodic_task', name=name)
805 if not response['status']:
806 raise self.PeriodicTaskFailed(f'Periodic task {name} failed')
807
808 async def suspend_periodic_tasks(self, names: list[str]) -> None:
809 if not self._periodic_tasks:
810 raise ConfigurationError('No periodic_tasks_state given')
811 self._periodic_tasks.tasks_to_suspend.update(names)
812 await self._suspend_periodic_tasks()
813
814 async def resume_periodic_tasks(self, names: list[str]) -> None:
815 warnings.warn(userver_warnings.WARN_PERIODIC_DEPRECATION, DeprecationWarning)
816 if not self._periodic_tasks:
817 raise ConfigurationError('No periodic_tasks_state given')
818 self._periodic_tasks.tasks_to_suspend.difference_update(names)
819 await self._suspend_periodic_tasks()
820
821 async def resume_all_periodic_tasks(self) -> None:
822 if not self._periodic_tasks:
823 raise ConfigurationError('No periodic_tasks_state given')
824 self._periodic_tasks.tasks_to_suspend.clear()
825 await self._suspend_periodic_tasks()
826
827 async def write_cache_dumps(
828 self,
829 names: list[str],
830 *,
831 testsuite_skip_prepare=False,
832 ) -> None:
833 await self._testsuite_action(
834 'write_cache_dumps',
835 names=names,
836 testsuite_skip_prepare=testsuite_skip_prepare,
837 )
838
839 async def read_cache_dumps(
840 self,
841 names: list[str],
842 *,
843 testsuite_skip_prepare=False,
844 ) -> None:
845 await self._testsuite_action(
846 'read_cache_dumps',
847 names=names,
848 testsuite_skip_prepare=testsuite_skip_prepare,
849 )
850
851 async def run_distlock_task(self, name: str) -> None:
852 await self.run_task(f'distlock/{name}')
853
854 async def reset_metrics(self) -> None:
855 await self._testsuite_action('reset_metrics')
856
857 async def metrics_portability(
858 self,
859 *,
860 prefix: str | None = None,
861 ) -> dict[str, list[dict[str, str]]]:
862 return await self._testsuite_action(
863 'metrics_portability',
864 prefix=prefix,
865 )
866
867 async def list_tasks(self) -> list[str]:
868 response = await self._do_testsuite_action('tasks_list')
869 async with response:
870 response.raise_for_status()
871 body = await response.json(content_type=None)
872 return body['tasks']
873
874 async def run_task(self, name: str) -> None:
875 response = await self._do_testsuite_action(
876 'task_run',
877 json={'name': name},
878 )
879 await _task_check_response(name, response)
880
881 @contextlib.asynccontextmanager
882 async def spawn_task(self, name: str):
883 task_id = await self._task_spawn(name)
884 try:
885 yield
886 finally:
887 await self._task_stop_spawned(task_id)
888
889 async def _task_spawn(self, name: str) -> str:
890 response = await self._do_testsuite_action(
891 'task_spawn',
892 json={'name': name},
893 )
894 data = await _task_check_response(name, response)
895 return data['task_id']
896
897 async def _task_stop_spawned(self, task_id: str) -> None:
898 response = await self._do_testsuite_action(
899 'task_stop',
900 json={'task_id': task_id},
901 )
902 await _task_check_response(task_id, response)
903
904 async def http_allowed_urls_extra(
905 self,
906 http_allowed_urls_extra: list[str],
907 ) -> None:
908 await self._do_testsuite_action(
909 'http_allowed_urls_extra',
910 json={'allowed_urls_extra': http_allowed_urls_extra},
911 testsuite_skip_prepare=True,
912 )
913
914 @contextlib.asynccontextmanager
915 async def capture_logs(
916 self,
917 *,
918 log_level: str = 'DEBUG',
919 testsuite_skip_prepare: bool = False,
920 ):
921 async with self._log_capture_fixture.capture(
922 log_level=logcapture.LogLevel.from_string(log_level),
923 ) as capture:
924 logger.debug('Starting logcapture')
925 await self._testsuite_action(
926 'log_capture',
927 log_level=log_level,
928 socket_logging_duplication=True,
929 testsuite_skip_prepare=testsuite_skip_prepare,
930 )
931
932 try:
933 await self._log_capture_fixture.wait_for_client()
934 yield capture
935 finally:
936 await self._testsuite_action(
937 'log_capture',
938 log_level=self._log_capture_fixture.default_log_level.name,
939 socket_logging_duplication=False,
940 testsuite_skip_prepare=testsuite_skip_prepare,
941 )
942
943 async def log_flush(self, logger_name: str | None = None):
944 await self._testsuite_action(
945 'log_flush',
946 logger_name=logger_name,
947 testsuite_skip_prepare=True,
948 )
949
950 async def invalidate_caches(
951 self,
952 *,
953 clean_update: bool = True,
954 cache_names: list[str] | None = None,
955 testsuite_skip_prepare: bool = False,
956 ) -> None:
957 if cache_names is None and clean_update:
958 if self._allow_all_caches_invalidation:
959 warnings.warn(CACHE_INVALIDATION_MESSAGE, DeprecationWarning, stacklevel=2)
960 else:
961 __tracebackhide__ = True
962 raise RuntimeError(CACHE_INVALIDATION_MESSAGE)
963
964 if testsuite_skip_prepare:
965 await self._tests_control({
966 'invalidate_caches': {
967 'update_type': ('full' if clean_update else 'incremental'),
968 **({'names': cache_names} if cache_names else {}),
969 },
970 })
971 else:
972 await self.tests_control(
973 invalidate_caches=True,
974 clean_update=clean_update,
975 cache_names=cache_names,
976 )
977
978 async def tests_control(
979 self,
980 *,
981 invalidate_caches: bool = True,
982 clean_update: bool = True,
983 cache_names: list[str] | None = None,
984 http_allowed_urls_extra: list[str] | None = None,
985 ) -> dict[str, Any]:
986 body: dict[str, Any] = self._state_manager.get_pending_update()
987
988 if 'invalidate_caches' in body and invalidate_caches:
989 if not clean_update or cache_names:
990 logger.warning(
991 'Manual cache invalidation leads to indirect initial full cache invalidation',
992 )
993 await self._prepare()
994 body = {}
995
996 if invalidate_caches:
997 body['invalidate_caches'] = {
998 'update_type': ('full' if clean_update else 'incremental'),
999 }
1000 if cache_names:
1001 body['invalidate_caches']['names'] = cache_names
1002
1003 if http_allowed_urls_extra is not None:
1004 await self.http_allowed_urls_extra(http_allowed_urls_extra)
1005
1006 return await self._tests_control(body)
1007
1008 async def update_server_state(self) -> None:
1009 await self._prepare()
1010
1011 async def enable_testpoints(self, *, no_auto_cache_cleanup=False) -> None:
1012 if not self._testpoint:
1013 return
1014 if no_auto_cache_cleanup:
1015 await self._tests_control({
1016 'testpoints': sorted(self._testpoint.keys()),
1017 })
1018 else:
1019 await self.update_server_state()
1020
1021 async def get_dynamic_config_defaults(
1022 self,
1023 ) -> dict[str, Any]:
1024 return await self._testsuite_action(
1025 'get_dynamic_config_defaults',
1026 testsuite_skip_prepare=True,
1027 )
1028
1029 async def _tests_control(self, body: dict) -> dict[str, Any]:
1030 with self._state_manager.updating_state(body):
1031 async with await self._do_testsuite_action(
1032 'control',
1033 json=body,
1034 testsuite_skip_prepare=True,
1035 ) as response:
1036 if response.status == 404:
1037 raise ConfigurationError(
1038 'It seems that testsuite support is not enabled for your service',
1039 )
1040 response.raise_for_status()
1041 return await response.json(content_type=None)
1042
1043 async def _suspend_periodic_tasks(self):
1044 if self._periodic_tasks.tasks_to_suspend != self._periodic_tasks.suspended_tasks:
1045 await self._testsuite_action(
1046 'suspend_periodic_tasks',
1047 names=sorted(self._periodic_tasks.tasks_to_suspend),
1048 )
1049 self._periodic_tasks.suspended_tasks = set(
1050 self._periodic_tasks.tasks_to_suspend,
1051 )
1052
1053 def _do_testsuite_action(self, action, **kwargs):
1054 if not self._config.testsuite_action_path:
1055 raise ConfigurationError(
1056 'tests-control component is not properly configured',
1057 )
1058 path = self._config.testsuite_action_path.format(action=action)
1059 return self.post(path, **kwargs)
1060
1061 async def _testsuite_action(
1062 self,
1063 action,
1064 *,
1065 testsuite_skip_prepare=False,
1066 **kwargs,
1067 ):
1068 async with await self._do_testsuite_action(
1069 action,
1070 json=kwargs,
1071 testsuite_skip_prepare=testsuite_skip_prepare,
1072 ) as response:
1073 if response.status == 500:
1074 raise TestsuiteActionFailed
1075 response.raise_for_status()
1076 return await response.json(content_type=None)
1077
1078 async def _prepare(self) -> None:
1079 with self._state_manager.cache_control_update() as pending_update:
1080 if pending_update:
1081 await self._tests_control(pending_update)
1082
1083 async def _request( # pylint: disable=arguments-differ
1084 self,
1085 http_method: str,
1086 path: str,
1087 headers: dict[str, str] | None = None,
1088 bearer: str | None = None,
1089 x_real_ip: str | None = None,
1090 *,
1091 testsuite_skip_prepare: bool = False,
1092 **kwargs,
1093 ) -> aiohttp.ClientResponse:
1094 if self._asyncexc_check:
1095 # Check for pending background exceptions before call.
1096 self._asyncexc_check()
1097
1098 if not testsuite_skip_prepare:
1099 await self._prepare()
1100
1101 response = await super()._request(
1102 http_method,
1103 path,
1104 headers,
1105 bearer,
1106 x_real_ip,
1107 **kwargs,
1108 )
1109 if self._api_coverage_report:
1110 self._api_coverage_report.update_usage_stat(
1111 path,
1112 http_method,
1113 response.status,
1114 response.content_type,
1115 )
1116
1117 if self._asyncexc_check:
1118 # Check for pending background exceptions after call.
1119 self._asyncexc_check()
1120
1121 return response
1122
1123
1124# @endcond
1125
1126
class Client(ClientWrapper):
    """
    Asyncio userver client, typically retrieved from
    @ref service_client "plugins.service_client.service_client"
    fixture.

    Compatible with werkzeug interface.

    @ingroup userver_testsuite
    """

    # Module-level exception types re-exported as class attributes
    # (presumably so tests can reach them through the client fixture).
    PeriodicTaskFailed = PeriodicTaskFailed
    TestsuiteActionFailed = TestsuiteActionFailed
    TestsuiteTaskNotFound = TestsuiteTaskNotFound
    TestsuiteTaskConflict = TestsuiteTaskConflict
    TestsuiteTaskFailed = TestsuiteTaskFailed

    def _wrap_client_response(self, response: aiohttp.ClientResponse) -> Awaitable[http.ClientResponse]:
        """Wrap a raw aiohttp response, using `approx.json_loads` as the JSON decoder."""
        return http.wrap_client_response(
            response,
            json_loads=approx.json_loads,
        )

    @_wrap_client_error
    async def run_periodic(self, name: str) -> None:
        """Run the periodic task `name` through the underlying testsuite client."""
        await self._client.run_periodic(name)

    @_wrap_client_error
    async def run_periodic_task(self, name: str) -> None:
        """
        Deprecated: emits a `DeprecationWarning`
        (`userver_warnings.WARN_PERIODIC_DEPRECATION`), then runs the
        periodic task `name` via the underlying client.
        """
        warnings.warn(userver_warnings.WARN_PERIODIC_DEPRECATION, DeprecationWarning)
        await self._client.run_periodic_task(name)

    @_wrap_client_error
    async def suspend_periodic_tasks(self, names: list[str]) -> None:
        """Suspend the periodic tasks listed in `names`."""
        await self._client.suspend_periodic_tasks(names)

    @_wrap_client_error
    async def resume_periodic_tasks(self, names: list[str]) -> None:
        """
        Deprecated: emits a `DeprecationWarning`
        (`userver_warnings.WARN_PERIODIC_DEPRECATION`), then resumes the
        periodic tasks listed in `names`.
        """
        warnings.warn(userver_warnings.WARN_PERIODIC_DEPRECATION, DeprecationWarning)
        await self._client.resume_periodic_tasks(names)

    @_wrap_client_error
    async def resume_all_periodic_tasks(self) -> None:
        """Resume all periodic tasks via the underlying client."""
        await self._client.resume_all_periodic_tasks()

    @_wrap_client_error
    async def write_cache_dumps(
        self,
        names: list[str],
        *,
        testsuite_skip_prepare: bool = False,
    ) -> None:
        """
        Ask the service to write dumps of the caches listed in `names`.

        @param testsuite_skip_prepare An advanced parameter to skip auto-`update_server_state`.
        """
        await self._client.write_cache_dumps(
            names=names,
            testsuite_skip_prepare=testsuite_skip_prepare,
        )

    @_wrap_client_error
    async def read_cache_dumps(
        self,
        names: list[str],
        *,
        testsuite_skip_prepare: bool = False,
    ) -> None:
        """
        Ask the service to read dumps of the caches listed in `names`.

        @param testsuite_skip_prepare An advanced parameter to skip auto-`update_server_state`.
        """
        await self._client.read_cache_dumps(
            names=names,
            testsuite_skip_prepare=testsuite_skip_prepare,
        )

    async def run_task(self, name: str) -> None:
        """
        Run the testsuite task `name`.

        Unlike most methods here, this is not decorated with `_wrap_client_error`.
        """
        await self._client.run_task(name)

    async def run_distlock_task(self, name: str) -> None:
        """
        Run the distlock task `name`.

        Unlike most methods here, this is not decorated with `_wrap_client_error`.
        """
        await self._client.run_distlock_task(name)

    async def reset_metrics(self) -> None:
        """
        Calls `ResetMetric(metric);` for each metric that has such C++ function.

        Note that using `reset_metrics()` is discouraged, prefer using a more reliable
        @ref pytest_userver.client.ClientMonitor.metrics_diff "await monitor_client.metrics_diff()".

        @snippet samples/testsuite-support/tests/test_metrics.py metrics reset
        """
        await self._client.reset_metrics()

    async def metrics_portability(
        self,
        *,
        prefix: str | None = None,
    ) -> dict[str, list[dict[str, str]]]:
        """
        Reports metrics related issues that could be encountered on
        different monitoring systems.

        @param prefix Forwarded to the underlying client; presumably limits
               the report to metrics with this name prefix (TODO confirm).

        @sa @ref utils::statistics::GetPortabilityWarnings
        """
        return await self._client.metrics_portability(prefix=prefix)

    def list_tasks(self) -> list[str]:
        """Return the names of testsuite tasks reported by the underlying client."""
        return self._client.list_tasks()

    def spawn_task(self, name: str):
        """Spawn the testsuite task `name`; returns whatever the underlying client returns."""
        return self._client.spawn_task(name)

    def capture_logs(
        self,
        *,
        log_level: str = 'DEBUG',
        testsuite_skip_prepare: bool = False,
    ):
        """
        Captures logs from the service.

        @param log_level Do not capture logs below this level.
        @param testsuite_skip_prepare An advanced parameter to skip auto-`update_server_state`.

        @see @ref testsuite_logs_capture
        """
        return self._client.capture_logs(
            log_level=log_level,
            testsuite_skip_prepare=testsuite_skip_prepare,
        )

    def log_flush(self, logger_name: str | None = None):
        """
        Flush service logs.

        @param logger_name Forwarded to the underlying client; None is passed through as-is.
        """
        return self._client.log_flush(logger_name=logger_name)

    @_wrap_client_error
    async def invalidate_caches(
        self,
        *,
        clean_update: bool = True,
        cache_names: list[str] | None = None,
        testsuite_skip_prepare: bool = False,
    ) -> None:
        """
        Send request to service to update caches.

        @param clean_update if False, service will do a faster incremental
               update of caches whenever possible.
        @param cache_names which caches specifically should be updated;
               update all if None.
        @param testsuite_skip_prepare if False, the client automatically runs
               update_server_state() before the request.
        """
        # Hide this frame from pytest tracebacks.
        __tracebackhide__ = True
        await self._client.invalidate_caches(
            clean_update=clean_update,
            cache_names=cache_names,
            testsuite_skip_prepare=testsuite_skip_prepare,
        )

    @_wrap_client_error
    async def tests_control(
        self,
        invalidate_caches: bool = True,
        clean_update: bool = True,
        cache_names: list[str] | None = None,
        http_allowed_urls_extra: list[str] | None = None,
    ) -> dict[str, Any]:
        """
        Issue a tests-control request through the underlying client and
        return its result. All arguments are forwarded unchanged.
        """
        return await self._client.tests_control(
            invalidate_caches=invalidate_caches,
            clean_update=clean_update,
            cache_names=cache_names,
            http_allowed_urls_extra=http_allowed_urls_extra,
        )

    @_wrap_client_error
    async def update_server_state(self) -> None:
        """
        Update service-side state through http call to 'tests/control':
        - clear dirty (from other tests) caches
        - set service-side mocked time,
        - resume / suspend periodic tasks
        - enable testpoints
        If service is up-to-date, does nothing.
        """
        await self._client.update_server_state()

    @_wrap_client_error
    async def enable_testpoints(self, no_auto_cache_cleanup: bool = False) -> None:
        """
        Send the list of handled testpoint paths to the service. For these
        paths the service will no longer skip http calls from the
        TESTPOINT(...) macro.

        @param no_auto_cache_cleanup prevent automatic cache cleanup.
        When calling service client first time in scope of current test, client
        makes additional http call to `tests/control` to update caches, to get
        rid of data from previous test.
        """
        await self._client.enable_testpoints(no_auto_cache_cleanup=no_auto_cache_cleanup)

    @_wrap_client_error
    async def get_dynamic_config_defaults(
        self,
    ) -> dict[str, Any]:
        """Return the dynamic config defaults reported by the underlying client."""
        return await self._client.get_dynamic_config_defaults()
1327
1328
@dataclasses.dataclass
class _State:
    """Reflects the (supposed) current service state."""

    # Cache invalidation state as last synchronized with the service.
    invalidation_state: caches.InvalidationState
    # Mocked "now" last sent to the service; None means time mocking is off.
    # The `_UNKNOWN_STATE` placeholder forces the first comparison to differ.
    now: str | None = _UNKNOWN_STATE
    # Testpoints last enabled on the service; the `_UNKNOWN_STATE`
    # placeholder forces the first comparison to differ.
    testpoints: frozenset[str] = frozenset([_UNKNOWN_STATE])
1336
1337
1339 """
1340 Used for computing the requests that we need to automatically align
1341 the service state with the test fixtures state.
1342 """
1343
1344 def __init__(
1345 self,
1346 *,
1347 mocked_time,
1348 testpoint,
1349 testpoint_control,
1350 invalidation_state: caches.InvalidationState,
1351 cache_control: caches.CacheControl | None,
1352 ):
1353 self._state = _State(
1354 invalidation_state=copy.deepcopy(invalidation_state),
1355 )
1356 self._mocked_time = mocked_time
1357 self._testpoint = testpoint
1358 self._testpoint_control = testpoint_control
1359 self._invalidation_state = invalidation_state
1360 self._cache_control = cache_control
1361
1362 @contextlib.contextmanager
1363 def updating_state(self, body: dict[str, Any]):
1364 """
1365 Whenever `tests_control` handler is invoked
1366 (by the client itself during `prepare` or manually by the user),
1367 we need to synchronize `_state` with the (supposed) service state.
1368 The state update is decoded from the request body.
1369 """
1370 saved_state = copy.deepcopy(self._state)
1371 try:
1372 self._update_state(body)
1373 self._apply_new_state()
1374 yield
1375 except Exception: # noqa
1376 self._state = saved_state
1377 self._apply_new_state()
1378 raise
1379
1380 def get_pending_update(self) -> dict[str, Any]:
1381 """
1382 Compose the body of the `tests_control` request required to completely
1383 synchronize the service state with the state of test fixtures.
1384 """
1385 body: dict[str, Any] = {}
1386
1387 if self._invalidation_state.has_caches_to_update:
1388 body['invalidate_caches'] = {'update_type': 'full'}
1389 if not self._invalidation_state.should_update_all_caches:
1390 body['invalidate_caches']['names'] = list(
1391 self._invalidation_state.caches_to_update,
1392 )
1393
1394 desired_testpoints = self._testpoint.keys()
1395 if self._state.testpoints != frozenset(desired_testpoints):
1396 body['testpoints'] = sorted(desired_testpoints)
1397
1398 desired_now = self._get_desired_now()
1399 if self._state.now != desired_now:
1400 body['mock_now'] = desired_now
1401
1402 return body
1403
1404 @contextlib.contextmanager
1405 def cache_control_update(self) -> Iterator[dict[Any, Any]]:
1406 pending_update = self.get_pending_update()
1407 invalidate_caches = pending_update.get('invalidate_caches')
1408 if not invalidate_caches or not self._cache_control:
1409 yield pending_update
1410 else:
1411 cache_names = invalidate_caches.get('names')
1412 staged, actions = self._cache_control.query_caches(cache_names)
1413 self._apply_cache_control_actions(invalidate_caches, actions)
1414 yield pending_update
1415 self._cache_control.commit_staged(staged)
1416
1417 @staticmethod
1418 def _apply_cache_control_actions(
1419 invalidate_caches: dict[Any, Any],
1420 actions: list[tuple[str, caches.CacheControlAction]],
1421 ) -> None:
1422 cache_names = invalidate_caches.get('names')
1423 exclude_names = invalidate_caches.setdefault('exclude_names', [])
1424 force_incremental_names = invalidate_caches.setdefault(
1425 'force_incremental_names',
1426 [],
1427 )
1428 for cache_name, action in actions:
1429 match action:
1430 case caches.CacheControlAction.FULL:
1431 pass
1432 case caches.CacheControlAction.INCREMENTAL:
1433 force_incremental_names.append(cache_name)
1434 case caches.CacheControlAction.EXCLUDE:
1435 if cache_names is not None:
1436 cache_names.remove(cache_name)
1437 else:
1438 exclude_names.append(cache_name)
1439
1440 def _update_state(self, body: dict[str, Any]) -> None:
1441 body_invalidate_caches = body.get('invalidate_caches', {})
1442 update_type = body_invalidate_caches.get('update_type', 'full')
1443 body_cache_names = body_invalidate_caches.get('names', None)
1444 # An incremental update is considered insufficient to bring a cache
1445 # to a known state.
1446 if body_invalidate_caches and update_type == 'full':
1447 if body_cache_names is None:
1448 self._state.invalidation_state.on_all_caches_updated()
1449 else:
1450 self._state.invalidation_state.on_caches_updated(
1451 body_cache_names,
1452 )
1453
1454 if 'mock_now' in body:
1455 self._state.now = body['mock_now']
1456
1457 testpoints: list[str] | None = body.get('testpoints')
1458 if testpoints is not None:
1459 self._state.testpoints = frozenset(testpoints)
1460
1462 """Apply new state to related components."""
1463 self._testpoint_control.enabled_testpoints = self._state.testpoints
1464 self._invalidation_state.assign_copy(self._state.invalidation_state)
1465
1466 def _get_desired_now(self) -> str | None:
1467 if self._mocked_time.is_enabled:
1468 return utils.timestring(self._mocked_time.now())
1469 return None
1470
1471
1472async def _task_check_response(name: str, response) -> dict:
1473 async with response:
1474 if response.status == 404:
1475 raise TestsuiteTaskNotFound(f'Testsuite task {name!r} not found')
1476 if response.status == 409:
1477 raise TestsuiteTaskConflict(f'Testsuite task {name!r} conflict')
1478 assert response.status == 200
1479 data = await response.json()
1480 if not data.get('status', True):
1481 raise TestsuiteTaskFailed(name, data['reason'])
1482 return data