# langgraph.py — Sentry SDK integration for LangGraph agents.
  1. from functools import wraps
  2. from typing import Any, Callable, List, Optional
  3. import sentry_sdk
  4. from sentry_sdk.ai.utils import set_data_normalized, normalize_message_roles
  5. from sentry_sdk.consts import OP, SPANDATA
  6. from sentry_sdk.integrations import DidNotEnable, Integration
  7. from sentry_sdk.scope import should_send_default_pii
  8. from sentry_sdk.utils import safe_serialize
  9. try:
  10. from langgraph.graph import StateGraph
  11. from langgraph.pregel import Pregel
  12. except ImportError:
  13. raise DidNotEnable("langgraph not installed")
  14. class LanggraphIntegration(Integration):
  15. identifier = "langgraph"
  16. origin = f"auto.ai.{identifier}"
  17. def __init__(self, include_prompts=True):
  18. # type: (LanggraphIntegration, bool) -> None
  19. self.include_prompts = include_prompts
  20. @staticmethod
  21. def setup_once():
  22. # type: () -> None
  23. # LangGraph lets users create agents using a StateGraph or the Functional API.
  24. # StateGraphs are then compiled to a CompiledStateGraph. Both CompiledStateGraph and
  25. # the functional API execute on a Pregel instance. Pregel is the runtime for the graph
  26. # and the invocation happens on Pregel, so patching the invoke methods takes care of both.
  27. # The streaming methods are not patched, because due to some internal reasons, LangGraph
  28. # will automatically patch the streaming methods to run through invoke, and by doing this
  29. # we prevent duplicate spans for invocations.
  30. StateGraph.compile = _wrap_state_graph_compile(StateGraph.compile)
  31. if hasattr(Pregel, "invoke"):
  32. Pregel.invoke = _wrap_pregel_invoke(Pregel.invoke)
  33. if hasattr(Pregel, "ainvoke"):
  34. Pregel.ainvoke = _wrap_pregel_ainvoke(Pregel.ainvoke)
  35. def _get_graph_name(graph_obj):
  36. # type: (Any) -> Optional[str]
  37. for attr in ["name", "graph_name", "__name__", "_name"]:
  38. if hasattr(graph_obj, attr):
  39. name = getattr(graph_obj, attr)
  40. if name and isinstance(name, str):
  41. return name
  42. return None
  43. def _normalize_langgraph_message(message):
  44. # type: (Any) -> Any
  45. if not hasattr(message, "content"):
  46. return None
  47. parsed = {"role": getattr(message, "type", None), "content": message.content}
  48. for attr in ["name", "tool_calls", "function_call", "tool_call_id"]:
  49. if hasattr(message, attr):
  50. value = getattr(message, attr)
  51. if value is not None:
  52. parsed[attr] = value
  53. return parsed
  54. def _parse_langgraph_messages(state):
  55. # type: (Any) -> Optional[List[Any]]
  56. if not state:
  57. return None
  58. messages = None
  59. if isinstance(state, dict):
  60. messages = state.get("messages")
  61. elif hasattr(state, "messages"):
  62. messages = state.messages
  63. elif hasattr(state, "get") and callable(state.get):
  64. try:
  65. messages = state.get("messages")
  66. except Exception:
  67. pass
  68. if not messages or not isinstance(messages, (list, tuple)):
  69. return None
  70. normalized_messages = []
  71. for message in messages:
  72. try:
  73. normalized = _normalize_langgraph_message(message)
  74. if normalized:
  75. normalized_messages.append(normalized)
  76. except Exception:
  77. continue
  78. return normalized_messages if normalized_messages else None
  79. def _wrap_state_graph_compile(f):
  80. # type: (Callable[..., Any]) -> Callable[..., Any]
  81. @wraps(f)
  82. def new_compile(self, *args, **kwargs):
  83. # type: (Any, Any, Any) -> Any
  84. integration = sentry_sdk.get_client().get_integration(LanggraphIntegration)
  85. if integration is None:
  86. return f(self, *args, **kwargs)
  87. with sentry_sdk.start_span(
  88. op=OP.GEN_AI_CREATE_AGENT,
  89. origin=LanggraphIntegration.origin,
  90. ) as span:
  91. compiled_graph = f(self, *args, **kwargs)
  92. compiled_graph_name = getattr(compiled_graph, "name", None)
  93. span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "create_agent")
  94. span.set_data(SPANDATA.GEN_AI_AGENT_NAME, compiled_graph_name)
  95. if compiled_graph_name:
  96. span.description = f"create_agent {compiled_graph_name}"
  97. else:
  98. span.description = "create_agent"
  99. if kwargs.get("model", None) is not None:
  100. span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, kwargs.get("model"))
  101. tools = None
  102. get_graph = getattr(compiled_graph, "get_graph", None)
  103. if get_graph and callable(get_graph):
  104. graph_obj = compiled_graph.get_graph()
  105. nodes = getattr(graph_obj, "nodes", None)
  106. if nodes and isinstance(nodes, dict):
  107. tools_node = nodes.get("tools")
  108. if tools_node:
  109. data = getattr(tools_node, "data", None)
  110. if data and hasattr(data, "tools_by_name"):
  111. tools = list(data.tools_by_name.keys())
  112. if tools is not None:
  113. span.set_data(SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS, tools)
  114. return compiled_graph
  115. return new_compile
  116. def _wrap_pregel_invoke(f):
  117. # type: (Callable[..., Any]) -> Callable[..., Any]
  118. @wraps(f)
  119. def new_invoke(self, *args, **kwargs):
  120. # type: (Any, Any, Any) -> Any
  121. integration = sentry_sdk.get_client().get_integration(LanggraphIntegration)
  122. if integration is None:
  123. return f(self, *args, **kwargs)
  124. graph_name = _get_graph_name(self)
  125. span_name = (
  126. f"invoke_agent {graph_name}".strip() if graph_name else "invoke_agent"
  127. )
  128. with sentry_sdk.start_span(
  129. op=OP.GEN_AI_INVOKE_AGENT,
  130. name=span_name,
  131. origin=LanggraphIntegration.origin,
  132. ) as span:
  133. if graph_name:
  134. span.set_data(SPANDATA.GEN_AI_PIPELINE_NAME, graph_name)
  135. span.set_data(SPANDATA.GEN_AI_AGENT_NAME, graph_name)
  136. span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
  137. # Store input messages to later compare with output
  138. input_messages = None
  139. if (
  140. len(args) > 0
  141. and should_send_default_pii()
  142. and integration.include_prompts
  143. ):
  144. input_messages = _parse_langgraph_messages(args[0])
  145. if input_messages:
  146. normalized_input_messages = normalize_message_roles(input_messages)
  147. set_data_normalized(
  148. span,
  149. SPANDATA.GEN_AI_REQUEST_MESSAGES,
  150. normalized_input_messages,
  151. unpack=False,
  152. )
  153. result = f(self, *args, **kwargs)
  154. _set_response_attributes(span, input_messages, result, integration)
  155. return result
  156. return new_invoke
  157. def _wrap_pregel_ainvoke(f):
  158. # type: (Callable[..., Any]) -> Callable[..., Any]
  159. @wraps(f)
  160. async def new_ainvoke(self, *args, **kwargs):
  161. # type: (Any, Any, Any) -> Any
  162. integration = sentry_sdk.get_client().get_integration(LanggraphIntegration)
  163. if integration is None:
  164. return await f(self, *args, **kwargs)
  165. graph_name = _get_graph_name(self)
  166. span_name = (
  167. f"invoke_agent {graph_name}".strip() if graph_name else "invoke_agent"
  168. )
  169. with sentry_sdk.start_span(
  170. op=OP.GEN_AI_INVOKE_AGENT,
  171. name=span_name,
  172. origin=LanggraphIntegration.origin,
  173. ) as span:
  174. if graph_name:
  175. span.set_data(SPANDATA.GEN_AI_PIPELINE_NAME, graph_name)
  176. span.set_data(SPANDATA.GEN_AI_AGENT_NAME, graph_name)
  177. span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
  178. input_messages = None
  179. if (
  180. len(args) > 0
  181. and should_send_default_pii()
  182. and integration.include_prompts
  183. ):
  184. input_messages = _parse_langgraph_messages(args[0])
  185. if input_messages:
  186. normalized_input_messages = normalize_message_roles(input_messages)
  187. set_data_normalized(
  188. span,
  189. SPANDATA.GEN_AI_REQUEST_MESSAGES,
  190. normalized_input_messages,
  191. unpack=False,
  192. )
  193. result = await f(self, *args, **kwargs)
  194. _set_response_attributes(span, input_messages, result, integration)
  195. return result
  196. return new_ainvoke
  197. def _get_new_messages(input_messages, output_messages):
  198. # type: (Optional[List[Any]], Optional[List[Any]]) -> Optional[List[Any]]
  199. """Extract only the new messages added during this invocation."""
  200. if not output_messages:
  201. return None
  202. if not input_messages:
  203. return output_messages
  204. # only return the new messages, aka the output messages that are not in the input messages
  205. input_count = len(input_messages)
  206. new_messages = (
  207. output_messages[input_count:] if len(output_messages) > input_count else []
  208. )
  209. return new_messages if new_messages else None
  210. def _extract_llm_response_text(messages):
  211. # type: (Optional[List[Any]]) -> Optional[str]
  212. if not messages:
  213. return None
  214. for message in reversed(messages):
  215. if isinstance(message, dict):
  216. role = message.get("role")
  217. if role in ["assistant", "ai"]:
  218. content = message.get("content")
  219. if content and isinstance(content, str):
  220. return content
  221. return None
  222. def _extract_tool_calls(messages):
  223. # type: (Optional[List[Any]]) -> Optional[List[Any]]
  224. if not messages:
  225. return None
  226. tool_calls = []
  227. for message in messages:
  228. if isinstance(message, dict):
  229. msg_tool_calls = message.get("tool_calls")
  230. if msg_tool_calls and isinstance(msg_tool_calls, list):
  231. tool_calls.extend(msg_tool_calls)
  232. return tool_calls if tool_calls else None
  233. def _set_response_attributes(span, input_messages, result, integration):
  234. # type: (Any, Optional[List[Any]], Any, LanggraphIntegration) -> None
  235. if not (should_send_default_pii() and integration.include_prompts):
  236. return
  237. parsed_response_messages = _parse_langgraph_messages(result)
  238. new_messages = _get_new_messages(input_messages, parsed_response_messages)
  239. llm_response_text = _extract_llm_response_text(new_messages)
  240. if llm_response_text:
  241. set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, llm_response_text)
  242. elif new_messages:
  243. set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, new_messages)
  244. else:
  245. set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, result)
  246. tool_calls = _extract_tool_calls(new_messages)
  247. if tool_calls:
  248. set_data_normalized(
  249. span,
  250. SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
  251. safe_serialize(tool_calls),
  252. unpack=False,
  253. )