anthropic.py

from functools import wraps
from typing import TYPE_CHECKING

import sentry_sdk
from sentry_sdk.ai.monitoring import record_token_usage
from sentry_sdk.ai.utils import (
    set_data_normalized,
    normalize_message_roles,
    get_start_span_function,
)
from sentry_sdk.consts import OP, SPANDATA, SPANSTATUS
from sentry_sdk.integrations import _check_minimum_version, DidNotEnable, Integration
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.tracing_utils import set_span_errored
from sentry_sdk.utils import (
    capture_internal_exceptions,
    event_from_exception,
    package_version,
    safe_serialize,
)

try:
    try:
        from anthropic import NOT_GIVEN
    except ImportError:
        NOT_GIVEN = None

    from anthropic.resources import AsyncMessages, Messages

    if TYPE_CHECKING:
        from anthropic.types import MessageStreamEvent
except ImportError:
    raise DidNotEnable("Anthropic not installed")

if TYPE_CHECKING:
    from typing import Any, AsyncIterator, Iterator
    from sentry_sdk.tracing import Span


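# A minimal usage sketch (illustrative only, not executed by this module). The
# integration is configured through sentry_sdk.init(); prompts and responses are
# only recorded when send_default_pii is enabled and include_prompts is left True
# (see _set_input_data / _set_output_data below):
#
#     import sentry_sdk
#     from sentry_sdk.integrations.anthropic import AnthropicIntegration
#
#     sentry_sdk.init(
#         dsn="...",  # your project DSN
#         send_default_pii=True,
#         integrations=[AnthropicIntegration(include_prompts=True)],
#     )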
class AnthropicIntegration(Integration):
    identifier = "anthropic"
    origin = f"auto.ai.{identifier}"

    def __init__(self, include_prompts=True):
        # type: (AnthropicIntegration, bool) -> None
        self.include_prompts = include_prompts

    @staticmethod
    def setup_once():
        # type: () -> None
        version = package_version("anthropic")
        _check_minimum_version(AnthropicIntegration, version)

        Messages.create = _wrap_message_create(Messages.create)
        AsyncMessages.create = _wrap_message_create_async(AsyncMessages.create)


def _capture_exception(exc):
    # type: (Any) -> None
    set_span_errored()

    event, hint = event_from_exception(
        exc,
        client_options=sentry_sdk.get_client().options,
        mechanism={"type": "anthropic", "handled": False},
    )
    sentry_sdk.capture_event(event, hint=hint)


def _get_token_usage(result):
    # type: (Messages) -> tuple[int, int]
    """
    Get token usage from the Anthropic response.
    """
    input_tokens = 0
    output_tokens = 0
    if hasattr(result, "usage"):
        usage = result.usage
        if hasattr(usage, "input_tokens") and isinstance(usage.input_tokens, int):
            input_tokens = usage.input_tokens
        if hasattr(usage, "output_tokens") and isinstance(usage.output_tokens, int):
            output_tokens = usage.output_tokens

    return input_tokens, output_tokens


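# Anthropic streaming responses report usage incrementally: a "message_start"
# event carries the model name and the initial token counts, "content_block_delta"
# events carry text or partial JSON fragments, and "message_delta" events update
# the output token count. The helper below accumulates these values across events.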
def _collect_ai_data(event, model, input_tokens, output_tokens, content_blocks):
    # type: (MessageStreamEvent, str | None, int, int, list[str]) -> tuple[str | None, int, int, list[str]]
    """
    Collect model information, token usage, and content blocks from the AI streaming response.
    """
    with capture_internal_exceptions():
        if hasattr(event, "type"):
            if event.type == "message_start":
                usage = event.message.usage
                input_tokens += usage.input_tokens
                output_tokens += usage.output_tokens
                model = event.message.model or model
            elif event.type == "content_block_start":
                pass
            elif event.type == "content_block_delta":
                if hasattr(event.delta, "text"):
                    content_blocks.append(event.delta.text)
                elif hasattr(event.delta, "partial_json"):
                    content_blocks.append(event.delta.partial_json)
            elif event.type == "content_block_stop":
                pass
            elif event.type == "message_delta":
                output_tokens += event.usage.output_tokens

    return model, input_tokens, output_tokens, content_blocks


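# User messages may embed "tool_result" blocks in their content; these are split
# out below as synthetic "tool"-role messages before the roles are normalized and
# attached to the span.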
def _set_input_data(span, kwargs, integration):
    # type: (Span, dict[str, Any], AnthropicIntegration) -> None
    """
    Set input data for the span based on the provided keyword arguments for the anthropic message creation.
    """
    messages = kwargs.get("messages")
    if (
        messages is not None
        and len(messages) > 0
        and should_send_default_pii()
        and integration.include_prompts
    ):
        normalized_messages = []
        for message in messages:
            if (
                message.get("role") == "user"
                and "content" in message
                and isinstance(message["content"], (list, tuple))
            ):
                for item in message["content"]:
                    if item.get("type") == "tool_result":
                        normalized_messages.append(
                            {
                                "role": "tool",
                                "content": {
                                    "tool_use_id": item.get("tool_use_id"),
                                    "output": item.get("content"),
                                },
                            }
                        )
            else:
                normalized_messages.append(message)

        role_normalized_messages = normalize_message_roles(normalized_messages)
        set_data_normalized(
            span,
            SPANDATA.GEN_AI_REQUEST_MESSAGES,
            role_normalized_messages,
            unpack=False,
        )

    set_data_normalized(
        span, SPANDATA.GEN_AI_RESPONSE_STREAMING, kwargs.get("stream", False)
    )

    kwargs_keys_to_attributes = {
        "max_tokens": SPANDATA.GEN_AI_REQUEST_MAX_TOKENS,
        "model": SPANDATA.GEN_AI_REQUEST_MODEL,
        "temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE,
        "top_k": SPANDATA.GEN_AI_REQUEST_TOP_K,
        "top_p": SPANDATA.GEN_AI_REQUEST_TOP_P,
    }
    for key, attribute in kwargs_keys_to_attributes.items():
        value = kwargs.get(key)
        if value is not NOT_GIVEN and value is not None:
            set_data_normalized(span, attribute, value)

    # Input attributes: Tools
    tools = kwargs.get("tools")
    if tools is not NOT_GIVEN and tools is not None and len(tools) > 0:
        set_data_normalized(
            span, SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS, safe_serialize(tools)
        )


def _set_output_data(
    span,
    integration,
    model,
    input_tokens,
    output_tokens,
    content_blocks,
    finish_span=False,
):
    # type: (Span, AnthropicIntegration, str | None, int | None, int | None, list[Any], bool) -> None
    """
    Set output data for the span based on the AI response.
    """
    span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, model)
    if should_send_default_pii() and integration.include_prompts:
        output_messages = {
            "response": [],
            "tool": [],
        }  # type: (dict[str, list[Any]])

        for output in content_blocks:
            if output["type"] == "text":
                output_messages["response"].append(output["text"])
            elif output["type"] == "tool_use":
                output_messages["tool"].append(output)

        if len(output_messages["tool"]) > 0:
            set_data_normalized(
                span,
                SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
                output_messages["tool"],
                unpack=False,
            )

        if len(output_messages["response"]) > 0:
            set_data_normalized(
                span, SPANDATA.GEN_AI_RESPONSE_TEXT, output_messages["response"]
            )

    record_token_usage(
        span,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
    )

    if finish_span:
        span.__exit__(None, None, None)


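# _sentry_patched_create_common is written as a generator so the sync and async
# wrappers can share one code path: it yields (f, args, kwargs) back to the
# wrapper, which performs the actual (possibly awaited) call and feeds the result
# back in via gen.send(). A plain `return` before the first yield (disabled
# integration, missing or non-iterable messages) surfaces to the wrapper as
# StopIteration.value, i.e. the result of the unwrapped call.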
def _sentry_patched_create_common(f, *args, **kwargs):
    # type: (Any, *Any, **Any) -> Any
    integration = kwargs.pop("integration")
    if integration is None:
        return f(*args, **kwargs)

    if "messages" not in kwargs:
        return f(*args, **kwargs)

    try:
        iter(kwargs["messages"])
    except TypeError:
        return f(*args, **kwargs)

    model = kwargs.get("model", "")

    span = get_start_span_function()(
        op=OP.GEN_AI_CHAT,
        name=f"chat {model}".strip(),
        origin=AnthropicIntegration.origin,
    )
    span.__enter__()

    _set_input_data(span, kwargs, integration)

    result = yield f, args, kwargs

    with capture_internal_exceptions():
        if hasattr(result, "content"):
            input_tokens, output_tokens = _get_token_usage(result)

            content_blocks = []
            for content_block in result.content:
                if hasattr(content_block, "to_dict"):
                    content_blocks.append(content_block.to_dict())
                elif hasattr(content_block, "model_dump"):
                    content_blocks.append(content_block.model_dump())
                elif hasattr(content_block, "text"):
                    content_blocks.append({"type": "text", "text": content_block.text})

            _set_output_data(
                span=span,
                integration=integration,
                model=getattr(result, "model", None),
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                content_blocks=content_blocks,
                finish_span=True,
            )

        # Streaming response
        elif hasattr(result, "_iterator"):
            old_iterator = result._iterator

            def new_iterator():
                # type: () -> Iterator[MessageStreamEvent]
                model = None
                input_tokens = 0
                output_tokens = 0
                content_blocks = []  # type: list[str]

                for event in old_iterator:
                    model, input_tokens, output_tokens, content_blocks = (
                        _collect_ai_data(
                            event, model, input_tokens, output_tokens, content_blocks
                        )
                    )
                    yield event

                _set_output_data(
                    span=span,
                    integration=integration,
                    model=model,
                    input_tokens=input_tokens,
                    output_tokens=output_tokens,
                    content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
                    finish_span=True,
                )

            async def new_iterator_async():
                # type: () -> AsyncIterator[MessageStreamEvent]
                model = None
                input_tokens = 0
                output_tokens = 0
                content_blocks = []  # type: list[str]

                async for event in old_iterator:
                    model, input_tokens, output_tokens, content_blocks = (
                        _collect_ai_data(
                            event, model, input_tokens, output_tokens, content_blocks
                        )
                    )
                    yield event

                _set_output_data(
                    span=span,
                    integration=integration,
                    model=model,
                    input_tokens=input_tokens,
                    output_tokens=output_tokens,
                    content_blocks=[{"text": "".join(content_blocks), "type": "text"}],
                    finish_span=True,
                )

            if str(type(result._iterator)) == "<class 'async_generator'>":
                result._iterator = new_iterator_async()
            else:
                result._iterator = new_iterator()

        else:
            span.set_data("unknown_response", True)
            span.__exit__(None, None, None)

    return result


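# The sync wrapper drives the generator above: next() runs the setup phase and
# opens the span, the real Messages.create call happens here so exceptions can be
# captured and re-raised, and gen.send(result) runs the finalization phase. The
# try/finally in the patched method closes the span if it was marked as errored.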
def _wrap_message_create(f):
    # type: (Any) -> Any
    def _execute_sync(f, *args, **kwargs):
        # type: (Any, *Any, **Any) -> Any
        gen = _sentry_patched_create_common(f, *args, **kwargs)

        try:
            f, args, kwargs = next(gen)
        except StopIteration as e:
            return e.value

        try:
            try:
                result = f(*args, **kwargs)
            except Exception as exc:
                _capture_exception(exc)
                raise exc from None

            return gen.send(result)
        except StopIteration as e:
            return e.value

    @wraps(f)
    def _sentry_patched_create_sync(*args, **kwargs):
        # type: (*Any, **Any) -> Any
        integration = sentry_sdk.get_client().get_integration(AnthropicIntegration)
        kwargs["integration"] = integration

        try:
            return _execute_sync(f, *args, **kwargs)
        finally:
            span = sentry_sdk.get_current_span()
            if span is not None and span.status == SPANSTATUS.ERROR:
                with capture_internal_exceptions():
                    span.__exit__(None, None, None)

    return _sentry_patched_create_sync


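# Async variant of the wrapper above. Note the `await e.value` on the early-return
# path: when the setup phase bails out, the generator has already called the
# original (un-awaited) AsyncMessages.create, so the value carried by StopIteration
# is a coroutine that still needs to be awaited here.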
def _wrap_message_create_async(f):
    # type: (Any) -> Any
    async def _execute_async(f, *args, **kwargs):
        # type: (Any, *Any, **Any) -> Any
        gen = _sentry_patched_create_common(f, *args, **kwargs)

        try:
            f, args, kwargs = next(gen)
        except StopIteration as e:
            return await e.value

        try:
            try:
                result = await f(*args, **kwargs)
            except Exception as exc:
                _capture_exception(exc)
                raise exc from None

            return gen.send(result)
        except StopIteration as e:
            return e.value

    @wraps(f)
    async def _sentry_patched_create_async(*args, **kwargs):
        # type: (*Any, **Any) -> Any
        integration = sentry_sdk.get_client().get_integration(AnthropicIntegration)
        kwargs["integration"] = integration

        try:
            return await _execute_async(f, *args, **kwargs)
        finally:
            span = sentry_sdk.get_current_span()
            if span is not None and span.status == SPANSTATUS.ERROR:
                with capture_internal_exceptions():
                    span.__exit__(None, None, None)

    return _sentry_patched_create_async