userver: /home/antonyzhilin/arcadia/taxi/uservices/userver/testsuite/pytest_plugins/pytest_userver/chaos.py Source File
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages Concepts
chaos.py
1# pylint: disable=too-many-lines
2"""
3Python module that provides testsuite support for
4chaos tests; see
5@ref scripts/docs/en/userver/chaos_testing.md for an introduction.
6
7@ingroup userver_testsuite
8"""
9
10import asyncio
11import dataclasses
12import functools
13import io
14import logging
15import random
16import re
17import socket
18import time
19import typing
20
21import pytest
22
23from testsuite import asyncio_socket
24from testsuite.utils import callinfo
25
26
class BaseError(Exception):
    """Base class for the errors raised by the interceptors of this module."""


class ConnectionClosedError(BaseError):
    """
    Raised by an interceptor when the socket it reads from returned no data,
    i.e. the peer has closed its side of the connection.
    """
33
34
@dataclasses.dataclass(frozen=True)
class GateRoute:
    """
    Class that describes the route for TcpGate or UdpGate.

    Use `port_for_client == 0` to bind to some unused port. In that case the
    actual address can be retrieved via BaseGate.get_sockname_for_clients().

    @ingroup userver_testsuite
    """

    # Name of the route, used in log messages and in BaseGate.info()
    name: str
    # Address of the server the gate forwards the data to
    host_to_server: str
    port_to_server: int
    # Address the gate listens on for the clients
    host_for_client: str = '127.0.0.1'
    port_for_client: int = 0
51
52
# @cond

# Largest chunk requested from a socket by a single recv() call, see
# https://docs.python.org/3/library/socket.html#socket.socket.recv
RECV_MAX_SIZE = 4096
# Seconds _InterceptConcatPackets waits for a packet to reach the requested
# size before failing the test
MAX_DELAY = 60.0


logger = logging.getLogger(__name__)


Address = typing.Tuple[str, int]  # (host, port)
EvLoop = typing.Any
Socket = socket.socket
# Coroutine function awaited as interceptor(loop, socket_from, socket_to)
Interceptor = typing.Callable[
    [EvLoop, Socket, Socket], typing.Coroutine[typing.Any, typing.Any, None]
]
70
71
class GateException(Exception):
    """Raised when a gate can not start, e.g. no socket for clients could be created."""


class GateInterceptException(Exception):
    """
    Raised by an interceptor to deliberately terminate the connection it
    serves; the pair of sockets is closed afterwards.
    """
78
79
async def _intercept_ok(
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """
    Forward one chunk of data from `socket_from` to `socket_to` unchanged.

    Raises ConnectionClosedError if `socket_from` returned no data.
    """
    chunk = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    if chunk:
        await loop.sock_sendall(socket_to, chunk)
        return
    raise ConnectionClosedError()
89
90
async def _intercept_drop(
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """
    Read one chunk of data from `socket_from` and discard it; `socket_to`
    is not used.

    Raises ConnectionClosedError if `socket_from` returned no data.
    """
    discarded = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    if discarded:
        return
    raise ConnectionClosedError()
99
100
async def _intercept_delay(
    delay: float,
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """
    Read one chunk of data from `socket_from`, wait `delay` seconds and then
    forward it to `socket_to`.

    Raises ConnectionClosedError if `socket_from` returned no data.
    """
    chunk = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    if not chunk:
        raise ConnectionClosedError()

    await asyncio.sleep(delay)
    await loop.sock_sendall(socket_to, chunk)
112
113
async def _intercept_close_on_data(
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """
    Wait for the first byte on `socket_from` and abort the connection by
    raising GateInterceptException. Nothing is forwarded to `socket_to`.

    Raises ConnectionClosedError if `socket_from` returned no data.
    """
    first_byte = await loop.sock_recv(socket_from, 1)
    if first_byte:
        raise GateInterceptException('Closing socket on data')
    raise ConnectionClosedError()
123
124
async def _intercept_corrupt(
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """
    Read one chunk of data from `socket_from` and forward a corrupted copy to
    `socket_to`: every zero byte becomes 1 and every non-zero byte becomes 0.

    Raises ConnectionClosedError if `socket_from` returned no data.
    """
    chunk = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    if not chunk:
        raise ConnectionClosedError()

    corrupted = bytearray(0 if byte else 1 for byte in chunk)
    await loop.sock_sendall(socket_to, corrupted)
134
135
class _InterceptBpsLimit:
    """
    Token-bucket interceptor that forwards at most `bytes_per_second` bytes
    per second from `socket_from` to `socket_to`. The budget never exceeds
    one second worth of bytes.
    """

    def __init__(self, bytes_per_second: float):
        assert bytes_per_second >= 1
        self._bytes_per_second = bytes_per_second
        self._time_last_added = 0.0
        # The initial budget is full
        self._bytes_left = self._bytes_per_second

    def _update_limit(self) -> None:
        """Refill the budget proportionally to the time elapsed since the last refill."""
        now = time.monotonic()
        refill = (now - self._time_last_added) * self._bytes_per_second
        if refill > 0:
            self._bytes_left += refill
            self._time_last_added = now

        self._bytes_left = min(self._bytes_left, self._bytes_per_second)

    async def __call__(
        self,
        loop: EvLoop,
        socket_from: Socket,
        socket_to: Socket,
    ) -> None:
        self._update_limit()

        budget = min(int(self._bytes_left), RECV_MAX_SIZE)
        if budget <= 0:
            logger.info('Socket hits the bytes per second limit')
            # Wait roughly the time needed to earn one byte of budget
            await asyncio.sleep(1.0 / self._bytes_per_second)
            return

        chunk = await loop.sock_recv(socket_from, budget)
        if not chunk:
            raise ConnectionClosedError()
        self._bytes_left -= len(chunk)

        await loop.sock_sendall(socket_to, chunk)
173
174
class _InterceptTimeLimit:
    """
    Interceptor that forwards data unchanged until a per-socket deadline
    expires. The deadline is set the first time a socket is seen, to
    `timeout` seconds plus a random share of `jitter` seconds.
    """

    def __init__(self, timeout: float, jitter: float):
        self._sockets: typing.Dict[Socket, float] = {}
        assert timeout >= 0.0
        self._timeout = timeout
        assert jitter >= 0.0
        self._jitter = jitter

    def raise_if_timed_out(self, socket_from: Socket) -> None:
        """
        Register `socket_from` on the first call and raise
        GateInterceptException once its deadline has passed. The socket is
        forgotten when the exception is raised.
        """
        deadline = self._sockets.get(socket_from)
        if deadline is None:
            jitter = self._jitter * random.random()
            deadline = time.monotonic() + self._timeout + jitter
            self._sockets[socket_from] = deadline

        if time.monotonic() >= deadline:
            del self._sockets[socket_from]
            raise GateInterceptException('Socket hits the time limit')

    async def __call__(
        self,
        loop: EvLoop,
        socket_from: Socket,
        socket_to: Socket,
    ) -> None:
        self.raise_if_timed_out(socket_from)
        await _intercept_ok(loop, socket_from, socket_to)
201
202
class _InterceptSmallerParts:
    """
    Interceptor that forwards data in chunks of at most `max_size` bytes and
    sleeps `sleep_per_packet` seconds before sending each chunk.
    """

    def __init__(self, max_size: int, sleep_per_packet: float):
        assert max_size > 0
        self._max_size = max_size
        self._sleep_per_packet = sleep_per_packet

    async def __call__(
        self,
        loop: EvLoop,
        socket_from: Socket,
        socket_to: Socket,
    ) -> None:
        part = await loop.sock_recv(socket_from, self._max_size)
        if not part:
            raise ConnectionClosedError()

        await asyncio.sleep(self._sleep_per_packet)
        await loop.sock_sendall(socket_to, part)
220
221
class _InterceptConcatPackets:
    """
    Interceptor that buffers incoming data and forwards it only once at least
    `packet_size` bytes are accumulated.

    The test fails if a packet is not completed within MAX_DELAY seconds of
    the first call made for it.
    """

    def __init__(self, packet_size: int):
        assert packet_size >= 0
        self._packet_size = packet_size
        # Deadline for the packet currently being accumulated, None if no
        # packet is in progress
        self._expire_at: typing.Optional[float] = None
        self._buf = io.BytesIO()

    async def __call__(
        self,
        loop: EvLoop,
        socket_from: Socket,
        socket_to: Socket,
    ) -> None:
        if self._expire_at is None:
            self._expire_at = time.monotonic() + MAX_DELAY

        if time.monotonic() >= self._expire_at:
            pytest.fail(
                f'Failed to make a packet of sufficient size in {MAX_DELAY} '
                'seconds. Check the test logic, it should end with checking '
                'that the data was sent and by calling TcpGate function '
                'to_client_pass() to pass the remaining packets.',
            )

        chunk = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
        if not chunk:
            raise ConnectionClosedError()

        self._buf.write(chunk)
        if self._buf.tell() < self._packet_size:
            return

        await loop.sock_sendall(socket_to, self._buf.getvalue())
        self._buf = io.BytesIO()
        self._expire_at = None
254
255
class _InterceptBytesLimit:
    """
    Interceptor that forwards data until `bytes_limit` bytes have passed.

    When the limit is reached, it sends only the part that still fits, closes
    all the connections of `gate` and resets the counter. It then aborts the
    current connection with GateInterceptException.
    """

    def __init__(self, bytes_limit: int, gate: 'BaseGate'):
        assert bytes_limit >= 0
        self._bytes_limit = bytes_limit
        self._bytes_remain = self._bytes_limit
        self._gate = gate

    async def __call__(
        self,
        loop: EvLoop,
        socket_from: Socket,
        socket_to: Socket,
    ) -> None:
        data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
        if not data:
            raise ConnectionClosedError()

        if len(data) < self._bytes_remain:
            self._bytes_remain -= len(data)
            await loop.sock_sendall(socket_to, data)
            return

        # Limit reached: send the allowed prefix, then drop every connection
        await loop.sock_sendall(socket_to, data[: self._bytes_remain])
        await self._gate.sockets_close()
        self._bytes_remain = self._bytes_limit
        raise GateInterceptException('Data transmission limit reached')
279
280
class _InterceptSubstitute:
    """
    Interceptor that applies a regex substitution to every chunk of data.

    Each chunk is decoded with `encoding`, substituted and encoded back.
    Chunks that fail to decode or encode are forwarded unchanged. Each chunk
    is processed independently, so a match split between two recv() chunks
    is not replaced.
    """

    def __init__(self, pattern: str, repl: str, encoding='utf-8'):
        self._pattern = re.compile(pattern)
        self._repl = repl
        self._encoding = encoding

    def _substitute(self, data: bytes) -> bytes:
        """Return `data` with substitutions applied, or as is on UnicodeError."""
        try:
            text = data.decode(self._encoding)
            return self._pattern.sub(self._repl, text).encode(self._encoding)
        except UnicodeError:
            return data

    async def __call__(
        self,
        loop: EvLoop,
        socket_from: Socket,
        socket_to: Socket,
    ) -> None:
        data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
        if not data:
            raise ConnectionClosedError()
        await loop.sock_sendall(socket_to, self._substitute(data))
302
303
async def _cancel_and_join(task: typing.Optional[asyncio.Task]) -> None:
    """
    Cancel `task` (if any) and wait for it to finish.

    Cancellation is swallowed silently. Any other exception the task ended
    with is logged and swallowed as well.
    """
    if not task or task.cancelled():
        return

    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
    except Exception:  # pylint: disable=broad-except
        logger.exception('Exception in _cancel_and_join')
315
316
def _make_socket_nonblocking(sock: Socket) -> None:
    """Switch `sock` to non-blocking mode; stream sockets also get TCP_NODELAY."""
    sock.setblocking(False)
    if sock.type != socket.SOCK_STREAM:
        return
    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
321
322
class _UdpDemuxSocketMock:
    """
    Emulates a point-to-point connection over UDP socket
    with a non-blocking socket interface.

    Datagrams of a single peer are pushed into an internal datagram socket
    pair and read back from its other end. send() writes directly to the
    peer through the shared UDP socket.
    """

    def __init__(self, sock: Socket, peer_address: Address):
        self._sock: Socket = sock
        self._peeraddr: Address = peer_address

        self._demux_in, self._demux_out = socket.socketpair(type=socket.SOCK_DGRAM)
        _make_socket_nonblocking(self._demux_in)
        _make_socket_nonblocking(self._demux_out)
        self._is_active: bool = True

    @property
    def peer_address(self):
        """Address of the UDP peer this object represents."""
        return self._peeraddr

    def is_active(self):
        return self._is_active

    def gettimeout(self):
        """Timeout of the shared UDP socket."""
        return self._sock.gettimeout()

    def fileno(self):
        return self._demux_out.fileno()

    def get_demux_out(self):
        return self._demux_out

    async def push(self, loop: EvLoop, data: bytes):
        """Queue a datagram received from the peer so that recv() returns it."""
        return await loop.sock_sendall(self._demux_in, data)

    def recv(self, bufsize: int, flags: int = 0):
        return self._demux_out.recv(bufsize, flags)

    def recvfrom(self, bufsize: int, flags: int = 0):
        return self._demux_out.recvfrom(bufsize, flags)

    def send(self, data: bytes):
        """Send `data` to the peer through the shared UDP socket."""
        return self._sock.sendto(data, self._peeraddr)

    def close(self):
        """Close the internal socket pair; the shared UDP socket stays open."""
        self._is_active = False
        self._demux_out.close()
        self._demux_in.close()
372
373
class InterceptTask:
    """
    Repeatedly applies the current interceptor to move data from
    `socket_from` to `socket_to`.

    The interceptor can be replaced at any time with set_interceptor(). A
    falsy interceptor (e.g. None) pauses the transfer until a new one is set.
    """

    def __init__(self, socket_from, socket_to, interceptor):
        self._socket_from = socket_from
        self._socket_to = socket_to
        self._condition = asyncio.Condition()
        self._interceptor = interceptor

    def get_interceptor(self):
        """Return the current interceptor (may be None)."""
        return self._interceptor

    async def set_interceptor(self, interceptor):
        """Replace the interceptor and wake up run() if it waits for one."""
        async with self._condition:
            self._interceptor = interceptor
            self._condition.notify()

    async def run(self):
        """Transfer data until an interceptor raises; the exception propagates."""
        loop = asyncio.get_running_loop()
        while True:
            # Applies new interceptors faster.
            #
            # To avoid long awaiting on sock_recv in an outdated
            # interceptor we wait for data before grabbing and applying
            # the interceptor.
            await _wait_for_data(self._socket_from)

            # Wait for an interceptor to be attached
            async with self._condition:
                interceptor = await self._condition.wait_for(self.get_interceptor)

            # stdlib `logging` has no trace() function; use the module logger
            logger.debug('running interceptor: %s', interceptor)
            await interceptor(loop, self._socket_from, self._socket_to)
405
406
class _SocketsPaired:
    """
    Owns a client socket and the matching server socket, and runs two
    InterceptTask objects: one moves data client -> server, the other
    server -> client. Both sockets are closed when the transfer ends.
    """

    def __init__(
        self,
        proxy_name: str,
        loop: EvLoop,
        client: typing.Union[socket.socket, _UdpDemuxSocketMock],
        server: socket.socket,
        to_server_intercept: Interceptor,
        to_client_intercept: Interceptor,
    ) -> None:
        # `loop` is not used; the parameter is kept for backward compatibility
        self._proxy_name = proxy_name

        self._client = client
        self._server = server

        self._task_to_server = InterceptTask(client, server, to_server_intercept)
        self._task_to_client = InterceptTask(server, client, to_client_intercept)

        # Initialized before the task is created so that shutdown() works
        # even if it is called before _run() starts
        self._interceptor_tasks: typing.List[asyncio.Task] = []
        self._task = asyncio.create_task(self._run())

    async def set_to_server_interceptor(self, interceptor: Interceptor) -> None:
        await self._task_to_server.set_interceptor(interceptor)

    async def set_to_client_interceptor(self, interceptor: Interceptor) -> None:
        await self._task_to_client.set_interceptor(interceptor)

    async def shutdown(self) -> None:
        """Stop the data transfer and close both sockets."""
        for task in self._interceptor_tasks:
            await _cancel_and_join(task)
        await _cancel_and_join(self._task)
        # If _run() was cancelled before it started, its `finally` clause
        # never ran and the sockets are still open. Closing an already
        # closed socket is a no-op.
        self._close_sockets()

    def is_active(self) -> bool:
        return not self._task.done()

    def info(self) -> str:
        if not self.is_active():
            return '<inactive>'

        return f'client fd={self._client.fileno()} <=> server fd={self._server.fileno()}'

    async def _run(self) -> None:
        self._interceptor_tasks = [
            asyncio.create_task(obj.run())
            for obj in (self._task_to_server, self._task_to_client)
        ]
        try:
            done, _ = await asyncio.wait(
                self._interceptor_tasks,
                return_when=asyncio.FIRST_EXCEPTION,
            )
            for task in done:
                task.result()
        except ConnectionClosedError:
            # The peer closed its side of the connection: not a failure
            logger.info('In "%s": connection closed by peer', self._proxy_name)
        except GateInterceptException as exc:
            logger.info('In "%s": %s', self._proxy_name, exc)
        except socket.error as exc:
            logger.error('Exception in "%s": %s', self._proxy_name, exc)
        except Exception:  # pylint: disable=broad-except
            logger.exception('interceptor failed')
        finally:
            for task in self._interceptor_tasks:
                task.cancel()

            # Closing the sockets here so that the self.shutdown()
            # returns only when the sockets are actually closed
            self._close_sockets()

    def _close_sockets(self) -> None:
        """Close both sockets, logging (not raising) socket errors."""
        for sock in (self._server, self._client):
            try:
                sock.close()
            except socket.error:
                logger.exception('Exception in "%s" on closing %s:', self._proxy_name, sock)
473
474
475# @endcond
476
477
class BaseGate:
    """
    This base class maintains endpoints of two types:

    Server-side endpoints to receive messages from clients. Address of this
    endpoint is described by (host_for_client, port_for_client).

    Client-side endpoints to forward messages to server. Server must listen on
    (host_to_server, port_to_server).

    Asynchronously concurrently passes data from client to server and from
    server to client, allowing intercepting the data, injecting delays and
    dropping connections.

    @warning Do not use this class itself. Use one of the specializations
    TcpGate or UdpGate

    @ingroup userver_testsuite

    @see @ref scripts/docs/en/userver/chaos_testing.md
    """

    _NOT_IMPLEMENTED_MESSAGE = 'Do not use BaseGate itself, use one of specializations TcpGate or UdpGate'
501
502 def __init__(self, route: GateRoute, loop: typing.Optional[EvLoop] = None) -> None:
503 self._route = route
504 if loop is None:
505 loop = asyncio.get_running_loop()
506 self._loop = loop
507
508 self._to_server_intercept: Interceptor = _intercept_ok
509 self._to_client_intercept: Interceptor = _intercept_ok
510
511 self._accept_sockets: typing.List[socket.socket] = []
512 self._accept_tasks: typing.List[asyncio.Task[None]] = []
513
514 self._sockets: typing.Set[_SocketsPaired] = set()
515
516 async def __aenter__(self) -> 'BaseGate':
517 self.start()
518 return self
519
520 async def __aexit__(self, exc_type, exc_value, traceback) -> None:
521 await self.stop()
522
523 def _create_accepting_sockets(self) -> typing.List[Socket]:
524 raise NotImplementedError(self._NOT_IMPLEMENTED_MESSAGE)
525
526 def start(self):
527 """Open the socket and start accepting tasks"""
528 if self._accept_sockets:
529 return
530
532
533 if not self._accept_sockets:
534 raise GateException(
535 f'Could not resolve hostname {self._route.host_for_client}',
536 )
537
538 if self._route.port_for_client == 0:
539 # In case of stop()+start() bind to the same port
540 self._route = GateRoute(
541 name=self._route.name,
542 host_to_server=self._route.host_to_server,
543 port_to_server=self._route.port_to_server,
544 host_for_client=self._accept_sockets[0].getsockname()[0],
545 port_for_client=self._accept_sockets[0].getsockname()[1],
546 )
547
548 self.start_accepting()
549
550 def start_accepting(self) -> None:
551 """Start accepting tasks"""
552 assert self._accept_sockets
553 if not all(tsk.done() for tsk in self._accept_tasks):
554 return
555
556 self._accept_tasks.clear()
557 for sock in self._accept_sockets:
558 self._accept_tasks.append(
559 asyncio.create_task(self._do_accept(sock)),
560 )
561
562 async def stop_accepting(self) -> None:
563 """
564 Stop accepting tasks without closing the accepting socket.
565 """
566 for tsk in self._accept_tasks:
567 await _cancel_and_join(tsk)
568 self._accept_tasks.clear()
569
570 async def stop(self) -> None:
571 """
572 Stop accepting tasks, close all the sockets
573 """
574 if not self._accept_sockets and not self._sockets:
575 return
576
577 await self.to_server_pass()
578 await self.to_client_pass()
579
580 await self.stop_accepting()
581 logger.info('Before close() %s', self.info())
582 await self.sockets_close()
583 assert not self._sockets
584
585 for sock in self._accept_sockets:
586 sock.close()
587 self._accept_sockets.clear()
588 logger.info('Stopped. %s', self.info())
589
590 async def sockets_close(
591 self,
592 *,
593 count: typing.Optional[int] = None,
594 ) -> None:
595 """Close all the connection going through the gate"""
596 for x in list(self._sockets)[0:count]:
597 await x.shutdown()
598 self._collect_garbage()
599
600 def get_sockname_for_clients(self) -> Address:
601 """
602 Returns the client socket address that the gate listens on.
603
604 This function allows to use 0 in GateRoute.port_for_client and retrieve
605 the actual port and host.
606 """
607 assert self._route.port_for_client != 0, ('Gate was not started and the port_for_client is still 0',)
608 return (self._route.host_for_client, self._route.port_for_client)
609
610 def info(self) -> str:
611 """Print info on open sockets"""
612 if not self._sockets:
613 return f'"{self._route.name}" no active sockets'
614
615 return f'"{self._route.name}" active sockets:\n\t' + '\n\t'.join(x.info() for x in self._sockets)
616
617 def _collect_garbage(self) -> None:
618 self._sockets = {x for x in self._sockets if x.is_active()}
619
620 async def _do_accept(self, accept_sock: Socket) -> None:
621 """
622 This task should wait for connections and create SocketPair
623 """
624 raise NotImplementedError(self._NOT_IMPLEMENTED_MESSAGE)
625
626 async def set_to_server_interceptor(self, interceptor: Interceptor) -> callinfo.AsyncCallQueue:
627 """
628 Replace existing interceptor of client to server data with a custom
629 """
630 self._to_server_intercept = _create_callqueue(interceptor)
631 for x in self._sockets:
632 await x.set_to_server_interceptor(self._to_server_intercept)
633 return self._to_server_intercept
634
635 async def set_to_client_interceptor(self, interceptor: Interceptor) -> callinfo.AsyncCallQueue:
636 """
637 Replace existing interceptor of server to client data with a custom
638
639 """
640 if interceptor is not None:
641 self._to_client_intercept = _create_callqueue(interceptor)
642 else:
643 self._to_client_intercept = None
644 for x in self._sockets:
645 await x.set_to_client_interceptor(self._to_client_intercept)
646 return self._to_client_intercept
647
648 async def to_server_pass(self) -> callinfo.AsyncCallQueue:
649 """Pass data as is"""
650 logging.trace('to_server_pass')
651 return await self.set_to_server_interceptor(_intercept_ok)
652
653 async def to_client_pass(self) -> callinfo.AsyncCallQueue:
654 """Pass data as is"""
655 logging.trace('to_client_pass')
656 return await self.set_to_client_interceptor(_intercept_ok)
657
658 async def to_server_noop(self) -> callinfo.AsyncCallQueue:
659 """Do not read data, causing client to keep multiple data"""
660 logging.trace('to_server_noop')
661 return await self.set_to_server_interceptor(None)
662
663 async def to_client_noop(self) -> callinfo.AsyncCallQueue:
664 """Do not read data, causing server to keep multiple data"""
665 logging.trace('to_client_noop')
666 return await self.set_to_client_interceptor(None)
667
668 async def to_server_drop(self) -> callinfo.AsyncCallQueue:
669 """Read and discard data"""
670 logging.trace('to_server_drop')
671 return await self.set_to_server_interceptor(_intercept_drop)
672
673 async def to_client_drop(self) -> callinfo.AsyncCallQueue:
674 """Read and discard data"""
675 logging.trace('to_client_drop')
676 return await self.set_to_client_interceptor(_intercept_drop)
677
678 async def to_server_delay(self, delay: float) -> callinfo.AsyncCallQueue:
679 """Delay data transmission"""
680 logging.trace('to_server_delay, delay: %s', delay)
681
682 async def _intercept_delay_bound(
683 loop: EvLoop,
684 socket_from: Socket,
685 socket_to: Socket,
686 ) -> None:
687 await _intercept_delay(delay, loop, socket_from, socket_to)
688
689 return await self.set_to_server_interceptor(_intercept_delay_bound)
690
691 async def to_client_delay(self, delay: float) -> callinfo.AsyncCallQueue:
692 """Delay data transmission"""
693 logging.trace('to_client_delay, delay: %s', delay)
694
695 async def _intercept_delay_bound(
696 loop: EvLoop,
697 socket_from: Socket,
698 socket_to: Socket,
699 ) -> None:
700 await _intercept_delay(delay, loop, socket_from, socket_to)
701
702 return await self.set_to_client_interceptor(_intercept_delay_bound)
703
704 async def to_server_close_on_data(self) -> callinfo.AsyncCallQueue:
705 """Close on first bytes of data from client"""
706 logging.trace('to_server_close_on_data')
707 return await self.set_to_server_interceptor(_intercept_close_on_data)
708
709 async def to_client_close_on_data(self) -> callinfo.AsyncCallQueue:
710 """Close on first bytes of data from server"""
711 logging.trace('to_client_close_on_data')
712 return await self.set_to_client_interceptor(_intercept_close_on_data)
713
714 async def to_server_corrupt_data(self) -> callinfo.AsyncCallQueue:
715 """Corrupt data received from client"""
716 logging.trace('to_server_corrupt_data')
717 return await self.set_to_server_interceptor(_intercept_corrupt)
718
719 async def to_client_corrupt_data(self) -> callinfo.AsyncCallQueue:
720 """Corrupt data received from server"""
721 logging.trace('to_client_corrupt_data')
722 return await self.set_to_client_interceptor(_intercept_corrupt)
723
724 async def to_server_limit_bps(self, bytes_per_second: float) -> callinfo.AsyncCallQueue:
725 """Limit bytes per second transmission by network from client"""
726 logging.trace(
727 'to_server_limit_bps, bytes_per_second: %s',
728 bytes_per_second,
729 )
730 return await self.set_to_server_interceptor(_InterceptBpsLimit(bytes_per_second))
731
732 async def to_client_limit_bps(self, bytes_per_second: float) -> callinfo.AsyncCallQueue:
733 """Limit bytes per second transmission by network from server"""
734 logging.trace(
735 'to_client_limit_bps, bytes_per_second: %s',
736 bytes_per_second,
737 )
738 return await self.set_to_client_interceptor(_InterceptBpsLimit(bytes_per_second))
739
740 async def to_server_limit_time(self, timeout: float, jitter: float) -> callinfo.AsyncCallQueue:
741 """Limit connection lifetime on receive of first bytes from client"""
742 logging.trace(
743 'to_server_limit_time, timeout: %s, jitter: %s',
744 timeout,
745 jitter,
746 )
747 return await self.set_to_server_interceptor(_InterceptTimeLimit(timeout, jitter))
748
749 async def to_client_limit_time(self, timeout: float, jitter: float) -> callinfo.AsyncCallQueue:
750 """Limit connection lifetime on receive of first bytes from server"""
751 logging.trace(
752 'to_client_limit_time, timeout: %s, jitter: %s',
753 timeout,
754 jitter,
755 )
756 return await self.set_to_client_interceptor(_InterceptTimeLimit(timeout, jitter))
757
758 async def to_server_smaller_parts(
759 self,
760 max_size: int,
761 *,
762 sleep_per_packet: float = 0,
763 ) -> callinfo.AsyncCallQueue:
764 """
765 Pass data to server in smaller parts
766
767 @param max_size Max packet size to send to server
768 @param sleep_per_packet Optional sleep interval per packet, seconds
769 """
770 logging.trace('to_server_smaller_parts, max_size: %s', max_size)
771 return await self.set_to_server_interceptor(
772 _InterceptSmallerParts(max_size, sleep_per_packet),
773 )
774
775 async def to_client_smaller_parts(
776 self,
777 max_size: int,
778 *,
779 sleep_per_packet: float = 0,
780 ) -> callinfo.AsyncCallQueue:
781 """
782 Pass data to client in smaller parts
783
784 @param max_size Max packet size to send to client
785 @param sleep_per_packet Optional sleep interval per packet, seconds
786 """
787 logging.trace('to_client_smaller_parts, max_size: %s', max_size)
788 return await self.set_to_client_interceptor(
789 _InterceptSmallerParts(max_size, sleep_per_packet),
790 )
791
792 async def to_server_concat_packets(self, packet_size: int) -> callinfo.AsyncCallQueue:
793 """
794 Pass data in bigger parts
795 @param packet_size minimal size of the resulting packet
796 """
797 logging.trace('to_server_concat_packets, packet_size: %s', packet_size)
798 return await self.set_to_server_interceptor(_InterceptConcatPackets(packet_size))
799
800 async def to_client_concat_packets(self, packet_size: int) -> callinfo.AsyncCallQueue:
801 """
802 Pass data in bigger parts
803 @param packet_size minimal size of the resulting packet
804 """
805 logging.trace('to_client_concat_packets, packet_size: %s', packet_size)
806 return await self.set_to_client_interceptor(_InterceptConcatPackets(packet_size))
807
808 async def to_server_limit_bytes(self, bytes_limit: int) -> callinfo.AsyncCallQueue:
809 """Drop all connections each `bytes_limit` of data sent by network"""
810 logging.trace('to_server_limit_bytes, bytes_limit: %s', bytes_limit)
811 return await self.set_to_server_interceptor(_InterceptBytesLimit(bytes_limit, self))
812
813 async def to_client_limit_bytes(self, bytes_limit: int) -> callinfo.AsyncCallQueue:
814 """Drop all connections each `bytes_limit` of data sent by network"""
815 logging.trace('to_client_limit_bytes, bytes_limit: %s', bytes_limit)
816 return await self.set_to_client_interceptor(_InterceptBytesLimit(bytes_limit, self))
817
818 async def to_server_substitute(self, pattern: str, repl: str) -> callinfo.AsyncCallQueue:
819 """Apply regex substitution to data from client"""
820 logging.trace(
821 'to_server_substitute, pattern: %s, repl: %s',
822 pattern,
823 repl,
824 )
825 return await self.set_to_server_interceptor(_InterceptSubstitute(pattern, repl))
826
827 async def to_client_substitute(self, pattern: str, repl: str) -> callinfo.AsyncCallQueue:
828 """Apply regex substitution to data from server"""
829 logging.trace(
830 'to_client_substitute, pattern: %s, repl: %s',
831 pattern,
832 repl,
833 )
834 return await self.set_to_client_interceptor(_InterceptSubstitute(pattern, repl))
835
836
class TcpGate(BaseGate):
    """
    Implements TCP chaos-proxy logic such as accepting incoming tcp client
    connections. On each new connection new tcp client connects to server
    (host_to_server, port_to_server).

    @ingroup userver_testsuite

    @see @ref scripts/docs/en/userver/chaos_testing.md
    """

    def __init__(self, route: GateRoute, loop: typing.Optional[EvLoop] = None) -> None:
        # Set each time a new client/server socket pair is registered
        self._connected_event = asyncio.Event()
        super().__init__(route, loop)

    def connections_count(self) -> int:
        """
        Returns maximal amount of connections going through the gate at
        the moment.

        @warning Some of the connections could be closing, or could be opened
        right before the function starts. Use with caution!
        """
        return len(self._sockets)

    async def wait_for_connections(self, *, count=1, timeout=0.0) -> None:
        """
        Wait for at least `count` connections going through the gate.

        @param count minimal amount of connections to wait for
        @param timeout seconds to wait; a non-positive value means no limit

        @throws asyncio.TimeoutError exception if failed to get the
        required amount of connections in time.
        """
        if timeout <= 0.0:
            while self.connections_count() < count:
                await self._connected_event.wait()
                self._connected_event.clear()
            return

        deadline = time.monotonic() + timeout
        while self.connections_count() < count:
            time_left = deadline - time.monotonic()
            await asyncio.wait_for(
                self._connected_event.wait(),
                timeout=time_left,
            )
            self._connected_event.clear()

    def _create_accepting_sockets(self) -> typing.List[Socket]:
        """Create listening sockets for every address of host_for_client."""
        res: typing.List[Socket] = []
        try:
            for addr in socket.getaddrinfo(
                self._route.host_for_client,
                self._route.port_for_client,
                type=socket.SOCK_STREAM,
            ):
                sock = Socket(addr[0], addr[1])
                res.append(sock)
                _make_socket_nonblocking(sock)
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                sock.bind(addr[4])
                sock.listen()
                logger.debug(
                    'Accepting connections on %s, fd=%s',
                    sock.getsockname(),
                    sock.fileno(),
                )
        except Exception:
            # Do not leak the sockets created before the failure
            for sock in res:
                sock.close()
            raise

        return res

    async def _connect_to_server(self) -> typing.Optional[Socket]:
        """Return a socket connected to the server, or None if every address failed."""
        addrs = await self._loop.getaddrinfo(
            self._route.host_to_server,
            self._route.port_to_server,
            type=socket.SOCK_STREAM,
        )
        for addr in addrs:
            server = Socket(addr[0], addr[1])
            _make_socket_nonblocking(server)
            try:
                await self._loop.sock_connect(server, addr[4])
            except asyncio.CancelledError:
                # e.g. stop_accepting(): do not leak the half-open socket
                server.close()
                raise
            except Exception as exc:  # pylint: disable=broad-except
                server.close()
                logger.warning('Could not connect to %s: %s', addr[4], exc)
                continue
            logger.debug('Connected to %s', addr[4])
            return server
        return None

    async def _do_accept(self, accept_sock: Socket) -> None:
        while True:
            client, _ = await self._loop.sock_accept(accept_sock)
            try:
                _make_socket_nonblocking(client)
                server = await self._connect_to_server()
            except BaseException:
                # Cancellation or a resolve error: do not leak the client
                client.close()
                raise

            if server:
                self._sockets.add(
                    _SocketsPaired(
                        self._route.name,
                        self._loop,
                        client,
                        server,
                        self._to_server_intercept,
                        self._to_client_intercept,
                    ),
                )
                self._connected_event.set()
            else:
                client.close()

            self._collect_garbage()
942
943
class UdpGate(BaseGate):
    """
    Implements UDP chaos-proxy logic such as demuxing incoming datagrams
    from different clients.
    Separate connections to server are made for each new client.

    @ingroup userver_testsuite

    @see @ref scripts/docs/en/userver/chaos_testing.md
    """

    def __init__(self, route: GateRoute, loop: typing.Optional[EvLoop] = None):
        # One demultiplexing mock per known client address
        self._clients: typing.Set[_UdpDemuxSocketMock] = set()
        super().__init__(route, loop)

    def is_connected(self) -> bool:
        """
        Returns True if there is active pair of sockets ready to transfer data
        at the moment.
        """
        return len(self._sockets) > 0

    def _create_accepting_sockets(self) -> typing.List[Socket]:
        """Create bound UDP sockets for every address of host_for_client."""
        res: typing.List[Socket] = []
        try:
            for addr in socket.getaddrinfo(
                self._route.host_for_client,
                self._route.port_for_client,
                type=socket.SOCK_DGRAM,
            ):
                sock = socket.socket(addr[0], addr[1])
                res.append(sock)
                _make_socket_nonblocking(sock)
                sock.bind(addr[4])
                logger.debug('Accepting connections on %s', sock.getsockname())
        except Exception:
            # Do not leak the sockets created before the failure
            for sock in res:
                sock.close()
            raise

        return res

    async def _connect_to_server(self) -> typing.Optional[Socket]:
        """Return a socket connected to the server, or None if every address failed."""
        addrs = await self._loop.getaddrinfo(
            self._route.host_to_server,
            self._route.port_to_server,
            type=socket.SOCK_DGRAM,
        )
        for addr in addrs:
            server = Socket(addr[0], addr[1])
            try:
                _make_socket_nonblocking(server)
                await self._loop.sock_connect(server, addr[4])
            except asyncio.CancelledError:
                server.close()
                raise
            except Exception as exc:  # pylint: disable=broad-except
                # Close the failed socket, as TcpGate does
                server.close()
                logger.warning('Could not connect to %s: %s', addr[4], exc)
                continue
            logger.debug('Connected to %s', addr[4])
            return server
        return None

    def _collect_garbage(self) -> None:
        super()._collect_garbage()
        self._clients = {c for c in self._clients if c.is_active()}

    async def _do_accept(self, accept_sock: Socket):
        sock = asyncio_socket.from_socket(accept_sock)
        while True:
            # NOTE(review): the 60 second timeout presumably raises when no
            # datagram arrives in time, which ends this accept loop; confirm
            # against testsuite.asyncio_socket.
            data, addr = await sock.recvfrom(RECV_MAX_SIZE, timeout=60.0)

            client = next(
                (known for known in self._clients if known.peer_address == addr),
                None,
            )

            if client is None:
                server = await self._connect_to_server()
                if not server:
                    accept_sock.close()
                    break

                client = _UdpDemuxSocketMock(accept_sock, addr)
                self._clients.add(client)

                self._sockets.add(
                    _SocketsPaired(
                        self._route.name,
                        self._loop,
                        client,
                        server,
                        self._to_server_intercept,
                        self._to_client_intercept,
                    ),
                )

            await client.push(self._loop, data)
            self._collect_garbage()

    async def to_server_concat_packets(self, packet_size: int) -> None:
        """Not supported for UDP."""
        raise NotImplementedError('Udp packets cannot be concatenated')

    async def to_client_concat_packets(self, packet_size: int) -> None:
        """Not supported for UDP."""
        raise NotImplementedError('Udp packets cannot be concatenated')

    async def to_server_smaller_parts(
        self,
        max_size: int,
        *,
        sleep_per_packet: float = 0,
    ) -> None:
        """Not supported for UDP."""
        raise NotImplementedError('Udp packets cannot be split')

    async def to_client_smaller_parts(
        self,
        max_size: int,
        *,
        sleep_per_packet: float = 0,
    ) -> None:
        """Not supported for UDP."""
        raise NotImplementedError('Udp packets cannot be split')
1056
1057
def _create_callqueue(obj):
    """
    Wrap interceptor `obj` into a callinfo AsyncCallQueue so that tests can
    inspect its calls.

    None is returned as is, and an existing AsyncCallQueue is reused.
    """
    if obj is None or isinstance(obj, callinfo.AsyncCallQueue):
        return obj

    if not hasattr(obj, '__name__'):
        # workaround testsuite acallqueue that does not work with instances
        # (e.g. interceptor objects with __call__): wrap them into a function
        @functools.wraps(obj)
        async def wrapper(*args, **kwargs):
            return await obj(*args, **kwargs)

        obj = wrapper

    return callinfo.acallqueue(obj)
1073
1074
async def _wait_for_data(sock, timeout=60.0):
    """
    Wait until `sock` has data to read.

    For a _UdpDemuxSocketMock, waits on the readable end of its internal
    socket pair. `timeout` is forwarded to the wait_for_data() of
    testsuite's asyncio_socket wrapper.
    """
    raw_sock = sock.get_demux_out() if isinstance(sock, _UdpDemuxSocketMock) else sock
    await asyncio_socket.from_socket(raw_sock).wait_for_data(timeout=timeout)