cohere.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. from functools import wraps
  2. from sentry_sdk import consts
  3. from sentry_sdk.ai.monitoring import record_token_usage
  4. from sentry_sdk.consts import SPANDATA
  5. from sentry_sdk.ai.utils import set_data_normalized
  6. from typing import TYPE_CHECKING
  7. from sentry_sdk.tracing_utils import set_span_errored
  8. if TYPE_CHECKING:
  9. from typing import Any, Callable, Iterator
  10. from sentry_sdk.tracing import Span
  11. import sentry_sdk
  12. from sentry_sdk.scope import should_send_default_pii
  13. from sentry_sdk.integrations import DidNotEnable, Integration
  14. from sentry_sdk.utils import capture_internal_exceptions, event_from_exception
  15. try:
  16. from cohere.client import Client
  17. from cohere.base_client import BaseCohere
  18. from cohere import (
  19. ChatStreamEndEvent,
  20. NonStreamedChatResponse,
  21. )
  22. if TYPE_CHECKING:
  23. from cohere import StreamedChatResponse
  24. except ImportError:
  25. raise DidNotEnable("Cohere not installed")
  26. try:
  27. # cohere 5.9.3+
  28. from cohere import StreamEndStreamedChatResponse
  29. except ImportError:
  30. from cohere import StreamedChatResponse_StreamEnd as StreamEndStreamedChatResponse
# Non-PII chat request kwargs and the span-data keys they are recorded under.
COLLECTED_CHAT_PARAMS = {
    "model": SPANDATA.AI_MODEL_ID,
    "k": SPANDATA.AI_TOP_K,
    "p": SPANDATA.AI_TOP_P,
    "seed": SPANDATA.AI_SEED,
    "frequency_penalty": SPANDATA.AI_FREQUENCY_PENALTY,
    "presence_penalty": SPANDATA.AI_PRESENCE_PENALTY,
    "raw_prompting": SPANDATA.AI_RAW_PROMPTING,
}

# Chat request kwargs that may contain PII; only recorded when the user has
# opted in (send_default_pii and include_prompts).
COLLECTED_PII_CHAT_PARAMS = {
    "tools": SPANDATA.AI_TOOLS,
    "preamble": SPANDATA.AI_PREAMBLE,
}

# Non-PII attributes read off the chat response object, keyed by attribute
# name; each is stored under "ai.<attribute name>" on the span.
COLLECTED_CHAT_RESP_ATTRS = {
    "generation_id": SPANDATA.AI_GENERATION_ID,
    "is_search_required": SPANDATA.AI_SEARCH_REQUIRED,
    "finish_reason": SPANDATA.AI_FINISH_REASON,
}

# PII-bearing response attributes; same "ai.<attribute name>" storage, but
# gated behind the PII opt-in.
COLLECTED_PII_CHAT_RESP_ATTRS = {
    "citations": SPANDATA.AI_CITATIONS,
    "documents": SPANDATA.AI_DOCUMENTS,
    "search_queries": SPANDATA.AI_SEARCH_QUERIES,
    "search_results": SPANDATA.AI_SEARCH_RESULTS,
    "tool_calls": SPANDATA.AI_TOOL_CALLS,
}
  56. class CohereIntegration(Integration):
  57. identifier = "cohere"
  58. origin = f"auto.ai.{identifier}"
  59. def __init__(self, include_prompts=True):
  60. # type: (CohereIntegration, bool) -> None
  61. self.include_prompts = include_prompts
  62. @staticmethod
  63. def setup_once():
  64. # type: () -> None
  65. BaseCohere.chat = _wrap_chat(BaseCohere.chat, streaming=False)
  66. Client.embed = _wrap_embed(Client.embed)
  67. BaseCohere.chat_stream = _wrap_chat(BaseCohere.chat_stream, streaming=True)
  68. def _capture_exception(exc):
  69. # type: (Any) -> None
  70. set_span_errored()
  71. event, hint = event_from_exception(
  72. exc,
  73. client_options=sentry_sdk.get_client().options,
  74. mechanism={"type": "cohere", "handled": False},
  75. )
  76. sentry_sdk.capture_event(event, hint=hint)
  77. def _wrap_chat(f, streaming):
  78. # type: (Callable[..., Any], bool) -> Callable[..., Any]
  79. def collect_chat_response_fields(span, res, include_pii):
  80. # type: (Span, NonStreamedChatResponse, bool) -> None
  81. if include_pii:
  82. if hasattr(res, "text"):
  83. set_data_normalized(
  84. span,
  85. SPANDATA.AI_RESPONSES,
  86. [res.text],
  87. )
  88. for pii_attr in COLLECTED_PII_CHAT_RESP_ATTRS:
  89. if hasattr(res, pii_attr):
  90. set_data_normalized(span, "ai." + pii_attr, getattr(res, pii_attr))
  91. for attr in COLLECTED_CHAT_RESP_ATTRS:
  92. if hasattr(res, attr):
  93. set_data_normalized(span, "ai." + attr, getattr(res, attr))
  94. if hasattr(res, "meta"):
  95. if hasattr(res.meta, "billed_units"):
  96. record_token_usage(
  97. span,
  98. input_tokens=res.meta.billed_units.input_tokens,
  99. output_tokens=res.meta.billed_units.output_tokens,
  100. )
  101. elif hasattr(res.meta, "tokens"):
  102. record_token_usage(
  103. span,
  104. input_tokens=res.meta.tokens.input_tokens,
  105. output_tokens=res.meta.tokens.output_tokens,
  106. )
  107. if hasattr(res.meta, "warnings"):
  108. set_data_normalized(span, SPANDATA.AI_WARNINGS, res.meta.warnings)
  109. @wraps(f)
  110. def new_chat(*args, **kwargs):
  111. # type: (*Any, **Any) -> Any
  112. integration = sentry_sdk.get_client().get_integration(CohereIntegration)
  113. if (
  114. integration is None
  115. or "message" not in kwargs
  116. or not isinstance(kwargs.get("message"), str)
  117. ):
  118. return f(*args, **kwargs)
  119. message = kwargs.get("message")
  120. span = sentry_sdk.start_span(
  121. op=consts.OP.COHERE_CHAT_COMPLETIONS_CREATE,
  122. name="cohere.client.Chat",
  123. origin=CohereIntegration.origin,
  124. )
  125. span.__enter__()
  126. try:
  127. res = f(*args, **kwargs)
  128. except Exception as e:
  129. _capture_exception(e)
  130. span.__exit__(None, None, None)
  131. raise e from None
  132. with capture_internal_exceptions():
  133. if should_send_default_pii() and integration.include_prompts:
  134. set_data_normalized(
  135. span,
  136. SPANDATA.AI_INPUT_MESSAGES,
  137. list(
  138. map(
  139. lambda x: {
  140. "role": getattr(x, "role", "").lower(),
  141. "content": getattr(x, "message", ""),
  142. },
  143. kwargs.get("chat_history", []),
  144. )
  145. )
  146. + [{"role": "user", "content": message}],
  147. )
  148. for k, v in COLLECTED_PII_CHAT_PARAMS.items():
  149. if k in kwargs:
  150. set_data_normalized(span, v, kwargs[k])
  151. for k, v in COLLECTED_CHAT_PARAMS.items():
  152. if k in kwargs:
  153. set_data_normalized(span, v, kwargs[k])
  154. set_data_normalized(span, SPANDATA.AI_STREAMING, False)
  155. if streaming:
  156. old_iterator = res
  157. def new_iterator():
  158. # type: () -> Iterator[StreamedChatResponse]
  159. with capture_internal_exceptions():
  160. for x in old_iterator:
  161. if isinstance(x, ChatStreamEndEvent) or isinstance(
  162. x, StreamEndStreamedChatResponse
  163. ):
  164. collect_chat_response_fields(
  165. span,
  166. x.response,
  167. include_pii=should_send_default_pii()
  168. and integration.include_prompts,
  169. )
  170. yield x
  171. span.__exit__(None, None, None)
  172. return new_iterator()
  173. elif isinstance(res, NonStreamedChatResponse):
  174. collect_chat_response_fields(
  175. span,
  176. res,
  177. include_pii=should_send_default_pii()
  178. and integration.include_prompts,
  179. )
  180. span.__exit__(None, None, None)
  181. else:
  182. set_data_normalized(span, "unknown_response", True)
  183. span.__exit__(None, None, None)
  184. return res
  185. return new_chat
  186. def _wrap_embed(f):
  187. # type: (Callable[..., Any]) -> Callable[..., Any]
  188. @wraps(f)
  189. def new_embed(*args, **kwargs):
  190. # type: (*Any, **Any) -> Any
  191. integration = sentry_sdk.get_client().get_integration(CohereIntegration)
  192. if integration is None:
  193. return f(*args, **kwargs)
  194. with sentry_sdk.start_span(
  195. op=consts.OP.COHERE_EMBEDDINGS_CREATE,
  196. name="Cohere Embedding Creation",
  197. origin=CohereIntegration.origin,
  198. ) as span:
  199. if "texts" in kwargs and (
  200. should_send_default_pii() and integration.include_prompts
  201. ):
  202. if isinstance(kwargs["texts"], str):
  203. set_data_normalized(span, SPANDATA.AI_TEXTS, [kwargs["texts"]])
  204. elif (
  205. isinstance(kwargs["texts"], list)
  206. and len(kwargs["texts"]) > 0
  207. and isinstance(kwargs["texts"][0], str)
  208. ):
  209. set_data_normalized(
  210. span, SPANDATA.AI_INPUT_MESSAGES, kwargs["texts"]
  211. )
  212. if "model" in kwargs:
  213. set_data_normalized(span, SPANDATA.AI_MODEL_ID, kwargs["model"])
  214. try:
  215. res = f(*args, **kwargs)
  216. except Exception as e:
  217. _capture_exception(e)
  218. raise e from None
  219. if (
  220. hasattr(res, "meta")
  221. and hasattr(res.meta, "billed_units")
  222. and hasattr(res.meta.billed_units, "input_tokens")
  223. ):
  224. record_token_usage(
  225. span,
  226. input_tokens=res.meta.billed_units.input_tokens,
  227. total_tokens=res.meta.billed_units.input_tokens,
  228. )
  229. return res
  230. return new_embed