| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137 |
- import inspect
- from functools import wraps
- from sentry_sdk.consts import SPANDATA
- import sentry_sdk.utils
- from sentry_sdk import start_span
- from sentry_sdk.tracing import Span
- from sentry_sdk.utils import ContextVar
- from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from typing import Optional, Callable, Awaitable, Any, Union, TypeVar

    # Type of a function that ai_track can wrap: either a plain callable or a
    # callable returning an awaitable (coroutine function). Only used in type
    # comments, so it exists only for static type checkers.
    F = TypeVar("F", bound=Union[Callable[..., Any], Callable[..., Awaitable[Any]]])

# Name of the AI pipeline currently executing, or None when no pipeline is
# active. ai_track sets it to the outermost decorated call's description and
# resets it to None when that call finishes; nested decorated calls read it to
# tag their spans. Context-local via sentry_sdk.utils.ContextVar
# (NOTE(review): whether that is the stdlib contextvars class or a fallback
# depends on sentry_sdk.utils — confirm before relying on per-task isolation).
_ai_pipeline_name = ContextVar("ai_pipeline_name", default=None)
def set_ai_pipeline_name(name):
    # type: (Optional[str]) -> None
    """Store ``name`` as the active AI pipeline name.

    Passing ``None`` clears the active pipeline name. The token returned by the
    underlying ``ContextVar.set`` is intentionally discarded.
    """
    _ai_pipeline_name.set(name)
    return None
def get_ai_pipeline_name():
    # type: () -> Optional[str]
    """Return the active AI pipeline name, or ``None`` when none is set."""
    current_name = _ai_pipeline_name.get()
    return current_name
def _capture_ai_exception(exc):
    # type: (Exception) -> None
    """Report ``exc`` to Sentry as an unhandled ``ai_monitoring`` error."""
    event, hint = sentry_sdk.utils.event_from_exception(
        exc,
        client_options=sentry_sdk.get_client().options,
        mechanism={"type": "ai_monitoring", "handled": False},
    )
    sentry_sdk.capture_event(event, hint=hint)


def _apply_span_extras(span, call_kwargs, curr_pipeline):
    # type: (Span, Any, Optional[str]) -> None
    """Move per-call ``sentry_tags``/``sentry_data`` from ``call_kwargs`` onto ``span``.

    Both keys are popped, so they are never forwarded to the wrapped function.
    A missing or ``None`` value is treated as an empty mapping. When
    ``curr_pipeline`` is set, the pipeline name is also recorded on the span.
    """
    for key, value in (call_kwargs.pop("sentry_tags", None) or {}).items():
        span.set_tag(key, value)
    for key, value in (call_kwargs.pop("sentry_data", None) or {}).items():
        span.set_data(key, value)
    if curr_pipeline:
        # Nested call: link this run span to the enclosing pipeline.
        span.set_data(SPANDATA.GEN_AI_PIPELINE_NAME, curr_pipeline)


def ai_track(description, **span_kwargs):
    # type: (str, Any) -> Callable[[F], F]
    """Decorate a sync or async function so each call runs inside a Sentry span.

    The span is named ``description``.

    * **Outermost call** (no pipeline name active): the span op defaults to
      ``"ai.pipeline"``. ``description`` is published as the active pipeline
      name for the duration of the call and reset to ``None`` afterwards.
      Any ``Exception`` raised by the wrapped function is reported to Sentry
      (mechanism ``ai_monitoring``, unhandled) and then re-raised unchanged.
    * **Nested call** (a pipeline name is active): the span op defaults to
      ``"ai.run"`` and the span records the pipeline name. Exceptions
      propagate without being reported here.

    The wrapped function additionally accepts ``sentry_tags`` and
    ``sentry_data`` keyword arguments (mappings). They are applied to the span
    and are not forwarded to the function.

    :param description: Span name; also the pipeline name for outermost calls.
    :param span_kwargs: Extra keyword arguments for ``start_span``. An ``op``
        entry overrides the default op on every call.
    :returns: A decorator that preserves the wrapped function's metadata.
    """
    # Resolve the optional ``op`` override once, without mutating
    # ``span_kwargs``. Popping it inside the wrappers removed a caller-supplied
    # op after the first call, so every later call used the default op instead.
    has_custom_op = "op" in span_kwargs
    custom_op = span_kwargs.get("op")
    start_span_kwargs = {k: v for k, v in span_kwargs.items() if k != "op"}

    def _span_op(curr_pipeline):
        # type: (Optional[str]) -> Any
        """Return the op for this call: the override, else run/pipeline."""
        if has_custom_op:
            return custom_op
        return "ai.run" if curr_pipeline else "ai.pipeline"

    def decorator(f):
        # type: (F) -> F
        def sync_wrapped(*args, **kwargs):
            # type: (Any, Any) -> Any
            curr_pipeline = _ai_pipeline_name.get()
            op = _span_op(curr_pipeline)
            with start_span(name=description, op=op, **start_span_kwargs) as span:
                _apply_span_extras(span, kwargs, curr_pipeline)
                if curr_pipeline:
                    return f(*args, **kwargs)
                # Outermost call: make nested ai_track calls runs of this pipeline.
                _ai_pipeline_name.set(description)
                try:
                    return f(*args, **kwargs)
                except Exception as exc:
                    _capture_ai_exception(exc)
                    # Bare raise keeps the original traceback and chained
                    # context (``raise e from None`` suppressed the context).
                    raise
                finally:
                    _ai_pipeline_name.set(None)

        async def async_wrapped(*args, **kwargs):
            # type: (Any, Any) -> Any
            curr_pipeline = _ai_pipeline_name.get()
            op = _span_op(curr_pipeline)
            with start_span(name=description, op=op, **start_span_kwargs) as span:
                _apply_span_extras(span, kwargs, curr_pipeline)
                if curr_pipeline:
                    return await f(*args, **kwargs)
                # Outermost call: make nested ai_track calls runs of this pipeline.
                _ai_pipeline_name.set(description)
                try:
                    return await f(*args, **kwargs)
                except Exception as exc:
                    _capture_ai_exception(exc)
                    # Bare raise keeps the original traceback and chained context.
                    raise
                finally:
                    _ai_pipeline_name.set(None)

        if inspect.iscoroutinefunction(f):
            return wraps(f)(async_wrapped)  # type: ignore
        return wraps(f)(sync_wrapped)  # type: ignore

    return decorator
def record_token_usage(
    span,
    input_tokens=None,
    input_tokens_cached=None,
    output_tokens=None,
    output_tokens_reasoning=None,
    total_tokens=None,
):
    # type: (Span, Optional[int], Optional[int], Optional[int], Optional[int], Optional[int]) -> None
    """Attach token-usage counts (and the active pipeline name) to ``span``.

    Only counts that are not ``None`` are recorded. When ``total_tokens`` is
    omitted but both ``input_tokens`` and ``output_tokens`` are given, the
    total is derived as their sum.
    """
    # TODO: move pipeline name elsewhere
    pipeline_name = get_ai_pipeline_name()
    if pipeline_name:
        span.set_data(SPANDATA.GEN_AI_PIPELINE_NAME, pipeline_name)

    can_derive_total = input_tokens is not None and output_tokens is not None
    if total_tokens is None and can_derive_total:
        total_tokens = input_tokens + output_tokens

    # Recorded in this fixed order; entries whose value is None are skipped.
    usage_fields = (
        (SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, input_tokens),
        (SPANDATA.GEN_AI_USAGE_INPUT_TOKENS_CACHED, input_tokens_cached),
        (SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, output_tokens),
        (SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS_REASONING, output_tokens_reasoning),
        (SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, total_tokens),
    )
    for data_key, count in usage_fields:
        if count is not None:
            span.set_data(data_key, count)
|