# userver: testsuite/pytest_plugins/pytest_userver/chaos.py
"""
Python module that provides testsuite support for chaos tests;
see @ref scripts/docs/en/userver/chaos_testing.md for an introduction.

@ingroup userver_testsuite
"""
8
import asyncio
import dataclasses
import errno
import fcntl
import functools
import logging
import os
import random
import re
import socket
import sys
import time
import typing
21
22
@dataclasses.dataclass(frozen=True)
class GateRoute:
    """
    Immutable description of a single route served by TcpGate or UdpGate.

    Clients talk to (host_for_client, port_for_client); the gate forwards
    their traffic to the server at (host_to_server, port_to_server).

    Leave `port_for_client` at 0 to let the OS pick an unused port; the
    address that was actually bound is then available through
    BaseGate.get_sockname_for_clients().

    @ingroup userver_testsuite
    """

    # Route name, used to tag log messages of the gate.
    name: str
    # Address of the real server the traffic is forwarded to.
    host_to_server: str
    port_to_server: int
    # Address the gate listens on for clients.
    host_for_client: str = '127.0.0.1'
    port_for_client: int = 0
39
40
# @cond

# Upper bound for the size of a single recv()/recvfrom() call, see
# https://docs.python.org/3/library/socket.html#socket.socket.recv
RECV_MAX_SIZE = 4096

# Seconds _InterceptConcatPackets may wait for a big enough packet.
MAX_DELAY = 60.0

logger = logging.getLogger(__name__)

# Type aliases shared by the whole module.
Address = typing.Tuple[str, int]  # (host, port)
EvLoop = typing.Any  # presumably an asyncio event loop; only duck-typed here
Socket = socket.socket
# Interceptor: async callable (loop, socket_from, socket_to) -> None that
# moves (or withholds) one portion of data between the two sockets.
Interceptor = typing.Callable[
    [EvLoop, Socket, Socket],
    typing.Coroutine[typing.Any, typing.Any, None],
]
57
58
class GateException(Exception):
    """Raised when a gate cannot be set up, e.g. no listening socket."""
61
62
class GateInterceptException(Exception):
    """Raised by an interceptor to terminate the data channel it serves."""
65
66
67async def _yield() -> None:
68 # Minamal delay can be 0. This will be fast path for coroutine switching
69 # https://docs.python.org/3/library/asyncio-task.html#sleeping
70
71 min_delay = 0
72 await asyncio.sleep(min_delay)
73
74
def _try_get_message(
    recv_socket: Socket,
) -> typing.Tuple[typing.Optional[bytes], typing.Optional[Address]]:
    """
    Peek at pending data on `recv_socket` without consuming it.

    Returns the (data, address) pair of recvfrom(MSG_PEEK), reading at most
    RECV_MAX_SIZE bytes, or (None, None) if the call would block
    (EAGAIN / EWOULDBLOCK). Any other socket error is re-raised.
    """
    try:
        peeked = recv_socket.recvfrom(RECV_MAX_SIZE, socket.MSG_PEEK)
    except socket.error as exc:
        if exc.args[0] not in (errno.EAGAIN, errno.EWOULDBLOCK):
            raise
        return None, None
    return peeked
86
87
async def _wait_for_message_task(
    recv_socket: Socket,
) -> typing.Tuple[bytes, Address]:
    """
    Poll `recv_socket` until a non-empty message is pending.

    Returns the peeked message (it stays unconsumed in the socket) together
    with the sender address.
    """
    message, sender = _try_get_message(recv_socket)
    while not message:
        await _yield()
        message, sender = _try_get_message(recv_socket)
    assert sender
    return message, sender
98
99
def _incoming_data_size(recv_socket: Socket) -> int:
    """
    Return the number of bytes that can be peeked from `recv_socket`
    (at most RECV_MAX_SIZE); 0 when nothing, or an empty read, is peeked.
    """
    peeked, _ = _try_get_message(recv_socket)
    if not peeked:
        return 0
    return len(peeked)
103
104
async def _intercept_ok(
    loop: EvLoop, socket_from: Socket, socket_to: Socket,
) -> None:
    """Forward one chunk (up to RECV_MAX_SIZE bytes) unchanged."""
    chunk = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    await loop.sock_sendall(socket_to, chunk)
110
111
async def _intercept_noop(
    loop: EvLoop, socket_from: Socket, socket_to: Socket,
) -> None:
    """Read nothing and send nothing: pending data stays in `socket_from`."""
    return None
116
117
async def _intercept_drop(
    loop: EvLoop, socket_from: Socket, socket_to: Socket,
) -> None:
    """Consume one chunk from `socket_from` and throw it away."""
    # The received bytes are intentionally ignored; socket_to is untouched.
    await loop.sock_recv(socket_from, RECV_MAX_SIZE)
122
123
async def _intercept_delay(
    delay: float, loop: EvLoop, socket_from: Socket, socket_to: Socket,
) -> None:
    """
    Read one chunk, hold it for `delay` seconds, then forward it unchanged.
    """
    held = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    await asyncio.sleep(delay)
    await loop.sock_sendall(socket_to, held)
130
131
async def _intercept_close_on_data(
    loop: EvLoop, socket_from: Socket, socket_to: Socket,
) -> None:
    """
    Swallow a single byte of incoming data, then abort the channel by
    raising GateInterceptException.
    """
    await loop.sock_recv(socket_from, 1)
    raise GateInterceptException('Closing socket on data')
137
138
async def _intercept_corrupt(
    loop: EvLoop, socket_from: Socket, socket_to: Socket,
) -> None:
    """
    Read one chunk and forward a corrupted copy of the same length.

    Each zero byte becomes 1 and every other byte becomes 0, so every byte
    of the forwarded chunk differs from the original one.
    """
    original = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    corrupted = bytearray(1 if byte == 0 else 0 for byte in original)
    await loop.sock_sendall(socket_to, corrupted)
144
145
class _InterceptBpsLimit:
    """
    Token-bucket throughput limiter: the bucket refills at
    `bytes_per_second` and never holds more than `bytes_per_second` bytes.
    """

    def __init__(self, bytes_per_second: float):
        assert bytes_per_second >= 1
        self._bytes_per_second = bytes_per_second
        # The bucket starts full; the cap in _update_limit() bounds the
        # (large) first refill computed from this 0.0 timestamp.
        self._time_last_added = 0.0
        self._bytes_left = self._bytes_per_second

    def _update_limit(self) -> None:
        """Refill the bucket for the time elapsed since the last refill."""
        now = time.monotonic()
        refill = (now - self._time_last_added) * self._bytes_per_second
        if refill > 0:
            self._bytes_left += refill
            self._time_last_added = now
        self._bytes_left = min(self._bytes_left, self._bytes_per_second)

    async def __call__(
        self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
    ) -> None:
        self._update_limit()

        budget = min(int(self._bytes_left), RECV_MAX_SIZE)
        if budget <= 0:
            # Bucket is empty: wait roughly the time needed for one byte.
            logger.info('Socket hits the bytes per second limit')
            await asyncio.sleep(1.0 / self._bytes_per_second)
            return

        data = await loop.sock_recv(socket_from, budget)
        self._bytes_left -= len(data)
        await loop.sock_sendall(socket_to, data)
178
179
class _InterceptTimeLimit:
    """
    Terminates a channel once `timeout` plus a random share of `jitter`
    seconds have passed since the first time its source socket was seen.
    """

    def __init__(self, timeout: float, jitter: float):
        # Deadline (time.monotonic() based) per source socket.
        self._sockets: typing.Dict[Socket, float] = {}
        assert timeout >= 0.0
        self._timeout = timeout
        assert jitter >= 0.0
        self._jitter = jitter

    def raise_if_timed_out(self, socket_from: Socket) -> None:
        """
        Register `socket_from` on first call; raise GateInterceptException
        (and forget the socket) once its deadline has passed.
        """
        deadline = self._sockets.get(socket_from)
        if deadline is None:
            extra = self._jitter * random.random()
            deadline = time.monotonic() + self._timeout + extra
            self._sockets[socket_from] = deadline

        if deadline <= time.monotonic():
            del self._sockets[socket_from]
            raise GateInterceptException('Socket hits the time limit')

    async def __call__(
        self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
    ) -> None:
        self.raise_if_timed_out(socket_from)
        await _intercept_ok(loop, socket_from, socket_to)
203
204
class _InterceptSmallerParts:
    """
    Forwards pending data in parts of at most `max_size` bytes, sleeping
    `sleep_per_packet` seconds before sending each part.
    """

    def __init__(self, max_size: int, sleep_per_packet: float):
        assert max_size > 0
        self._max_size = max_size
        self._sleep_per_packet = sleep_per_packet

    async def __call__(
        self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
    ) -> None:
        part_size = min(_incoming_data_size(socket_from), self._max_size)
        part = await loop.sock_recv(socket_from, part_size)
        await asyncio.sleep(self._sleep_per_packet)
        await loop.sock_sendall(socket_to, part)
219
220
class _InterceptConcatPackets:
    """
    Holds data back until at least `packet_size` bytes are pending, then
    forwards them in one go. Terminates the process (exit code 2) if no
    such packet accumulates within MAX_DELAY seconds.
    """

    def __init__(self, packet_size: int):
        assert packet_size >= 0
        self._packet_size = packet_size
        # Deadline for the packet being accumulated; None when not waiting.
        self._expire_at: typing.Optional[float] = None

    async def __call__(
        self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
    ) -> None:
        if self._expire_at is None:
            self._expire_at = time.monotonic() + MAX_DELAY

        timed_out = self._expire_at <= time.monotonic()
        if timed_out:
            logger.error(
                f'Failed to make a packet of sufficient size in {MAX_DELAY} '
                'seconds. Check the test logic, it should end with checking '
                'that the data was sent and by calling TcpGate function '
                'to_client_pass() to pass the remaining packets.',
            )
            sys.exit(2)

        if _incoming_data_size(socket_from) < self._packet_size:
            return  # keep accumulating

        packet = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
        await loop.sock_sendall(socket_to, packet)
        self._expire_at = None
247
248
class _InterceptBytesLimit:
    """
    Forwards data until `bytes_limit` bytes have passed; then sends the part
    that still fits, closes all connections of `gate`, resets the counter
    and ends the channel with GateInterceptException.
    """

    def __init__(self, bytes_limit: int, gate: 'BaseGate'):
        assert bytes_limit >= 0
        self._bytes_limit = bytes_limit
        self._bytes_remain = self._bytes_limit
        self._gate = gate

    async def __call__(
        self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
    ) -> None:
        data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
        if len(data) < self._bytes_remain:
            self._bytes_remain -= len(data)
            await loop.sock_sendall(socket_to, data)
            return

        await loop.sock_sendall(socket_to, data[:self._bytes_remain])
        await self._gate.sockets_close()
        self._bytes_remain = self._bytes_limit
        raise GateInterceptException('Data transmission limit reached')
268
269
class _InterceptSubstitute:
    """
    Applies a regex substitution to each forwarded chunk decoded with
    `encoding`. Chunks that fail to decode or re-encode are passed as is.
    """

    def __init__(self, pattern: str, repl: str, encoding='utf-8'):
        self._pattern = re.compile(pattern)
        self._repl = repl
        self._encoding = encoding

    async def __call__(
        self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
    ) -> None:
        raw = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
        try:
            text = raw.decode(self._encoding)
            outgoing = self._pattern.sub(self._repl, text).encode(
                self._encoding,
            )
        except UnicodeError:
            # Possibly a chunk boundary split a multi-byte character.
            outgoing = raw
        await loop.sock_sendall(socket_to, outgoing)
286
287
288async def _cancel_and_join(task: typing.Optional[asyncio.Task]) -> None:
289 if not task or task.cancelled():
290 return
291
292 try:
293 task.cancel()
294 await task
295 except asyncio.CancelledError:
296 return
297 except Exception as exc: # pylint: disable=broad-except
298 logger.error('Exception in _cancel_and_join: %s', exc)
299
300
class _SocketsPaired:
    """
    A client socket and a server socket joined by two asyncio tasks, one per
    direction, each passing the data through its current interceptor.
    """

    def __init__(
        self,
        proxy_name: str,
        loop: EvLoop,
        client: socket.socket,
        server: socket.socket,
        to_server_intercept: Interceptor,
        to_client_intercept: Interceptor,
    ) -> None:
        self._proxy_name = proxy_name
        self._loop = loop

        self._client = client
        self._server = server

        self._to_server_intercept: Interceptor = to_server_intercept
        self._to_client_intercept: Interceptor = to_client_intercept

        # How many of the two direction tasks have already exited.
        self._finished_channels = 0

        self._task_to_server = asyncio.create_task(
            self._do_pipe_channels(to_server=True),
        )
        self._task_to_client = asyncio.create_task(
            self._do_pipe_channels(to_server=False),
        )

    async def _do_pipe_channels(self, *, to_server: bool) -> None:
        """Relay data in one direction until an error or cancellation."""
        if to_server:
            source, destination = self._client, self._server
        else:
            source, destination = self._server, self._client

        try:
            while True:
                # Wait for data *before* grabbing the interceptor, so that a
                # newly installed interceptor is applied promptly instead of
                # an outdated one blocking in sock_recv().
                if _incoming_data_size(source) == 0:
                    await _yield()
                    continue

                current = (
                    self._to_server_intercept
                    if to_server
                    else self._to_client_intercept
                )
                await current(self._loop, source, destination)
                await _yield()
        except GateInterceptException as exc:
            logger.info('In "%s": %s', self._proxy_name, exc)
        except socket.error as exc:
            logger.error('Exception in "%s": %s', self._proxy_name, exc)
        finally:
            self._finished_channels += 1
            if self._finished_channels == 2:
                # Both directions are done. Closing the sockets here makes
                # shutdown() return only after they are actually closed.
                logger.info('"%s" closes %s', self._proxy_name, self.info())
                self._close_socket(self._client)
                self._close_socket(self._server)
            else:
                assert self._finished_channels == 1
                # First direction to finish: tear down the opposite one too.
                opposite = (
                    self._task_to_client if to_server else self._task_to_server
                )
                opposite.cancel()

    def set_to_server_interceptor(self, interceptor: Interceptor) -> None:
        self._to_server_intercept = interceptor

    def set_to_client_interceptor(self, interceptor: Interceptor) -> None:
        self._to_client_intercept = interceptor

    def _close_socket(self, self_socket: Socket) -> None:
        """Close one of the paired sockets, logging (not raising) errors."""
        assert self_socket in {self._client, self._server}
        try:
            self_socket.close()
        except socket.error as exc:
            side = 'client' if self_socket == self._client else 'server'
            logger.error(
                'Exception in "%s" on closing %s: %s',
                self._proxy_name,
                side,
                exc,
            )

    async def shutdown(self) -> None:
        """Cancel both direction tasks and wait for them to finish."""
        direction_tasks = {self._task_to_client, self._task_to_server}
        for task in direction_tasks:
            await _cancel_and_join(task)

    def is_active(self) -> bool:
        """True while at least one direction task is still running."""
        return not (self._task_to_client.done() and self._task_to_server.done())

    def info(self) -> str:
        """Describe the pair by file descriptors, or '<inactive>'."""
        if self.is_active():
            return (
                f'client fd={self._client.fileno()} <=> '
                f'server fd={self._server.fileno()}'
            )
        return '<inactive>'
409
410
# @endcond
412
413
class BaseGate:
    """
    This base class maintains endpoints of two types:

    Server-side endpoints to receive messages from clients. Address of this
    endpoint is described by (host_for_client, port_for_client).

    Client-side endpoints to forward messages to server. Server must listen on
    (host_to_server, port_to_server).

    Asynchronously concurrently passes data from client to server and from
    server to client, allowing intercepting the data, injecting delays and
    dropping connections.

    @warning Do not use this class itself. Use one of the specializations
    TcpGate or UdpGate

    @ingroup userver_testsuite

    @see @ref scripts/docs/en/userver/chaos_testing.md
    """

    _NOT_IMPLEMENTED_MESSAGE = (
        'Do not use BaseGate itself, use one of '
        'specializations TcpGate or UdpGate'
    )

    def __init__(self, route: GateRoute, loop: EvLoop) -> None:
        self._route = route
        self._loop = loop

        # Interceptors handed to connections created from now on; existing
        # connections are updated by set_to_*_interceptor().
        self._to_server_intercept: Interceptor = _intercept_ok
        self._to_client_intercept: Interceptor = _intercept_ok

        self._accept_sockets: typing.List[socket.socket] = []
        self._accept_tasks: typing.List[asyncio.Task[None]] = []

        # Set by _do_accept() implementations when a connection is created.
        self._connected_event = asyncio.Event()

        self._sockets: typing.Set[_SocketsPaired] = set()

    async def __aenter__(self) -> 'BaseGate':
        self.start()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        await self.stop()

    def _create_accepting_sockets(self) -> typing.List[Socket]:
        """Create and bind the client-facing sockets (subclass hook)."""
        raise NotImplementedError(self._NOT_IMPLEMENTED_MESSAGE)

    def start(self) -> None:
        """ Open the socket and start accepting tasks """
        if self._accept_sockets:
            return

        self._accept_sockets.extend(self._create_accepting_sockets())

        if not self._accept_sockets:
            raise GateException(
                f'Could not resolve hostname {self._route.host_for_client}',
            )

        if self._route.port_for_client == 0:
            # Remember the address chosen by the OS, so that stop() + start()
            # binds to the same port again.
            host, port = self._accept_sockets[0].getsockname()[:2]
            self._route = dataclasses.replace(
                self._route, host_for_client=host, port_for_client=port,
            )

        BaseGate.start_accepting(self)

    def start_accepting(self) -> None:
        """ Start accepting tasks """
        assert self._accept_sockets
        if not all(tsk.done() for tsk in self._accept_tasks):
            return  # accepting is already in progress

        self._accept_tasks.clear()
        for sock in self._accept_sockets:
            self._accept_tasks.append(
                asyncio.create_task(self._do_accept(sock)),
            )

    async def stop_accepting(self) -> None:
        """
        Stop accepting tasks without closing the accepting socket.
        """
        for tsk in self._accept_tasks:
            await _cancel_and_join(tsk)
        self._accept_tasks.clear()

    async def stop(self) -> None:
        """
        Stop accepting tasks, close all the sockets
        """
        if not self._accept_sockets and not self._sockets:
            return

        self.to_server_pass()
        self.to_client_pass()

        # Cancel the accepting tasks and the connections *before* closing the
        # accepting sockets, so that no task is left waiting on (or reading
        # from) an already closed file descriptor.
        await BaseGate.stop_accepting(self)
        logger.info('Before close() %s', self.info())
        await self.sockets_close()
        assert not self._sockets

        for sock in self._accept_sockets:
            sock.close()
        self._accept_sockets.clear()
        logger.info('Stopped. %s', self.info())

    async def sockets_close(
        self, *, count: typing.Optional[int] = None,
    ) -> None:
        """
        Close the connections going through the gate.

        @param count If set, close at most `count` connections (picked in
               arbitrary order), otherwise close all of them.
        """
        for x in list(self._sockets)[:count]:
            await x.shutdown()
        self._collect_garbage()

    def get_sockname_for_clients(self) -> Address:
        """
        Returns the client socket address that the gate listens on.

        This function allows to use 0 in GateRoute.port_for_client and retrieve
        the actual port and host.
        """
        assert self._route.port_for_client != 0, (
            'Gate was not started and the port_for_client is still 0'
        )
        return (self._route.host_for_client, self._route.port_for_client)

    def info(self) -> str:
        """ Return a human-readable description of the open connections """
        if not self._sockets:
            return f'"{self._route.name}" no active sockets'

        return f'"{self._route.name}" active sockets:\n\t' + '\n\t'.join(
            x.info() for x in self._sockets
        )

    def _collect_garbage(self) -> None:
        """Forget connections whose both direction tasks have finished."""
        self._sockets = {x for x in self._sockets if x.is_active()}

    async def _do_accept(self, accept_sock: Socket) -> None:
        """
        This task should wait for connections and create SocketPair
        """
        raise NotImplementedError(self._NOT_IMPLEMENTED_MESSAGE)

    def set_to_server_interceptor(self, interceptor: Interceptor) -> None:
        """
        Replace existing interceptor of client to server data with a custom
        """
        self._to_server_intercept = interceptor
        for x in self._sockets:
            x.set_to_server_interceptor(self._to_server_intercept)

    def set_to_client_interceptor(self, interceptor: Interceptor) -> None:
        """
        Replace existing interceptor of server to client data with a custom
        """
        self._to_client_intercept = interceptor
        for x in self._sockets:
            x.set_to_client_interceptor(self._to_client_intercept)

    # The stdlib `logging` module has no trace(); log through the module
    # logger at DEBUG level instead.

    def to_server_pass(self) -> None:
        """ Pass data as is """
        logger.debug('to_server_pass')
        self.set_to_server_interceptor(_intercept_ok)

    def to_client_pass(self) -> None:
        """ Pass data as is """
        logger.debug('to_client_pass')
        self.set_to_client_interceptor(_intercept_ok)

    def to_server_noop(self) -> None:
        """ Do not read data, causing client to keep multiple data """
        logger.debug('to_server_noop')
        self.set_to_server_interceptor(_intercept_noop)

    def to_client_noop(self) -> None:
        """ Do not read data, causing server to keep multiple data """
        logger.debug('to_client_noop')
        self.set_to_client_interceptor(_intercept_noop)

    def to_server_drop(self) -> None:
        """ Read and discard data """
        logger.debug('to_server_drop')
        self.set_to_server_interceptor(_intercept_drop)

    def to_client_drop(self) -> None:
        """ Read and discard data """
        logger.debug('to_client_drop')
        self.set_to_client_interceptor(_intercept_drop)

    def to_server_delay(self, delay: float) -> None:
        """ Delay data transmission """
        logger.debug('to_server_delay, delay: %s', delay)
        self.set_to_server_interceptor(
            functools.partial(_intercept_delay, delay),
        )

    def to_client_delay(self, delay: float) -> None:
        """ Delay data transmission """
        logger.debug('to_client_delay, delay: %s', delay)
        self.set_to_client_interceptor(
            functools.partial(_intercept_delay, delay),
        )

    def to_server_close_on_data(self) -> None:
        """ Close on first bytes of data from client """
        logger.debug('to_server_close_on_data')
        self.set_to_server_interceptor(_intercept_close_on_data)

    def to_client_close_on_data(self) -> None:
        """ Close on first bytes of data from server """
        logger.debug('to_client_close_on_data')
        self.set_to_client_interceptor(_intercept_close_on_data)

    def to_server_corrupt_data(self) -> None:
        """ Corrupt data received from client """
        logger.debug('to_server_corrupt_data')
        self.set_to_server_interceptor(_intercept_corrupt)

    def to_client_corrupt_data(self) -> None:
        """ Corrupt data received from server """
        logger.debug('to_client_corrupt_data')
        self.set_to_client_interceptor(_intercept_corrupt)

    def to_server_limit_bps(self, bytes_per_second: float) -> None:
        """ Limit bytes per second transmission by network from client """
        logger.debug(
            'to_server_limit_bps, bytes_per_second: %s', bytes_per_second,
        )
        self.set_to_server_interceptor(_InterceptBpsLimit(bytes_per_second))

    def to_client_limit_bps(self, bytes_per_second: float) -> None:
        """ Limit bytes per second transmission by network from server """
        logger.debug(
            'to_client_limit_bps, bytes_per_second: %s', bytes_per_second,
        )
        self.set_to_client_interceptor(_InterceptBpsLimit(bytes_per_second))

    def to_server_limit_time(self, timeout: float, jitter: float) -> None:
        """ Limit connection lifetime on receive of first bytes from client """
        logger.debug(
            'to_server_limit_time, timeout: %s, jitter: %s', timeout, jitter,
        )
        self.set_to_server_interceptor(_InterceptTimeLimit(timeout, jitter))

    def to_client_limit_time(self, timeout: float, jitter: float) -> None:
        """ Limit connection lifetime on receive of first bytes from server """
        logger.debug(
            'to_client_limit_time, timeout: %s, jitter: %s', timeout, jitter,
        )
        self.set_to_client_interceptor(_InterceptTimeLimit(timeout, jitter))

    def to_server_smaller_parts(
        self, max_size: int, *, sleep_per_packet: float = 0,
    ) -> None:
        """
        Pass data to server in smaller parts

        @param max_size Max packet size to send to server
        @param sleep_per_packet Optional sleep interval per packet, seconds
        """
        logger.debug('to_server_smaller_parts, max_size: %s', max_size)
        self.set_to_server_interceptor(
            _InterceptSmallerParts(max_size, sleep_per_packet),
        )

    def to_client_smaller_parts(
        self, max_size: int, *, sleep_per_packet: float = 0,
    ) -> None:
        """
        Pass data to client in smaller parts

        @param max_size Max packet size to send to client
        @param sleep_per_packet Optional sleep interval per packet, seconds
        """
        logger.debug('to_client_smaller_parts, max_size: %s', max_size)
        self.set_to_client_interceptor(
            _InterceptSmallerParts(max_size, sleep_per_packet),
        )

    def to_server_concat_packets(self, packet_size: int) -> None:
        """
        Pass data in bigger parts
        @param packet_size minimal size of the resulting packet
        """
        logger.debug('to_server_concat_packets, packet_size: %s', packet_size)
        self.set_to_server_interceptor(_InterceptConcatPackets(packet_size))

    def to_client_concat_packets(self, packet_size: int) -> None:
        """
        Pass data in bigger parts
        @param packet_size minimal size of the resulting packet
        """
        logger.debug('to_client_concat_packets, packet_size: %s', packet_size)
        self.set_to_client_interceptor(_InterceptConcatPackets(packet_size))

    def to_server_limit_bytes(self, bytes_limit: int) -> None:
        """ Drop all connections each `bytes_limit` of data sent by network """
        logger.debug('to_server_limit_bytes, bytes_limit: %s', bytes_limit)
        self.set_to_server_interceptor(_InterceptBytesLimit(bytes_limit, self))

    def to_client_limit_bytes(self, bytes_limit: int) -> None:
        """ Drop all connections each `bytes_limit` of data sent by network """
        logger.debug('to_client_limit_bytes, bytes_limit: %s', bytes_limit)
        self.set_to_client_interceptor(_InterceptBytesLimit(bytes_limit, self))

    def to_server_substitute(self, pattern: str, repl: str) -> None:
        """ Apply regex substitution to data from client """
        logger.debug(
            'to_server_substitute, pattern: %s, repl: %s', pattern, repl,
        )
        self.set_to_server_interceptor(_InterceptSubstitute(pattern, repl))

    def to_client_substitute(self, pattern: str, repl: str) -> None:
        """ Apply regex substitution to data from server """
        logger.debug(
            'to_client_substitute, pattern: %s, repl: %s', pattern, repl,
        )
        self.set_to_client_interceptor(_InterceptSubstitute(pattern, repl))
751
752
class TcpGate(BaseGate):
    """
    Implements TCP chaos-proxy logic such as accepting incoming tcp client
    connections. On each new connection new tcp client connects to server
    (host_to_server, port_to_server).

    @ingroup userver_testsuite

    @see @ref scripts/docs/en/userver/chaos_testing.md
    """

    def __init__(self, route: GateRoute, loop: EvLoop) -> None:
        BaseGate.__init__(self, route, loop)

    def connections_count(self) -> int:
        """
        Returns maximal amount of connections going through the gate at
        the moment.

        @warning Some of the connections could be closing, or could be opened
                 right before the function starts. Use with caution!
        """
        return len(self._sockets)

    async def wait_for_connections(self, *, count=1, timeout=0.0) -> None:
        """
        Wait for at least `count` connections going through the gate.

        @param timeout Seconds to wait; a non-positive value waits forever.
        @throws asyncio.TimeoutError exception if failed to get the
                required amount of connections in time.
        """
        if timeout <= 0.0:
            while self.connections_count() < count:
                await self._connected_event.wait()
                self._connected_event.clear()
            return

        deadline = time.monotonic() + timeout
        while self.connections_count() < count:
            time_left = deadline - time.monotonic()
            await asyncio.wait_for(
                self._connected_event.wait(), timeout=time_left,
            )
            self._connected_event.clear()

    def _make_socket_nonblocking(self, sock: Socket) -> None:
        """Switch `sock` to non-blocking mode and enable TCP_NODELAY."""
        sock.setblocking(False)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        fcntl.fcntl(sock, fcntl.F_SETFL, os.O_NONBLOCK)

    def _create_accepting_sockets(self) -> typing.List[Socket]:
        """
        Create a listening socket for every address host_for_client resolves
        to. If any of them fails to set up (e.g. the port is busy), the
        sockets created so far are closed and the error is re-raised.
        """
        res: typing.List[Socket] = []
        try:
            for addr in socket.getaddrinfo(
                self._route.host_for_client,
                self._route.port_for_client,
                type=socket.SOCK_STREAM,
            ):
                sock = Socket(addr[0], addr[1])
                res.append(sock)
                self._make_socket_nonblocking(sock)
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                sock.bind(addr[4])
                sock.listen()
                logger.debug(
                    'Accepting connections on %s, fd=%s',
                    sock.getsockname(),
                    sock.fileno(),
                )
        except BaseException:
            for sock in res:
                sock.close()
            raise

        return res

    async def _connect_to_server(self) -> typing.Optional[Socket]:
        """
        Connect to the first reachable address of the server.

        Returns the connected socket, or None if every address failed.
        Sockets of failed attempts are closed.
        """
        addrs = await self._loop.getaddrinfo(
            self._route.host_to_server,
            self._route.port_to_server,
            type=socket.SOCK_STREAM,
        )
        for addr in addrs:
            server = Socket(addr[0], addr[1])
            try:
                self._make_socket_nonblocking(server)
                await self._loop.sock_connect(server, addr[4])
            except Exception as exc:  # pylint: disable=broad-except
                server.close()
                logger.warning('Could not connect to %s: %s', addr[4], exc)
                continue
            logger.debug('Connected to %s', addr[4])
            return server
        return None

    async def _do_accept(self, accept_sock: Socket) -> None:
        """
        Accept clients forever and pair each of them with a fresh connection
        to the server. Ends only by cancellation or an accept error.
        """
        while True:
            client, _ = await self._loop.sock_accept(accept_sock)
            try:
                self._make_socket_nonblocking(client)
                server = await self._connect_to_server()
            except BaseException:
                # Do not leak the accepted client if connecting to the server
                # fails or the task is cancelled meanwhile.
                client.close()
                raise

            if server:
                self._sockets.add(
                    _SocketsPaired(
                        self._route.name,
                        self._loop,
                        client,
                        server,
                        self._to_server_intercept,
                        self._to_client_intercept,
                    ),
                )
                self._connected_event.set()
            else:
                client.close()

            self._collect_garbage()
862
863
class UdpGate(BaseGate):
    """
    Implements UDP chaos-proxy logic such as waiting for first
    message and setting up sockets for forwarding messages between
    udp-client and udp-server

    @ingroup userver_testsuite

    @see @ref scripts/docs/en/userver/chaos_testing.md
    """

    _NOT_IMPLEMENTED_IN_UDP = 'This method is not allowed in UDP gate'

    def __init__(self, route: GateRoute, loop: EvLoop):
        # Address of the first client that sent a datagram to the gate.
        self._client_addr: typing.Optional[Address] = None
        BaseGate.__init__(self, route, loop)

    def is_connected(self) -> bool:
        """
        Returns True if there is active pair of sockets ready to transfer data
        at the moment.
        """
        return len(self._sockets) > 0

    def _make_socket_nonblocking(self, sock: Socket) -> None:
        """Switch `sock` to non-blocking mode."""
        sock.setblocking(False)
        fcntl.fcntl(sock, fcntl.F_SETFL, os.O_NONBLOCK)

    def _create_accepting_sockets(self) -> typing.List[Socket]:
        """
        Create a bound datagram socket for every address host_for_client
        resolves to. On failure the sockets created so far are closed and
        the error is re-raised.
        """
        res: typing.List[Socket] = []
        try:
            for addr in socket.getaddrinfo(
                self._route.host_for_client,
                self._route.port_for_client,
                type=socket.SOCK_DGRAM,
            ):
                sock = socket.socket(addr[0], addr[1])
                res.append(sock)
                self._make_socket_nonblocking(sock)
                sock.bind(addr[4])
                logger.debug('Accepting connections on %s', sock.getsockname())
        except BaseException:
            for sock in res:
                sock.close()
            raise

        return res

    async def _connect_to_server(self) -> typing.Optional[Socket]:
        """
        Connect a datagram socket to the first usable server address.

        Returns the socket, or None if every address failed. Sockets of
        failed attempts are closed.
        """
        addrs = await self._loop.getaddrinfo(
            self._route.host_to_server,
            self._route.port_to_server,
            type=socket.SOCK_DGRAM,
        )
        for addr in addrs:
            server = Socket(addr[0], addr[1])
            try:
                self._make_socket_nonblocking(server)
                await self._loop.sock_connect(server, addr[4])
            except Exception as exc:  # pylint: disable=broad-except
                server.close()
                logger.warning('Could not connect to %s: %s', addr[4], exc)
                continue
            logger.debug('Connected to %s', addr[4])
            return server
        return None

    async def _do_accept(self, accept_sock: Socket):
        """
        Wait for the first datagram, connect `accept_sock` to its sender and
        pair it with a socket connected to the server.
        """
        if not accept_sock:
            return

        # The datagram is only peeked, so it stays queued and is forwarded
        # by the socket pair created below.
        _, addr = await _wait_for_message_task(accept_sock)

        self._client_addr = addr
        try:
            await self._loop.sock_connect(accept_sock, addr)
        except Exception as exc:  # pylint: disable=broad-except
            logger.warning('Could not connect to %s: %s', addr, exc)

        server = await self._connect_to_server()
        if server:
            self._sockets.add(
                _SocketsPaired(
                    self._route.name,
                    self._loop,
                    accept_sock,
                    server,
                    self._to_server_intercept,
                    self._to_client_intercept,
                ),
            )
            self._connected_event.set()
        else:
            accept_sock.close()

        self._collect_garbage()

    def start_accepting(self) -> None:
        raise NotImplementedError(
            'Since UdpGate can only have one connection, you cannot start or '
            'stop accepting tasks manually. Use start() and stop() methods to '
            'stop data transferring',
        )

    async def stop_accepting(self) -> None:
        raise NotImplementedError(
            'Since UdpGate can only have one connection, you cannot start or '
            'stop accepting tasks manually. Use start() and stop() methods to '
            'stop data transferring',
        )

    def to_server_concat_packets(self, packet_size: int) -> None:
        raise NotImplementedError('Udp packets cannot be concatenated')

    def to_client_concat_packets(self, packet_size: int) -> None:
        raise NotImplementedError('Udp packets cannot be concatenated')

    def to_server_smaller_parts(
        self, max_size: int, *, sleep_per_packet: float = 0,
    ) -> None:
        raise NotImplementedError('Udp packets cannot be split')

    def to_client_smaller_parts(
        self, max_size: int, *, sleep_per_packet: float = 0,
    ) -> None:
        raise NotImplementedError('Udp packets cannot be split')
977
978 def to_client_smaller_parts(
979 self, max_size: int, *, sleep_per_packet: float = 0,
980 ) -> None:
981 raise NotImplementedError('Udp packets cannot be split')