userver: /data/code/service_template/third_party/userver/testsuite/pytest_plugins/pytest_userver/client.py Source File
Loading...
Searching...
No Matches
client.py
"""
Python module that provides clients for functional tests with
testsuite; see
@ref scripts/docs/en/userver/functional_testing.md for an introduction.

@ingroup userver_testsuite
"""
8
9# pylint: disable=too-many-lines
10
11import contextlib
12import copy
13import dataclasses
14import json
15import logging
16import typing
17import warnings
18
19import aiohttp
20
21from testsuite import annotations
22from testsuite import utils
23from testsuite.daemons import service_client
24from testsuite.utils import approx
25from testsuite.utils import http
26
27import pytest_userver.metrics as metric_module # pylint: disable=import-error
28from pytest_userver.plugins import caches
29
# @cond
logger = logging.getLogger(__name__)
# @endcond

# Sentinel string; NOTE(review): its consumers are outside this part of
# the file.
_UNKNOWN_STATE = '__UNKNOWN__'

# Text reported when all caches are invalidated without explicit names.
CACHE_INVALIDATION_MESSAGE = '\n'.join(
    (
        'Direct cache invalidation is deprecated.',
        '',
        ' - Use client.update_server_state() to synchronize service state',
        ' - Explicitly pass cache names to invalidate, e.g.: '
        'invalidate_caches(cache_names=[...]).',
    ),
)
43
44
class BaseError(Exception):
    """Root of the exception hierarchy defined by this module."""
47
48
50 pass
51
52
54 pass
55
56
66 pass
67
68
70 pass
71
72
def __init__(self):
    """Start with two empty sets of periodic task names."""
    # Names that should end up suspended.
    self.tasks_to_suspend: typing.Set[str] = set()
    # Names recorded as already suspended.
    self.suspended_tasks: typing.Set[str] = set()
77
78
def __init__(self, name, reason):
    """
    Build the error for a failed testsuite task.

    @param name Name of the task; kept in `self.name`
    @param reason Failure description; kept in `self.reason`
    """
    message = f'Testsuite task {name!r} failed: {reason}'
    super().__init__(message)
    self.name = name
    self.reason = reason
84
85
86@dataclasses.dataclass(frozen=True)
88 testsuite_action_path: typing.Optional[str] = None
89 server_monitor_path: typing.Optional[str] = None
90
91
92Metric = metric_module.Metric
93
94
96 """
97 Base asyncio userver client that implements HTTP requests to service.
98
99 Compatible with werkzeug interface.
100
101 @ingroup userver_testsuite
102 """
103
def __init__(self, client):
    """
    @param client Low-level client that every request is delegated to
    """
    self._client = client
106
async def post(
    self,
    path: str,
    # pylint: disable=redefined-outer-name
    json: annotations.JsonAnyOptional = None,
    data: typing.Any = None,
    params: typing.Optional[typing.Dict[str, str]] = None,
    bearer: typing.Optional[str] = None,
    x_real_ip: typing.Optional[str] = None,
    headers: typing.Optional[typing.Dict[str, str]] = None,
    **kwargs,
) -> http.ClientResponse:
    """
    Make a HTTP POST request

    All arguments are forwarded to the `post` method of the wrapped
    client; the reply is passed through `_wrap_client_response`.
    """
    request_options = dict(
        json=json,
        data=data,
        params=params,
        headers=headers,
        bearer=bearer,
        x_real_ip=x_real_ip,
    )
    raw_response = await self._client.post(
        path, **request_options, **kwargs,
    )
    return await self._wrap_client_response(raw_response)
133
async def put(
    self,
    path: str,
    # pylint: disable=redefined-outer-name
    json: annotations.JsonAnyOptional = None,
    data: typing.Any = None,
    params: typing.Optional[typing.Dict[str, str]] = None,
    bearer: typing.Optional[str] = None,
    x_real_ip: typing.Optional[str] = None,
    headers: typing.Optional[typing.Dict[str, str]] = None,
    **kwargs,
) -> http.ClientResponse:
    """
    Make a HTTP PUT request

    All arguments are forwarded to the `put` method of the wrapped
    client; the reply is passed through `_wrap_client_response`.
    """
    response = await self._client.put(
        path,
        json=json,
        data=data,
        params=params,
        headers=headers,
        bearer=bearer,
        x_real_ip=x_real_ip,
        **kwargs,
    )
    return await self._wrap_client_response(response)
160
async def patch(
    self,
    path: str,
    # pylint: disable=redefined-outer-name
    json: annotations.JsonAnyOptional = None,
    data: typing.Any = None,
    params: typing.Optional[typing.Dict[str, str]] = None,
    bearer: typing.Optional[str] = None,
    x_real_ip: typing.Optional[str] = None,
    headers: typing.Optional[typing.Dict[str, str]] = None,
    **kwargs,
) -> http.ClientResponse:
    """
    Make a HTTP PATCH request

    All arguments are forwarded to the `patch` method of the wrapped
    client; the reply is passed through `_wrap_client_response`.
    """
    response = await self._client.patch(
        path,
        json=json,
        data=data,
        params=params,
        headers=headers,
        bearer=bearer,
        x_real_ip=x_real_ip,
        **kwargs,
    )
    return await self._wrap_client_response(response)
187
async def get(
    self,
    path: str,
    headers: typing.Optional[typing.Dict[str, str]] = None,
    bearer: typing.Optional[str] = None,
    x_real_ip: typing.Optional[str] = None,
    **kwargs,
) -> http.ClientResponse:
    """
    Make a HTTP GET request

    All arguments are forwarded to the `get` method of the wrapped
    client; the reply is passed through `_wrap_client_response`.
    """
    request_options = dict(
        headers=headers, bearer=bearer, x_real_ip=x_real_ip,
    )
    raw_response = await self._client.get(
        path, **request_options, **kwargs,
    )
    return await self._wrap_client_response(raw_response)
207
async def delete(
    self,
    path: str,
    headers: typing.Optional[typing.Dict[str, str]] = None,
    bearer: typing.Optional[str] = None,
    x_real_ip: typing.Optional[str] = None,
    **kwargs,
) -> http.ClientResponse:
    """
    Make a HTTP DELETE request

    All arguments are forwarded to the `delete` method of the wrapped
    client; the reply is passed through `_wrap_client_response`.
    """
    request_options = dict(
        headers=headers, bearer=bearer, x_real_ip=x_real_ip,
    )
    raw_response = await self._client.delete(
        path, **request_options, **kwargs,
    )
    return await self._wrap_client_response(raw_response)
227
async def options(
    self,
    path: str,
    headers: typing.Optional[typing.Dict[str, str]] = None,
    bearer: typing.Optional[str] = None,
    x_real_ip: typing.Optional[str] = None,
    **kwargs,
) -> http.ClientResponse:
    """
    Make a HTTP OPTIONS request

    All arguments are forwarded to the `options` method of the wrapped
    client; the reply is passed through `_wrap_client_response`.
    """
    request_options = dict(
        headers=headers, bearer=bearer, x_real_ip=x_real_ip,
    )
    raw_response = await self._client.options(
        path, **request_options, **kwargs,
    )
    return await self._wrap_client_response(raw_response)
247
async def request(
    self, http_method: str, path: str, **kwargs,
) -> http.ClientResponse:
    """
    Make a HTTP request with the specified method

    @param http_method Method name handed to the wrapped client
    @param path Request path
    @param kwargs Forwarded unchanged to the wrapped client
    """
    raw_response = await self._client.request(http_method, path, **kwargs)
    wrapped = await self._wrap_client_response(raw_response)
    return wrapped
256
def _wrap_client_response(
    self, response: aiohttp.ClientResponse,
) -> typing.Awaitable[http.ClientResponse]:
    """
    Return the awaitable produced by `http.wrap_client_response` for
    the given aiohttp response.
    """
    pending = http.wrap_client_response(response)
    return pending
261
262
263# @cond
264
265
def _wrap_client_error(func):
    """
    Decorator for coroutine functions that converts
    `aiohttp.client_exceptions.ClientResponseError` into
    `http.HttpResponseError` carrying the request URL and the status.

    The wrapper keeps the name and docstring of the decorated function,
    and the original aiohttp error is chained as the cause.

    @param func Coroutine function to wrap
    @returns Coroutine function with the same call signature
    """
    # Local import: keeps the module-level import list untouched.
    import functools

    @functools.wraps(func)
    async def _wrapper(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except aiohttp.client_exceptions.ClientResponseError as exc:
            raise http.HttpResponseError(
                url=exc.request_info.url, status=exc.status,
            ) from exc

    return _wrapper
276
277
class AiohttpClientMonitor(service_client.AiohttpClient):
    """
    Low-level aiohttp client that reads metrics from the path given by
    `config.server_monitor_path`.
    """

    _config: TestsuiteClientConfig

    def __init__(self, base_url, *, config: TestsuiteClientConfig, **kwargs):
        """
        @param base_url Passed to the base client
        @param config Testsuite client config; `server_monitor_path` is
        the only field read by this class
        @param kwargs Passed to the base client
        """
        super().__init__(base_url, **kwargs)
        self._config = config

    async def get_metrics(self, prefix=None):
        """
        Fetch metrics in the 'internal' format and return the parsed JSON.

        @param prefix Optional value for the 'prefix' query parameter
        @throws ConfigurationError if no monitor path is configured
        """
        if not self._config.server_monitor_path:
            raise ConfigurationError(
                'handler-server-monitor component is not configured',
            )
        params = {'format': 'internal'}
        if prefix is not None:
            params['prefix'] = prefix
        response = await self.get(
            self._config.server_monitor_path, params=params,
        )
        async with response:
            response.raise_for_status()
            return await response.json(content_type=None)

    async def get_metric(self, metric_name):
        """
        Fetch metrics with `metric_name` as the prefix and return the
        entry stored under exactly that name.

        @throws AssertionError if the name is absent from the reply
        """
        metrics = await self.get_metrics(metric_name)
        assert metric_name in metrics, (
            f'No metric with name {metric_name!r}. '
            f'Use "single_metric" function instead of "get_metric"'
        )
        return metrics[metric_name]

    async def metrics_raw(
        self,
        output_format,
        *,
        path: typing.Optional[str] = None,
        prefix: typing.Optional[str] = None,
        labels: typing.Optional[typing.Dict[str, str]] = None,
    ) -> str:
        """
        Fetch metrics in `output_format` and return the response text.

        Empty or None filters are not sent; `labels` is sent as JSON.

        @throws ConfigurationError if no monitor path is configured
        """
        if not self._config.server_monitor_path:
            raise ConfigurationError(
                'handler-server-monitor component is not configured',
            )

        params = {'format': output_format}
        if prefix:
            params['prefix'] = prefix

        if path:
            params['path'] = path

        if labels:
            params['labels'] = json.dumps(labels)

        response = await self.get(
            self._config.server_monitor_path, params=params,
        )
        async with response:
            response.raise_for_status()
            return await response.text()

    async def metrics(
        self,
        *,
        path: typing.Optional[str] = None,
        prefix: typing.Optional[str] = None,
        labels: typing.Optional[typing.Dict[str, str]] = None,
    ) -> metric_module.MetricsSnapshot:
        """
        Fetch metrics in the 'json' format and parse them into a
        MetricsSnapshot.
        """
        response = await self.metrics_raw(
            output_format='json', path=path, prefix=prefix, labels=labels,
        )
        return metric_module.MetricsSnapshot.from_json(str(response))

    async def single_metric_optional(
        self,
        path: str,
        *,
        labels: typing.Optional[typing.Dict[str, str]] = None,
    ) -> typing.Optional[Metric]:
        """
        Return the only metric found at `path`, or None if there is none.

        @throws AssertionError if more than one metric is found
        """
        response = await self.metrics(path=path, labels=labels)
        metrics_list = response.get(path, [])

        # The message must be a plain string: a trailing comma here would
        # turn it into a 1-tuple and garble the assertion output.
        assert len(metrics_list) <= 1, (
            f'More than one metric found for path {path} and labels {labels}: '
            f'{response}'
        )

        if not metrics_list:
            return None

        return next(iter(metrics_list))

    async def single_metric(
        self,
        path: str,
        *,
        labels: typing.Optional[typing.Dict[str, str]] = None,
    ) -> Metric:
        """
        Return the only metric found at `path`.

        @throws AssertionError if no metric or more than one is found
        """
        value = await self.single_metric_optional(path, labels=labels)
        assert (
            value is not None
        ), f'No metric was found for path {path} and labels {labels}'
        return value
380
381
382# @endcond
383
384
386 """
387 Asyncio userver client for monitor listeners, typically retrieved from
388 plugins.service_client.monitor_client fixture.
389
390 Compatible with werkzeug interface.
391
392 @ingroup userver_testsuite
393 """
394
396 self,
397 *,
398 path: typing.Optional[str] = None,
399 prefix: typing.Optional[str] = None,
400 labels: typing.Optional[typing.Dict[str, str]] = None,
401 diff_gauge: bool = False,
402 ) -> 'MetricsDiffer':
403 """
404 Creates a `MetricsDiffer` that fetches metrics using this client.
405 It's recommended to use this method over `metrics` to make sure
406 the tests don't affect each other.
407
408 With `diff_gauge` off, only RATE metrics are differentiated.
409 With `diff_gauge` on, GAUGE metrics are differentiated as well,
410 which may lead to nonsensical results for those.
411
412 @param path Optional full metric path
413 @param prefix Optional prefix on which the metric paths should start
414 @param labels Optional dictionary of labels that must be in the metric
415 @param diff_gauge Whether to differentiate GAUGE metrics
416
417 @code
418 async with monitor_client.metrics_diff(prefix='foo') as differ:
419 # Do something that makes the service update its metrics
420 assert differ.value_at('path-suffix', {'label'}) == 42
421 @endcode
422 """
423 return MetricsDiffer(
424 _client=self,
425 _path=path,
426 _prefix=prefix,
427 _labels=labels,
428 _diff_gauge=diff_gauge,
429 )
430
@_wrap_client_error
async def metrics(
    self,
    *,
    path: typing.Optional[str] = None,
    prefix: typing.Optional[str] = None,
    labels: typing.Optional[typing.Dict[str, str]] = None,
) -> metric_module.MetricsSnapshot:
    """
    Returns a dict of metric names to Metric.

    @param path Optional full metric path
    @param prefix Optional prefix on which the metric paths should start
    @param labels Optional dictionary of labels that must be in the metric
    """
    filters = {'path': path, 'prefix': prefix, 'labels': labels}
    snapshot = await self._client.metrics(**filters)
    return snapshot
449
450 @_wrap_client_error
452 self,
453 path: str,
454 *,
455 labels: typing.Optional[typing.Dict[str, str]] = None,
456 ) -> typing.Optional[Metric]:
457 """
458 Either return a Metric or None if there's no such metric.
459
460 @param path Full metric path
461 @param labels Optional dictionary of labels that must be in the metric
462
463 @throws AssertionError if more than one metric returned
464 """
465 return await self._client.single_metric_optional(path, labels=labels)
466
@_wrap_client_error
async def single_metric(
    self,
    path: str,
    *,
    labels: typing.Optional[typing.Dict[str, str]] = None,
) -> Metric:
    """
    Returns the Metric.

    Never returns None: a missing metric fails an assertion instead.

    @param path Full metric path
    @param labels Optional dictionary of labels that must be in the metric

    @throws AssertionError if more than one metric or no metric found
    """
    return await self._client.single_metric(path, labels=labels)
483
@_wrap_client_error
async def metrics_raw(
    self,
    output_format: str,
    *,
    path: typing.Optional[str] = None,
    prefix: typing.Optional[str] = None,
    labels: typing.Optional[typing.Dict[str, str]] = None,
) -> str:
    """
    Low level function that returns metrics in a specific format.
    Use `metrics` and `single_metric` instead if possible.

    The result is the response body as text, not parsed metrics.

    @param output_format Metric output format. See
    server::handlers::ServerMonitor for a list of supported formats.
    @param path Optional full metric path
    @param prefix Optional prefix on which the metric paths should start
    @param labels Optional dictionary of labels that must be in the metric
    """
    return await self._client.metrics_raw(
        output_format=output_format,
        path=path,
        prefix=prefix,
        labels=labels,
    )
509
@_wrap_client_error
async def get_metrics(self, prefix=None):
    """
    @deprecated Use metrics() or single_metric() instead
    """
    legacy_metrics = await self._client.get_metrics(prefix=prefix)
    return legacy_metrics
516
@_wrap_client_error
async def get_metric(self, metric_name):
    """
    @deprecated Use metrics() or single_metric() instead
    """
    legacy_metric = await self._client.get_metric(metric_name)
    return legacy_metric
523
@_wrap_client_error
async def fired_alerts(self):
    """
    Request '/service/fired-alerts' and return the 'alerts' entry of
    the JSON reply.

    @throws AssertionError if the response status is not 200
    """
    reply = await self._client.get('/service/fired-alerts')
    assert reply.status == 200
    payload = await reply.json()
    return payload['alerts']
529
530
532 """
533 A helper class for computing metric differences.
534
535 @see ClientMonitor.metrics_diff
536 @ingroup userver_testsuite
537 """
538
539 # @cond
def __init__(
    self,
    _client: ClientMonitor,
    _path: typing.Optional[str],
    _prefix: typing.Optional[str],
    _labels: typing.Optional[typing.Dict[str, str]],
    _diff_gauge: bool,
):
    """
    Store the client and the metric filters; no snapshot is fetched
    here, so baseline, current and diff all start as None.
    """
    self._baseline: typing.Optional[metric_module.MetricsSnapshot] = None
    self._current: typing.Optional[metric_module.MetricsSnapshot] = None
    self._diff: typing.Optional[metric_module.MetricsSnapshot] = None
    self._client = _client
    self._path, self._prefix, self._labels = _path, _prefix, _labels
    self._diff_gauge = _diff_gauge
556
557 # @endcond
558
@property
def baseline(self) -> metric_module.MetricsSnapshot:
    """Baseline snapshot; asserts that one has been set."""
    snapshot = self._baseline
    assert snapshot is not None
    return snapshot

@baseline.setter
def baseline(self, value: metric_module.MetricsSnapshot) -> None:
    """Replace the baseline; recompute the diff if `current` is set."""
    self._baseline = value
    if self._current is None:
        # Nothing to subtract from yet.
        return
    self._diff = _subtract_metrics_snapshots(
        self._current, self._baseline, self._diff_gauge,
    )
571
@property
def current(self) -> metric_module.MetricsSnapshot:
    """Current snapshot; asserts that one has been set."""
    snapshot = self._current
    assert snapshot is not None, 'Set self.current first'
    return snapshot

@current.setter
def current(self, value: metric_module.MetricsSnapshot) -> None:
    """
    Store the current snapshot and recompute the diff against the
    baseline. The value is stored before the baseline check runs.
    """
    self._current = value
    baseline_snapshot = self._baseline
    assert baseline_snapshot is not None, 'Set self.baseline first'
    self._diff = _subtract_metrics_snapshots(
        value, baseline_snapshot, self._diff_gauge,
    )
584
@property
def diff(self) -> metric_module.MetricsSnapshot:
    """Computed difference; asserts that it has been computed."""
    computed = self._diff
    assert computed is not None, 'Set self.current first'
    return computed
589
591 self,
592 subpath: typing.Optional[str] = None,
593 add_labels: typing.Optional[typing.Dict] = None,
594 *,
595 default: typing.Optional[float] = None,
596 ) -> metric_module.MetricValue:
597 """
598 Returns a single metric value at the specified path, prepending
599 the path provided at construction. If a dict of labels is provided,
600 does an exact match of labels, prepending the labels provided
601 at construction.
602
603 @param subpath Suffix of the metric path; the path provided
604 at construction is prepended
605 @param add_labels Labels that the metric must have in addition
606 to the labels provided at construction
607 @param default An optional default value in case the metric is missing
608 @throws AssertionError if not one metric by path
609 """
610 base_path = self._path or self._prefix
611 if base_path and subpath:
612 path = f'{base_path}.{subpath}'
613 else:
614 assert base_path or subpath, 'No path provided'
615 path = base_path or subpath or ''
616 labels: typing.Optional[dict] = None
617 if self._labels is not None or add_labels is not None:
618 labels = {**(self._labels or {}), **(add_labels or {})}
619 return self.diff.value_at(path, labels, default=default)
620
async def fetch(self) -> metric_module.MetricsSnapshot:
    """
    Fetches metric values from the service.

    Uses the path, prefix and labels stored on this object.
    """
    filters = {
        'path': self._path,
        'prefix': self._prefix,
        'labels': self._labels,
    }
    return await self._client.metrics(**filters)
628
async def __aenter__(self) -> 'MetricsDiffer':
    """
    Fetch a fresh baseline, clear the stored current snapshot and
    return this object.
    """
    fresh_baseline = await self.fetch()
    self._baseline = fresh_baseline
    self._current = None
    return self
633
async def __aexit__(self, exc_type, exc, exc_tb) -> None:
    """
    Fetch the final snapshot and assign it to `current`.

    The assignment goes through the `current` property, not `_current`,
    because that setter is what recomputes `diff`. The snapshot is
    fetched whether or not the `async with` body raised.
    """
    self.current = await self.fetch()
636
637
638# @cond
639
640
def _subtract_metrics_snapshots(
    current: metric_module.MetricsSnapshot,
    initial: metric_module.MetricsSnapshot,
    diff_gauge: bool,
) -> metric_module.MetricsSnapshot:
    """
    Build a snapshot with the same paths as `current`, where every
    metric is replaced by its difference from `initial`.
    """
    diff_by_path = {}
    for path, current_group in current.items():
        diff_by_path[path] = {
            _subtract_metrics(path, current_metric, initial, diff_gauge)
            for current_metric in current_group
        }
    return metric_module.MetricsSnapshot(diff_by_path)
655
656
def _subtract_metrics(
    path: str,
    current_metric: metric_module.Metric,
    initial: metric_module.MetricsSnapshot,
    diff_gauge: bool,
) -> metric_module.Metric:
    """
    Subtract from `current_metric` the metric in `initial` that has the
    same path and equal labels.

    If `initial` has no such metric, `current_metric` is returned as is.
    """
    initial_group = initial.get(path, None)
    if initial_group is None:
        return current_metric

    # The first metric with equal labels is the one to subtract.
    for candidate in initial_group:
        if candidate.labels == current_metric.labels:
            initial_metric = candidate
            break
    else:
        return current_metric

    difference = _subtract_metric_values(
        current=current_metric,
        initial=initial_metric,
        diff_gauge=diff_gauge,
    )
    return metric_module.Metric(
        labels=current_metric.labels,
        value=difference,
        _type=current_metric.type(),
    )
681
682
def _subtract_metric_values(
    current: metric_module.Metric,
    initial: metric_module.Metric,
    diff_gauge: bool,
) -> metric_module.MetricValue:
    """
    Dispatch to the histogram or the numeric subtraction.

    Asserts that both metrics have the same, specified type and that
    either both values are histograms or neither is.
    """
    assert current.type() is not metric_module.MetricType.UNSPECIFIED
    assert initial.type() is not metric_module.MetricType.UNSPECIFIED
    assert current.type() == initial.type()

    if not isinstance(current.value, metric_module.Histogram):
        assert not isinstance(initial.value, metric_module.Histogram)
        return _subtract_metric_values_num(
            current=current, initial=initial, diff_gauge=diff_gauge,
        )

    assert isinstance(initial.value, metric_module.Histogram)
    return _subtract_metric_values_hist(current=current, initial=initial)
700
701
def _subtract_metric_values_num(
    current: metric_module.Metric,
    initial: metric_module.Metric,
    diff_gauge: bool,
) -> float:
    """
    Return `current - initial` when either metric is of RATE type or
    `diff_gauge` is set; otherwise return the current value unchanged.
    """
    current_value = typing.cast(float, current.value)
    initial_value = typing.cast(float, initial.value)
    rate = metric_module.MetricType.RATE
    if current.type() is rate or initial.type() is rate or diff_gauge:
        return current_value - initial_value
    return current_value
715
716
def _subtract_metric_values_hist(
    current: metric_module.Metric, initial: metric_module.Metric,
) -> metric_module.Histogram:
    """
    Subtract two histograms bucket by bucket, including `inf`.

    Asserts that both histograms have equal bounds.
    """
    current_value = typing.cast(metric_module.Histogram, current.value)
    initial_value = typing.cast(metric_module.Histogram, initial.value)
    assert current_value.bounds == initial_value.bounds

    bucket_pairs = zip(current_value.buckets, initial_value.buckets)
    bucket_diffs = [now - before for now, before in bucket_pairs]
    return metric_module.Histogram(
        bounds=current_value.bounds,
        buckets=bucket_diffs,
        inf=current_value.inf - initial_value.inf,
    )
731
732
class AiohttpClient(service_client.AiohttpClient):
    """
    Low-level aiohttp client that adds testsuite actions (periodic
    tasks, cache dumps, tasks, log capture, cache invalidation).

    Unless `testsuite_skip_prepare` is passed, every request first
    sends the pending state update (see `_request` and `_prepare`).
    """

    PeriodicTaskFailed = PeriodicTaskFailed
    TestsuiteActionFailed = TestsuiteActionFailed
    TestsuiteTaskNotFound = TestsuiteTaskNotFound
    TestsuiteTaskConflict = TestsuiteTaskConflict
    TestsuiteTaskFailed = TestsuiteTaskFailed

    def __init__(
        self,
        base_url: str,
        *,
        config: TestsuiteClientConfig,
        mocked_time,
        log_capture_fixture,
        testpoint,
        testpoint_control,
        cache_invalidation_state,
        span_id_header=None,
        api_coverage_report=None,
        periodic_tasks_state: typing.Optional[PeriodicTasksState] = None,
        allow_all_caches_invalidation: bool = True,
        cache_control: typing.Optional[caches.CacheControl] = None,
        **kwargs,
    ):
        """
        Store the collaborators and build the state manager from the
        mocked time, testpoint and cache arguments.
        """
        super().__init__(base_url, span_id_header=span_id_header, **kwargs)
        self._config = config
        self._periodic_tasks = periodic_tasks_state
        self._testpoint = testpoint
        self._log_capture_fixture = log_capture_fixture
        self._state_manager = _StateManager(
            mocked_time=mocked_time,
            testpoint=self._testpoint,
            testpoint_control=testpoint_control,
            invalidation_state=cache_invalidation_state,
            cache_control=cache_control,
        )
        self._api_coverage_report = api_coverage_report
        self._allow_all_caches_invalidation = allow_all_caches_invalidation

    async def run_periodic_task(self, name):
        """
        Run the periodic task `name`.

        @throws PeriodicTaskFailed if the reply has a falsy 'status'
        """
        response = await self._testsuite_action('run_periodic_task', name=name)
        if not response['status']:
            raise self.PeriodicTaskFailed(f'Periodic task {name} failed')

    async def suspend_periodic_tasks(self, names: typing.List[str]) -> None:
        """Add `names` to the set to suspend and sync it to the service."""
        if not self._periodic_tasks:
            raise ConfigurationError('No periodic_tasks_state given')
        self._periodic_tasks.tasks_to_suspend.update(names)
        await self._suspend_periodic_tasks()

    async def resume_periodic_tasks(self, names: typing.List[str]) -> None:
        """Drop `names` from the set to suspend and sync it."""
        if not self._periodic_tasks:
            raise ConfigurationError('No periodic_tasks_state given')
        self._periodic_tasks.tasks_to_suspend.difference_update(names)
        await self._suspend_periodic_tasks()

    async def resume_all_periodic_tasks(self) -> None:
        """Empty the set to suspend and sync it to the service."""
        if not self._periodic_tasks:
            raise ConfigurationError('No periodic_tasks_state given')
        self._periodic_tasks.tasks_to_suspend.clear()
        await self._suspend_periodic_tasks()

    async def write_cache_dumps(
        self, names: typing.List[str], *, testsuite_skip_prepare=False,
    ) -> None:
        """Send the 'write_cache_dumps' action for `names`."""
        await self._testsuite_action(
            'write_cache_dumps',
            names=names,
            testsuite_skip_prepare=testsuite_skip_prepare,
        )

    async def read_cache_dumps(
        self, names: typing.List[str], *, testsuite_skip_prepare=False,
    ) -> None:
        """Send the 'read_cache_dumps' action for `names`."""
        await self._testsuite_action(
            'read_cache_dumps',
            names=names,
            testsuite_skip_prepare=testsuite_skip_prepare,
        )

    async def run_distlock_task(self, name: str) -> None:
        """Run the testsuite task named 'distlock/<name>'."""
        await self.run_task(f'distlock/{name}')

    async def reset_metrics(self) -> None:
        """Send the 'reset_metrics' action."""
        await self._testsuite_action('reset_metrics')

    async def metrics_portability(
        self, *, prefix: typing.Optional[str] = None,
    ) -> typing.Dict[str, typing.List[typing.Dict[str, str]]]:
        """Send 'metrics_portability' and return the parsed reply."""
        return await self._testsuite_action(
            'metrics_portability', prefix=prefix,
        )

    async def list_tasks(self) -> typing.List[str]:
        """Return the 'tasks' entry of the 'tasks_list' reply."""
        response = await self._do_testsuite_action('tasks_list')
        async with response:
            response.raise_for_status()
            body = await response.json(content_type=None)
            return body['tasks']

    async def run_task(self, name: str) -> None:
        """Send 'task_run' for `name` and check the reply."""
        response = await self._do_testsuite_action(
            'task_run', json={'name': name},
        )
        await _task_check_response(name, response)

    @contextlib.asynccontextmanager
    async def spawn_task(self, name: str):
        """Spawn task `name` on entry and always stop it on exit."""
        task_id = await self._task_spawn(name)
        try:
            yield
        finally:
            await self._task_stop_spawned(task_id)

    async def _task_spawn(self, name: str) -> str:
        """Send 'task_spawn' and return the 'task_id' from the reply."""
        response = await self._do_testsuite_action(
            'task_spawn', json={'name': name},
        )
        data = await _task_check_response(name, response)
        return data['task_id']

    async def _task_stop_spawned(self, task_id: str) -> None:
        """Send 'task_stop' for `task_id` and check the reply."""
        response = await self._do_testsuite_action(
            'task_stop', json={'task_id': task_id},
        )
        await _task_check_response(task_id, response)

    async def http_allowed_urls_extra(
        self, http_allowed_urls_extra: typing.List[str],
    ) -> None:
        """
        Send the 'http_allowed_urls_extra' action, skipping the prepare
        step. The response status is not checked.
        """
        response = await self._do_testsuite_action(
            'http_allowed_urls_extra',
            json={'allowed_urls_extra': http_allowed_urls_extra},
            testsuite_skip_prepare=True,
        )
        # Release the response, as every other action here does,
        # instead of leaving it open until garbage collection.
        async with response:
            pass

    @contextlib.asynccontextmanager
    async def capture_logs(
        self,
        *,
        log_level: str = 'DEBUG',
        testsuite_skip_prepare: bool = False,
    ):
        """
        Start a capture on the log capture fixture, switch the service
        to `log_level` with socket duplication on, and yield the
        capture. On exit the fixture's default level is restored and
        duplication is switched off.
        """
        async with self._log_capture_fixture.start_capture(
            log_level=log_level,
        ) as capture:
            await self._testsuite_action(
                'log_capture',
                log_level=log_level,
                socket_logging_duplication=True,
                testsuite_skip_prepare=testsuite_skip_prepare,
            )
            try:
                yield capture
            finally:
                await self._testsuite_action(
                    'log_capture',
                    log_level=self._log_capture_fixture.default_log_level,
                    socket_logging_duplication=False,
                    testsuite_skip_prepare=testsuite_skip_prepare,
                )

    async def invalidate_caches(
        self,
        *,
        clean_update: bool = True,
        cache_names: typing.Optional[typing.List[str]] = None,
        testsuite_skip_prepare: bool = False,
    ) -> None:
        """
        Ask the service to update caches.

        A full update with no cache names warns (DeprecationWarning), or
        raises RuntimeError if all-caches invalidation is not allowed.

        @param clean_update 'full' update if True, else 'incremental'
        @param cache_names Caches to update
        @param testsuite_skip_prepare Send the control request directly,
        without going through `tests_control`
        """
        if cache_names is None and clean_update:
            if self._allow_all_caches_invalidation:
                warnings.warn(CACHE_INVALIDATION_MESSAGE, DeprecationWarning)
            else:
                __tracebackhide__ = True
                raise RuntimeError(CACHE_INVALIDATION_MESSAGE)

        if testsuite_skip_prepare:
            await self._tests_control(
                {
                    'invalidate_caches': {
                        'update_type': (
                            'full' if clean_update else 'incremental'
                        ),
                        **({'names': cache_names} if cache_names else {}),
                    },
                },
            )
        else:
            await self.tests_control(
                invalidate_caches=True,
                clean_update=clean_update,
                cache_names=cache_names,
            )

    async def tests_control(
        self,
        *,
        invalidate_caches: bool = True,
        clean_update: bool = True,
        cache_names: typing.Optional[typing.List[str]] = None,
        http_allowed_urls_extra=None,
    ) -> typing.Dict[str, typing.Any]:
        """
        Send a control request that starts from the pending state update
        and optionally adds a cache invalidation.

        If the pending update already invalidates caches and this call
        asks for an incremental or named invalidation, the pending
        update is sent first and the body starts empty.
        """
        body: typing.Dict[
            str, typing.Any,
        ] = self._state_manager.get_pending_update()

        if 'invalidate_caches' in body and invalidate_caches:
            if not clean_update or cache_names:
                logger.warning(
                    'Manual cache invalidation leads to indirect initial '
                    'full cache invalidation',
                )
                await self._prepare()
                body = {}

        if invalidate_caches:
            body['invalidate_caches'] = {
                'update_type': ('full' if clean_update else 'incremental'),
            }
            if cache_names:
                body['invalidate_caches']['names'] = cache_names

        if http_allowed_urls_extra is not None:
            await self.http_allowed_urls_extra(http_allowed_urls_extra)

        return await self._tests_control(body)

    async def update_server_state(self) -> None:
        """Send the pending state update, if there is one."""
        await self._prepare()

    async def enable_testpoints(self, *, no_auto_cache_cleanup=False) -> None:
        """
        Tell the service about the registered testpoints; does nothing
        when `self._testpoint` is falsy.

        @param no_auto_cache_cleanup Send only the sorted testpoint
        names instead of the whole pending state update
        """
        if not self._testpoint:
            return
        if no_auto_cache_cleanup:
            await self._tests_control(
                {'testpoints': sorted(self._testpoint.keys())},
            )
        else:
            await self.update_server_state()

    async def get_dynamic_config_defaults(
        self,
    ) -> typing.Dict[str, typing.Any]:
        """Send 'get_dynamic_config_defaults', skipping the prepare step."""
        return await self._testsuite_action(
            'get_dynamic_config_defaults', testsuite_skip_prepare=True,
        )

    async def _tests_control(self, body: dict) -> typing.Dict[str, typing.Any]:
        """
        Post `body` to the 'control' action inside the state manager's
        `updating_state` context and return the parsed reply.

        @throws ConfigurationError if the service answers 404
        """
        with self._state_manager.updating_state(body):
            async with await self._do_testsuite_action(
                'control', json=body, testsuite_skip_prepare=True,
            ) as response:
                if response.status == 404:
                    raise ConfigurationError(
                        'It seems that testsuite support is not enabled '
                        'for your service',
                    )
                response.raise_for_status()
                return await response.json(content_type=None)

    async def _suspend_periodic_tasks(self):
        """Send the set to suspend if it differs from the recorded set."""
        if (
            self._periodic_tasks.tasks_to_suspend
            != self._periodic_tasks.suspended_tasks
        ):
            await self._testsuite_action(
                'suspend_periodic_tasks',
                names=sorted(self._periodic_tasks.tasks_to_suspend),
            )
            # Store a copy so later edits of the wanted set do not
            # change the recorded state.
            self._periodic_tasks.suspended_tasks = set(
                self._periodic_tasks.tasks_to_suspend,
            )

    def _do_testsuite_action(self, action, **kwargs):
        """
        Return the un-awaited POST to the action path built from the
        configured `testsuite_action_path`.

        @throws ConfigurationError if no action path is configured
        """
        if not self._config.testsuite_action_path:
            raise ConfigurationError(
                'tests-control component is not properly configured',
            )
        path = self._config.testsuite_action_path.format(action=action)
        return self.post(path, **kwargs)

    async def _testsuite_action(
        self, action, *, testsuite_skip_prepare=False, **kwargs,
    ):
        """
        Post `kwargs` as the JSON body of `action` and return the
        parsed reply.

        @throws TestsuiteActionFailed if the service answers 500
        """
        async with await self._do_testsuite_action(
            action,
            json=kwargs,
            testsuite_skip_prepare=testsuite_skip_prepare,
        ) as response:
            if response.status == 500:
                raise TestsuiteActionFailed
            response.raise_for_status()
            return await response.json(content_type=None)

    async def _prepare(self) -> None:
        """Send the pending state update if it is non-empty."""
        with self._state_manager.cache_control_update() as pending_update:
            if pending_update:
                await self._tests_control(pending_update)

    async def _request(  # pylint: disable=arguments-differ
        self,
        http_method: str,
        path: str,
        headers: typing.Optional[typing.Dict[str, str]] = None,
        bearer: typing.Optional[str] = None,
        x_real_ip: typing.Optional[str] = None,
        *,
        testsuite_skip_prepare: bool = False,
        **kwargs,
    ) -> aiohttp.ClientResponse:
        """
        Run the prepare step unless `testsuite_skip_prepare` is set,
        perform the request through the base class, and record it in
        the API coverage report when one is configured.
        """
        if not testsuite_skip_prepare:
            await self._prepare()

        response = await super()._request(
            http_method, path, headers, bearer, x_real_ip, **kwargs,
        )
        if self._api_coverage_report:
            self._api_coverage_report.update_usage_stat(
                path, http_method, response.status, response.content_type,
            )

        return response
1055
1056
1057# @endcond
1058
1059
1061 """
1062 Asyncio userver client, typically retrieved from
1063 @ref service_client "plugins.service_client.service_client"
1064 fixture.
1065
1066 Compatible with werkzeug interface.
1067
1068 @ingroup userver_testsuite
1069 """
1070
1071 PeriodicTaskFailed = PeriodicTaskFailed
1072 TestsuiteActionFailed = TestsuiteActionFailed
1073 TestsuiteTaskNotFound = TestsuiteTaskNotFound
1074 TestsuiteTaskConflict = TestsuiteTaskConflict
1075 TestsuiteTaskFailed = TestsuiteTaskFailed
1076
def _wrap_client_response(
    self, response: aiohttp.ClientResponse,
) -> typing.Awaitable[http.ClientResponse]:
    """
    Wrap the aiohttp response, using `approx.json_loads` as the JSON
    loader.
    """
    pending = http.wrap_client_response(
        response, json_loads=approx.json_loads,
    )
    return pending
1083
@_wrap_client_error
async def run_periodic_task(self, name):
    """Delegate to `run_periodic_task` of the wrapped client."""
    backend = self._client
    await backend.run_periodic_task(name)
1087
@_wrap_client_error
async def suspend_periodic_tasks(self, names: typing.List[str]) -> None:
    """Delegate to `suspend_periodic_tasks` of the wrapped client."""
    backend = self._client
    await backend.suspend_periodic_tasks(names)
1091
@_wrap_client_error
async def resume_periodic_tasks(self, names: typing.List[str]) -> None:
    """Delegate to `resume_periodic_tasks` of the wrapped client."""
    backend = self._client
    await backend.resume_periodic_tasks(names)
1095
@_wrap_client_error
async def resume_all_periodic_tasks(self) -> None:
    """Delegate to `resume_all_periodic_tasks` of the wrapped client."""
    backend = self._client
    await backend.resume_all_periodic_tasks()
1099
@_wrap_client_error
async def write_cache_dumps(
    self, names: typing.List[str], *, testsuite_skip_prepare=False,
) -> None:
    """
    Delegate to `write_cache_dumps` of the wrapped client, passing both
    arguments by keyword.
    """
    forwarded = {
        'names': names,
        'testsuite_skip_prepare': testsuite_skip_prepare,
    }
    await self._client.write_cache_dumps(**forwarded)
1107
@_wrap_client_error
async def read_cache_dumps(
    self, names: typing.List[str], *, testsuite_skip_prepare=False,
) -> None:
    """
    Delegate to `read_cache_dumps` of the wrapped client, passing both
    arguments by keyword.
    """
    forwarded = {
        'names': names,
        'testsuite_skip_prepare': testsuite_skip_prepare,
    }
    await self._client.read_cache_dumps(**forwarded)
1115
async def run_task(self, name: str) -> None:
    """Forward `run_task(name)` to the wrapped client."""
    client = self._client
    await client.run_task(name)
1118
async def run_distlock_task(self, name: str) -> None:
    """Forward `run_distlock_task(name)` to the wrapped client."""
    client = self._client
    await client.run_distlock_task(name)
1121
async def reset_metrics(self) -> None:
    """
    Calls `ResetMetric(metric);` for each metric that has such C++ function
    (the request itself is delegated to the wrapped client).
    """
    client = self._client
    await client.reset_metrics()
1127
async def metrics_portability(
    self, *, prefix: typing.Optional[str] = None,
) -> typing.Dict[str, typing.List[typing.Dict[str, str]]]:
    """
    Reports metrics related issues that could be encountered on
    different monitoring systems.

    @param prefix forwarded unchanged to the wrapped client.

    @sa @ref utils::statistics::GetPortabilityInfo
    """
    report = await self._client.metrics_portability(prefix=prefix)
    return report
1138
def list_tasks(self) -> typing.List[str]:
    """Return whatever `list_tasks()` of the wrapped client returns."""
    client = self._client
    return client.list_tasks()
1141
def spawn_task(self, name: str):
    """Return whatever `spawn_task(name)` of the wrapped client returns."""
    client = self._client
    return client.spawn_task(name)
1144
def capture_logs(
    self,
    *,
    log_level: str = 'DEBUG',
    testsuite_skip_prepare: bool = False,
):
    """
    Captures logs from the service.

    @param log_level Do not capture logs below this level.
    @param testsuite_skip_prepare forwarded unchanged to the wrapped client.

    @see @ref testsuite_logs_capture
    """
    forwarded = {
        'log_level': log_level,
        'testsuite_skip_prepare': testsuite_skip_prepare,
    }
    return self._client.capture_logs(**forwarded)
1161
@_wrap_client_error
async def invalidate_caches(
    self,
    *,
    clean_update: bool = True,
    cache_names: typing.Optional[typing.List[str]] = None,
    testsuite_skip_prepare: bool = False,
) -> None:
    """
    Send request to service to update caches.

    @param clean_update if False, service will do a faster incremental
           update of caches whenever possible.
    @param cache_names which caches specifically should be updated;
           update all if None.
    @param testsuite_skip_prepare if False, service will automatically do
           update_server_state().
    """
    # Keep this wrapper frame out of pytest tracebacks.
    __tracebackhide__ = True
    forwarded = {
        'clean_update': clean_update,
        'cache_names': cache_names,
        'testsuite_skip_prepare': testsuite_skip_prepare,
    }
    await self._client.invalidate_caches(**forwarded)
1186
@_wrap_client_error
async def tests_control(
    self, *args, **kwargs,
) -> typing.Dict[str, typing.Any]:
    """
    Forward all positional and keyword arguments to `tests_control` of
    the wrapped client and return its result.
    """
    result = await self._client.tests_control(*args, **kwargs)
    return result
1192
@_wrap_client_error
async def update_server_state(self) -> None:
    """
    Update service-side state through http call to 'tests/control':
    - clear dirty (from other tests) caches
    - set service-side mocked time
    - resume / suspend periodic tasks
    - enable testpoints

    If service is up-to-date, does nothing.
    """
    client = self._client
    await client.update_server_state()
1204
@_wrap_client_error
async def enable_testpoints(self, *args, **kwargs) -> None:
    """
    Send the list of handled testpoint paths to the service. For these
    paths the service will no longer skip http calls from the
    TESTPOINT(...) macro.

    All arguments are forwarded unchanged to the wrapped client.

    @param no_auto_cache_cleanup prevent automatic cache cleanup.
    When calling service client first time in scope of current test, client
    makes additional http call to `tests/control` to update caches, to get
    rid of data from previous test.
    """
    client = self._client
    await client.enable_testpoints(*args, **kwargs)
1217
@_wrap_client_error
async def get_dynamic_config_defaults(
    self,
) -> typing.Dict[str, typing.Any]:
    """
    Return the result of `get_dynamic_config_defaults()` of the wrapped
    client.
    """
    defaults = await self._client.get_dynamic_config_defaults()
    return defaults
1223
1224
@dataclasses.dataclass
class _State:
    """Reflects the (supposed) current service state."""

    # Cache invalidation bookkeeping as the service is believed to see it.
    invalidation_state: caches.InvalidationState
    # Mocked time believed to be set in the service; starts as the
    # `_UNKNOWN_STATE` sentinel.
    now: typing.Optional[str] = _UNKNOWN_STATE
    # Testpoints believed to be enabled in the service; starts as a set
    # holding only the `_UNKNOWN_STATE` sentinel.
    testpoints: typing.FrozenSet[str] = frozenset((_UNKNOWN_STATE,))
1232
1233
1235 """
1236 Used for computing the requests that we need to automatically align
1237 the service state with the test fixtures state.
1238 """
1239
def __init__(
    self,
    *,
    mocked_time,
    testpoint,
    testpoint_control,
    invalidation_state: caches.InvalidationState,
    cache_control: typing.Optional[caches.CacheControl],
):
    """
    Store the collaborating objects and create the tracked `_State`.

    @param invalidation_state kept by reference in `_invalidation_state`;
           the tracked `_State` is seeded with a deep copy of it.
    @param cache_control may be None.
    """
    self._mocked_time = mocked_time
    self._testpoint = testpoint
    self._testpoint_control = testpoint_control
    self._cache_control = cache_control
    # The caller's object is kept as is ...
    self._invalidation_state = invalidation_state
    # ... while the tracked state owns an independent deep copy.
    self._state = _State(
        invalidation_state=copy.deepcopy(invalidation_state),
    )
1257
@contextlib.contextmanager
def updating_state(self, body: typing.Dict[str, typing.Any]):
    """
    Whenever `tests_control` handler is invoked
    (by the client itself during `prepare` or manually by the user),
    we need to synchronize `_state` with the (supposed) service state.
    The state update is decoded from the request body.

    The new state is applied before the `with` body runs. If applying it
    or the `with` body raises an `Exception`, the previous state is
    restored and re-applied, then the exception is re-raised.
    """
    # Deep copy taken up front so the rollback is not affected by
    # in-place changes made while updating.
    snapshot = copy.deepcopy(self._state)
    try:
        self._update_state(body)
        self._apply_new_state()
        yield
    except Exception:  # noqa
        self._state = snapshot
        self._apply_new_state()
        raise
1275
def get_pending_update(self) -> typing.Dict[str, typing.Any]:
    """
    Compose the body of the `tests_control` request required to completely
    synchronize the service state with the state of test fixtures.

    Sections are added independently:
    - `invalidate_caches` when the invalidation state reports caches to
      update (with `names` unless all caches should be updated);
    - `testpoints` when the registered testpoints differ from the
      tracked ones;
    - `mock_now` when the desired mocked time differs from the tracked one.
    """
    pending: typing.Dict[str, typing.Any] = {}

    invalidation = self._invalidation_state
    if invalidation.has_caches_to_update:
        request: typing.Dict[str, typing.Any] = {'update_type': 'full'}
        if not invalidation.should_update_all_caches:
            request['names'] = list(invalidation.caches_to_update)
        pending['invalidate_caches'] = request

    wanted_testpoints = self._testpoint.keys()
    if self._state.testpoints != frozenset(wanted_testpoints):
        pending['testpoints'] = sorted(wanted_testpoints)

    wanted_now = self._get_desired_now()
    if self._state.now != wanted_now:
        pending['mock_now'] = wanted_now

    return pending
1299
@contextlib.contextmanager
def cache_control_update(
    self,
) -> typing.Iterator[typing.Dict[str, typing.Any]]:
    """
    Context manager that yields the pending `tests_control` body.

    If the body requests cache invalidation and a cache control object is
    set, the cache control is queried for the requested cache names, the
    returned actions are applied to the `invalidate_caches` section, and
    the staged data is committed after the `with` body finishes without
    raising. Otherwise the pending body is yielded unchanged.

    The function itself is a generator, hence the `Iterator` return
    annotation; `contextlib.contextmanager` turns it into a context manager.
    """
    pending_update = self.get_pending_update()
    invalidate_caches = pending_update.get('invalidate_caches')
    if not invalidate_caches or not self._cache_control:
        yield pending_update
        return

    cache_names = invalidate_caches.get('names')
    staged, actions = self._cache_control.query_caches(cache_names)
    self._apply_cache_control_actions(invalidate_caches, actions)
    yield pending_update
    # Not reached when the `with` body raises: nothing is committed then.
    self._cache_control.commit_staged(staged)
1312
@staticmethod
def _apply_cache_control_actions(
    invalidate_caches: typing.Dict,
    actions: typing.List[typing.Tuple[str, caches.CacheControlAction]],
) -> None:
    """
    Adjust the `invalidate_caches` request section in place according to
    the per-cache actions.

    - INCREMENTAL: the cache name is appended to `force_incremental_names`.
    - EXCLUDE: the cache name is removed from `names` when that list is
      present, otherwise appended to `exclude_names`.
    - Any other action leaves the request untouched.

    `exclude_names` and `force_incremental_names` are created (as empty
    lists) if missing, even when no action fills them.
    """
    requested = invalidate_caches.get('names')
    excluded = invalidate_caches.setdefault('exclude_names', [])
    incremental = invalidate_caches.setdefault(
        'force_incremental_names', [],
    )
    for name, action in actions:
        if action == caches.CacheControlAction.INCREMENTAL:
            incremental.append(name)
            continue
        if action != caches.CacheControlAction.EXCLUDE:
            continue
        if requested is None:
            excluded.append(name)
        else:
            requested.remove(name)
1331
def _update_state(self, body: typing.Dict[str, typing.Any]) -> None:
    """
    Fold a `tests_control` request body into the tracked `_state`.

    - A non-empty `invalidate_caches` section with `update_type` 'full'
      (the default) marks either all caches or the listed `names` as
      updated.
    - `mock_now`, when present, replaces the tracked time (None included).
    - `testpoints`, when present and not None, replaces the tracked set.
    """
    invalidation = body.get('invalidate_caches', {})
    update_type = invalidation.get('update_type', 'full')
    names = invalidation.get('names', None)
    # An incremental update is considered insufficient to bring a cache
    # to a known state.
    is_full_update = bool(invalidation) and update_type == 'full'
    if is_full_update:
        tracked = self._state.invalidation_state
        if names is None:
            tracked.on_all_caches_updated()
        else:
            tracked.on_caches_updated(names)

    if 'mock_now' in body:
        self._state.now = body['mock_now']

    new_testpoints: typing.Optional[typing.List[str]] = body.get(
        'testpoints', None,
    )
    if new_testpoints is not None:
        self._state.testpoints = frozenset(new_testpoints)
1354
def _apply_new_state(self) -> None:
    """Apply new state to related components."""
    tracked = self._state
    # Publish the tracked testpoints to the testpoint control object.
    self._testpoint_control.enabled_testpoints = tracked.testpoints
    # Copy the tracked invalidation state into the shared object.
    self._invalidation_state.assign_copy(tracked.invalidation_state)
1359
def _get_desired_now(self) -> typing.Optional[str]:
    """
    Return the mocked time formatted with `utils.timestring`, or None
    when mocked time is not enabled.
    """
    mocked_time = self._mocked_time
    if not mocked_time.is_enabled:
        return None
    return utils.timestring(mocked_time.now())
1364
1365
async def _task_check_response(name: str, response) -> dict:
    """
    Validate the response of a testsuite task request and return its JSON.

    @param name task name, used only in error messages.
    @param response response object; it is used as an async context
           manager and must provide `status` and an awaitable `json()`.
    @returns decoded JSON body.
    @throws TestsuiteTaskNotFound on HTTP 404.
    @throws TestsuiteTaskConflict on HTTP 409.
    @throws AssertionError on any other non-200 status.
    @throws TestsuiteTaskFailed when the body has a falsy `status` field
            (a missing field counts as success).
    """
    async with response:
        status = response.status
        if status == 404:
            raise TestsuiteTaskNotFound(f'Testsuite task {name!r} not found')
        if status == 409:
            raise TestsuiteTaskConflict(f'Testsuite task {name!r} conflict')
        # Report the task and the actual status instead of a bare
        # assertion failure.
        assert status == 200, (
            f'Testsuite task {name!r}: unexpected HTTP status {status}'
        )
        data = await response.json()
        if not data.get('status', True):
            raise TestsuiteTaskFailed(name, data['reason'])
        return data