26 Class that describes the route for TcpGate or UdpGate.
28 Use `port_for_client == 0` to bind to some unused port. In that case the
29 actual address could be retrieved via BaseGate.get_sockname_for_clients().
31 @ingroup userver_testsuite
37 host_for_client: str =
'127.0.0.1'
38 port_for_client: int = 0
# Module-level logger, named after this module's dotted import path.
logger = logging.getLogger(name=__name__)
# (host, port) pair identifying a socket endpoint, i.e. the address part
# of a socket.recvfrom() result.
Address = typing.Tuple[str, int]
# Type of every interceptor: an async callable receiving the event loop,
# the socket to read from and the socket to write to; it returns nothing.
Interceptor = typing.Callable[
    [EvLoop, Socket, Socket],
    typing.Coroutine[typing.Any, typing.Any, None],
]
class GateException(Exception):
    """Error type for the TCP/UDP gate helpers."""
class GateInterceptException(Exception):
    """Raised by an interceptor to end interception deliberately.

    Interceptors raise it, e.g. when a time or byte limit is hit. The pipe
    code catches it and logs the message at INFO level rather than as an
    error.
    """
67async def _yield() -> None:
72 await asyncio.sleep(min_delay)
76 recv_socket: Socket, flags: int,
77) -> typing.Tuple[typing.Optional[bytes], typing.Optional[Address]]:
79 return recv_socket.recvfrom(RECV_MAX_SIZE, flags)
80 except (BlockingIOError, InterruptedError):
84async def _get_message_task(
86) -> typing.Tuple[bytes, Address]:
88 msg, addr = _try_get_message(recv_socket, 0)
def _incoming_data_size(recv_socket: Socket) -> int:
    """Return how many bytes a single peek on *recv_socket* yields.

    The data is read with MSG_PEEK, so it stays queued for a later real read.
    Returns 0 when the peek produced no data.
    """
    peeked, _sender = _try_get_message(recv_socket, socket.MSG_PEEK)
    if not peeked:
        return 0
    return len(peeked)
async def _intercept_ok(
    loop: EvLoop, socket_from: Socket, socket_to: Socket,
) -> None:
    """Pass-through interceptor: forward one received chunk unchanged.

    Reads up to RECV_MAX_SIZE bytes from *socket_from* and writes all of
    them to *socket_to*.
    """
    chunk = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    await loop.sock_sendall(socket_to, chunk)
108async def _intercept_noop(
109 loop: EvLoop, socket_from: Socket, socket_to: Socket,
async def _intercept_drop(
    loop: EvLoop, socket_from: Socket, socket_to: Socket,
) -> None:
    """Consume one chunk from *socket_from* and throw it away.

    *socket_to* is never written to, so the data does not reach the peer.
    """
    # The received bytes are deliberately ignored.
    await loop.sock_recv(socket_from, RECV_MAX_SIZE)
async def _intercept_delay(
    delay: float, loop: EvLoop, socket_from: Socket, socket_to: Socket,
) -> None:
    """Forward one received chunk after holding it for *delay* seconds.

    The chunk (up to RECV_MAX_SIZE bytes) is read first, then the coroutine
    sleeps, then the same bytes are sent unchanged to *socket_to*.
    """
    payload = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    await asyncio.sleep(delay)
    await loop.sock_sendall(socket_to, payload)
async def _intercept_close_on_data(
    loop: EvLoop, socket_from: Socket, socket_to: Socket,
) -> None:
    """Wait for a read on *socket_from*, then abort the channel.

    Only a single byte is consumed and nothing is forwarded to *socket_to*.

    Raises:
        GateInterceptException: always, once the one-byte read returns.
    """
    await loop.sock_recv(socket_from, 1)
    raise GateInterceptException('Closing socket on data')
async def _intercept_corrupt(
    loop: EvLoop, socket_from: Socket, socket_to: Socket,
) -> None:
    """Forward one received chunk with every bit of every byte inverted.

    Reads up to RECV_MAX_SIZE bytes from *socket_from* and sends the same
    number of bytes to *socket_to*. Each byte is replaced by its bitwise
    complement, so the payload keeps its length but its content is garbage.
    """
    data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    # Bitwise XOR with 0xFF, not logical `not`: `not byte` evaluates to a
    # bool, which would turn every byte into 0x00 or 0x01.
    corrupted = bytes(byte ^ 0xFF for byte in data)
    await loop.sock_sendall(socket_to, corrupted)
142class _InterceptBpsLimit:
143 def __init__(self, bytes_per_second: float):
144 assert bytes_per_second >= 1
145 self._bytes_per_second = bytes_per_second
146 self._time_last_added = 0.0
147 self._bytes_left = self._bytes_per_second
149 def _update_limit(self) -> None:
150 current_time = time.monotonic()
151 elapsed = current_time - self._time_last_added
152 bytes_addition = self._bytes_per_second * elapsed
153 if bytes_addition > 0:
154 self._bytes_left += bytes_addition
155 self._time_last_added = current_time
157 if self._bytes_left > self._bytes_per_second:
158 self._bytes_left = self._bytes_per_second
161 self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
165 bytes_to_recv = min(int(self._bytes_left), RECV_MAX_SIZE)
166 if bytes_to_recv > 0:
167 data = await loop.sock_recv(socket_from, bytes_to_recv)
168 self._bytes_left -= len(data)
170 await loop.sock_sendall(socket_to, data)
172 logger.info(
'Socket hits the bytes per second limit')
173 await asyncio.sleep(1.0 / self._bytes_per_second)
176class _InterceptTimeLimit:
177 def __init__(self, timeout: float, jitter: float):
178 self._sockets: typing.Dict[Socket, float] = {}
179 assert timeout >= 0.0
180 self._timeout = timeout
182 self._jitter = jitter
184 def raise_if_timed_out(self, socket_from: Socket) ->
None:
185 if socket_from
not in self._sockets:
186 jitter = self._jitter * random.random()
187 expire_at = time.monotonic() + self._timeout + jitter
188 self._sockets[socket_from] = expire_at
190 if self._sockets[socket_from] <= time.monotonic():
191 del self._sockets[socket_from]
192 raise GateInterceptException(
'Socket hits the time limit')
195 self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
197 self.raise_if_timed_out(socket_from)
198 await _intercept_ok(loop, socket_from, socket_to)
201class _InterceptSmallerParts:
202 def __init__(self, max_size: int, sleep_per_packet: float):
204 self._max_size = max_size
205 self._sleep_per_packet = sleep_per_packet
208 self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
210 incoming_size = _incoming_data_size(socket_from)
211 chunk_size = min(incoming_size, self._max_size)
212 data = await loop.sock_recv(socket_from, chunk_size)
213 await asyncio.sleep(self._sleep_per_packet)
214 await loop.sock_sendall(socket_to, data)
217class _InterceptConcatPackets:
218 def __init__(self, packet_size: int):
219 assert packet_size >= 0
220 self._packet_size = packet_size
221 self._expire_at: typing.Optional[float] =
None
224 self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
226 if self._expire_at
is None:
227 self._expire_at = time.monotonic() + MAX_DELAY
229 if self._expire_at <= time.monotonic():
231 f
'Failed to make a packet of sufficient size in {MAX_DELAY} '
232 'seconds. Check the test logic, it should end with checking '
233 'that the data was sent and by calling TcpGate function '
234 'to_client_pass() to pass the remaining packets.',
238 incoming_size = _incoming_data_size(socket_from)
239 if incoming_size >= self._packet_size:
240 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
241 await loop.sock_sendall(socket_to, data)
242 self._expire_at =
None
245class _InterceptBytesLimit:
246 def __init__(self, bytes_limit: int, gate:
'BaseGate'):
247 assert bytes_limit >= 0
248 self._bytes_limit = bytes_limit
249 self._bytes_remain = self._bytes_limit
253 self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
255 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
256 if self._bytes_remain <= len(data):
257 await loop.sock_sendall(socket_to, data[0 : self._bytes_remain])
258 await self._gate.sockets_close()
259 self._bytes_remain = self._bytes_limit
260 raise GateInterceptException(
'Data transmission limit reached')
262 self._bytes_remain -= len(data)
263 await loop.sock_sendall(socket_to, data)
266class _InterceptSubstitute:
267 def __init__(self, pattern: str, repl: str, encoding=
'utf-8'):
268 self._pattern = re.compile(pattern)
270 self._encoding = encoding
273 self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
275 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
277 res = self._pattern.sub(self._repl, data.decode(self._encoding))
278 data = res.encode(self._encoding)
281 await loop.sock_sendall(socket_to, data)
284async def _cancel_and_join(task: typing.Optional[asyncio.Task]) ->
None:
285 if not task
or task.cancelled():
291 except asyncio.CancelledError:
293 except Exception
as exc:
294 logger.error(
'Exception in _cancel_and_join: %s', exc)
def _make_socket_nonblocking(sock: Socket) -> None:
    """Put *sock* into non-blocking mode; stream sockets also get TCP_NODELAY.

    TCP_NODELAY disables Nagle's algorithm, so small writes are sent
    immediately instead of being coalesced.
    """
    sock.setblocking(False)
    if sock.type == socket.SOCK_STREAM:
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    # Read-modify-write: F_SETFL with a bare O_NONBLOCK would replace (and
    # thereby clear) every other file status flag on the descriptor.
    flags = fcntl.fcntl(sock, fcntl.F_GETFL)
    fcntl.fcntl(sock, fcntl.F_SETFL, flags | os.O_NONBLOCK)
304class _UdpDemuxSocketMock:
306 Emulates a point-to-point connection over UDP socket
307 with a non-blocking socket interface
310 def __init__(self, sock: Socket, peer_address: Address):
311 self._sock: Socket = sock
312 self._peeraddr: Address = peer_address
314 sockpair = socket.socketpair(type=socket.SOCK_DGRAM)
315 self._demux_in: Socket = sockpair[0]
316 self._demux_out: Socket = sockpair[1]
317 _make_socket_nonblocking(self._demux_in)
318 _make_socket_nonblocking(self._demux_out)
319 self._is_active: bool =
True
322 def peer_address(self):
323 return self._peeraddr
325 async def push(self, loop: EvLoop, data: bytes):
326 return await loop.sock_sendall(self._demux_in, data)
329 return self._is_active
332 self._is_active =
False
333 self._demux_out.close()
334 self._demux_in.close()
336 def recvfrom(self, bufsize: int, flags: int = 0):
337 return self._demux_out.recvfrom(bufsize, flags)
339 def recv(self, bufsize: int, flags: int = 0):
340 return self._demux_out.recv(bufsize, flags)
343 return self._demux_out.fileno()
345 def send(self, data: bytes):
346 return self._sock.sendto(data, self._peeraddr)
354 client: typing.Union[socket.socket, _UdpDemuxSocketMock],
355 server: socket.socket,
356 to_server_intercept: Interceptor,
357 to_client_intercept: Interceptor,
359 self._proxy_name = proxy_name
362 self._client = client
363 self._server = server
365 self._to_server_intercept: Interceptor = to_server_intercept
366 self._to_client_intercept: Interceptor = to_client_intercept
368 self._task_to_server = asyncio.create_task(
369 self._do_pipe_channels(to_server=
True),
371 self._task_to_client = asyncio.create_task(
372 self._do_pipe_channels(to_server=
False),
375 self._finished_channels = 0
377 async def _do_pipe_channels(self, *, to_server: bool) ->
None:
379 socket_from = self._client
380 socket_to = self._server
382 socket_from = self._server
383 socket_to = self._client
392 if not _incoming_data_size(socket_from):
397 interceptor = self._to_server_intercept
399 interceptor = self._to_client_intercept
401 await interceptor(self._loop, socket_from, socket_to)
403 except GateInterceptException
as exc:
404 logger.info(
'In "%s": %s', self._proxy_name, exc)
405 except socket.error
as exc:
406 logger.error(
'Exception in "%s": %s', self._proxy_name, exc)
408 self._finished_channels += 1
409 if self._finished_channels == 2:
412 logger.info(
'"%s" closes %s', self._proxy_name, self.info())
413 self._close_socket(self._client)
414 self._close_socket(self._server)
416 assert self._finished_channels == 1
418 self._task_to_client.cancel()
420 self._task_to_server.cancel()
422 def set_to_server_interceptor(self, interceptor: Interceptor) ->
None:
423 self._to_server_intercept = interceptor
425 def set_to_client_interceptor(self, interceptor: Interceptor) ->
None:
426 self._to_client_intercept = interceptor
428 def _close_socket(self, self_socket: Socket) ->
None:
429 assert self_socket
in {self._client, self._server}
432 except socket.error
as exc:
434 'Exception in "%s" on closing %s: %s',
436 'client' if self_socket == self._client
else 'server',
440 async def shutdown(self) -> None:
441 for task
in {self._task_to_client, self._task_to_server}:
442 await _cancel_and_join(task)
444 def is_active(self) -> bool:
446 not self._task_to_client.done()
or not self._task_to_server.done()
449 def info(self) -> str:
450 if not self.is_active():
454 f
'client fd={self._client.fileno()} <=> '
455 f
'server fd={self._server.fileno()}'