| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378 |
- import inspect
- from functools import wraps
- import sentry_sdk
- from sentry_sdk.ai.monitoring import record_token_usage
- from sentry_sdk.ai.utils import set_data_normalized
- from sentry_sdk.consts import OP, SPANDATA
- from sentry_sdk.integrations import DidNotEnable, Integration
- from sentry_sdk.scope import should_send_default_pii
- from sentry_sdk.tracing_utils import set_span_errored
- from sentry_sdk.utils import (
- capture_internal_exceptions,
- event_from_exception,
- )
- from typing import TYPE_CHECKING
- if TYPE_CHECKING:
- from typing import Any, Callable, Iterable
try:
    import huggingface_hub.inference._client
except ImportError:
    # Raising DidNotEnable tells the Sentry SDK to silently skip this
    # integration when huggingface_hub is not installed, rather than fail.
    raise DidNotEnable("Huggingface not installed")
class HuggingfaceHubIntegration(Integration):
    """Sentry integration instrumenting Huggingface Hub inference calls.

    Monkeypatches ``InferenceClient.text_generation`` and
    ``InferenceClient.chat_completion`` so each call is recorded as a
    gen_ai span.
    """

    identifier = "huggingface_hub"
    origin = f"auto.ai.{identifier}"

    def __init__(self, include_prompts=True):
        # type: (HuggingfaceHubIntegration, bool) -> None
        # When False (or when PII sending is off), prompts and responses
        # are kept out of span data.
        self.include_prompts = include_prompts

    @staticmethod
    def setup_once():
        # type: () -> None
        # Other tasks that can be called: https://huggingface.co/docs/huggingface_hub/guides/inference#supported-providers-and-tasks
        client_cls = huggingface_hub.inference._client.InferenceClient
        for method_name, span_op in (
            ("text_generation", OP.GEN_AI_GENERATE_TEXT),
            ("chat_completion", OP.GEN_AI_CHAT),
        ):
            original = getattr(client_cls, method_name)
            setattr(client_cls, method_name, _wrap_huggingface_task(original, span_op))
def _capture_exception(exc):
    # type: (Any) -> None
    """Mark the active span as errored and report *exc* to Sentry."""
    set_span_errored()

    error_event, error_hint = event_from_exception(
        exc,
        client_options=sentry_sdk.get_client().options,
        mechanism={"type": "huggingface_hub", "handled": False},
    )
    sentry_sdk.capture_event(error_event, hint=error_hint)
def _wrap_huggingface_task(f, op):
    # type: (Callable[..., Any], str) -> Callable[..., Any]
    """Wrap an InferenceClient task method *f* in a Sentry gen_ai span.

    Records request attributes (model, prompt, sampling params), response
    attributes (text, tool calls, finish reason), and token usage. For
    streaming responses the returned generator is wrapped so span data is
    collected while the caller consumes it; the span is closed on every
    exit path.
    """

    @wraps(f)
    def new_huggingface_task(*args, **kwargs):
        # type: (*Any, **Any) -> Any
        integration = sentry_sdk.get_client().get_integration(HuggingfaceHubIntegration)
        if integration is None:
            # Integration not enabled on this client: call through untouched.
            return f(*args, **kwargs)

        # The prompt arrives either as a keyword ("prompt" for
        # text_generation, "messages" for chat_completion) or as the first
        # positional argument after self.
        prompt = None
        if "prompt" in kwargs:
            prompt = kwargs["prompt"]
        elif "messages" in kwargs:
            prompt = kwargs["messages"]
        elif len(args) >= 2:
            if isinstance(args[1], str) or isinstance(args[1], list):
                prompt = args[1]

        if prompt is None:
            # invalid call, dont instrument, let it return error
            return f(*args, **kwargs)

        client = args[0]
        model = client.model or kwargs.get("model") or ""
        operation_name = op.split(".")[-1]

        # The span is entered manually (not via `with`) because streaming
        # responses outlive this frame; __exit__ is called on every path.
        span = sentry_sdk.start_span(
            op=op,
            name=f"{operation_name} {model}",
            origin=HuggingfaceHubIntegration.origin,
        )
        span.__enter__()

        span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, operation_name)

        if model:
            span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model)

        # Input attributes: prompts are PII, so only record them when the
        # user opted in on both the SDK and the integration.
        if should_send_default_pii() and integration.include_prompts:
            set_data_normalized(
                span, SPANDATA.GEN_AI_REQUEST_MESSAGES, prompt, unpack=False
            )

        # Supported request kwargs and the span attribute each maps to.
        attribute_mapping = {
            "tools": SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
            "frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY,
            "max_tokens": SPANDATA.GEN_AI_REQUEST_MAX_TOKENS,
            "presence_penalty": SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY,
            "temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE,
            "top_p": SPANDATA.GEN_AI_REQUEST_TOP_P,
            "top_k": SPANDATA.GEN_AI_REQUEST_TOP_K,
            "stream": SPANDATA.GEN_AI_RESPONSE_STREAMING,
        }

        for attribute, span_attribute in attribute_mapping.items():
            value = kwargs.get(attribute, None)
            if value is not None:
                # Scalars go on the span directly; lists/objects (e.g.
                # tools) are normalized first.
                if isinstance(value, (int, float, bool, str)):
                    span.set_data(span_attribute, value)
                else:
                    set_data_normalized(span, span_attribute, value, unpack=False)

        # LLM Execution
        try:
            res = f(*args, **kwargs)
        except Exception as e:
            _capture_exception(e)
            span.__exit__(None, None, None)

            # `from None` suppresses exception chaining so the user sees
            # only the original Huggingface error.
            raise e from None

        # Output attributes. `res` may be a plain string, a
        # TextGenerationOutput, a ChatCompletionOutput, or (for streaming)
        # a generator handled further below — probe defensively for each
        # shape with hasattr.
        finish_reason = None
        response_model = None
        response_text_buffer: list[str] = []
        tokens_used = 0
        tool_calls = None
        usage = None

        with capture_internal_exceptions():
            if isinstance(res, str) and res is not None:
                response_text_buffer.append(res)

            if hasattr(res, "generated_text") and res.generated_text is not None:
                response_text_buffer.append(res.generated_text)

            if hasattr(res, "model") and res.model is not None:
                response_model = res.model

            if hasattr(res, "details") and hasattr(res.details, "finish_reason"):
                finish_reason = res.details.finish_reason

            if (
                hasattr(res, "details")
                and hasattr(res.details, "generated_tokens")
                and res.details.generated_tokens is not None
            ):
                tokens_used = res.details.generated_tokens

            if hasattr(res, "usage") and res.usage is not None:
                usage = res.usage

            if hasattr(res, "choices") and res.choices is not None:
                # chat-completion shape: text/tool calls live on
                # choice.message.
                for choice in res.choices:
                    if hasattr(choice, "finish_reason"):
                        finish_reason = choice.finish_reason
                    if hasattr(choice, "message") and hasattr(
                        choice.message, "tool_calls"
                    ):
                        tool_calls = choice.message.tool_calls
                    if (
                        hasattr(choice, "message")
                        and hasattr(choice.message, "content")
                        and choice.message.content is not None
                    ):
                        response_text_buffer.append(choice.message.content)

            if response_model is not None:
                span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, response_model)

            if finish_reason is not None:
                set_data_normalized(
                    span,
                    SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS,
                    finish_reason,
                )

            # Response text/tool calls are PII-gated like the prompt above.
            if should_send_default_pii() and integration.include_prompts:
                if tool_calls is not None and len(tool_calls) > 0:
                    set_data_normalized(
                        span,
                        SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
                        tool_calls,
                        unpack=False,
                    )
                if len(response_text_buffer) > 0:
                    text_response = "".join(response_text_buffer)
                    if text_response:
                        set_data_normalized(
                            span,
                            SPANDATA.GEN_AI_RESPONSE_TEXT,
                            text_response,
                        )

            # Prefer the full usage object; fall back to the
            # text-generation details token count.
            if usage is not None:
                record_token_usage(
                    span,
                    input_tokens=usage.prompt_tokens,
                    output_tokens=usage.completion_tokens,
                    total_tokens=usage.total_tokens,
                )
            elif tokens_used > 0:
                record_token_usage(
                    span,
                    total_tokens=tokens_used,
                )

        # If the response is not a generator (i.e. NOT a streaming
        # response) we are done and can close the span and return it.
        if not inspect.isgenerator(res):
            span.__exit__(None, None, None)
            return res

        # Streaming from here on: wrap the generator so output data is
        # gathered as the caller iterates, and close the span when the
        # stream is exhausted.
        if kwargs.get("details", False):
            # text-generation stream output
            def new_details_iterator():
                # type: () -> Iterable[Any]
                finish_reason = None
                response_text_buffer: list[str] = []
                tokens_used = 0

                with capture_internal_exceptions():
                    for chunk in res:
                        if (
                            hasattr(chunk, "token")
                            and hasattr(chunk.token, "text")
                            and chunk.token.text is not None
                        ):
                            response_text_buffer.append(chunk.token.text)
                        # details only arrive on the final chunk.
                        if hasattr(chunk, "details") and hasattr(
                            chunk.details, "finish_reason"
                        ):
                            finish_reason = chunk.details.finish_reason
                        if (
                            hasattr(chunk, "details")
                            and hasattr(chunk.details, "generated_tokens")
                            and chunk.details.generated_tokens is not None
                        ):
                            tokens_used = chunk.details.generated_tokens
                        yield chunk

                    if finish_reason is not None:
                        set_data_normalized(
                            span,
                            SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS,
                            finish_reason,
                        )

                    if should_send_default_pii() and integration.include_prompts:
                        if len(response_text_buffer) > 0:
                            text_response = "".join(response_text_buffer)
                            if text_response:
                                set_data_normalized(
                                    span,
                                    SPANDATA.GEN_AI_RESPONSE_TEXT,
                                    text_response,
                                )

                    if tokens_used > 0:
                        record_token_usage(
                            span,
                            total_tokens=tokens_used,
                        )

                span.__exit__(None, None, None)

            return new_details_iterator()

        else:
            # chat-completion stream output
            def new_iterator():
                # type: () -> Iterable[str]
                finish_reason = None
                response_model = None
                response_text_buffer: list[str] = []
                tool_calls = None
                usage = None

                with capture_internal_exceptions():
                    for chunk in res:
                        if hasattr(chunk, "model") and chunk.model is not None:
                            response_model = chunk.model

                        if hasattr(chunk, "usage") and chunk.usage is not None:
                            usage = chunk.usage

                        if isinstance(chunk, str):
                            if chunk is not None:
                                response_text_buffer.append(chunk)

                        if hasattr(chunk, "choices") and chunk.choices is not None:
                            # Streaming deltas: text/tool calls live on
                            # choice.delta, not choice.message.
                            for choice in chunk.choices:
                                if (
                                    hasattr(choice, "delta")
                                    and hasattr(choice.delta, "content")
                                    and choice.delta.content is not None
                                ):
                                    response_text_buffer.append(
                                        choice.delta.content
                                    )
                                if (
                                    hasattr(choice, "finish_reason")
                                    and choice.finish_reason is not None
                                ):
                                    finish_reason = choice.finish_reason
                                if (
                                    hasattr(choice, "delta")
                                    and hasattr(choice.delta, "tool_calls")
                                    and choice.delta.tool_calls is not None
                                ):
                                    tool_calls = choice.delta.tool_calls

                        yield chunk

                    if response_model is not None:
                        span.set_data(
                            SPANDATA.GEN_AI_RESPONSE_MODEL, response_model
                        )

                    if finish_reason is not None:
                        set_data_normalized(
                            span,
                            SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS,
                            finish_reason,
                        )

                    if should_send_default_pii() and integration.include_prompts:
                        if tool_calls is not None and len(tool_calls) > 0:
                            set_data_normalized(
                                span,
                                SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
                                tool_calls,
                                unpack=False,
                            )
                        if len(response_text_buffer) > 0:
                            text_response = "".join(response_text_buffer)
                            if text_response:
                                set_data_normalized(
                                    span,
                                    SPANDATA.GEN_AI_RESPONSE_TEXT,
                                    text_response,
                                )

                    if usage is not None:
                        record_token_usage(
                            span,
                            input_tokens=usage.prompt_tokens,
                            output_tokens=usage.completion_tokens,
                            total_tokens=usage.total_tokens,
                        )

                span.__exit__(None, None, None)

            return new_iterator()

    return new_huggingface_task
|