monitoring.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. import inspect
  2. from functools import wraps
  3. from sentry_sdk.consts import SPANDATA
  4. import sentry_sdk.utils
  5. from sentry_sdk import start_span
  6. from sentry_sdk.tracing import Span
  7. from sentry_sdk.utils import ContextVar
  8. from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from typing import Optional, Callable, Awaitable, Any, Union, TypeVar

    # Type var for a decorated callable: either a sync or an async function.
    F = TypeVar("F", bound=Union[Callable[..., Any], Callable[..., Awaitable[Any]]])

# Name of the AI pipeline currently running in this execution context,
# or None when no pipeline is active. Context-local so concurrent
# tasks/threads each see their own pipeline name.
_ai_pipeline_name = ContextVar("ai_pipeline_name", default=None)
  13. def set_ai_pipeline_name(name):
  14. # type: (Optional[str]) -> None
  15. _ai_pipeline_name.set(name)
  16. def get_ai_pipeline_name():
  17. # type: () -> Optional[str]
  18. return _ai_pipeline_name.get()
  19. def ai_track(description, **span_kwargs):
  20. # type: (str, Any) -> Callable[[F], F]
  21. def decorator(f):
  22. # type: (F) -> F
  23. def sync_wrapped(*args, **kwargs):
  24. # type: (Any, Any) -> Any
  25. curr_pipeline = _ai_pipeline_name.get()
  26. op = span_kwargs.pop("op", "ai.run" if curr_pipeline else "ai.pipeline")
  27. with start_span(name=description, op=op, **span_kwargs) as span:
  28. for k, v in kwargs.pop("sentry_tags", {}).items():
  29. span.set_tag(k, v)
  30. for k, v in kwargs.pop("sentry_data", {}).items():
  31. span.set_data(k, v)
  32. if curr_pipeline:
  33. span.set_data(SPANDATA.GEN_AI_PIPELINE_NAME, curr_pipeline)
  34. return f(*args, **kwargs)
  35. else:
  36. _ai_pipeline_name.set(description)
  37. try:
  38. res = f(*args, **kwargs)
  39. except Exception as e:
  40. event, hint = sentry_sdk.utils.event_from_exception(
  41. e,
  42. client_options=sentry_sdk.get_client().options,
  43. mechanism={"type": "ai_monitoring", "handled": False},
  44. )
  45. sentry_sdk.capture_event(event, hint=hint)
  46. raise e from None
  47. finally:
  48. _ai_pipeline_name.set(None)
  49. return res
  50. async def async_wrapped(*args, **kwargs):
  51. # type: (Any, Any) -> Any
  52. curr_pipeline = _ai_pipeline_name.get()
  53. op = span_kwargs.pop("op", "ai.run" if curr_pipeline else "ai.pipeline")
  54. with start_span(name=description, op=op, **span_kwargs) as span:
  55. for k, v in kwargs.pop("sentry_tags", {}).items():
  56. span.set_tag(k, v)
  57. for k, v in kwargs.pop("sentry_data", {}).items():
  58. span.set_data(k, v)
  59. if curr_pipeline:
  60. span.set_data(SPANDATA.GEN_AI_PIPELINE_NAME, curr_pipeline)
  61. return await f(*args, **kwargs)
  62. else:
  63. _ai_pipeline_name.set(description)
  64. try:
  65. res = await f(*args, **kwargs)
  66. except Exception as e:
  67. event, hint = sentry_sdk.utils.event_from_exception(
  68. e,
  69. client_options=sentry_sdk.get_client().options,
  70. mechanism={"type": "ai_monitoring", "handled": False},
  71. )
  72. sentry_sdk.capture_event(event, hint=hint)
  73. raise e from None
  74. finally:
  75. _ai_pipeline_name.set(None)
  76. return res
  77. if inspect.iscoroutinefunction(f):
  78. return wraps(f)(async_wrapped) # type: ignore
  79. else:
  80. return wraps(f)(sync_wrapped) # type: ignore
  81. return decorator
  82. def record_token_usage(
  83. span,
  84. input_tokens=None,
  85. input_tokens_cached=None,
  86. output_tokens=None,
  87. output_tokens_reasoning=None,
  88. total_tokens=None,
  89. ):
  90. # type: (Span, Optional[int], Optional[int], Optional[int], Optional[int], Optional[int]) -> None
  91. # TODO: move pipeline name elsewhere
  92. ai_pipeline_name = get_ai_pipeline_name()
  93. if ai_pipeline_name:
  94. span.set_data(SPANDATA.GEN_AI_PIPELINE_NAME, ai_pipeline_name)
  95. if input_tokens is not None:
  96. span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, input_tokens)
  97. if input_tokens_cached is not None:
  98. span.set_data(
  99. SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED,
  100. input_tokens_cached,
  101. )
  102. if output_tokens is not None:
  103. span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, output_tokens)
  104. if output_tokens_reasoning is not None:
  105. span.set_data(
  106. SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING,
  107. output_tokens_reasoning,
  108. )
  109. if total_tokens is None and input_tokens is not None and output_tokens is not None:
  110. total_tokens = input_tokens + output_tokens
  111. if total_tokens is not None:
  112. span.set_data(SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, total_tokens)