buffered.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. from __future__ import annotations
  2. import sys
  3. from collections.abc import Callable, Iterable, Mapping
  4. from dataclasses import dataclass, field
  5. from typing import Any, SupportsIndex
  6. from .. import ClosedResourceError, DelimiterNotFound, EndOfStream, IncompleteRead
  7. from ..abc import (
  8. AnyByteReceiveStream,
  9. AnyByteStream,
  10. AnyByteStreamConnectable,
  11. ByteReceiveStream,
  12. ByteStream,
  13. ByteStreamConnectable,
  14. )
  15. if sys.version_info >= (3, 12):
  16. from typing import override
  17. else:
  18. from typing_extensions import override
  19. @dataclass(eq=False)
  20. class BufferedByteReceiveStream(ByteReceiveStream):
  21. """
  22. Wraps any bytes-based receive stream and uses a buffer to provide sophisticated
  23. receiving capabilities in the form of a byte stream.
  24. """
  25. receive_stream: AnyByteReceiveStream
  26. _buffer: bytearray = field(init=False, default_factory=bytearray)
  27. _closed: bool = field(init=False, default=False)
  28. async def aclose(self) -> None:
  29. await self.receive_stream.aclose()
  30. self._closed = True
  31. @property
  32. def buffer(self) -> bytes:
  33. """The bytes currently in the buffer."""
  34. return bytes(self._buffer)
  35. @property
  36. def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
  37. return self.receive_stream.extra_attributes
  38. def feed_data(self, data: Iterable[SupportsIndex], /) -> None:
  39. """
  40. Append data directly into the buffer.
  41. Any data in the buffer will be consumed by receive operations before receiving
  42. anything from the wrapped stream.
  43. :param data: the data to append to the buffer (can be bytes or anything else
  44. that supports ``__index__()``)
  45. """
  46. self._buffer.extend(data)
  47. async def receive(self, max_bytes: int = 65536) -> bytes:
  48. if self._closed:
  49. raise ClosedResourceError
  50. if self._buffer:
  51. chunk = bytes(self._buffer[:max_bytes])
  52. del self._buffer[:max_bytes]
  53. return chunk
  54. elif isinstance(self.receive_stream, ByteReceiveStream):
  55. return await self.receive_stream.receive(max_bytes)
  56. else:
  57. # With a bytes-oriented object stream, we need to handle any surplus bytes
  58. # we get from the receive() call
  59. chunk = await self.receive_stream.receive()
  60. if len(chunk) > max_bytes:
  61. # Save the surplus bytes in the buffer
  62. self._buffer.extend(chunk[max_bytes:])
  63. return chunk[:max_bytes]
  64. else:
  65. return chunk
  66. async def receive_exactly(self, nbytes: int) -> bytes:
  67. """
  68. Read exactly the given amount of bytes from the stream.
  69. :param nbytes: the number of bytes to read
  70. :return: the bytes read
  71. :raises ~anyio.IncompleteRead: if the stream was closed before the requested
  72. amount of bytes could be read from the stream
  73. """
  74. while True:
  75. remaining = nbytes - len(self._buffer)
  76. if remaining <= 0:
  77. retval = self._buffer[:nbytes]
  78. del self._buffer[:nbytes]
  79. return bytes(retval)
  80. try:
  81. if isinstance(self.receive_stream, ByteReceiveStream):
  82. chunk = await self.receive_stream.receive(remaining)
  83. else:
  84. chunk = await self.receive_stream.receive()
  85. except EndOfStream as exc:
  86. raise IncompleteRead from exc
  87. self._buffer.extend(chunk)
  88. async def receive_until(self, delimiter: bytes, max_bytes: int) -> bytes:
  89. """
  90. Read from the stream until the delimiter is found or max_bytes have been read.
  91. :param delimiter: the marker to look for in the stream
  92. :param max_bytes: maximum number of bytes that will be read before raising
  93. :exc:`~anyio.DelimiterNotFound`
  94. :return: the bytes read (not including the delimiter)
  95. :raises ~anyio.IncompleteRead: if the stream was closed before the delimiter
  96. was found
  97. :raises ~anyio.DelimiterNotFound: if the delimiter is not found within the
  98. bytes read up to the maximum allowed
  99. """
  100. delimiter_size = len(delimiter)
  101. offset = 0
  102. while True:
  103. # Check if the delimiter can be found in the current buffer
  104. index = self._buffer.find(delimiter, offset)
  105. if index >= 0:
  106. found = self._buffer[:index]
  107. del self._buffer[: index + len(delimiter) :]
  108. return bytes(found)
  109. # Check if the buffer is already at or over the limit
  110. if len(self._buffer) >= max_bytes:
  111. raise DelimiterNotFound(max_bytes)
  112. # Read more data into the buffer from the socket
  113. try:
  114. data = await self.receive_stream.receive()
  115. except EndOfStream as exc:
  116. raise IncompleteRead from exc
  117. # Move the offset forward and add the new data to the buffer
  118. offset = max(len(self._buffer) - delimiter_size + 1, 0)
  119. self._buffer.extend(data)
  120. class BufferedByteStream(BufferedByteReceiveStream, ByteStream):
  121. """
  122. A full-duplex variant of :class:`BufferedByteReceiveStream`. All writes are passed
  123. through to the wrapped stream as-is.
  124. """
  125. def __init__(self, stream: AnyByteStream):
  126. """
  127. :param stream: the stream to be wrapped
  128. """
  129. super().__init__(stream)
  130. self._stream = stream
  131. @override
  132. async def send_eof(self) -> None:
  133. await self._stream.send_eof()
  134. @override
  135. async def send(self, item: bytes) -> None:
  136. await self._stream.send(item)
  137. class BufferedConnectable(ByteStreamConnectable):
  138. def __init__(self, connectable: AnyByteStreamConnectable):
  139. """
  140. :param connectable: the connectable to wrap
  141. """
  142. self.connectable = connectable
  143. @override
  144. async def connect(self) -> BufferedByteStream:
  145. stream = await self.connectable.connect()
  146. return BufferedByteStream(stream)