# sentry_sdk/integrations/huggingface_hub.py
import inspect
from functools import wraps

import sentry_sdk
from sentry_sdk.ai.monitoring import record_token_usage
from sentry_sdk.ai.utils import set_data_normalized
from sentry_sdk.consts import OP, SPANDATA
from sentry_sdk.integrations import DidNotEnable, Integration
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.tracing_utils import set_span_errored
from sentry_sdk.utils import (
    capture_internal_exceptions,
    event_from_exception,
)

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing import Any, Callable, Iterable

# Abort integration setup entirely if huggingface_hub is not installed.
try:
    import huggingface_hub.inference._client
except ImportError:
    raise DidNotEnable("Huggingface not installed")
  21. class HuggingfaceHubIntegration(Integration):
  22. identifier = "huggingface_hub"
  23. origin = f"auto.ai.{identifier}"
  24. def __init__(self, include_prompts=True):
  25. # type: (HuggingfaceHubIntegration, bool) -> None
  26. self.include_prompts = include_prompts
  27. @staticmethod
  28. def setup_once():
  29. # type: () -> None
  30. # Other tasks that can be called: https://huggingface.co/docs/huggingface_hub/guides/inference#supported-providers-and-tasks
  31. huggingface_hub.inference._client.InferenceClient.text_generation = (
  32. _wrap_huggingface_task(
  33. huggingface_hub.inference._client.InferenceClient.text_generation,
  34. OP.GEN_AI_GENERATE_TEXT,
  35. )
  36. )
  37. huggingface_hub.inference._client.InferenceClient.chat_completion = (
  38. _wrap_huggingface_task(
  39. huggingface_hub.inference._client.InferenceClient.chat_completion,
  40. OP.GEN_AI_CHAT,
  41. )
  42. )
  43. def _capture_exception(exc):
  44. # type: (Any) -> None
  45. set_span_errored()
  46. event, hint = event_from_exception(
  47. exc,
  48. client_options=sentry_sdk.get_client().options,
  49. mechanism={"type": "huggingface_hub", "handled": False},
  50. )
  51. sentry_sdk.capture_event(event, hint=hint)
def _wrap_huggingface_task(f, op):
    # type: (Callable[..., Any], str) -> Callable[..., Any]
    """Wrap an ``InferenceClient`` task method in a Sentry gen-AI span.

    ``f`` is the original method (``text_generation`` or ``chat_completion``)
    and ``op`` the Sentry span op constant. The wrapper records request
    attributes, output text, finish reasons, tool calls, and token usage,
    and handles both plain and streaming (generator) responses.
    """

    @wraps(f)
    def new_huggingface_task(*args, **kwargs):
        # type: (*Any, **Any) -> Any
        integration = sentry_sdk.get_client().get_integration(HuggingfaceHubIntegration)
        if integration is None:
            return f(*args, **kwargs)

        # The prompt can arrive as the "prompt" kwarg (text-generation), the
        # "messages" kwarg (chat-completion), or as the first positional
        # argument after ``self`` (args[0] is the bound client instance).
        prompt = None
        if "prompt" in kwargs:
            prompt = kwargs["prompt"]
        elif "messages" in kwargs:
            prompt = kwargs["messages"]
        elif len(args) >= 2:
            if isinstance(args[1], str) or isinstance(args[1], list):
                prompt = args[1]

        if prompt is None:
            # invalid call, dont instrument, let it return error
            return f(*args, **kwargs)

        client = args[0]
        model = client.model or kwargs.get("model") or ""
        operation_name = op.split(".")[-1]

        span = sentry_sdk.start_span(
            op=op,
            name=f"{operation_name} {model}",
            origin=HuggingfaceHubIntegration.origin,
        )
        # Entered manually (not via ``with``) because streaming responses
        # outlive this frame; every return path below must call
        # ``span.__exit__`` exactly once.
        span.__enter__()

        span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, operation_name)

        if model:
            span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model)

        # Input attributes — prompts are PII, so they are only attached when
        # both the SDK's PII setting and the integration opt-in allow it.
        if should_send_default_pii() and integration.include_prompts:
            set_data_normalized(
                span, SPANDATA.GEN_AI_REQUEST_MESSAGES, prompt, unpack=False
            )

        # Map well-known request kwargs onto span attributes.
        attribute_mapping = {
            "tools": SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
            "frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY,
            "max_tokens": SPANDATA.GEN_AI_REQUEST_MAX_TOKENS,
            "presence_penalty": SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY,
            "temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE,
            "top_p": SPANDATA.GEN_AI_REQUEST_TOP_P,
            "top_k": SPANDATA.GEN_AI_REQUEST_TOP_K,
            "stream": SPANDATA.GEN_AI_RESPONSE_STREAMING,
        }

        for attribute, span_attribute in attribute_mapping.items():
            value = kwargs.get(attribute, None)
            if value is not None:
                # Scalars can be set directly; anything else (lists of tools,
                # etc.) goes through normalization/serialization.
                if isinstance(value, (int, float, bool, str)):
                    span.set_data(span_attribute, value)
                else:
                    set_data_normalized(span, span_attribute, value, unpack=False)

        # LLM Execution
        try:
            res = f(*args, **kwargs)
        except Exception as e:
            _capture_exception(e)
            span.__exit__(None, None, None)
            # ``from None`` suppresses exception chaining so the user sees
            # the original error, not the instrumentation frames.
            raise e from None

        # Output attributes. The hasattr cascades below duck-type across the
        # several response shapes huggingface_hub may return (plain str,
        # TextGenerationOutput, ChatCompletionOutput).
        finish_reason = None
        response_model = None
        response_text_buffer: list[str] = []
        tokens_used = 0
        tool_calls = None
        usage = None

        with capture_internal_exceptions():
            if isinstance(res, str) and res is not None:
                response_text_buffer.append(res)

            if hasattr(res, "generated_text") and res.generated_text is not None:
                response_text_buffer.append(res.generated_text)

            if hasattr(res, "model") and res.model is not None:
                response_model = res.model

            if hasattr(res, "details") and hasattr(res.details, "finish_reason"):
                finish_reason = res.details.finish_reason

            if (
                hasattr(res, "details")
                and hasattr(res.details, "generated_tokens")
                and res.details.generated_tokens is not None
            ):
                tokens_used = res.details.generated_tokens

            if hasattr(res, "usage") and res.usage is not None:
                usage = res.usage

            if hasattr(res, "choices") and res.choices is not None:
                for choice in res.choices:
                    if hasattr(choice, "finish_reason"):
                        finish_reason = choice.finish_reason
                    if hasattr(choice, "message") and hasattr(
                        choice.message, "tool_calls"
                    ):
                        tool_calls = choice.message.tool_calls
                    if (
                        hasattr(choice, "message")
                        and hasattr(choice.message, "content")
                        and choice.message.content is not None
                    ):
                        response_text_buffer.append(choice.message.content)

            if response_model is not None:
                span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, response_model)

            if finish_reason is not None:
                set_data_normalized(
                    span,
                    SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS,
                    finish_reason,
                )

            # Response text and tool calls are PII-gated like the prompt.
            if should_send_default_pii() and integration.include_prompts:
                if tool_calls is not None and len(tool_calls) > 0:
                    set_data_normalized(
                        span,
                        SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
                        tool_calls,
                        unpack=False,
                    )
                if len(response_text_buffer) > 0:
                    text_response = "".join(response_text_buffer)
                    if text_response:
                        set_data_normalized(
                            span,
                            SPANDATA.GEN_AI_RESPONSE_TEXT,
                            text_response,
                        )

            # Prefer the full usage object; fall back to the text-generation
            # details' generated token count if that is all we have.
            if usage is not None:
                record_token_usage(
                    span,
                    input_tokens=usage.prompt_tokens,
                    output_tokens=usage.completion_tokens,
                    total_tokens=usage.total_tokens,
                )
            elif tokens_used > 0:
                record_token_usage(
                    span,
                    total_tokens=tokens_used,
                )

        # If the response is not a generator (i.e. not a streaming response)
        # we are done and can return the response
        if not inspect.isgenerator(res):
            span.__exit__(None, None, None)
            return res

        if kwargs.get("details", False):
            # text-generation stream output

            def new_details_iterator():
                # type: () -> Iterable[Any]
                # Consumes the stream, forwarding each chunk to the caller
                # while accumulating text/finish-reason/token data; the span
                # is closed only once the stream is exhausted.
                finish_reason = None
                response_text_buffer: list[str] = []
                tokens_used = 0

                with capture_internal_exceptions():
                    for chunk in res:
                        if (
                            hasattr(chunk, "token")
                            and hasattr(chunk.token, "text")
                            and chunk.token.text is not None
                        ):
                            response_text_buffer.append(chunk.token.text)

                        if hasattr(chunk, "details") and hasattr(
                            chunk.details, "finish_reason"
                        ):
                            finish_reason = chunk.details.finish_reason

                        if (
                            hasattr(chunk, "details")
                            and hasattr(chunk.details, "generated_tokens")
                            and chunk.details.generated_tokens is not None
                        ):
                            tokens_used = chunk.details.generated_tokens

                        yield chunk

                    if finish_reason is not None:
                        set_data_normalized(
                            span,
                            SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS,
                            finish_reason,
                        )

                    if should_send_default_pii() and integration.include_prompts:
                        if len(response_text_buffer) > 0:
                            text_response = "".join(response_text_buffer)
                            if text_response:
                                set_data_normalized(
                                    span,
                                    SPANDATA.GEN_AI_RESPONSE_TEXT,
                                    text_response,
                                )

                    if tokens_used > 0:
                        record_token_usage(
                            span,
                            total_tokens=tokens_used,
                        )

                span.__exit__(None, None, None)

            return new_details_iterator()

        else:
            # chat-completion stream output

            def new_iterator():
                # type: () -> Iterable[str]
                # Same pattern as new_details_iterator, but for chat-style
                # chunks (delta.content / delta.tool_calls per choice).
                finish_reason = None
                response_model = None
                response_text_buffer: list[str] = []
                tool_calls = None
                usage = None

                with capture_internal_exceptions():
                    for chunk in res:
                        if hasattr(chunk, "model") and chunk.model is not None:
                            response_model = chunk.model

                        if hasattr(chunk, "usage") and chunk.usage is not None:
                            usage = chunk.usage

                        if isinstance(chunk, str):
                            if chunk is not None:
                                response_text_buffer.append(chunk)

                        if hasattr(chunk, "choices") and chunk.choices is not None:
                            for choice in chunk.choices:
                                if (
                                    hasattr(choice, "delta")
                                    and hasattr(choice.delta, "content")
                                    and choice.delta.content is not None
                                ):
                                    response_text_buffer.append(
                                        choice.delta.content
                                    )
                                if (
                                    hasattr(choice, "finish_reason")
                                    and choice.finish_reason is not None
                                ):
                                    finish_reason = choice.finish_reason
                                if (
                                    hasattr(choice, "delta")
                                    and hasattr(choice.delta, "tool_calls")
                                    and choice.delta.tool_calls is not None
                                ):
                                    tool_calls = choice.delta.tool_calls

                        yield chunk

                    if response_model is not None:
                        span.set_data(
                            SPANDATA.GEN_AI_RESPONSE_MODEL, response_model
                        )

                    if finish_reason is not None:
                        set_data_normalized(
                            span,
                            SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS,
                            finish_reason,
                        )

                    if should_send_default_pii() and integration.include_prompts:
                        if tool_calls is not None and len(tool_calls) > 0:
                            set_data_normalized(
                                span,
                                SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
                                tool_calls,
                                unpack=False,
                            )
                        if len(response_text_buffer) > 0:
                            text_response = "".join(response_text_buffer)
                            if text_response:
                                set_data_normalized(
                                    span,
                                    SPANDATA.GEN_AI_RESPONSE_TEXT,
                                    text_response,
                                )

                    if usage is not None:
                        record_token_usage(
                            span,
                            input_tokens=usage.prompt_tokens,
                            output_tokens=usage.completion_tokens,
                            total_tokens=usage.total_tokens,
                        )

                span.__exit__(None, None, None)

            return new_iterator()

    return new_huggingface_task