asyncpg.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
from __future__ import annotations

import contextlib
import functools
from typing import Any, Awaitable, Callable, Iterator, TypeVar

import sentry_sdk
from sentry_sdk.consts import OP, SPANDATA
from sentry_sdk.integrations import _check_minimum_version, DidNotEnable, Integration
from sentry_sdk.tracing import Span
from sentry_sdk.tracing_utils import add_query_source, record_sql_queries
from sentry_sdk.utils import (
    capture_internal_exceptions,
    ensure_integration_enabled,
    parse_version,
)

try:
    import asyncpg  # type: ignore[import-not-found]
    from asyncpg.cursor import BaseCursor  # type: ignore
except ImportError:
    raise DidNotEnable("asyncpg not installed.")
  19. class AsyncPGIntegration(Integration):
  20. identifier = "asyncpg"
  21. origin = f"auto.db.{identifier}"
  22. _record_params = False
  23. def __init__(self, *, record_params: bool = False):
  24. AsyncPGIntegration._record_params = record_params
  25. @staticmethod
  26. def setup_once() -> None:
  27. # asyncpg.__version__ is a string containing the semantic version in the form of "<major>.<minor>.<patch>"
  28. asyncpg_version = parse_version(asyncpg.__version__)
  29. _check_minimum_version(AsyncPGIntegration, asyncpg_version)
  30. asyncpg.Connection.execute = _wrap_execute(
  31. asyncpg.Connection.execute,
  32. )
  33. asyncpg.Connection._execute = _wrap_connection_method(
  34. asyncpg.Connection._execute
  35. )
  36. asyncpg.Connection._executemany = _wrap_connection_method(
  37. asyncpg.Connection._executemany, executemany=True
  38. )
  39. asyncpg.Connection.cursor = _wrap_cursor_creation(asyncpg.Connection.cursor)
  40. asyncpg.Connection.prepare = _wrap_connection_method(asyncpg.Connection.prepare)
  41. asyncpg.connect_utils._connect_addr = _wrap_connect_addr(
  42. asyncpg.connect_utils._connect_addr
  43. )
  44. T = TypeVar("T")
  45. def _wrap_execute(f: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
  46. async def _inner(*args: Any, **kwargs: Any) -> T:
  47. if sentry_sdk.get_client().get_integration(AsyncPGIntegration) is None:
  48. return await f(*args, **kwargs)
  49. # Avoid recording calls to _execute twice.
  50. # Calls to Connection.execute with args also call
  51. # Connection._execute, which is recorded separately
  52. # args[0] = the connection object, args[1] is the query
  53. if len(args) > 2:
  54. return await f(*args, **kwargs)
  55. query = args[1]
  56. with record_sql_queries(
  57. cursor=None,
  58. query=query,
  59. params_list=None,
  60. paramstyle=None,
  61. executemany=False,
  62. span_origin=AsyncPGIntegration.origin,
  63. ) as span:
  64. res = await f(*args, **kwargs)
  65. with capture_internal_exceptions():
  66. add_query_source(span)
  67. return res
  68. return _inner
  69. SubCursor = TypeVar("SubCursor", bound=BaseCursor)
  70. @contextlib.contextmanager
  71. def _record(
  72. cursor: SubCursor | None,
  73. query: str,
  74. params_list: tuple[Any, ...] | None,
  75. *,
  76. executemany: bool = False,
  77. ) -> Iterator[Span]:
  78. integration = sentry_sdk.get_client().get_integration(AsyncPGIntegration)
  79. if integration is not None and not integration._record_params:
  80. params_list = None
  81. param_style = "pyformat" if params_list else None
  82. with record_sql_queries(
  83. cursor=cursor,
  84. query=query,
  85. params_list=params_list,
  86. paramstyle=param_style,
  87. executemany=executemany,
  88. record_cursor_repr=cursor is not None,
  89. span_origin=AsyncPGIntegration.origin,
  90. ) as span:
  91. yield span
  92. def _wrap_connection_method(
  93. f: Callable[..., Awaitable[T]], *, executemany: bool = False
  94. ) -> Callable[..., Awaitable[T]]:
  95. async def _inner(*args: Any, **kwargs: Any) -> T:
  96. if sentry_sdk.get_client().get_integration(AsyncPGIntegration) is None:
  97. return await f(*args, **kwargs)
  98. query = args[1]
  99. params_list = args[2] if len(args) > 2 else None
  100. with _record(None, query, params_list, executemany=executemany) as span:
  101. _set_db_data(span, args[0])
  102. res = await f(*args, **kwargs)
  103. return res
  104. return _inner
  105. def _wrap_cursor_creation(f: Callable[..., T]) -> Callable[..., T]:
  106. @ensure_integration_enabled(AsyncPGIntegration, f)
  107. def _inner(*args: Any, **kwargs: Any) -> T: # noqa: N807
  108. query = args[1]
  109. params_list = args[2] if len(args) > 2 else None
  110. with _record(
  111. None,
  112. query,
  113. params_list,
  114. executemany=False,
  115. ) as span:
  116. _set_db_data(span, args[0])
  117. res = f(*args, **kwargs)
  118. span.set_data("db.cursor", res)
  119. return res
  120. return _inner
  121. def _wrap_connect_addr(f: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
  122. async def _inner(*args: Any, **kwargs: Any) -> T:
  123. if sentry_sdk.get_client().get_integration(AsyncPGIntegration) is None:
  124. return await f(*args, **kwargs)
  125. user = kwargs["params"].user
  126. database = kwargs["params"].database
  127. with sentry_sdk.start_span(
  128. op=OP.DB,
  129. name="connect",
  130. origin=AsyncPGIntegration.origin,
  131. ) as span:
  132. span.set_data(SPANDATA.DB_SYSTEM, "postgresql")
  133. addr = kwargs.get("addr")
  134. if addr:
  135. try:
  136. span.set_data(SPANDATA.SERVER_ADDRESS, addr[0])
  137. span.set_data(SPANDATA.SERVER_PORT, addr[1])
  138. except IndexError:
  139. pass
  140. span.set_data(SPANDATA.DB_NAME, database)
  141. span.set_data(SPANDATA.DB_USER, user)
  142. with capture_internal_exceptions():
  143. sentry_sdk.add_breadcrumb(
  144. message="connect", category="query", data=span._data
  145. )
  146. res = await f(*args, **kwargs)
  147. return res
  148. return _inner
  149. def _set_db_data(span: Span, conn: Any) -> None:
  150. span.set_data(SPANDATA.DB_SYSTEM, "postgresql")
  151. addr = conn._addr
  152. if addr:
  153. try:
  154. span.set_data(SPANDATA.SERVER_ADDRESS, addr[0])
  155. span.set_data(SPANDATA.SERVER_PORT, addr[1])
  156. except IndexError:
  157. pass
  158. database = conn._params.database
  159. if database:
  160. span.set_data(SPANDATA.DB_NAME, database)
  161. user = conn._params.user
  162. if user:
  163. span.set_data(SPANDATA.DB_USER, user)