From 12fed83d6e1a86a07d33ea9151db7638bd78562e Mon Sep 17 00:00:00 2001 From: kempersc Date: Fri, 9 Jan 2026 18:57:40 +0100 Subject: [PATCH] fix(rag): preserve count value for COUNT queries in non-streaming endpoint - Detect COUNT queries by checking for 'count' key in SPARQL results - Skip institution transformation for COUNT queries to preserve count value - Fixes bug where 'Hoeveel archieven in Utrecht?' returned 1 instead of 10 - COUNT queries now correctly extract integer count from SPARQL response --- backend/rag/main.py | 91 ++++++++++++++++++++++++++++----------------- 1 file changed, 57 insertions(+), 34 deletions(-) diff --git a/backend/rag/main.py b/backend/rag/main.py index 37c52e96d4..1628c31bb0 100644 --- a/backend/rag/main.py +++ b/backend/rag/main.py @@ -3052,41 +3052,64 @@ async def dspy_query(request: DSPyQueryRequest) -> DSPyQueryResponse: for binding in bindings ] - # Transform SPARQL results to match frontend expected format - # Frontend expects: {name, website, metadata: {latitude, longitude, city, ...}} - # SPARQL returns: {name, website, lat, lon, city, ...} - sparql_results = [] - for row in raw_results: - # Parse lat/lon to float if present - lat = None - lon = None - if row.get("lat"): - try: - lat = float(row["lat"]) - except (ValueError, TypeError): - pass - if row.get("lon"): - try: - lon = float(row["lon"]) - except (ValueError, TypeError): - pass - - transformed = { - "name": row.get("name"), - "website": row.get("website"), - "metadata": { - "latitude": lat, - "longitude": lon, - "city": row.get("city") or template_result.slots.get("city"), - "country": row.get("country") or template_result.slots.get("country"), - "region": row.get("region") or template_result.slots.get("region"), - "institution_type": row.get("type") or template_result.slots.get("institution_type"), - }, - "scores": {"combined": 1.0}, # SPARQL results are exact matches - } - sparql_results.append(transformed) + # Check if this is a COUNT query (raw_results has 'count' key) + # COUNT queries return [{"count": "10"}] - don't transform these + is_count_query = raw_results and "count" in raw_results[0] - logger.debug(f"[FACTUAL-QUERY] Transformed {len(sparql_results)} results, {sum(1 for r in sparql_results if r['metadata']['latitude'])} with coordinates") + if is_count_query: + # For COUNT queries, preserve raw results with count value + # Convert count string to int for template rendering + sparql_results = [] + for row in raw_results: + count_val = row.get("count", "0") + try: + count_int = int(count_val) + except (ValueError, TypeError): + count_int = 0 + sparql_results.append({ + "count": count_int, + "metadata": { + "institution_type": template_result.slots.get("institution_type"), + }, + "scores": {"combined": 1.0}, + }) + logger.debug(f"[FACTUAL-QUERY] COUNT query result: {sparql_results[0].get('count') if sparql_results else 0}") + else: + # Transform SPARQL results to match frontend expected format + # Frontend expects: {name, website, metadata: {latitude, longitude, city, ...}} + # SPARQL returns: {name, website, lat, lon, city, ...} + sparql_results = [] + for row in raw_results: + # Parse lat/lon to float if present + lat = None + lon = None + if row.get("lat"): + try: + lat = float(row["lat"]) + except (ValueError, TypeError): + pass + if row.get("lon"): + try: + lon = float(row["lon"]) + except (ValueError, TypeError): + pass + + transformed = { + "name": row.get("name"), + "website": row.get("website"), + "metadata": { + "latitude": lat, + "longitude": lon, + "city": row.get("city") or template_result.slots.get("city"), + "country": row.get("country") or template_result.slots.get("country"), + "region": row.get("region") or template_result.slots.get("region"), + "institution_type": row.get("type") or template_result.slots.get("institution_type"), + }, + "scores": {"combined": 1.0}, # SPARQL results are exact matches + } + sparql_results.append(transformed) + + logger.debug(f"[FACTUAL-QUERY] Transformed {len(sparql_results)} results, {sum(1 for r in sparql_results if r['metadata']['latitude'])} with coordinates") else: sparql_error = f"SPARQL returned {response.status_code}" else: