langchain.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873
  1. import itertools
  2. from collections import OrderedDict
  3. from functools import wraps
  4. import sentry_sdk
  5. from sentry_sdk.ai.monitoring import set_ai_pipeline_name
  6. from sentry_sdk.ai.utils import (
  7. GEN_AI_ALLOWED_MESSAGE_ROLES,
  8. normalize_message_roles,
  9. set_data_normalized,
  10. get_start_span_function,
  11. )
  12. from sentry_sdk.consts import OP, SPANDATA
  13. from sentry_sdk.integrations import DidNotEnable, Integration
  14. from sentry_sdk.scope import should_send_default_pii
  15. from sentry_sdk.tracing_utils import _get_value, set_span_errored
  16. from sentry_sdk.utils import logger, capture_internal_exceptions
  17. from typing import TYPE_CHECKING
  18. if TYPE_CHECKING:
  19. from typing import (
  20. Any,
  21. AsyncIterator,
  22. Callable,
  23. Dict,
  24. Iterator,
  25. List,
  26. Optional,
  27. Union,
  28. )
  29. from uuid import UUID
  30. from sentry_sdk.tracing import Span
  31. try:
  32. from langchain_core.agents import AgentFinish
  33. from langchain_core.callbacks import (
  34. BaseCallbackHandler,
  35. BaseCallbackManager,
  36. Callbacks,
  37. manager,
  38. )
  39. from langchain_core.messages import BaseMessage
  40. from langchain_core.outputs import LLMResult
  41. except ImportError:
  42. raise DidNotEnable("langchain not installed")
  43. try:
  44. from langchain.agents import AgentExecutor
  45. except ImportError:
  46. AgentExecutor = None
# Mapping of langchain invocation-parameter names to the Sentry span-data
# attributes they are recorded under. Note that both "function_call" and
# "tool_calls" are recorded under GEN_AI_RESPONSE_TOOL_CALLS.
DATA_FIELDS = {
    "frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY,
    "function_call": SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
    "max_tokens": SPANDATA.GEN_AI_REQUEST_MAX_TOKENS,
    "presence_penalty": SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY,
    "temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE,
    "tool_calls": SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
    "top_k": SPANDATA.GEN_AI_REQUEST_TOP_K,
    "top_p": SPANDATA.GEN_AI_REQUEST_TOP_P,
}
  57. class LangchainIntegration(Integration):
  58. identifier = "langchain"
  59. origin = f"auto.ai.{identifier}"
  60. # The most number of spans (e.g., LLM calls) that can be processed at the same time.
  61. max_spans = 1024
  62. def __init__(self, include_prompts=True, max_spans=1024):
  63. # type: (LangchainIntegration, bool, int) -> None
  64. self.include_prompts = include_prompts
  65. self.max_spans = max_spans
  66. @staticmethod
  67. def setup_once():
  68. # type: () -> None
  69. manager._configure = _wrap_configure(manager._configure)
  70. if AgentExecutor is not None:
  71. AgentExecutor.invoke = _wrap_agent_executor_invoke(AgentExecutor.invoke)
  72. AgentExecutor.stream = _wrap_agent_executor_stream(AgentExecutor.stream)
  73. class WatchedSpan:
  74. span = None # type: Span
  75. children = [] # type: List[WatchedSpan]
  76. is_pipeline = False # type: bool
  77. def __init__(self, span):
  78. # type: (Span) -> None
  79. self.span = span
  80. class SentryLangchainCallback(BaseCallbackHandler): # type: ignore[misc]
  81. """Callback handler that creates Sentry spans."""
  82. def __init__(self, max_span_map_size, include_prompts):
  83. # type: (int, bool) -> None
  84. self.span_map = OrderedDict() # type: OrderedDict[UUID, WatchedSpan]
  85. self.max_span_map_size = max_span_map_size
  86. self.include_prompts = include_prompts
  87. def gc_span_map(self):
  88. # type: () -> None
  89. while len(self.span_map) > self.max_span_map_size:
  90. run_id, watched_span = self.span_map.popitem(last=False)
  91. self._exit_span(watched_span, run_id)
  92. def _handle_error(self, run_id, error):
  93. # type: (UUID, Any) -> None
  94. with capture_internal_exceptions():
  95. if not run_id or run_id not in self.span_map:
  96. return
  97. span_data = self.span_map[run_id]
  98. span = span_data.span
  99. set_span_errored(span)
  100. sentry_sdk.capture_exception(error, span.scope)
  101. span.__exit__(None, None, None)
  102. del self.span_map[run_id]
  103. def _normalize_langchain_message(self, message):
  104. # type: (BaseMessage) -> Any
  105. parsed = {"role": message.type, "content": message.content}
  106. parsed.update(message.additional_kwargs)
  107. return parsed
  108. def _create_span(self, run_id, parent_id, **kwargs):
  109. # type: (SentryLangchainCallback, UUID, Optional[Any], Any) -> WatchedSpan
  110. watched_span = None # type: Optional[WatchedSpan]
  111. if parent_id:
  112. parent_span = self.span_map.get(parent_id) # type: Optional[WatchedSpan]
  113. if parent_span:
  114. watched_span = WatchedSpan(parent_span.span.start_child(**kwargs))
  115. parent_span.children.append(watched_span)
  116. if watched_span is None:
  117. watched_span = WatchedSpan(sentry_sdk.start_span(**kwargs))
  118. watched_span.span.__enter__()
  119. self.span_map[run_id] = watched_span
  120. self.gc_span_map()
  121. return watched_span
  122. def _exit_span(self, span_data, run_id):
  123. # type: (SentryLangchainCallback, WatchedSpan, UUID) -> None
  124. if span_data.is_pipeline:
  125. set_ai_pipeline_name(None)
  126. span_data.span.__exit__(None, None, None)
  127. del self.span_map[run_id]
  128. def on_llm_start(
  129. self,
  130. serialized,
  131. prompts,
  132. *,
  133. run_id,
  134. tags=None,
  135. parent_run_id=None,
  136. metadata=None,
  137. **kwargs,
  138. ):
  139. # type: (SentryLangchainCallback, Dict[str, Any], List[str], UUID, Optional[List[str]], Optional[UUID], Optional[Dict[str, Any]], Any) -> Any
  140. """Run when LLM starts running."""
  141. with capture_internal_exceptions():
  142. if not run_id:
  143. return
  144. all_params = kwargs.get("invocation_params", {})
  145. all_params.update(serialized.get("kwargs", {}))
  146. model = (
  147. all_params.get("model")
  148. or all_params.get("model_name")
  149. or all_params.get("model_id")
  150. or ""
  151. )
  152. watched_span = self._create_span(
  153. run_id,
  154. parent_run_id,
  155. op=OP.GEN_AI_PIPELINE,
  156. name=kwargs.get("name") or "Langchain LLM call",
  157. origin=LangchainIntegration.origin,
  158. )
  159. span = watched_span.span
  160. if model:
  161. span.set_data(
  162. SPANDATA.GEN_AI_REQUEST_MODEL,
  163. model,
  164. )
  165. ai_type = all_params.get("_type", "")
  166. if "anthropic" in ai_type:
  167. span.set_data(SPANDATA.GEN_AI_SYSTEM, "anthropic")
  168. elif "openai" in ai_type:
  169. span.set_data(SPANDATA.GEN_AI_SYSTEM, "openai")
  170. for key, attribute in DATA_FIELDS.items():
  171. if key in all_params and all_params[key] is not None:
  172. set_data_normalized(span, attribute, all_params[key], unpack=False)
  173. _set_tools_on_span(span, all_params.get("tools"))
  174. if should_send_default_pii() and self.include_prompts:
  175. normalized_messages = [
  176. {
  177. "role": GEN_AI_ALLOWED_MESSAGE_ROLES.USER,
  178. "content": {"type": "text", "text": prompt},
  179. }
  180. for prompt in prompts
  181. ]
  182. set_data_normalized(
  183. span,
  184. SPANDATA.GEN_AI_REQUEST_MESSAGES,
  185. normalized_messages,
  186. unpack=False,
  187. )
  188. def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
  189. # type: (SentryLangchainCallback, Dict[str, Any], List[List[BaseMessage]], UUID, Any) -> Any
  190. """Run when Chat Model starts running."""
  191. with capture_internal_exceptions():
  192. if not run_id:
  193. return
  194. all_params = kwargs.get("invocation_params", {})
  195. all_params.update(serialized.get("kwargs", {}))
  196. model = (
  197. all_params.get("model")
  198. or all_params.get("model_name")
  199. or all_params.get("model_id")
  200. or ""
  201. )
  202. watched_span = self._create_span(
  203. run_id,
  204. kwargs.get("parent_run_id"),
  205. op=OP.GEN_AI_CHAT,
  206. name=f"chat {model}".strip(),
  207. origin=LangchainIntegration.origin,
  208. )
  209. span = watched_span.span
  210. span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "chat")
  211. if model:
  212. span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model)
  213. ai_type = all_params.get("_type", "")
  214. if "anthropic" in ai_type:
  215. span.set_data(SPANDATA.GEN_AI_SYSTEM, "anthropic")
  216. elif "openai" in ai_type:
  217. span.set_data(SPANDATA.GEN_AI_SYSTEM, "openai")
  218. for key, attribute in DATA_FIELDS.items():
  219. if key in all_params and all_params[key] is not None:
  220. set_data_normalized(span, attribute, all_params[key], unpack=False)
  221. _set_tools_on_span(span, all_params.get("tools"))
  222. if should_send_default_pii() and self.include_prompts:
  223. normalized_messages = []
  224. for list_ in messages:
  225. for message in list_:
  226. normalized_messages.append(
  227. self._normalize_langchain_message(message)
  228. )
  229. normalized_messages = normalize_message_roles(normalized_messages)
  230. set_data_normalized(
  231. span,
  232. SPANDATA.GEN_AI_REQUEST_MESSAGES,
  233. normalized_messages,
  234. unpack=False,
  235. )
  236. def on_chat_model_end(self, response, *, run_id, **kwargs):
  237. # type: (SentryLangchainCallback, LLMResult, UUID, Any) -> Any
  238. """Run when Chat Model ends running."""
  239. with capture_internal_exceptions():
  240. if not run_id or run_id not in self.span_map:
  241. return
  242. span_data = self.span_map[run_id]
  243. span = span_data.span
  244. if should_send_default_pii() and self.include_prompts:
  245. set_data_normalized(
  246. span,
  247. SPANDATA.GEN_AI_RESPONSE_TEXT,
  248. [[x.text for x in list_] for list_ in response.generations],
  249. )
  250. _record_token_usage(span, response)
  251. self._exit_span(span_data, run_id)
  252. def on_llm_end(self, response, *, run_id, **kwargs):
  253. # type: (SentryLangchainCallback, LLMResult, UUID, Any) -> Any
  254. """Run when LLM ends running."""
  255. with capture_internal_exceptions():
  256. if not run_id or run_id not in self.span_map:
  257. return
  258. span_data = self.span_map[run_id]
  259. span = span_data.span
  260. try:
  261. generation = response.generations[0][0]
  262. except IndexError:
  263. generation = None
  264. if generation is not None:
  265. try:
  266. response_model = generation.generation_info.get("model_name")
  267. if response_model is not None:
  268. span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, response_model)
  269. except AttributeError:
  270. pass
  271. try:
  272. finish_reason = generation.generation_info.get("finish_reason")
  273. if finish_reason is not None:
  274. span.set_data(
  275. SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS, finish_reason
  276. )
  277. except AttributeError:
  278. pass
  279. try:
  280. if should_send_default_pii() and self.include_prompts:
  281. tool_calls = getattr(generation.message, "tool_calls", None)
  282. if tool_calls is not None and tool_calls != []:
  283. set_data_normalized(
  284. span,
  285. SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
  286. tool_calls,
  287. unpack=False,
  288. )
  289. except AttributeError:
  290. pass
  291. if should_send_default_pii() and self.include_prompts:
  292. set_data_normalized(
  293. span,
  294. SPANDATA.GEN_AI_RESPONSE_TEXT,
  295. [[x.text for x in list_] for list_ in response.generations],
  296. )
  297. _record_token_usage(span, response)
  298. self._exit_span(span_data, run_id)
  299. def on_llm_error(self, error, *, run_id, **kwargs):
  300. # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
  301. """Run when LLM errors."""
  302. self._handle_error(run_id, error)
  303. def on_chat_model_error(self, error, *, run_id, **kwargs):
  304. # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
  305. """Run when Chat Model errors."""
  306. self._handle_error(run_id, error)
  307. def on_agent_finish(self, finish, *, run_id, **kwargs):
  308. # type: (SentryLangchainCallback, AgentFinish, UUID, Any) -> Any
  309. with capture_internal_exceptions():
  310. if not run_id or run_id not in self.span_map:
  311. return
  312. span_data = self.span_map[run_id]
  313. span = span_data.span
  314. if should_send_default_pii() and self.include_prompts:
  315. set_data_normalized(
  316. span, SPANDATA.GEN_AI_RESPONSE_TEXT, finish.return_values.items()
  317. )
  318. self._exit_span(span_data, run_id)
  319. def on_tool_start(self, serialized, input_str, *, run_id, **kwargs):
  320. # type: (SentryLangchainCallback, Dict[str, Any], str, UUID, Any) -> Any
  321. """Run when tool starts running."""
  322. with capture_internal_exceptions():
  323. if not run_id:
  324. return
  325. tool_name = serialized.get("name") or kwargs.get("name") or ""
  326. watched_span = self._create_span(
  327. run_id,
  328. kwargs.get("parent_run_id"),
  329. op=OP.GEN_AI_EXECUTE_TOOL,
  330. name=f"execute_tool {tool_name}".strip(),
  331. origin=LangchainIntegration.origin,
  332. )
  333. span = watched_span.span
  334. span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "execute_tool")
  335. span.set_data(SPANDATA.GEN_AI_TOOL_NAME, tool_name)
  336. tool_description = serialized.get("description")
  337. if tool_description is not None:
  338. span.set_data(SPANDATA.GEN_AI_TOOL_DESCRIPTION, tool_description)
  339. if should_send_default_pii() and self.include_prompts:
  340. set_data_normalized(
  341. span,
  342. SPANDATA.GEN_AI_TOOL_INPUT,
  343. kwargs.get("inputs", [input_str]),
  344. )
  345. def on_tool_end(self, output, *, run_id, **kwargs):
  346. # type: (SentryLangchainCallback, str, UUID, Any) -> Any
  347. """Run when tool ends running."""
  348. with capture_internal_exceptions():
  349. if not run_id or run_id not in self.span_map:
  350. return
  351. span_data = self.span_map[run_id]
  352. span = span_data.span
  353. if should_send_default_pii() and self.include_prompts:
  354. set_data_normalized(span, SPANDATA.GEN_AI_TOOL_OUTPUT, output)
  355. self._exit_span(span_data, run_id)
  356. def on_tool_error(self, error, *args, run_id, **kwargs):
  357. # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
  358. """Run when tool errors."""
  359. self._handle_error(run_id, error)
  360. def _extract_tokens(token_usage):
  361. # type: (Any) -> tuple[Optional[int], Optional[int], Optional[int]]
  362. if not token_usage:
  363. return None, None, None
  364. input_tokens = _get_value(token_usage, "prompt_tokens") or _get_value(
  365. token_usage, "input_tokens"
  366. )
  367. output_tokens = _get_value(token_usage, "completion_tokens") or _get_value(
  368. token_usage, "output_tokens"
  369. )
  370. total_tokens = _get_value(token_usage, "total_tokens")
  371. return input_tokens, output_tokens, total_tokens
  372. def _extract_tokens_from_generations(generations):
  373. # type: (Any) -> tuple[Optional[int], Optional[int], Optional[int]]
  374. """Extract token usage from response.generations structure."""
  375. if not generations:
  376. return None, None, None
  377. total_input = 0
  378. total_output = 0
  379. total_total = 0
  380. for gen_list in generations:
  381. for gen in gen_list:
  382. token_usage = _get_token_usage(gen)
  383. input_tokens, output_tokens, total_tokens = _extract_tokens(token_usage)
  384. total_input += input_tokens if input_tokens is not None else 0
  385. total_output += output_tokens if output_tokens is not None else 0
  386. total_total += total_tokens if total_tokens is not None else 0
  387. return (
  388. total_input if total_input > 0 else None,
  389. total_output if total_output > 0 else None,
  390. total_total if total_total > 0 else None,
  391. )
  392. def _get_token_usage(obj):
  393. # type: (Any) -> Optional[Dict[str, Any]]
  394. """
  395. Check multiple paths to extract token usage from different objects.
  396. """
  397. possible_names = ("usage", "token_usage", "usage_metadata")
  398. message = _get_value(obj, "message")
  399. if message is not None:
  400. for name in possible_names:
  401. usage = _get_value(message, name)
  402. if usage is not None:
  403. return usage
  404. llm_output = _get_value(obj, "llm_output")
  405. if llm_output is not None:
  406. for name in possible_names:
  407. usage = _get_value(llm_output, name)
  408. if usage is not None:
  409. return usage
  410. for name in possible_names:
  411. usage = _get_value(obj, name)
  412. if usage is not None:
  413. return usage
  414. return None
  415. def _record_token_usage(span, response):
  416. # type: (Span, Any) -> None
  417. token_usage = _get_token_usage(response)
  418. if token_usage:
  419. input_tokens, output_tokens, total_tokens = _extract_tokens(token_usage)
  420. else:
  421. input_tokens, output_tokens, total_tokens = _extract_tokens_from_generations(
  422. response.generations
  423. )
  424. if input_tokens is not None:
  425. span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, input_tokens)
  426. if output_tokens is not None:
  427. span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, output_tokens)
  428. if total_tokens is not None:
  429. span.set_data(SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, total_tokens)
  430. def _get_request_data(obj, args, kwargs):
  431. # type: (Any, Any, Any) -> tuple[Optional[str], Optional[List[Any]]]
  432. """
  433. Get the agent name and available tools for the agent.
  434. """
  435. agent = getattr(obj, "agent", None)
  436. runnable = getattr(agent, "runnable", None)
  437. runnable_config = getattr(runnable, "config", {})
  438. tools = (
  439. getattr(obj, "tools", None)
  440. or getattr(agent, "tools", None)
  441. or runnable_config.get("tools")
  442. or runnable_config.get("available_tools")
  443. )
  444. tools = tools if tools and len(tools) > 0 else None
  445. try:
  446. agent_name = None
  447. if len(args) > 1:
  448. agent_name = args[1].get("run_name")
  449. if agent_name is None:
  450. agent_name = runnable_config.get("run_name")
  451. except Exception:
  452. pass
  453. return (agent_name, tools)
  454. def _simplify_langchain_tools(tools):
  455. # type: (Any) -> Optional[List[Any]]
  456. """Parse and simplify tools into a cleaner format."""
  457. if not tools:
  458. return None
  459. if not isinstance(tools, (list, tuple)):
  460. return None
  461. simplified_tools = []
  462. for tool in tools:
  463. try:
  464. if isinstance(tool, dict):
  465. if "function" in tool and isinstance(tool["function"], dict):
  466. func = tool["function"]
  467. simplified_tool = {
  468. "name": func.get("name"),
  469. "description": func.get("description"),
  470. }
  471. if simplified_tool["name"]:
  472. simplified_tools.append(simplified_tool)
  473. elif "name" in tool:
  474. simplified_tool = {
  475. "name": tool.get("name"),
  476. "description": tool.get("description"),
  477. }
  478. simplified_tools.append(simplified_tool)
  479. else:
  480. name = (
  481. tool.get("name")
  482. or tool.get("tool_name")
  483. or tool.get("function_name")
  484. )
  485. if name:
  486. simplified_tools.append(
  487. {
  488. "name": name,
  489. "description": tool.get("description")
  490. or tool.get("desc"),
  491. }
  492. )
  493. elif hasattr(tool, "name"):
  494. simplified_tool = {
  495. "name": getattr(tool, "name", None),
  496. "description": getattr(tool, "description", None)
  497. or getattr(tool, "desc", None),
  498. }
  499. if simplified_tool["name"]:
  500. simplified_tools.append(simplified_tool)
  501. elif hasattr(tool, "__name__"):
  502. simplified_tools.append(
  503. {
  504. "name": tool.__name__,
  505. "description": getattr(tool, "__doc__", None),
  506. }
  507. )
  508. else:
  509. tool_str = str(tool)
  510. if tool_str and tool_str != "":
  511. simplified_tools.append({"name": tool_str, "description": None})
  512. except Exception:
  513. continue
  514. return simplified_tools if simplified_tools else None
  515. def _set_tools_on_span(span, tools):
  516. # type: (Span, Any) -> None
  517. """Set available tools data on a span if tools are provided."""
  518. if tools is not None:
  519. simplified_tools = _simplify_langchain_tools(tools)
  520. if simplified_tools:
  521. set_data_normalized(
  522. span,
  523. SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
  524. simplified_tools,
  525. unpack=False,
  526. )
def _wrap_configure(f):
    # type: (Callable[..., Any]) -> Callable[..., Any]
    """Wrap langchain's ``manager._configure`` so that a
    SentryLangchainCallback is injected into the local callbacks, unless one
    is already present among the local or inheritable callbacks.
    """

    @wraps(f)
    def new_configure(
        callback_manager_cls,  # type: type
        inheritable_callbacks=None,  # type: Callbacks
        local_callbacks=None,  # type: Callbacks
        *args,  # type: Any
        **kwargs,  # type: Any
    ):
        # type: (...) -> Any
        integration = sentry_sdk.get_client().get_integration(LangchainIntegration)
        if integration is None:
            # Integration disabled: pass through untouched.
            return f(
                callback_manager_cls,
                inheritable_callbacks,
                local_callbacks,
                *args,
                **kwargs,
            )

        local_callbacks = local_callbacks or []

        # Handle each possible type of local_callbacks. For each type, we
        # extract the list of callbacks to check for SentryLangchainCallback,
        # and define a function that would add the SentryLangchainCallback
        # to the existing callbacks list.
        if isinstance(local_callbacks, BaseCallbackManager):
            callbacks_list = local_callbacks.handlers
        elif isinstance(local_callbacks, BaseCallbackHandler):
            callbacks_list = [local_callbacks]
        elif isinstance(local_callbacks, list):
            callbacks_list = local_callbacks
        else:
            logger.debug("Unknown callback type: %s", local_callbacks)
            # Just proceed with original function call
            return f(
                callback_manager_cls,
                inheritable_callbacks,
                local_callbacks,
                *args,
                **kwargs,
            )

        # Handle each possible type of inheritable_callbacks.
        if isinstance(inheritable_callbacks, BaseCallbackManager):
            inheritable_callbacks_list = inheritable_callbacks.handlers
        elif isinstance(inheritable_callbacks, list):
            inheritable_callbacks_list = inheritable_callbacks
        else:
            inheritable_callbacks_list = []

        # Only inject if no Sentry handler is registered anywhere yet,
        # to avoid double-instrumenting a run.
        if not any(
            isinstance(cb, SentryLangchainCallback)
            for cb in itertools.chain(callbacks_list, inheritable_callbacks_list)
        ):
            sentry_handler = SentryLangchainCallback(
                integration.max_spans,
                integration.include_prompts,
            )
            if isinstance(local_callbacks, BaseCallbackManager):
                # Copy the manager so the caller's object is not mutated.
                local_callbacks = local_callbacks.copy()
                local_callbacks.handlers = [
                    *local_callbacks.handlers,
                    sentry_handler,
                ]
            elif isinstance(local_callbacks, BaseCallbackHandler):
                local_callbacks = [local_callbacks, sentry_handler]
            else:
                local_callbacks = [*local_callbacks, sentry_handler]

        return f(
            callback_manager_cls,
            inheritable_callbacks,
            local_callbacks,
            *args,
            **kwargs,
        )

    return new_configure
def _wrap_agent_executor_invoke(f):
    # type: (Callable[..., Any]) -> Callable[..., Any]
    """Wrap ``AgentExecutor.invoke`` in a gen_ai invoke_agent span.

    Records agent name, available tools, and (PII permitting) the request
    input and response output from the invoke result.
    """

    @wraps(f)
    def new_invoke(self, *args, **kwargs):
        # type: (Any, Any, Any) -> Any
        integration = sentry_sdk.get_client().get_integration(LangchainIntegration)
        if integration is None:
            return f(self, *args, **kwargs)

        agent_name, tools = _get_request_data(self, args, kwargs)
        start_span_function = get_start_span_function()

        with start_span_function(
            op=OP.GEN_AI_INVOKE_AGENT,
            name=f"invoke_agent {agent_name}" if agent_name else "invoke_agent",
            origin=LangchainIntegration.origin,
        ) as span:
            if agent_name:
                span.set_data(SPANDATA.GEN_AI_AGENT_NAME, agent_name)

            span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
            span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, False)

            _set_tools_on_span(span, tools)

            # Run the agent
            result = f(self, *args, **kwargs)

            # The request input is read from the *result* dict (langchain
            # echoes the input back), so it is only recorded after the run.
            input = result.get("input")
            if (
                input is not None
                and should_send_default_pii()
                and integration.include_prompts
            ):
                normalized_messages = normalize_message_roles([input])
                set_data_normalized(
                    span,
                    SPANDATA.GEN_AI_REQUEST_MESSAGES,
                    normalized_messages,
                    unpack=False,
                )

            output = result.get("output")
            if (
                output is not None
                and should_send_default_pii()
                and integration.include_prompts
            ):
                set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, output)

            return result

    return new_invoke
def _wrap_agent_executor_stream(f):
    # type: (Callable[..., Any]) -> Callable[..., Any]
    """Wrap ``AgentExecutor.stream`` in a gen_ai invoke_agent span.

    The span is entered manually and closed only when the returned (sync or
    async) iterator is exhausted, so it covers the whole streamed run.
    """

    @wraps(f)
    def new_stream(self, *args, **kwargs):
        # type: (Any, Any, Any) -> Any
        integration = sentry_sdk.get_client().get_integration(LangchainIntegration)
        if integration is None:
            return f(self, *args, **kwargs)

        agent_name, tools = _get_request_data(self, args, kwargs)
        start_span_function = get_start_span_function()

        span = start_span_function(
            op=OP.GEN_AI_INVOKE_AGENT,
            name=f"invoke_agent {agent_name}".strip(),
            origin=LangchainIntegration.origin,
        )
        span.__enter__()

        if agent_name:
            span.set_data(SPANDATA.GEN_AI_AGENT_NAME, agent_name)

        span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
        span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, True)

        _set_tools_on_span(span, tools)

        input = args[0].get("input") if len(args) >= 1 else None
        if (
            input is not None
            and should_send_default_pii()
            and integration.include_prompts
        ):
            normalized_messages = normalize_message_roles([input])
            set_data_normalized(
                span,
                SPANDATA.GEN_AI_REQUEST_MESSAGES,
                normalized_messages,
                unpack=False,
            )

        # Run the agent
        # NOTE(review): if this call (or iteration below) raises, the span
        # entered above is never exited — confirm whether an error path
        # should close it.
        result = f(self, *args, **kwargs)

        old_iterator = result

        def new_iterator():
            # type: () -> Iterator[Any]
            # Drain the stream; `event` keeps the last yielded item, from
            # which the final output is read once iteration completes.
            for event in old_iterator:
                yield event

            try:
                output = event.get("output")
            except Exception:
                output = None

            if (
                output is not None
                and should_send_default_pii()
                and integration.include_prompts
            ):
                set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, output)

            span.__exit__(None, None, None)

        async def new_iterator_async():
            # type: () -> AsyncIterator[Any]
            async for event in old_iterator:
                yield event

            try:
                output = event.get("output")
            except Exception:
                output = None

            if (
                output is not None
                and should_send_default_pii()
                and integration.include_prompts
            ):
                set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, output)

            span.__exit__(None, None, None)

        # Pick the sync or async wrapper based on what f() returned.
        if str(type(result)) == "<class 'async_generator'>":
            result = new_iterator_async()
        else:
            result = new_iterator()

        return result

    return new_stream