227 lines
6.2 KiB
Python
227 lines
6.2 KiB
Python
"""
|
|
SPARQL Correctness Metrics
|
|
|
|
Validates SPARQL syntax and query results.
|
|
"""
|
|
|
|
import re
|
|
from typing import Any, Optional
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def validate_sparql_syntax(sparql: str) -> tuple[bool, Optional[str]]:
    """Validate SPARQL syntax using basic, heuristic rules.

    This is a lightweight sanity check, not a parser. It works on the raw
    query text, so delimiters or keywords that appear inside string literals
    or comments are counted like any others.

    Checks, in order:
        1. The query is not empty or whitespace-only.
        2. A query form keyword (SELECT/ASK/CONSTRUCT/DESCRIBE) is present
           as a whole word (so e.g. "task" does not count as ASK).
        3. A query containing SELECT also contains WHERE.
        4. Curly braces are balanced and never close before they open.
        5. Parentheses are balanced and never close before they open.
        6. No "??" sequence is present.

    Args:
        sparql: SPARQL query string

    Returns:
        Tuple of (is_valid, error_message). error_message is None when the
        query passes every check, otherwise it describes the first failure.
    """
    if not sparql or not sparql.strip():
        return False, "Empty query"

    def has_keyword(keyword: str) -> bool:
        # Whole-word, case-insensitive match; a plain substring test would
        # accept identifiers that merely contain the keyword.
        return re.search(rf"\b{keyword}\b", sparql, re.IGNORECASE) is not None

    def is_balanced(opener: str, closer: str) -> bool:
        # Track nesting depth so that a closer appearing before its opener
        # is rejected; comparing counts alone would accept "} {".
        depth = 0
        for char in sparql:
            if char == opener:
                depth += 1
            elif char == closer:
                depth -= 1
                if depth < 0:
                    return False
        return depth == 0

    # Must have SELECT, ASK, CONSTRUCT, or DESCRIBE
    if not any(has_keyword(kw) for kw in ("SELECT", "ASK", "CONSTRUCT", "DESCRIBE")):
        return False, "Missing query form (SELECT/ASK/CONSTRUCT/DESCRIBE)"

    # SELECT queries are required to carry an explicit WHERE keyword;
    # other query forms are not checked for it.
    if has_keyword("SELECT") and not has_keyword("WHERE"):
        return False, "Missing WHERE clause"

    # Check balanced braces
    if not is_balanced("{", "}"):
        return False, "Unbalanced braces"

    # Check balanced parentheses
    if not is_balanced("(", ")"):
        return False, "Unbalanced parentheses"

    # Check for common errors
    if "??" in sparql:
        return False, "Double question mark in variable"

    return True, None
|
|
|
|
|
|
def check_required_prefixes(sparql: str, required: Optional[list[str]] = None) -> tuple[bool, list[str]]:
    """Check that required prefixes used in the query are also declared.

    A prefix is only reported when it is actually used in the query text
    without a matching ``PREFIX name:`` declaration; prefixes that are not
    used at all are never reported. Matching is case-insensitive.

    Args:
        sparql: SPARQL query string
        required: List of prefixes to check (e.g., ["hc:", "crm:"]).
            Defaults to ["hc:", "crm:"] when None. A trailing colon is
            optional.

    Returns:
        Tuple of (has_all, missing_prefixes), where missing_prefixes holds
        the entries of ``required`` that are used but not declared.
    """
    if required is None:
        required = ["hc:", "crm:"]  # Default heritage prefixes

    sparql_lower = sparql.lower()
    missing = []

    for prefix in required:
        name = re.escape(prefix.lower().rstrip(":"))

        # Usage: "name:" not preceded by a character that could belong to a
        # longer prefix name, so "xhc:" does not count as a use of "hc:".
        if not re.search(rf"(?<![\w.-]){name}:", sparql_lower):
            continue

        # Declaration: PREFIX keyword, any whitespace, then exactly "name:".
        # The colon anchor stops "prefix hcx:" from satisfying "hc:".
        if not re.search(rf"\bprefix\s+{name}:", sparql_lower):
            missing.append(prefix)

    return not missing, missing
|
|
|
|
|
|
def sparql_validation_score(sparql: str) -> float:
    """Turn the syntax and prefix checks into a single score.

    Args:
        sparql: SPARQL query string

    Returns:
        0.0 when the query fails syntax validation; otherwise 1.0 minus
        0.1 for each prefix reported as missing, floored at 0.0.
    """
    syntax_ok, _ = validate_sparql_syntax(sparql)
    if not syntax_ok:
        return 0.0

    # Each prefix reported missing costs a tenth of the score.
    all_declared, undeclared = check_required_prefixes(sparql)
    penalty = 0.0 if all_declared else 0.1 * len(undeclared)

    return max(0.0, 1.0 - penalty)
|
|
|
|
|
|
async def execute_sparql_query(
    sparql: str,
    endpoint: str = "http://91.98.224.44:7878/query"
) -> tuple[bool, Any]:
    """Send a SPARQL query to an HTTP endpoint and return the outcome.

    The query is POSTed as the form field ``query`` with an Accept header
    requesting SPARQL JSON results, using a 30 second client timeout.

    Args:
        sparql: SPARQL query string
        endpoint: SPARQL endpoint URL

    Returns:
        Tuple of (success, results_or_error). On HTTP 200 this is
        (True, decoded JSON body). On any other status it is
        (False, "HTTP <status>: <first 200 chars of body>"). Any exception,
        including a failed httpx import, yields (False, str(exception)).
    """
    try:
        # Imported lazily; an ImportError is reported like any other failure.
        import httpx

        form_payload = {"query": sparql}
        request_headers = {"Accept": "application/sparql-results+json"}

        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                endpoint,
                data=form_payload,
                headers=request_headers,
            )

            if response.status_code != 200:
                return False, f"HTTP {response.status_code}: {response.text[:200]}"

            return True, response.json()
    except Exception as exc:
        return False, str(exc)
|
|
|
|
|
|
def sparql_result_score(
    results: dict,
    expected_min: int = 0,
    expected_max: Optional[int] = None,
) -> float:
    """Score SPARQL results based on how many bindings were returned.

    The binding count is read from ``results["results"]["bindings"]``;
    a missing key at either level is treated as zero bindings.

    Args:
        results: SPARQL JSON results
        expected_min: Minimum expected results
        expected_max: Maximum expected results (None = no limit)

    Returns:
        Score 0.0-1.0:
            * count / expected_min when below the minimum (0.0 for no
              results when some were expected),
            * 1.0 when the count lies within the expected range,
            * a linearly decreasing penalty floored at 0.5 when the count
              exceeds a positive expected_max,
            * 0.0 when the count exceeds a zero or negative expected_max,
            * 0.0 (with a logged warning) if scoring raises, e.g. because
              ``results`` is not a dict.
    """
    try:
        bindings = results.get("results", {}).get("bindings", [])
        count = len(bindings)

        # No results when some expected
        if count == 0 and expected_min > 0:
            return 0.0

        # Below minimum: partial credit proportional to the shortfall.
        # expected_min is necessarily positive here because count >= 0.
        if count < expected_min:
            return count / expected_min

        # Within expected range
        if expected_max is None or count <= expected_max:
            return 1.0

        # A non-positive upper bound cannot be used as a divisor. Return 0.0
        # explicitly instead of relying on a ZeroDivisionError being swallowed
        # below; this also keeps negative bounds from producing scores > 1.0.
        if expected_max <= 0:
            return 0.0

        # Penalty for too many results
        overshoot = (count - expected_max) / expected_max
        return max(0.5, 1.0 - overshoot * 0.5)

    except Exception as e:
        # Best-effort metric: malformed input scores zero rather than raising.
        logger.warning("Error scoring SPARQL results: %s", e)
        return 0.0
|
|
|
|
|
|
def sparql_correctness_metric(example: Any, pred: Any, trace: Any = None) -> float:
    """DSPy-compatible SPARQL correctness metric.

    Scores the query found on ``pred.sparql`` with
    ``sparql_validation_score``; the query is never executed.
    ``example`` and ``trace`` are accepted but not used.

    Args:
        example: DSPy Example
        pred: Prediction with sparql field
        trace: Optional trace

    Returns:
        Validation score 0.0-1.0; 0.0 when ``pred`` has no ``sparql``
        attribute or the attribute is empty/falsy.
    """
    query = getattr(pred, "sparql", None)
    return sparql_validation_score(query) if query else 0.0
|
|
|
|
|
|
# Common SPARQL patterns for heritage queries.
# Each pattern is case-insensitive and lets "." span newlines so that
# multi-line queries match.
HERITAGE_SPARQL_PATTERNS = {
    # SELECT ... COUNT ... WHERE ... institutionType
    "count_by_type": re.compile(
        r"SELECT.*COUNT.*WHERE.*institutionType",
        re.IGNORECASE | re.DOTALL
    ),
    # SELECT ... WHERE ... followed by either address property.
    # The alternation is grouped: ungrouped, "|" binds loosest and any text
    # merely containing "addressCountry" would match on its own.
    "list_by_location": re.compile(
        r"SELECT.*WHERE.*(?:addressLocality|addressCountry)",
        re.IGNORECASE | re.DOTALL
    ),
    # SELECT ... WHERE ... prefLabel ... FILTER ... CONTAINS
    "entity_lookup": re.compile(
        r"SELECT.*WHERE.*prefLabel.*FILTER.*CONTAINS",
        re.IGNORECASE | re.DOTALL
    ),
}
|
|
|
|
|
|
def matches_expected_pattern(sparql: str, intent: str) -> bool:
    """Tell whether a query has the shape expected for a given intent.

    Args:
        sparql: SPARQL query string
        intent: Query intent; "statistical", "geographic" and
            "entity_lookup" have an associated pattern.

    Returns:
        True if the intent's pattern is found in the query, or if the
        intent has no associated pattern; False otherwise.
    """
    # Intent name -> key in HERITAGE_SPARQL_PATTERNS.
    intent_to_pattern_key = {
        "statistical": "count_by_type",
        "geographic": "list_by_location",
        "entity_lookup": "entity_lookup",
    }

    pattern_key = intent_to_pattern_key.get(intent)
    if pattern_key is None:
        return True  # No pattern defined, assume OK

    matcher = HERITAGE_SPARQL_PATTERNS[pattern_key]
    return matcher.search(sparql) is not None
|