userver: /home/antonyzhilin/arcadia/taxi/uservices/userver/testsuite/pytest_plugins/pytest_userver/chaos.py Source File
Loading...
Searching...
No Matches
chaos.py
1# pylint: disable=too-many-lines
2"""
3Python module that provides testsuite support for
4chaos tests; see
5@ref scripts/docs/en/userver/chaos_testing.md for an introduction.
6
7@ingroup userver_testsuite
8"""
9
10import asyncio
11import dataclasses
12import functools
13import io
14import logging
15import random
16import re
17import socket
18import time
19import typing
20
21import pytest
22
23from testsuite.utils import callinfo
24
25from pytest_userver import asyncio_socket
26
27
class BaseError(Exception):
    """Base exception class declared by the chaos testing module."""
30
31
33 pass
34
35
@dataclasses.dataclass(frozen=True)
class GateRoute:
    """
    Immutable description of a route proxied by TcpGate or UdpGate.

    Clients connect to (host_for_client, port_for_client); the gate forwards
    their traffic to (host_to_server, port_to_server). Set
    `port_for_client == 0` to let the OS pick an unused port; the address the
    gate actually listens on is then available via
    BaseGate.get_sockname_for_clients().

    @ingroup userver_testsuite
    """

    # Route name; it is used in the gate's log messages.
    name: str

    # Where the gate forwards the traffic (the real server).
    host_to_server: str
    port_to_server: int

    # Where the gate listens for clients; port 0 means "pick a free port".
    host_for_client: str = '127.0.0.1'
    port_for_client: int = 0
52
53
# @cond

# Upper bound on the number of bytes requested by a single recv(), see
# https://docs.python.org/3/library/socket.html#socket.socket.recv
RECV_MAX_SIZE = 4096

# Seconds _InterceptConcatPackets may spend assembling one packet.
MAX_DELAY = 60.0

logger = logging.getLogger(__name__)

# A (host, port) pair.
Address = typing.Tuple[str, int]
EvLoop = typing.Any
Socket = socket.socket

# Coroutine function that moves one portion of data from the first socket to
# the second one, possibly modifying, delaying or dropping it.
Interceptor = typing.Callable[[EvLoop, Socket, Socket], typing.Coroutine[typing.Any, typing.Any, None]]
71
72
class GateException(Exception):
    """
    Error reported by a gate itself, e.g. by BaseGate.start() when no
    client-facing socket could be created for the configured host.
    """
75
76
class GateInterceptException(Exception):
    """
    Raised by an interceptor to deliberately terminate the connection it
    serves. The connection owner logs it at INFO level instead of treating
    it as a failure.
    """
79
80
async def _intercept_ok(
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """
    Forward one chunk (up to RECV_MAX_SIZE bytes) from `socket_from` to
    `socket_to` without modification.

    @throws ConnectionClosedError if `socket_from` returned no data.
    """
    chunk = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    if chunk:
        await loop.sock_sendall(socket_to, chunk)
        return
    raise ConnectionClosedError()
90
91
async def _intercept_drop(
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """
    Read one chunk from `socket_from` and throw it away; `socket_to` is
    never written to.

    @throws ConnectionClosedError if `socket_from` returned no data.
    """
    discarded = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    if discarded:
        return
    raise ConnectionClosedError()
100
101
async def _intercept_delay(
    delay: float,
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """
    Read one chunk from `socket_from`, hold it for `delay` seconds, then
    forward it unchanged to `socket_to`.

    @throws ConnectionClosedError if `socket_from` returned no data.
    """
    payload = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    if not payload:
        raise ConnectionClosedError()

    await asyncio.sleep(delay)
    await loop.sock_sendall(socket_to, payload)
113
114
async def _intercept_close_on_data(
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """
    Wait for the first byte from `socket_from` and abort the connection
    instead of forwarding it.

    @throws GateInterceptException once a byte arrives.
    @throws ConnectionClosedError if `socket_from` returned no data.
    """
    first_byte = await loop.sock_recv(socket_from, 1)
    if first_byte:
        raise GateInterceptException('Closing socket on data')
    raise ConnectionClosedError()
124
125
async def _intercept_corrupt(
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """
    Forward one chunk with every byte altered. The length is kept, every
    zero byte becomes 0x01, and every non-zero byte becomes 0x00.

    @throws ConnectionClosedError if `socket_from` returned no data.
    """
    received = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    if not received:
        raise ConnectionClosedError()

    corrupted = bytearray(int(not byte) for byte in received)
    await loop.sock_sendall(socket_to, corrupted)
135
136
137class _InterceptBpsLimit:
138 def __init__(self, bytes_per_second: float):
139 assert bytes_per_second >= 1
140 self._bytes_per_second = bytes_per_second
141 self._time_last_added = 0.0
142 self._bytes_left = self._bytes_per_second
143
144 def _update_limit(self) -> None:
145 current_time = time.monotonic()
146 elapsed = current_time - self._time_last_added
147 bytes_addition = self._bytes_per_second * elapsed
148 if bytes_addition > 0:
149 self._bytes_left += bytes_addition
150 self._time_last_added = current_time
151
152 if self._bytes_left > self._bytes_per_second:
153 self._bytes_left = self._bytes_per_second
154
155 async def __call__(
156 self,
157 loop: EvLoop,
158 socket_from: Socket,
159 socket_to: Socket,
160 ) -> None:
161 self._update_limit()
162
163 bytes_to_recv = min(int(self._bytes_left), RECV_MAX_SIZE)
164 if bytes_to_recv > 0:
165 data = await loop.sock_recv(socket_from, bytes_to_recv)
166 if not data:
167 raise ConnectionClosedError()
168 self._bytes_left -= len(data)
169
170 await loop.sock_sendall(socket_to, data)
171 else:
172 logger.info('Socket hits the bytes per second limit')
173 await asyncio.sleep(1.0 / self._bytes_per_second)
174
175
176class _InterceptTimeLimit:
177 def __init__(self, timeout: float, jitter: float):
178 self._sockets: typing.Dict[Socket, float] = {}
179 assert timeout >= 0.0
180 self._timeout = timeout
181 assert jitter >= 0.0
182 self._jitter = jitter
183
184 def raise_if_timed_out(self, socket_from: Socket) -> None:
185 if socket_from not in self._sockets:
186 jitter = self._jitter * random.random()
187 expire_at = time.monotonic() + self._timeout + jitter
188 self._sockets[socket_from] = expire_at
189
190 if self._sockets[socket_from] <= time.monotonic():
191 del self._sockets[socket_from]
192 raise GateInterceptException('Socket hits the time limit')
193
194 async def __call__(
195 self,
196 loop: EvLoop,
197 socket_from: Socket,
198 socket_to: Socket,
199 ) -> None:
200 self.raise_if_timed_out(socket_from)
201 await _intercept_ok(loop, socket_from, socket_to)
202
203
204class _InterceptSmallerParts:
205 def __init__(self, max_size: int, sleep_per_packet: float):
206 assert max_size > 0
207 self._max_size = max_size
208 self._sleep_per_packet = sleep_per_packet
209
210 async def __call__(
211 self,
212 loop: EvLoop,
213 socket_from: Socket,
214 socket_to: Socket,
215 ) -> None:
216 data = await loop.sock_recv(socket_from, self._max_size)
217 if not data:
218 raise ConnectionClosedError()
219 await asyncio.sleep(self._sleep_per_packet)
220 await loop.sock_sendall(socket_to, data)
221
222
223class _InterceptConcatPackets:
224 def __init__(self, packet_size: int):
225 assert packet_size >= 0
226 self._packet_size = packet_size
227 self._expire_at: typing.Optional[float] = None
228 self._buf = io.BytesIO()
229
230 async def __call__(
231 self,
232 loop: EvLoop,
233 socket_from: Socket,
234 socket_to: Socket,
235 ) -> None:
236 if self._expire_at is None:
237 self._expire_at = time.monotonic() + MAX_DELAY
238
239 if self._expire_at <= time.monotonic():
240 pytest.fail(
241 f'Failed to make a packet of sufficient size in {MAX_DELAY} '
242 'seconds. Check the test logic, it should end with checking '
243 'that the data was sent and by calling TcpGate function '
244 'to_client_pass() to pass the remaining packets.',
245 )
246
247 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
248 if not data:
249 raise ConnectionClosedError()
250 self._buf.write(data)
251 if self._buf.tell() >= self._packet_size:
252 await loop.sock_sendall(socket_to, self._buf.getvalue())
253 self._buf = io.BytesIO()
254 self._expire_at = None
255
256
257class _InterceptBytesLimit:
258 def __init__(self, bytes_limit: int, gate: 'BaseGate'):
259 assert bytes_limit >= 0
260 self._bytes_limit = bytes_limit
261 self._bytes_remain = self._bytes_limit
262 self._gate = gate
263
264 async def __call__(
265 self,
266 loop: EvLoop,
267 socket_from: Socket,
268 socket_to: Socket,
269 ) -> None:
270 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
271 if not data:
272 raise ConnectionClosedError()
273 if self._bytes_remain <= len(data):
274 await loop.sock_sendall(socket_to, data[0 : self._bytes_remain])
275 await self._gate.sockets_close()
276 self._bytes_remain = self._bytes_limit
277 raise GateInterceptException('Data transmission limit reached')
278 self._bytes_remain -= len(data)
279 await loop.sock_sendall(socket_to, data)
280
281
282class _InterceptSubstitute:
283 def __init__(self, pattern: str, repl: str, encoding='utf-8'):
284 self._pattern = re.compile(pattern)
285 self._repl = repl
286 self._encoding = encoding
287
288 async def __call__(
289 self,
290 loop: EvLoop,
291 socket_from: Socket,
292 socket_to: Socket,
293 ) -> None:
294 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
295 if not data:
296 raise ConnectionClosedError()
297 try:
298 res = self._pattern.sub(self._repl, data.decode(self._encoding))
299 data = res.encode(self._encoding)
300 except UnicodeError:
301 pass
302 await loop.sock_sendall(socket_to, data)
303
304
305async def _cancel_and_join(task: typing.Optional[asyncio.Task]) -> None:
306 if not task or task.cancelled():
307 return
308
309 try:
310 task.cancel()
311 await task
312 except asyncio.CancelledError:
313 return
314 except Exception: # pylint: disable=broad-except
315 logger.exception('Exception in _cancel_and_join')
316
317
318def _make_socket_nonblocking(sock: Socket) -> None:
319 sock.setblocking(False)
320 if sock.type == socket.SOCK_STREAM:
321 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
322
323
class _UdpDemuxSocketMock:
    """
    Socket-like object that represents one peer of a shared UDP socket as a
    point-to-point, non-blocking connection.

    Reading side: data handed to push() is written into one end of an
    internal SOCK_DGRAM socketpair, and recv()/recvfrom()/fileno() work on
    the other end.

    Writing side: send() goes straight through the shared socket, addressed
    to `peer_address`.
    """

    def __init__(self, sock: Socket, peer_address: Address):
        self._sock: Socket = sock
        self._peeraddr: Address = peer_address

        self._demux_in, self._demux_out = socket.socketpair(type=socket.SOCK_DGRAM)
        for end in (self._demux_in, self._demux_out):
            _make_socket_nonblocking(end)
        self._is_active: bool = True

    @property
    def peer_address(self):
        return self._peeraddr

    def gettimeout(self):
        return self._sock.gettimeout()

    def fileno(self):
        return self._demux_out.fileno()

    def get_demux_out(self):
        return self._demux_out

    def is_active(self):
        return self._is_active

    async def push(self, loop: EvLoop, data: bytes):
        return await loop.sock_sendall(self._demux_in, data)

    def recv(self, bufsize: int, flags: int = 0):
        return self._demux_out.recv(bufsize, flags)

    def recvfrom(self, bufsize: int, flags: int = 0):
        return self._demux_out.recvfrom(bufsize, flags)

    def send(self, data: bytes):
        return self._sock.sendto(data, self._peeraddr)

    def close(self):
        # Only the private socketpair is closed; the shared socket is not owned here.
        self._is_active = False
        self._demux_out.close()
        self._demux_in.close()
373
374
class InterceptTask:
    """
    Pump that repeatedly applies the current interceptor to move data from
    `socket_from` to `socket_to`.

    The interceptor can be replaced at any time with set_interceptor(). While
    it is None, no data is read from `socket_from`.
    """

    def __init__(self, socket_from, socket_to, interceptor):
        self._socket_from = socket_from
        self._socket_to = socket_to
        self._interceptor = interceptor
        self._condition = asyncio.Condition()

    def get_interceptor(self):
        return self._interceptor

    async def set_interceptor(self, interceptor):
        async with self._condition:
            self._interceptor = interceptor
            self._condition.notify()

    async def run(self):
        loop = asyncio.get_running_loop()
        while True:
            # Wait until data is readable *before* picking the interceptor.
            # Otherwise an outdated interceptor could sit in sock_recv() for a
            # long time after a new one has been installed.
            await _wait_for_data(self._socket_from)

            # Block while no interceptor (None) is installed.
            async with self._condition:
                current = await self._condition.wait_for(self.get_interceptor)

            logging.trace('running interceptor: %s', current)
            await current(loop, self._socket_from, self._socket_to)
406
407
class _SocketsPaired:
    """
    Owns one client <-> server socket pair and the two InterceptTask pumps
    that move data client->server and server->client.

    A supervising task runs both pumps. When either pump ends with an
    exception, or the pair is shut down, both pumps are cancelled and both
    sockets are closed.
    """

    def __init__(
        self,
        proxy_name: str,
        loop: EvLoop,
        client: typing.Union[socket.socket, _UdpDemuxSocketMock],
        server: socket.socket,
        to_server_intercept: Interceptor,
        to_client_intercept: Interceptor,
    ) -> None:
        # `loop` is not used; the parameter is kept for caller compatibility.
        self._proxy_name = proxy_name

        self._client = client
        self._server = server

        self._task_to_server = InterceptTask(client, server, to_server_intercept)
        self._task_to_client = InterceptTask(server, client, to_client_intercept)

        # Must be initialized BEFORE the supervisor is created. With an eager
        # task factory, create_task() runs _run() synchronously up to its
        # first suspension, and _run() fills this list. Assigning it
        # afterwards would drop the real pump tasks, so they would never be
        # cancelled.
        self._interceptor_tasks: typing.List[asyncio.Task] = []
        self._task = asyncio.create_task(self._run())

    async def set_to_server_interceptor(self, interceptor):
        await self._task_to_server.set_interceptor(interceptor)

    async def set_to_client_interceptor(self, interceptor: Interceptor):
        await self._task_to_client.set_interceptor(interceptor)

    async def shutdown(self) -> None:
        """Cancel the pumps and the supervisor; return once the sockets are closed."""
        for task in self._interceptor_tasks:
            await _cancel_and_join(task)
        await _cancel_and_join(self._task)
        # If the supervisor was cancelled before it ever started, its
        # `finally` never ran; close the sockets here so they do not leak.
        # Closing an already-closed socket is a no-op.
        self._close_sockets()

    def is_active(self) -> bool:
        return not self._task.done()

    def info(self) -> str:
        if not self.is_active():
            return '<inactive>'

        return f'client fd={self._client.fileno()} <=> server fd={self._server.fileno()}'

    async def _run(self):
        """Supervise both pumps until one of them fails, then tear the pair down."""
        self._interceptor_tasks = [
            asyncio.create_task(obj.run()) for obj in (self._task_to_server, self._task_to_client)
        ]
        try:
            done, _ = await asyncio.wait(self._interceptor_tasks, return_when=asyncio.FIRST_EXCEPTION)
            for task in done:
                task.result()  # re-raise the pump's exception
        except GateInterceptException as exc:
            # Deliberate disconnect requested by an interceptor.
            logger.info('In "%s": %s', self._proxy_name, exc)
        except socket.error as exc:
            logger.error('Exception in "%s": %s', self._proxy_name, exc)
        except Exception:  # pylint: disable=broad-except
            logger.exception('interceptor failed')
        finally:
            for task in self._interceptor_tasks:
                task.cancel()

            # Closing the sockets here so that self.shutdown()
            # returns only when the sockets are actually closed.
            self._close_sockets()

    def _close_sockets(self) -> None:
        """Close both sockets, logging (not raising) socket errors."""
        for sock in (self._server, self._client):
            try:
                sock.close()
            except socket.error:
                logger.exception('Exception in "%s" on closing %s:', self._proxy_name, sock)
474
475
476# @endcond
477
class BaseGate:
    """
    This base class maintains endpoints of two types:

    Server-side endpoints to receive messages from clients. Address of this
    endpoint is described by (host_for_client, port_for_client).

    Client-side endpoints to forward messages to server. Server must listen on
    (host_to_server, port_to_server).

    Asynchronously concurrently passes data from client to server and from
    server to client, allowing intercepting the data, injecting delays and
    dropping connections.

    @warning Do not use this class itself. Use one of the specializations
    TcpGate or UdpGate

    @ingroup userver_testsuite

    @see @ref scripts/docs/en/userver/chaos_testing.md
    """

    _NOT_IMPLEMENTED_MESSAGE = 'Do not use BaseGate itself, use one of specializations TcpGate or UdpGate'

    def __init__(self, route: GateRoute, loop: typing.Optional[EvLoop] = None) -> None:
        self._route = route
        if loop is None:
            loop = asyncio.get_running_loop()
        self._loop = loop

        # Interceptors handed to each new connection; set_to_*_interceptor()
        # also pushes them to already open connections.
        self._to_server_intercept: typing.Optional[Interceptor] = _intercept_ok
        self._to_client_intercept: typing.Optional[Interceptor] = _intercept_ok

        self._accept_sockets: typing.List[socket.socket] = []
        self._accept_tasks: typing.List[asyncio.Task[None]] = []

        self._sockets: typing.Set[_SocketsPaired] = set()

    async def __aenter__(self) -> 'BaseGate':
        self.start()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        await self.stop()

    def _create_accepting_sockets(self) -> typing.List[Socket]:
        """Create and bind the client-facing sockets; implemented by subclasses."""
        raise NotImplementedError(self._NOT_IMPLEMENTED_MESSAGE)

    def start(self):
        """Open the socket and start accepting tasks"""
        if self._accept_sockets:
            return

        self._accept_sockets = self._create_accepting_sockets()

        if not self._accept_sockets:
            raise GateException(
                f'Could not resolve hostname {self._route.host_for_client}',
            )

        if self._route.port_for_client == 0:
            # Remember the address chosen by the OS, so that stop()+start()
            # binds to the same port again.
            bound = self._accept_sockets[0].getsockname()
            self._route = dataclasses.replace(
                self._route,
                host_for_client=bound[0],
                port_for_client=bound[1],
            )

        self.start_accepting()

    def start_accepting(self) -> None:
        """Start accepting tasks"""
        assert self._accept_sockets
        if not all(tsk.done() for tsk in self._accept_tasks):
            return  # at least one accept task is still running

        self._accept_tasks.clear()
        for sock in self._accept_sockets:
            self._accept_tasks.append(
                asyncio.create_task(self._do_accept(sock)),
            )

    async def stop_accepting(self) -> None:
        """
        Stop accepting tasks without closing the accepting socket.
        """
        for tsk in self._accept_tasks:
            await _cancel_and_join(tsk)
        self._accept_tasks.clear()

    async def stop(self) -> None:
        """
        Stop accepting tasks, close all the sockets
        """
        if not self._accept_sockets and not self._sockets:
            return

        await self.to_server_pass()
        await self.to_client_pass()

        await self.stop_accepting()
        logger.info('Before close() %s', self.info())
        await self.sockets_close()
        assert not self._sockets

        for sock in self._accept_sockets:
            sock.close()
        self._accept_sockets.clear()
        logger.info('Stopped. %s', self.info())

    async def sockets_close(
        self,
        *,
        count: typing.Optional[int] = None,
    ) -> None:
        """
        Close connections going through the gate.

        @param count If given, close at most `count` connections (chosen in
               set iteration order); otherwise close all of them.
        """
        for x in list(self._sockets)[0:count]:
            await x.shutdown()
        self._collect_garbage()

    def get_sockname_for_clients(self) -> Address:
        """
        Returns the client socket address that the gate listens on.

        This function allows to use 0 in GateRoute.port_for_client and retrieve
        the actual port and host.
        """
        # The message used to be a 1-tuple because of a stray trailing comma.
        assert self._route.port_for_client != 0, 'Gate was not started and the port_for_client is still 0'
        return (self._route.host_for_client, self._route.port_for_client)

    def info(self) -> str:
        """Print info on open sockets"""
        if not self._sockets:
            return f'"{self._route.name}" no active sockets'

        return f'"{self._route.name}" active sockets:\n\t' + '\n\t'.join(x.info() for x in self._sockets)

    def _collect_garbage(self) -> None:
        """Forget connections whose supervising task has finished."""
        self._sockets = {x for x in self._sockets if x.is_active()}

    async def _do_accept(self, accept_sock: Socket) -> None:
        """
        This task should wait for connections and create SocketPair
        """
        raise NotImplementedError(self._NOT_IMPLEMENTED_MESSAGE)

    async def set_to_server_interceptor(
        self,
        interceptor: typing.Optional[Interceptor],
    ) -> typing.Optional[callinfo.AsyncCallQueue]:
        """
        Replace existing interceptor of client to server data with a custom

        The interceptor is applied to already open and to future connections.
        None makes the gate stop reading client data (see to_server_noop()).
        """
        self._to_server_intercept = _create_callqueue(interceptor)
        # Iterate over a snapshot: the set may change while we await.
        for x in tuple(self._sockets):
            await x.set_to_server_interceptor(self._to_server_intercept)
        return self._to_server_intercept

    async def set_to_client_interceptor(
        self,
        interceptor: typing.Optional[Interceptor],
    ) -> typing.Optional[callinfo.AsyncCallQueue]:
        """
        Replace existing interceptor of server to client data with a custom

        The interceptor is applied to already open and to future connections.
        None makes the gate stop reading server data (see to_client_noop()).
        """
        self._to_client_intercept = _create_callqueue(interceptor)
        # Iterate over a snapshot: the set may change while we await.
        for x in tuple(self._sockets):
            await x.set_to_client_interceptor(self._to_client_intercept)
        return self._to_client_intercept

    async def to_server_pass(self) -> callinfo.AsyncCallQueue:
        """Pass data as is"""
        logging.trace('to_server_pass')
        return await self.set_to_server_interceptor(_intercept_ok)

    async def to_client_pass(self) -> callinfo.AsyncCallQueue:
        """Pass data as is"""
        logging.trace('to_client_pass')
        return await self.set_to_client_interceptor(_intercept_ok)

    async def to_server_noop(self) -> callinfo.AsyncCallQueue:
        """Stop reading data from the client, leaving it unread in the socket"""
        logging.trace('to_server_noop')
        return await self.set_to_server_interceptor(None)

    async def to_client_noop(self) -> callinfo.AsyncCallQueue:
        """Stop reading data from the server, leaving it unread in the socket"""
        logging.trace('to_client_noop')
        return await self.set_to_client_interceptor(None)

    async def to_server_drop(self) -> callinfo.AsyncCallQueue:
        """Read and discard data"""
        logging.trace('to_server_drop')
        return await self.set_to_server_interceptor(_intercept_drop)

    async def to_client_drop(self) -> callinfo.AsyncCallQueue:
        """Read and discard data"""
        logging.trace('to_client_drop')
        return await self.set_to_client_interceptor(_intercept_drop)

    async def to_server_delay(self, delay: float) -> callinfo.AsyncCallQueue:
        """Delay data transmission"""
        logging.trace('to_server_delay, delay: %s', delay)

        async def _intercept_delay_bound(
            loop: EvLoop,
            socket_from: Socket,
            socket_to: Socket,
        ) -> None:
            await _intercept_delay(delay, loop, socket_from, socket_to)

        return await self.set_to_server_interceptor(_intercept_delay_bound)

    async def to_client_delay(self, delay: float) -> callinfo.AsyncCallQueue:
        """Delay data transmission"""
        logging.trace('to_client_delay, delay: %s', delay)

        async def _intercept_delay_bound(
            loop: EvLoop,
            socket_from: Socket,
            socket_to: Socket,
        ) -> None:
            await _intercept_delay(delay, loop, socket_from, socket_to)

        return await self.set_to_client_interceptor(_intercept_delay_bound)

    async def to_server_close_on_data(self) -> callinfo.AsyncCallQueue:
        """Close on first bytes of data from client"""
        logging.trace('to_server_close_on_data')
        return await self.set_to_server_interceptor(_intercept_close_on_data)

    async def to_client_close_on_data(self) -> callinfo.AsyncCallQueue:
        """Close on first bytes of data from server"""
        logging.trace('to_client_close_on_data')
        return await self.set_to_client_interceptor(_intercept_close_on_data)

    async def to_server_corrupt_data(self) -> callinfo.AsyncCallQueue:
        """Corrupt data received from client"""
        logging.trace('to_server_corrupt_data')
        return await self.set_to_server_interceptor(_intercept_corrupt)

    async def to_client_corrupt_data(self) -> callinfo.AsyncCallQueue:
        """Corrupt data received from server"""
        logging.trace('to_client_corrupt_data')
        return await self.set_to_client_interceptor(_intercept_corrupt)

    async def to_server_limit_bps(self, bytes_per_second: float) -> callinfo.AsyncCallQueue:
        """Limit bytes per second transmission by network from client"""
        logging.trace(
            'to_server_limit_bps, bytes_per_second: %s',
            bytes_per_second,
        )
        return await self.set_to_server_interceptor(_InterceptBpsLimit(bytes_per_second))

    async def to_client_limit_bps(self, bytes_per_second: float) -> callinfo.AsyncCallQueue:
        """Limit bytes per second transmission by network from server"""
        logging.trace(
            'to_client_limit_bps, bytes_per_second: %s',
            bytes_per_second,
        )
        return await self.set_to_client_interceptor(_InterceptBpsLimit(bytes_per_second))

    async def to_server_limit_time(self, timeout: float, jitter: float) -> callinfo.AsyncCallQueue:
        """Limit connection lifetime on receive of first bytes from client"""
        logging.trace(
            'to_server_limit_time, timeout: %s, jitter: %s',
            timeout,
            jitter,
        )
        return await self.set_to_server_interceptor(_InterceptTimeLimit(timeout, jitter))

    async def to_client_limit_time(self, timeout: float, jitter: float) -> callinfo.AsyncCallQueue:
        """Limit connection lifetime on receive of first bytes from server"""
        logging.trace(
            'to_client_limit_time, timeout: %s, jitter: %s',
            timeout,
            jitter,
        )
        return await self.set_to_client_interceptor(_InterceptTimeLimit(timeout, jitter))

    async def to_server_smaller_parts(
        self,
        max_size: int,
        *,
        sleep_per_packet: float = 0,
    ) -> callinfo.AsyncCallQueue:
        """
        Pass data to server in smaller parts

        @param max_size Max packet size to send to server
        @param sleep_per_packet Optional sleep interval per packet, seconds
        """
        logging.trace('to_server_smaller_parts, max_size: %s', max_size)
        return await self.set_to_server_interceptor(
            _InterceptSmallerParts(max_size, sleep_per_packet),
        )

    async def to_client_smaller_parts(
        self,
        max_size: int,
        *,
        sleep_per_packet: float = 0,
    ) -> callinfo.AsyncCallQueue:
        """
        Pass data to client in smaller parts

        @param max_size Max packet size to send to client
        @param sleep_per_packet Optional sleep interval per packet, seconds
        """
        logging.trace('to_client_smaller_parts, max_size: %s', max_size)
        return await self.set_to_client_interceptor(
            _InterceptSmallerParts(max_size, sleep_per_packet),
        )

    async def to_server_concat_packets(self, packet_size: int) -> callinfo.AsyncCallQueue:
        """
        Pass data in bigger parts
        @param packet_size minimal size of the resulting packet
        """
        logging.trace('to_server_concat_packets, packet_size: %s', packet_size)
        return await self.set_to_server_interceptor(_InterceptConcatPackets(packet_size))

    async def to_client_concat_packets(self, packet_size: int) -> callinfo.AsyncCallQueue:
        """
        Pass data in bigger parts
        @param packet_size minimal size of the resulting packet
        """
        logging.trace('to_client_concat_packets, packet_size: %s', packet_size)
        return await self.set_to_client_interceptor(_InterceptConcatPackets(packet_size))

    async def to_server_limit_bytes(self, bytes_limit: int) -> callinfo.AsyncCallQueue:
        """Drop all connections each `bytes_limit` of data sent by network"""
        logging.trace('to_server_limit_bytes, bytes_limit: %s', bytes_limit)
        return await self.set_to_server_interceptor(_InterceptBytesLimit(bytes_limit, self))

    async def to_client_limit_bytes(self, bytes_limit: int) -> callinfo.AsyncCallQueue:
        """Drop all connections each `bytes_limit` of data sent by network"""
        logging.trace('to_client_limit_bytes, bytes_limit: %s', bytes_limit)
        return await self.set_to_client_interceptor(_InterceptBytesLimit(bytes_limit, self))

    async def to_server_substitute(self, pattern: str, repl: str) -> callinfo.AsyncCallQueue:
        """Apply regex substitution to data from client"""
        logging.trace(
            'to_server_substitute, pattern: %s, repl: %s',
            pattern,
            repl,
        )
        return await self.set_to_server_interceptor(_InterceptSubstitute(pattern, repl))

    async def to_client_substitute(self, pattern: str, repl: str) -> callinfo.AsyncCallQueue:
        """Apply regex substitution to data from server"""
        logging.trace(
            'to_client_substitute, pattern: %s, repl: %s',
            pattern,
            repl,
        )
        return await self.set_to_client_interceptor(_InterceptSubstitute(pattern, repl))
836
837
class TcpGate(BaseGate):
    """
    Implements TCP chaos-proxy logic such as accepting incoming tcp client
    connections. On each new connection new tcp client connects to server
    (host_to_server, port_to_server).

    @ingroup userver_testsuite

    @see @ref scripts/docs/en/userver/chaos_testing.md
    """

    def __init__(self, route: GateRoute, loop: typing.Optional[EvLoop] = None) -> None:
        # Set by _do_accept() on each new connection; awaited by wait_for_connections().
        self._connected_event = asyncio.Event()
        super().__init__(route, loop)

    def connections_count(self) -> int:
        """
        Returns maximal amount of connections going through the gate at
        the moment.

        @warning Some of the connections could be closing, or could be opened
        right before the function starts. Use with caution!
        """
        return len(self._sockets)

    async def wait_for_connections(self, *, count=1, timeout=0.0) -> None:
        """
        Wait for at least `count` connections going through the gate.

        A non-positive `timeout` means waiting without a deadline.

        @throws asyncio.TimeoutError exception if failed to get the
        required amount of connections in time.
        """
        if timeout <= 0.0:
            while self.connections_count() < count:
                await self._connected_event.wait()
                self._connected_event.clear()
            return

        deadline = time.monotonic() + timeout
        while self.connections_count() < count:
            time_left = deadline - time.monotonic()
            await asyncio.wait_for(
                self._connected_event.wait(),
                timeout=time_left,
            )
            self._connected_event.clear()

    def _create_accepting_sockets(self) -> typing.List[Socket]:
        """Create listening sockets for every address `host_for_client` resolves to."""
        res: typing.List[Socket] = []
        try:
            for addr in socket.getaddrinfo(
                self._route.host_for_client,
                self._route.port_for_client,
                type=socket.SOCK_STREAM,
            ):
                sock = Socket(addr[0], addr[1])
                # Registered first, so it is closed below if setup fails.
                res.append(sock)
                _make_socket_nonblocking(sock)
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                sock.bind(addr[4])
                sock.listen()
                logger.debug('Accepting connections on %s, fd=%s', sock.getsockname(), sock.fileno())
        except BaseException:
            # Do not leak the sockets created before the failure.
            for sock in res:
                sock.close()
            raise

        return res

    async def _connect_to_server(self) -> typing.Optional[Socket]:
        """Connect to the first reachable server address; None if none is reachable."""
        addrs = await self._loop.getaddrinfo(
            self._route.host_to_server,
            self._route.port_to_server,
            type=socket.SOCK_STREAM,
        )
        for addr in addrs:
            server = Socket(addr[0], addr[1])
            _make_socket_nonblocking(server)
            try:
                await self._loop.sock_connect(server, addr[4])
                # NOTE(review): logging.trace is not part of stdlib logging;
                # presumably registered by the test environment - confirm.
                logging.trace('Connected to %s', addr[4])
                return server
            except Exception as exc:  # pylint: disable=broad-except
                server.close()
                logger.warning('Could not connect to %s: %s', addr[4], exc)
            except BaseException:
                # Cancellation (e.g. by stop_accepting()) must not leak the socket.
                server.close()
                raise
        return None

    async def _do_accept(self, accept_sock: Socket) -> None:
        """Accept clients forever, pairing each one with a new server connection."""
        while True:
            client, _ = await self._loop.sock_accept(accept_sock)
            _make_socket_nonblocking(client)

            try:
                server = await self._connect_to_server()
            except BaseException:
                # e.g. cancelled by stop_accepting() while connecting, or the
                # server host could not be resolved: do not leak the client.
                client.close()
                raise

            if server:
                self._sockets.add(
                    _SocketsPaired(
                        self._route.name,
                        self._loop,
                        client,
                        server,
                        self._to_server_intercept,
                        self._to_client_intercept,
                    ),
                )
                self._connected_event.set()
            else:
                client.close()

            self._collect_garbage()
943
944
class UdpGate(BaseGate):
    """
    Implements UDP chaos-proxy logic such as demuxing incoming datagrams
    from different clients.
    Separate connections to server are made for each new client.

    @ingroup userver_testsuite

    @see @ref scripts/docs/en/userver/chaos_testing.md
    """

    def __init__(self, route: GateRoute, loop: typing.Optional[EvLoop] = None):
        # One pseudo-connection per client address seen on the accepting socket.
        self._clients: typing.Set[_UdpDemuxSocketMock] = set()
        super().__init__(route, loop)

    def is_connected(self) -> bool:
        """
        Returns True if there is active pair of sockets ready to transfer data
        at the moment.
        """
        return len(self._sockets) > 0

    def _create_accepting_sockets(self) -> typing.List[Socket]:
        """Create bound datagram sockets for every address `host_for_client` resolves to."""
        res: typing.List[Socket] = []
        try:
            for addr in socket.getaddrinfo(
                self._route.host_for_client,
                self._route.port_for_client,
                type=socket.SOCK_DGRAM,
            ):
                sock = socket.socket(addr[0], addr[1])
                # Registered first, so it is closed below if setup fails.
                res.append(sock)
                _make_socket_nonblocking(sock)
                sock.bind(addr[4])
                logger.debug('Accepting connections on %s', sock.getsockname())
        except BaseException:
            # Do not leak the sockets created before the failure.
            for sock in res:
                sock.close()
            raise

        return res

    async def _connect_to_server(self) -> typing.Optional[Socket]:
        """Connect to the first usable server address; None if none is usable."""
        addrs = await self._loop.getaddrinfo(
            self._route.host_to_server,
            self._route.port_to_server,
            type=socket.SOCK_DGRAM,
        )
        for addr in addrs:
            server = Socket(addr[0], addr[1])
            try:
                _make_socket_nonblocking(server)
                await self._loop.sock_connect(server, addr[4])
                # NOTE(review): logging.trace is not part of stdlib logging;
                # presumably registered by the test environment - confirm.
                logging.trace('Connected to %s', addr[4])
                return server
            except Exception as exc:  # pylint: disable=broad-except
                # Previously the failed socket was left open (leak).
                server.close()
                logger.warning('Could not connect to %s: %s', addr[4], exc)
            except BaseException:
                server.close()
                raise
        return None

    def _collect_garbage(self) -> None:
        super()._collect_garbage()
        self._clients = {c for c in self._clients if c.is_active()}

    async def _do_accept(self, accept_sock: Socket):
        """
        Receive datagrams forever and route each one to its sender's
        pseudo-connection, creating a new one (with a new server socket) for
        an unknown sender.
        """
        sock = asyncio_socket.from_socket(accept_sock)
        while True:
            # NOTE(review): presumably raises if no datagram arrives within
            # 60 s, which ends this accept task - confirm with asyncio_socket.
            data, addr = await sock.recvfrom(RECV_MAX_SIZE, timeout=60.0)

            client: typing.Optional[_UdpDemuxSocketMock] = None
            for known_clients in self._clients:
                if addr == known_clients.peer_address:
                    client = known_clients
                    break

            if client is None:
                server = await self._connect_to_server()
                if not server:
                    accept_sock.close()
                    break

                client = _UdpDemuxSocketMock(accept_sock, addr)
                self._clients.add(client)

                self._sockets.add(
                    _SocketsPaired(
                        self._route.name,
                        self._loop,
                        client,
                        server,
                        self._to_server_intercept,
                        self._to_client_intercept,
                    ),
                )

            await client.push(self._loop, data)
            self._collect_garbage()

    async def to_server_concat_packets(self, packet_size: int) -> None:
        raise NotImplementedError('Udp packets cannot be concatenated')

    async def to_client_concat_packets(self, packet_size: int) -> None:
        raise NotImplementedError('Udp packets cannot be concatenated')

    async def to_server_smaller_parts(
        self,
        max_size: int,
        *,
        sleep_per_packet: float = 0,
    ) -> None:
        raise NotImplementedError('Udp packets cannot be split')

    async def to_client_smaller_parts(
        self,
        max_size: int,
        *,
        sleep_per_packet: float = 0,
    ) -> None:
        raise NotImplementedError('Udp packets cannot be split')
1057
1058
1059def _create_callqueue(obj):
1060 if obj is None:
1061 return None
1062
1063 # workaround testsuite acallqueue that does not work with instances
1064 if isinstance(obj, callinfo.AsyncCallQueue):
1065 return obj
1066 if hasattr(obj, '__name__'):
1067 return callinfo.acallqueue(obj)
1068
1069 @functools.wraps(obj)
1070 async def wrapper(*args, **kwargs):
1071 return await obj(*args, **kwargs)
1072
1073 return callinfo.acallqueue(wrapper)
1074
1075
async def _wait_for_data(sock, timeout=60.0):
    """
    Wait until `sock` has data available for reading.

    For a _UdpDemuxSocketMock, the readable end of its internal socketpair is
    watched. What happens when `timeout` seconds pass without data is
    defined by asyncio_socket (presumably an exception - confirm).
    """
    watched = sock.get_demux_out() if isinstance(sock, _UdpDemuxSocketMock) else sock
    await asyncio_socket.from_socket(watched).wait_for_data(timeout=timeout)