26    Class that describes the route for TcpGate or UdpGate. 
   28    Use `port_for_client == 0` to bind to some unused port. In that case the 
   29    actual address could be retrieved via BaseGate.get_sockname_for_clients(). 
   31    @ingroup userver_testsuite 
   37    host_for_client: str = 
'127.0.0.1' 
   38    port_for_client: int = 0
 
   48logger = logging.getLogger(__name__)
 
   51Address = typing.Tuple[str, int]
 
   54Interceptor = typing.Callable[
 
   55    [EvLoop, Socket, Socket], typing.Coroutine[typing.Any, typing.Any, 
None],
 
   59class GateException(Exception):
 
   63class GateInterceptException(Exception):
 
   67async def _yield() -> None:
 
   72    await asyncio.sleep(min_delay)
 
   76        recv_socket: Socket, flags: int,
 
   77) -> typing.Tuple[typing.Optional[bytes], typing.Optional[Address]]:
 
   79        return recv_socket.recvfrom(RECV_MAX_SIZE, flags)
 
   80    except (BlockingIOError, InterruptedError):
 
   84async def _get_message_task(
 
   86) -> typing.Tuple[bytes, Address]:
 
   88        msg, addr = _try_get_message(recv_socket, 0)
 
   96def _incoming_data_size(recv_socket: Socket) -> int:
 
   97    msg, _ = _try_get_message(recv_socket, socket.MSG_PEEK)
 
   98    return len(msg) 
if msg 
else 0
 
  101async def _intercept_ok(
 
  102        loop: EvLoop, socket_from: Socket, socket_to: Socket,
 
  104    data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
 
  105    await loop.sock_sendall(socket_to, data)
 
  108async def _intercept_noop(
 
  109        loop: EvLoop, socket_from: Socket, socket_to: Socket,
 
  114async def _intercept_drop(
 
  115        loop: EvLoop, socket_from: Socket, socket_to: Socket,
 
  117    await loop.sock_recv(socket_from, RECV_MAX_SIZE)
 
  120async def _intercept_delay(
 
  121        delay: float, loop: EvLoop, socket_from: Socket, socket_to: Socket,
 
  123    data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
 
  124    await asyncio.sleep(delay)
 
  125    await loop.sock_sendall(socket_to, data)
 
  128async def _intercept_close_on_data(
 
  129        loop: EvLoop, socket_from: Socket, socket_to: Socket,
 
  131    await loop.sock_recv(socket_from, 1)
 
  132    raise GateInterceptException(
'Closing socket on data')
 
  135async def _intercept_corrupt(
 
  136        loop: EvLoop, socket_from: Socket, socket_to: Socket,
 
  138    data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
 
  139    await loop.sock_sendall(socket_to, bytearray([
not x 
for x 
in data]))
 
  142class _InterceptBpsLimit:
 
  143    def __init__(self, bytes_per_second: float):
 
  144        assert bytes_per_second >= 1
 
  145        self._bytes_per_second = bytes_per_second
 
  146        self._time_last_added = 0.0
 
  147        self._bytes_left = self._bytes_per_second
 
  149    def _update_limit(self) -> None:
 
  150        current_time = time.monotonic()
 
  151        elapsed = current_time - self._time_last_added
 
  152        bytes_addition = self._bytes_per_second * elapsed
 
  153        if bytes_addition > 0:
 
  154            self._bytes_left += bytes_addition
 
  155            self._time_last_added = current_time
 
  157            if self._bytes_left > self._bytes_per_second:
 
  158                self._bytes_left = self._bytes_per_second
 
  161            self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
 
  165        bytes_to_recv = min(int(self._bytes_left), RECV_MAX_SIZE)
 
  166        if bytes_to_recv > 0:
 
  167            data = await loop.sock_recv(socket_from, bytes_to_recv)
 
  168            self._bytes_left -= len(data)
 
  170            await loop.sock_sendall(socket_to, data)
 
  172            logger.info(
'Socket hits the bytes per second limit')
 
  173            await asyncio.sleep(1.0 / self._bytes_per_second)
 
  176class _InterceptTimeLimit:
 
  177    def __init__(self, timeout: float, jitter: float):
 
  178        self._sockets: typing.Dict[Socket, float] = {}
 
  179        assert timeout >= 0.0
 
  180        self._timeout = timeout
 
  182        self._jitter = jitter
 
  184    def raise_if_timed_out(self, socket_from: Socket) -> 
None:
 
  185        if socket_from 
not in self._sockets:
 
  186            jitter = self._jitter * random.random()
 
  187            expire_at = time.monotonic() + self._timeout + jitter
 
  188            self._sockets[socket_from] = expire_at
 
  190        if self._sockets[socket_from] <= time.monotonic():
 
  191            del self._sockets[socket_from]
 
  192            raise GateInterceptException(
'Socket hits the time limit')
 
  195            self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
 
  197        self.raise_if_timed_out(socket_from)
 
  198        await _intercept_ok(loop, socket_from, socket_to)
 
  201class _InterceptSmallerParts:
 
  202    def __init__(self, max_size: int, sleep_per_packet: float):
 
  204        self._max_size = max_size
 
  205        self._sleep_per_packet = sleep_per_packet
 
  208            self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
 
  210        incoming_size = _incoming_data_size(socket_from)
 
  211        chunk_size = min(incoming_size, self._max_size)
 
  212        data = await loop.sock_recv(socket_from, chunk_size)
 
  213        await asyncio.sleep(self._sleep_per_packet)
 
  214        await loop.sock_sendall(socket_to, data)
 
  217class _InterceptConcatPackets:
 
  218    def __init__(self, packet_size: int):
 
  219        assert packet_size >= 0
 
  220        self._packet_size = packet_size
 
  221        self._expire_at: typing.Optional[float] = 
None 
  224            self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
 
  226        if self._expire_at 
is None:
 
  227            self._expire_at = time.monotonic() + MAX_DELAY
 
  229        if self._expire_at <= time.monotonic():
 
  231                f
'Failed to make a packet of sufficient size in {MAX_DELAY} ' 
  232                'seconds. Check the test logic, it should end with checking ' 
  233                'that the data was sent and by calling TcpGate function ' 
  234                'to_client_pass() to pass the remaining packets.',
 
  238        incoming_size = _incoming_data_size(socket_from)
 
  239        if incoming_size >= self._packet_size:
 
  240            data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
 
  241            await loop.sock_sendall(socket_to, data)
 
  242            self._expire_at = 
None 
  245class _InterceptBytesLimit:
 
  246    def __init__(self, bytes_limit: int, gate: 
'BaseGate'):
 
  247        assert bytes_limit >= 0
 
  248        self._bytes_limit = bytes_limit
 
  249        self._bytes_remain = self._bytes_limit
 
  253            self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
 
  255        data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
 
  256        if self._bytes_remain <= len(data):
 
  257            await loop.sock_sendall(socket_to, data[0 : self._bytes_remain])
 
  258            await self._gate.sockets_close()
 
  259            self._bytes_remain = self._bytes_limit
 
  260            raise GateInterceptException(
'Data transmission limit reached')
 
  262        self._bytes_remain -= len(data)
 
  263        await loop.sock_sendall(socket_to, data)
 
  266class _InterceptSubstitute:
 
  267    def __init__(self, pattern: str, repl: str, encoding=
'utf-8'):
 
  268        self._pattern = re.compile(pattern)
 
  270        self._encoding = encoding
 
  273            self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
 
  275        data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
 
  277            res = self._pattern.sub(self._repl, data.decode(self._encoding))
 
  278            data = res.encode(self._encoding)
 
  281        await loop.sock_sendall(socket_to, data)
 
  284async def _cancel_and_join(task: typing.Optional[asyncio.Task]) -> 
None:
 
  285    if not task 
or task.cancelled():
 
  291    except asyncio.CancelledError:
 
  293    except Exception 
as exc:  
 
  294        logger.error(
'Exception in _cancel_and_join: %s', exc)
 
  297def _make_socket_nonblocking(sock: Socket) -> 
None:
 
  298    sock.setblocking(
False)
 
  299    if sock.type == socket.SOCK_STREAM:
 
  300        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 
  301    fcntl.fcntl(sock, fcntl.F_SETFL, os.O_NONBLOCK)
 
  304class _UdpDemuxSocketMock:
 
  306    Emulates a point-to-point connection over UDP socket 
  307    with a non-blocking socket interface 
  310    def __init__(self, sock: Socket, peer_address: Address):
 
  311        self._sock: Socket = sock
 
  312        self._peeraddr: Address = peer_address
 
  314        sockpair = socket.socketpair(type=socket.SOCK_DGRAM)
 
  315        self._demux_in: Socket = sockpair[0]
 
  316        self._demux_out: Socket = sockpair[1]
 
  317        _make_socket_nonblocking(self._demux_in)
 
  318        _make_socket_nonblocking(self._demux_out)
 
  319        self._is_active: bool = 
True 
  322    def peer_address(self):
 
  323        return self._peeraddr
 
  325    async def push(self, loop: EvLoop, data: bytes):
 
  326        return await loop.sock_sendall(self._demux_in, data)
 
  329        return self._is_active
 
  332        self._is_active = 
False 
  333        self._demux_out.close()
 
  334        self._demux_in.close()
 
  336    def recvfrom(self, bufsize: int, flags: int = 0):
 
  337        return self._demux_out.recvfrom(bufsize, flags)
 
  339    def recv(self, bufsize: int, flags: int = 0):
 
  340        return self._demux_out.recv(bufsize, flags)
 
  343        return self._demux_out.fileno()
 
  345    def send(self, data: bytes):
 
  346        return self._sock.sendto(data, self._peeraddr)
 
  354            client: typing.Union[socket.socket, _UdpDemuxSocketMock],
 
  355            server: socket.socket,
 
  356            to_server_intercept: Interceptor,
 
  357            to_client_intercept: Interceptor,
 
  359        self._proxy_name = proxy_name
 
  362        self._client = client
 
  363        self._server = server
 
  365        self._to_server_intercept: Interceptor = to_server_intercept
 
  366        self._to_client_intercept: Interceptor = to_client_intercept
 
  368        self._task_to_server = asyncio.create_task(
 
  369            self._do_pipe_channels(to_server=
True),
 
  371        self._task_to_client = asyncio.create_task(
 
  372            self._do_pipe_channels(to_server=
False),
 
  375        self._finished_channels = 0
 
  377    async def _do_pipe_channels(self, *, to_server: bool) -> 
None:
 
  379            socket_from = self._client
 
  380            socket_to = self._server
 
  382            socket_from = self._server
 
  383            socket_to = self._client
 
  392                if not _incoming_data_size(socket_from):
 
  397                    interceptor = self._to_server_intercept
 
  399                    interceptor = self._to_client_intercept
 
  401                await interceptor(self._loop, socket_from, socket_to)
 
  403        except GateInterceptException 
as exc:
 
  404            logger.info(
'In "%s": %s', self._proxy_name, exc)
 
  405        except socket.error 
as exc:
 
  406            logger.error(
'Exception in "%s": %s', self._proxy_name, exc)
 
  408            self._finished_channels += 1
 
  409            if self._finished_channels == 2:
 
  412                logger.info(
'"%s" closes  %s', self._proxy_name, self.info())
 
  413                self._close_socket(self._client)
 
  414                self._close_socket(self._server)
 
  416                assert self._finished_channels == 1
 
  418                    self._task_to_client.cancel()
 
  420                    self._task_to_server.cancel()
 
  422    def set_to_server_interceptor(self, interceptor: Interceptor) -> 
None:
 
  423        self._to_server_intercept = interceptor
 
  425    def set_to_client_interceptor(self, interceptor: Interceptor) -> 
None:
 
  426        self._to_client_intercept = interceptor
 
  428    def _close_socket(self, self_socket: Socket) -> 
None:
 
  429        assert self_socket 
in {self._client, self._server}
 
  432        except socket.error 
as exc:
 
  434                'Exception in "%s" on closing %s: %s',
 
  436                'client' if self_socket == self._client 
else 'server',
 
  440    async def shutdown(self) -> None:
 
  441        for task 
in {self._task_to_client, self._task_to_server}:
 
  442            await _cancel_and_join(task)
 
  444    def is_active(self) -> bool:
 
  446            not self._task_to_client.done() 
or not self._task_to_server.done()
 
  449    def info(self) -> str:
 
  450        if not self.is_active():
 
  454            f
'client fd={self._client.fileno()} <=> ' 
  455            f
'server fd={self._server.fileno()}'