fix(rag): preserve count value for COUNT queries in non-streaming endpoint
- Detect COUNT queries by checking for 'count' key in SPARQL results - Skip institution transformation for COUNT queries to preserve count value - Fixes bug where 'Hoeveel archieven in Utrecht?' returned 1 instead of 10 - COUNT queries now correctly extract integer count from SPARQL response
This commit is contained in:
parent
8a7ed757b8
commit
12fed83d6e
1 changed files with 57 additions and 34 deletions
|
|
@ -3052,41 +3052,64 @@ async def dspy_query(request: DSPyQueryRequest) -> DSPyQueryResponse:
|
|||
for binding in bindings
|
||||
]
|
||||
|
||||
# Transform SPARQL results to match frontend expected format
|
||||
# Frontend expects: {name, website, metadata: {latitude, longitude, city, ...}}
|
||||
# SPARQL returns: {name, website, lat, lon, city, ...}
|
||||
sparql_results = []
|
||||
for row in raw_results:
|
||||
# Parse lat/lon to float if present
|
||||
lat = None
|
||||
lon = None
|
||||
if row.get("lat"):
|
||||
try:
|
||||
lat = float(row["lat"])
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
if row.get("lon"):
|
||||
try:
|
||||
lon = float(row["lon"])
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
transformed = {
|
||||
"name": row.get("name"),
|
||||
"website": row.get("website"),
|
||||
"metadata": {
|
||||
"latitude": lat,
|
||||
"longitude": lon,
|
||||
"city": row.get("city") or template_result.slots.get("city"),
|
||||
"country": row.get("country") or template_result.slots.get("country"),
|
||||
"region": row.get("region") or template_result.slots.get("region"),
|
||||
"institution_type": row.get("type") or template_result.slots.get("institution_type"),
|
||||
},
|
||||
"scores": {"combined": 1.0}, # SPARQL results are exact matches
|
||||
}
|
||||
sparql_results.append(transformed)
|
||||
# Check if this is a COUNT query (raw_results has 'count' key)
|
||||
# COUNT queries return [{"count": "10"}] - don't transform these
|
||||
is_count_query = raw_results and "count" in raw_results[0]
|
||||
|
||||
logger.debug(f"[FACTUAL-QUERY] Transformed {len(sparql_results)} results, {sum(1 for r in sparql_results if r['metadata']['latitude'])} with coordinates")
|
||||
if is_count_query:
|
||||
# For COUNT queries, preserve raw results with count value
|
||||
# Convert count string to int for template rendering
|
||||
sparql_results = []
|
||||
for row in raw_results:
|
||||
count_val = row.get("count", "0")
|
||||
try:
|
||||
count_int = int(count_val)
|
||||
except (ValueError, TypeError):
|
||||
count_int = 0
|
||||
sparql_results.append({
|
||||
"count": count_int,
|
||||
"metadata": {
|
||||
"institution_type": template_result.slots.get("institution_type"),
|
||||
},
|
||||
"scores": {"combined": 1.0},
|
||||
})
|
||||
logger.debug(f"[FACTUAL-QUERY] COUNT query result: {sparql_results[0].get('count') if sparql_results else 0}")
|
||||
else:
|
||||
# Transform SPARQL results to match frontend expected format
|
||||
# Frontend expects: {name, website, metadata: {latitude, longitude, city, ...}}
|
||||
# SPARQL returns: {name, website, lat, lon, city, ...}
|
||||
sparql_results = []
|
||||
for row in raw_results:
|
||||
# Parse lat/lon to float if present
|
||||
lat = None
|
||||
lon = None
|
||||
if row.get("lat"):
|
||||
try:
|
||||
lat = float(row["lat"])
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
if row.get("lon"):
|
||||
try:
|
||||
lon = float(row["lon"])
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
transformed = {
|
||||
"name": row.get("name"),
|
||||
"website": row.get("website"),
|
||||
"metadata": {
|
||||
"latitude": lat,
|
||||
"longitude": lon,
|
||||
"city": row.get("city") or template_result.slots.get("city"),
|
||||
"country": row.get("country") or template_result.slots.get("country"),
|
||||
"region": row.get("region") or template_result.slots.get("region"),
|
||||
"institution_type": row.get("type") or template_result.slots.get("institution_type"),
|
||||
},
|
||||
"scores": {"combined": 1.0}, # SPARQL results are exact matches
|
||||
}
|
||||
sparql_results.append(transformed)
|
||||
|
||||
logger.debug(f"[FACTUAL-QUERY] Transformed {len(sparql_results)} results, {sum(1 for r in sparql_results if r['metadata']['latitude'])} with coordinates")
|
||||
else:
|
||||
sparql_error = f"SPARQL returned {response.status_code}"
|
||||
else:
|
||||
|
|
|
|||
Loading…
Reference in a new issue