glam/tests/dspy_gitops/metrics/sparql_correctness.py
2026-01-11 18:08:40 +01:00

227 lines
6.2 KiB
Python

"""
SPARQL Correctness Metrics
Validates SPARQL syntax and query results.
"""
import re
from typing import Any, Optional
import logging
logger = logging.getLogger(__name__)
def validate_sparql_syntax(sparql: str) -> tuple[bool, Optional[str]]:
    """Validate SPARQL syntax using basic, heuristic rules.

    This is a cheap sanity check, not a parser. It verifies that a query
    form keyword is present, that SELECT queries contain a WHERE keyword,
    that braces and parentheses are properly nested, and that no variable
    starts with a doubled question mark. Delimiters inside string literals
    or IRIs are not treated specially.

    Args:
        sparql: SPARQL query string

    Returns:
        Tuple of (is_valid, error_message); error_message is None when the
        query passes every check.
    """
    if not sparql or not sparql.strip():
        return False, "Empty query"

    def has_keyword(keyword: str) -> bool:
        # Whole-word, case-insensitive match so that identifiers such as
        # "ex:task" or "?selected" are not mistaken for ASK / SELECT.
        return re.search(rf"\b{keyword}\b", sparql, re.IGNORECASE) is not None

    def is_balanced(opener: str, closer: str) -> bool:
        # Track nesting depth: comparing counts alone would accept a closer
        # that appears before its opener (e.g. "} ... {").
        depth = 0
        for ch in sparql:
            if ch == opener:
                depth += 1
            elif ch == closer:
                depth -= 1
                if depth < 0:
                    return False
        return depth == 0

    # Must have SELECT, ASK, CONSTRUCT, or DESCRIBE
    if not any(has_keyword(kw) for kw in ("SELECT", "ASK", "CONSTRUCT", "DESCRIBE")):
        return False, "Missing query form (SELECT/ASK/CONSTRUCT/DESCRIBE)"

    # SELECT queries must spell out WHERE. The SPARQL grammar makes the
    # keyword optional; this check is deliberately stricter than the grammar.
    if has_keyword("SELECT") and not has_keyword("WHERE"):
        return False, "Missing WHERE clause"

    if not is_balanced("{", "}"):
        return False, "Unbalanced braces"

    if not is_balanced("(", ")"):
        return False, "Unbalanced parentheses"

    # Common generation error: "??var" instead of "?var"
    if "??" in sparql:
        return False, "Double question mark in variable"

    return True, None
def check_required_prefixes(sparql: str, required: Optional[list[str]] = None) -> tuple[bool, list[str]]:
    """Check that every required prefix used in the query is also declared.

    A prefix that is not used in the query at all is never reported as
    missing. Matching is case-insensitive.

    Args:
        sparql: SPARQL query string
        required: List of required prefixes (e.g., ["hc:", "crm:"]).
            Defaults to ["hc:", "crm:"] when None.

    Returns:
        Tuple of (has_all, missing_prefixes), where missing_prefixes holds
        the entries of ``required`` that are used but have no PREFIX
        declaration.
    """
    if required is None:
        required = ["hc:", "crm:"]  # Default heritage prefixes

    missing = []
    for prefix in required:
        name = re.escape(prefix.rstrip(":"))

        # The look-behind keeps "crm:" from matching inside a longer prefix
        # such as "ecrm:".
        used = re.search(rf"(?<![\w-]){name}:", sparql, re.IGNORECASE)
        if not used:
            continue

        # Accept any whitespace between PREFIX and the name, and require the
        # colon so that "PREFIX hcx:" does not count as declaring "hc:".
        declared = re.search(rf"\bprefix\s+{name}\s*:", sparql, re.IGNORECASE)
        if not declared:
            missing.append(prefix)

    return not missing, missing
def sparql_validation_score(sparql: str) -> float:
    """Turn the syntax and prefix checks into a single score.

    Args:
        sparql: SPARQL query string

    Returns:
        0.0 when the syntax check fails; otherwise 1.0 minus 0.1 for each
        used-but-undeclared prefix, floored at 0.0.
    """
    syntax_ok, _ = validate_sparql_syntax(sparql)
    if not syntax_ok:
        return 0.0

    # Each prefix that is used without a declaration costs a tenth of a point.
    all_declared, undeclared = check_required_prefixes(sparql)
    penalty = 0.0 if all_declared else 0.1 * len(undeclared)
    return max(0.0, 1.0 - penalty)
async def execute_sparql_query(
    sparql: str,
    endpoint: str = "http://91.98.224.44:7878/query"
) -> tuple[bool, Any]:
    """POST a SPARQL query to an endpoint (Oxigraph by default).

    Never raises: any failure, including a missing ``httpx`` install, is
    reported through the returned tuple.

    Args:
        sparql: SPARQL query string
        endpoint: SPARQL endpoint URL

    Returns:
        Tuple of (success, results_or_error): the decoded JSON body on
        HTTP 200, otherwise an error string.
    """
    try:
        # Imported lazily so the module can be loaded without httpx installed.
        import httpx

        request_headers = {"Accept": "application/sparql-results+json"}
        form_payload = {"query": sparql}

        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                endpoint,
                data=form_payload,
                headers=request_headers,
            )
            if response.status_code != 200:
                # Truncate the body so error strings stay short.
                return False, f"HTTP {response.status_code}: {response.text[:200]}"
            return True, response.json()
    except Exception as exc:
        return False, str(exc)
def sparql_result_score(
    results: dict,
    expected_min: int = 0,
    expected_max: Optional[int] = None,
) -> float:
    """Score SPARQL results based on how many bindings were returned.

    Args:
        results: SPARQL JSON results; bindings are read from
            ``results["results"]["bindings"]`` and a missing key counts as
            zero bindings.
        expected_min: Minimum expected results
        expected_max: Maximum expected results (None = no limit)

    Returns:
        Score 0.0-1.0:
          * 1.0 when the count lies within [expected_min, expected_max];
          * count / expected_min when below the minimum;
          * a linearly decreasing score, floored at 0.5, when above a
            positive maximum;
          * 0.0 when above a maximum that is zero or negative, or when
            ``results`` cannot be interpreted.
    """
    try:
        bindings = results.get("results", {}).get("bindings", [])
        count = len(bindings)

        # No results when some expected
        if count == 0 and expected_min > 0:
            return 0.0

        # Below minimum (expected_min is necessarily > 0 here)
        if count < expected_min:
            return count / expected_min

        # Within expected range
        if expected_max is None or count <= expected_max:
            return 1.0

        # A non-positive maximum means no result is acceptable. Handle it
        # explicitly instead of dividing by zero (or, for negative values,
        # producing a score above 1.0).
        if expected_max <= 0:
            return 0.0

        # Penalty for too many results, proportional to the relative excess
        return max(0.5, 1.0 - (count - expected_max) / expected_max * 0.5)
    except Exception as e:
        # Best-effort: malformed results (e.g. an error string instead of a
        # dict) score zero rather than aborting the evaluation run.
        logger.warning("Error scoring SPARQL results: %s", e)
        return 0.0
def sparql_correctness_metric(example: Any, pred: Any, trace: Any = None) -> float:
    """DSPy-compatible metric scoring the syntax of a predicted query.

    The query is validated only; it is never executed. ``example`` and
    ``trace`` are accepted for signature compatibility and are not read.

    Args:
        example: DSPy Example
        pred: Prediction with sparql field
        trace: Optional trace

    Returns:
        Validation score 0.0-1.0; 0.0 when ``pred`` has no ``sparql``
        attribute or the attribute is empty.
    """
    query = getattr(pred, "sparql", None)
    return sparql_validation_score(query) if query else 0.0
# Common SPARQL patterns for heritage queries.
# Every pattern is case-insensitive and uses DOTALL so that ".*" can span the
# line breaks of a multi-line query.
HERITAGE_SPARQL_PATTERNS = {
    "count_by_type": re.compile(
        r"SELECT.*COUNT.*WHERE.*institutionType",
        re.IGNORECASE | re.DOTALL
    ),
    "list_by_location": re.compile(
        # The alternation must be grouped: ungrouped, "|" splits the whole
        # expression and a bare "addressCountry" would match on its own.
        r"SELECT.*WHERE.*(?:addressLocality|addressCountry)",
        re.IGNORECASE | re.DOTALL
    ),
    "entity_lookup": re.compile(
        r"SELECT.*WHERE.*prefLabel.*FILTER.*CONTAINS",
        re.IGNORECASE | re.DOTALL
    ),
}
def matches_expected_pattern(sparql: str, intent: str) -> bool:
    """Check whether a query has the shape expected for its intent.

    Args:
        sparql: SPARQL query string
        intent: Query intent ("statistical", "geographic" or
            "entity_lookup"; any other value has no associated pattern)

    Returns:
        True if the intent's pattern is found in the query, or if no
        pattern is defined for the intent.
    """
    # Intent label -> key in HERITAGE_SPARQL_PATTERNS.
    intent_to_pattern_key = {
        "statistical": "count_by_type",
        "geographic": "list_by_location",
        "entity_lookup": "entity_lookup",
    }

    pattern_key = intent_to_pattern_key.get(intent)
    if pattern_key is None:
        return True  # No pattern defined, assume OK

    return HERITAGE_SPARQL_PATTERNS[pattern_key].search(sparql) is not None