utils.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. import json
  2. from typing import TYPE_CHECKING
  3. if TYPE_CHECKING:
  4. from typing import Any, Callable
  5. from sentry_sdk.tracing import Span
  6. import sentry_sdk
  7. from sentry_sdk.utils import logger
  8. class GEN_AI_ALLOWED_MESSAGE_ROLES:
  9. SYSTEM = "system"
  10. USER = "user"
  11. ASSISTANT = "assistant"
  12. TOOL = "tool"
  13. GEN_AI_MESSAGE_ROLE_REVERSE_MAPPING = {
  14. GEN_AI_ALLOWED_MESSAGE_ROLES.SYSTEM: ["system"],
  15. GEN_AI_ALLOWED_MESSAGE_ROLES.USER: ["user", "human"],
  16. GEN_AI_ALLOWED_MESSAGE_ROLES.ASSISTANT: ["assistant", "ai"],
  17. GEN_AI_ALLOWED_MESSAGE_ROLES.TOOL: ["tool", "tool_call"],
  18. }
  19. GEN_AI_MESSAGE_ROLE_MAPPING = {}
  20. for target_role, source_roles in GEN_AI_MESSAGE_ROLE_REVERSE_MAPPING.items():
  21. for source_role in source_roles:
  22. GEN_AI_MESSAGE_ROLE_MAPPING[source_role] = target_role
  23. def _normalize_data(data, unpack=True):
  24. # type: (Any, bool) -> Any
  25. # convert pydantic data (e.g. OpenAI v1+) to json compatible format
  26. if hasattr(data, "model_dump"):
  27. try:
  28. return _normalize_data(data.model_dump(), unpack=unpack)
  29. except Exception as e:
  30. logger.warning("Could not convert pydantic data to JSON: %s", e)
  31. return data if isinstance(data, (int, float, bool, str)) else str(data)
  32. if isinstance(data, list):
  33. if unpack and len(data) == 1:
  34. return _normalize_data(data[0], unpack=unpack) # remove empty dimensions
  35. return list(_normalize_data(x, unpack=unpack) for x in data)
  36. if isinstance(data, dict):
  37. return {k: _normalize_data(v, unpack=unpack) for (k, v) in data.items()}
  38. return data if isinstance(data, (int, float, bool, str)) else str(data)
  39. def set_data_normalized(span, key, value, unpack=True):
  40. # type: (Span, str, Any, bool) -> None
  41. normalized = _normalize_data(value, unpack=unpack)
  42. if isinstance(normalized, (int, float, bool, str)):
  43. span.set_data(key, normalized)
  44. else:
  45. span.set_data(key, json.dumps(normalized))
  46. def normalize_message_role(role):
  47. # type: (str) -> str
  48. """
  49. Normalize a message role to one of the 4 allowed gen_ai role values.
  50. Maps "ai" -> "assistant" and keeps other standard roles unchanged.
  51. """
  52. return GEN_AI_MESSAGE_ROLE_MAPPING.get(role, role)
  53. def normalize_message_roles(messages):
  54. # type: (list[dict[str, Any]]) -> list[dict[str, Any]]
  55. """
  56. Normalize roles in a list of messages to use standard gen_ai role values.
  57. Creates a deep copy to avoid modifying the original messages.
  58. """
  59. normalized_messages = []
  60. for message in messages:
  61. if not isinstance(message, dict):
  62. normalized_messages.append(message)
  63. continue
  64. normalized_message = message.copy()
  65. if "role" in message:
  66. normalized_message["role"] = normalize_message_role(message["role"])
  67. normalized_messages.append(normalized_message)
  68. return normalized_messages
  69. def get_start_span_function():
  70. # type: () -> Callable[..., Any]
  71. current_span = sentry_sdk.get_current_span()
  72. transaction_exists = (
  73. current_span is not None and current_span.containing_transaction is not None
  74. )
  75. return sentry_sdk.start_span if transaction_exists else sentry_sdk.start_transaction