glam/frontend/src/components/database/EmbeddingProjector.tsx
2025-12-10 13:01:13 +01:00

1493 lines
47 KiB
TypeScript

/**
* Embedding Projector Component
*
* A TensorFlow Projector-inspired visualization tool for high-dimensional embeddings.
* Supports multiple dimensionality reduction techniques:
* - PCA (Principal Component Analysis) - fast, deterministic
* - UMAP (Uniform Manifold Approximation and Projection) - preserves local structure
* - t-SNE (t-distributed Stochastic Neighbor Embedding) - cluster visualization
*
* Features:
* - 2D/3D interactive visualization
* - Point search and filtering
* - Nearest neighbor exploration
* - Color coding by metadata fields
* - Pan, zoom, rotate controls
* - Export/import projections
*
* References:
* - TensorFlow Embedding Projector: https://projector.tensorflow.org/
* - UMAP: https://umap-learn.readthedocs.io/
* - t-SNE: https://distill.pub/2016/misread-tsne/
*/
import { useState, useEffect, useRef, useCallback, useMemo } from 'react';
import * as d3 from 'd3';
import * as THREE from 'three';
import { OrbitControls } from 'three/examples/jsm/controls/OrbitControls.js';
import { UMAP } from 'umap-js';
import { useLanguage } from '@/contexts/LanguageContext';
import { PointDetailsPanel } from './PointDetailsPanel';
// Types

/** One input embedding supplied by the caller. */
export interface EmbeddingPoint {
  /** Identifier; shown in tooltips and labels, and matched by the search box. */
  id: string | number;
  /**
   * High-dimensional embedding vector.
   * NOTE(review): all points are assumed to share one length — not validated here.
   */
  vector: number[];
  /**
   * Arbitrary metadata. Its keys are collected as selectable payload fields;
   * values are matched (as strings) by the search box and can drive point colouring.
   */
  payload: Record<string, unknown>;
}

/** Low-dimensional coordinates of one point after projection. */
export interface ProjectedPoint {
  /** First axis, rescaled to [-1, 1] by the projection functions. */
  x: number;
  /** Second axis, rescaled to [-1, 1]. */
  y: number;
  /** Third axis; only populated when projecting to 3 components (3D view). */
  z?: number;
  /** Index of the source point in the `points` prop array. */
  originalIndex: number;
}

/** Dimensionality-reduction algorithm used for the projection. */
export type ProjectionMethod = 'pca' | 'umap' | 'tsne';

/** Rendering mode: D3/SVG scatter plot ('2d') or Three.js point cloud ('3d'). */
export type ViewMode = '2d' | '3d';

interface EmbeddingProjectorProps {
  /** Embeddings to visualise. */
  points: EmbeddingPoint[];
  /**
   * Called with the selected point, or `null` when the selection is cleared.
   * It fires from an effect whenever the selection, `points`, the neighbour
   * count or this callback's identity changes, so pass a stable callback.
   */
  onPointSelect?: (point: EmbeddingPoint | null) => void;
  /** Initial payload field used for colouring; only read as initial state (later changes are ignored). */
  colorByField?: string;
  /** Height in pixels of the 2D plot (default 600). */
  height?: number;
  /** Currently unused; the 2D plot width is taken from the container element. */
  width?: number;
  /** Simple mode hides advanced controls for cleaner UI */
  simpleMode?: boolean;
  /** Indices into `points` to highlight (e.g., RAG sources); drawn in green. */
  highlightedIndices?: number[];
  /** Callback when user selects a point to add to conversation context */
  onContextSelect?: (point: EmbeddingPoint) => void;
  /** Show "Add to context" button on point selection */
  showContextButton?: boolean;
  /** Replaces the default header title (applied in both simple and full mode). */
  title?: string;
}
// Translation strings
/**
 * UI strings keyed by message id, each holding Dutch (`nl`) and English (`en`)
 * text. The component resolves them with `TEXT[key][language]`, where
 * `language` comes from `useLanguage()` (presumably always 'nl' or 'en' — the
 * lookup has no fallback for any other value).
 */
const TEXT = {
  title: { nl: 'Embedding Projector', en: 'Embedding Projector' },
  description: {
    nl: 'Visualiseer hoog-dimensionale embeddings in 2D/3D',
    en: 'Visualize high-dimensional embeddings in 2D/3D',
  },
  projectionMethod: { nl: 'Projectie methode', en: 'Projection method' },
  viewMode: { nl: 'Weergave', en: 'View mode' },
  colorBy: { nl: 'Kleur op', en: 'Color by' },
  noField: { nl: 'Geen', en: 'None' },
  neighbors: { nl: 'Buren', en: 'Neighbors' },
  search: { nl: 'Zoeken...', en: 'Search...' },
  computing: { nl: 'Berekenen...', en: 'Computing...' },
  run: { nl: 'Start', en: 'Run' },
  stop: { nl: 'Stop', en: 'Stop' },
  reset: { nl: 'Reset', en: 'Reset' },
  iteration: { nl: 'Iteratie', en: 'Iteration' },
  // Algorithm parameter labels (t-SNE, UMAP, PCA)
  perplexity: { nl: 'Perplexiteit', en: 'Perplexity' },
  learningRate: { nl: 'Leersnelheid', en: 'Learning rate' },
  umapNeighbors: { nl: 'Buren', en: 'Neighbors' },
  minDist: { nl: 'Min afstand', en: 'Min distance' },
  spread: { nl: 'Spreiding', en: 'Spread' },
  pcaComponents: { nl: 'Componenten', en: 'Components' },
  // Stats and selection panel labels
  pointsLoaded: { nl: 'punten geladen', en: 'points loaded' },
  dimensions: { nl: 'dimensies', en: 'dimensions' },
  selectedPoint: { nl: 'Geselecteerd punt', en: 'Selected point' },
  nearestNeighbors: { nl: 'Dichtstbijzijnde buren', en: 'Nearest neighbors' },
  distance: { nl: 'Afstand', en: 'Distance' },
  showLabels: { nl: 'Labels tonen', en: 'Show labels' },
  sphereize: { nl: 'Sferiseren', en: 'Sphereize' },
  variance: { nl: 'Variantie', en: 'Variance' },
  addToContext: { nl: 'Toevoegen aan context', en: 'Add to context' },
  sourcesHighlighted: { nl: 'bronnen gemarkeerd', en: 'sources highlighted' },
};
// Color palette for categorical data
/**
 * Categorical palette for "colour by" values. A category's colour is
 * `COLORS[categoryIndex % COLORS.length]`. The component keeps at most 20
 * distinct categories but there are only 18 colours, so the 19th and 20th
 * categories reuse the first two colours.
 * NOTE(review): '#22c55e' (index 11) is also the colour used for highlighted
 * points, so the 12th category is visually indistinguishable from highlighted
 * sources.
 */
const COLORS = [
  '#6366f1', '#8b5cf6', '#a855f7', '#d946ef', '#ec4899',
  '#f43f5e', '#ef4444', '#f97316', '#f59e0b', '#eab308',
  '#84cc16', '#22c55e', '#10b981', '#14b8a6', '#06b6d4',
  '#0ea5e9', '#3b82f6', '#1d4ed8',
];
/**
 * Deterministic PRNG (mulberry32) used for the PCA power-iteration start
 * vectors, so repeated runs on the same data produce identical axes.
 *
 * @param seed - Any 32-bit integer.
 * @returns A generator of floats in [0, 1).
 */
function createPcaRandom(seed: number): () => number {
  let state = seed >>> 0;
  return () => {
    state = (state + 0x6d2b79f5) >>> 0;
    let t = state;
    t = Math.imul(t ^ (t >>> 15), t | 1);
    t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
  };
}

/**
 * Subtract from `v`, in place, its projection onto every vector in `basis`
 * (modified Gram-Schmidt). Basis vectors must be unit length.
 */
function orthogonalizeAgainst(v: number[], basis: number[][]): void {
  for (const b of basis) {
    let proj = 0;
    for (let i = 0; i < v.length; i++) proj += v[i] * b[i];
    for (let i = 0; i < v.length; i++) v[i] -= proj * b[i];
  }
}

/**
 * Principal Component Analysis via power iteration with deflation.
 *
 * Finds the top principal axes without forming the d×d covariance matrix.
 * Each axis comes from 50 power-iteration steps on XᵀX, where X is the
 * centred data, and is then removed (deflated) from the data before the next
 * axis is computed. Every iterate is re-orthogonalised against earlier axes,
 * so rank-deficient data cannot produce a spurious axis that partly repeats
 * an earlier one. Start vectors come from a fixed-seed PRNG, so the output is
 * deterministic for a given input. The sign of each axis is arbitrary.
 *
 * @param vectors - Input vectors; all assumed to have the length of `vectors[0]`.
 * @param nComponents - Number of axes to extract; capped to [0, dimensionality].
 * @returns
 *   - `projected`: per-point coordinates, each axis rescaled to [-1, 1].
 *     An axis with zero range, or one that carries no variance, maps to 0.
 *   - `variance`: eigenvalue per axis (variance of the data along that axis).
 *   - `explained`: percentage of the TOTAL variance captured by each axis.
 *     The values sum to at most 100, and reach 100 only if the kept axes
 *     explain all the variance.
 */
function computePCA(vectors: number[][], nComponents: number = 2): {
  projected: number[][];
  variance: number[];
  explained: number[];
} {
  if (vectors.length === 0) return { projected: [], variance: [], explained: [] };

  const n = vectors.length;
  const d = vectors[0].length;
  const k = Math.max(0, Math.min(nComponents, d));

  const dot = (a: number[], b: number[]): number => {
    let s = 0;
    for (let i = 0; i < a.length; i++) s += a[i] * b[i];
    return s;
  };

  // Center the data
  const means = new Array<number>(d).fill(0);
  for (const vec of vectors) {
    for (let i = 0; i < d; i++) {
      means[i] += vec[i] / n;
    }
  }
  const centered = vectors.map(vec => vec.map((v, i) => v - means[i]));

  // Total variance = trace of the covariance matrix. It is the denominator of
  // the explained-variance ratio, not the sum of the retained eigenvalues.
  let totalVariance = 0;
  for (const row of centered) {
    for (const x of row) totalVariance += x * x;
  }
  totalVariance /= n;

  const random = createPcaRandom(0x5eed);
  const components: number[][] = [];
  const eigenvalues: number[] = [];
  const degenerate: boolean[] = [];
  // Unit axes that carry variance; later iterates are kept orthogonal to them.
  const basis: number[][] = [];
  const workingData = centered.map(row => [...row]);

  for (let comp = 0; comp < k; comp++) {
    let v = Array.from({ length: d }, () => random() - 0.5);
    orthogonalizeAgainst(v, basis);
    let norm = Math.sqrt(dot(v, v));
    if (norm > 1e-12) {
      v = v.map(x => x / norm);
    }

    // Power iteration: v <- normalize(XᵀX v)
    for (let iter = 0; iter < 50; iter++) {
      const Xv = workingData.map(row => dot(row, v));
      const XtXv = new Array<number>(d).fill(0);
      for (let i = 0; i < n; i++) {
        const row = workingData[i];
        const s = Xv[i];
        for (let j = 0; j < d; j++) {
          XtXv[j] += row[j] * s;
        }
      }
      // Round-off can leak earlier axes back in; project them out again.
      orthogonalizeAgainst(XtXv, basis);
      norm = Math.sqrt(dot(XtXv, XtXv));
      // No variance left in the remaining subspace: further steps cannot change v.
      if (norm <= 1e-10) break;
      v = XtXv.map(x => x / norm);
    }

    // Eigenvalue = variance of the (deflated) data along v.
    const Xv = workingData.map(row => dot(row, v));
    const eigenvalue = dot(Xv, Xv) / n;
    // An axis with (numerically) no variance carries only round-off. It is
    // flagged so its coordinates are zeroed rather than stretched to [-1, 1].
    const isDegenerate = !(eigenvalue > 1e-12 * totalVariance);

    components.push(v);
    eigenvalues.push(eigenvalue);
    degenerate.push(isDegenerate);

    if (!isDegenerate) {
      basis.push(v);
      // Deflate: remove this component from the data.
      for (let i = 0; i < n; i++) {
        const row = workingData[i];
        const proj = dot(row, v);
        for (let j = 0; j < d; j++) {
          row[j] -= proj * v[j];
        }
      }
    }
  }

  // Project the centred data onto the principal axes.
  const projected = centered.map(vec =>
    components.map((comp, c) => (degenerate[c] ? 0 : dot(vec, comp)))
  );

  // Explained variance ratio, relative to the total variance of the data.
  const explained = eigenvalues.map(e => (totalVariance > 0 ? (e / totalVariance) * 100 : 0));

  // Normalize each axis to the [-1, 1] range.
  const mins = new Array<number>(k).fill(Infinity);
  const maxs = new Array<number>(k).fill(-Infinity);
  for (const point of projected) {
    for (let i = 0; i < k; i++) {
      mins[i] = Math.min(mins[i], point[i]);
      maxs[i] = Math.max(maxs[i], point[i]);
    }
  }
  const normalized = projected.map(point =>
    point.map((v, i) => {
      const range = maxs[i] - mins[i];
      return range > 0 ? ((v - mins[i]) / range) * 2 - 1 : 0;
    })
  );

  return {
    projected: normalized,
    variance: eigenvalues,
    explained,
  };
}
/**
 * t-SNE (t-distributed Stochastic Neighbor Embedding), exact O(n²) version.
 *
 * This is NOT the Barnes-Hut approximation. Every iteration evaluates all n²
 * point pairs, and the input-space distance matrix, the P affinities and the
 * Q affinities are held as dense n×n arrays. Time per iteration and memory
 * therefore grow quadratically with the number of points.
 *
 * Pipeline:
 *   1. Compute squared Euclidean distances.
 *   2. Pick a per-point Gaussian bandwidth by bisection to hit the target perplexity.
 *   3. Symmetrise P.
 *   4. Run gradient descent on KL(P‖Q) with a Student-t (1 DoF) kernel,
 *      momentum and per-coordinate gains.
 * NOTE(review): no early-exaggeration phase is applied.
 *
 * The initial layout uses `Math.random`, so results differ between runs.
 * `onProgress` is called synchronously from the optimisation loop; nothing in
 * this function yields to the event loop.
 *
 * @param vectors - Input vectors, one per point.
 * @param nComponents - Output dimensionality (2 or 3 in this component).
 * @param options.perplexity - Target perplexity (default 30).
 * @param options.learningRate - Gradient step size (default 200).
 * @param options.iterations - Number of gradient-descent iterations (default 500).
 * @param options.onProgress - Called every 10 iterations with the iteration
 *   index and the current KL divergence.
 * @returns One coordinate array per input point. Each axis is rescaled to
 *   [-1, 1], and an axis with zero range maps to 0. Returns `[]` for empty input.
 */
function computeTSNE(
  vectors: number[][],
  nComponents: number = 2,
  options: {
    perplexity?: number;
    learningRate?: number;
    iterations?: number;
    onProgress?: (iteration: number, error: number) => void;
  } = {}
): number[][] {
  const {
    perplexity = 30,
    learningRate = 200,
    iterations = 500,
    onProgress
  } = options;
  if (vectors.length === 0) return [];
  const n = vectors.length;
  // Compute pairwise distances (squared Euclidean — no square root is taken)
  const distances: number[][] = [];
  for (let i = 0; i < n; i++) {
    distances[i] = [];
    for (let j = 0; j < n; j++) {
      let d = 0;
      for (let k = 0; k < vectors[i].length; k++) {
        const diff = vectors[i][k] - vectors[j][k];
        d += diff * diff;
      }
      distances[i][j] = d;
    }
  }
  // Compute Gaussian perplexities: conditional affinities P(j|i), with a
  // per-point bandwidth sigma chosen so that 2^entropy matches `perplexity`.
  // NOTE: a row spread over n - 1 neighbours has perplexity at most n - 1, so a
  // larger target is unreachable and sigma is driven toward its upper bound.
  const P: number[][] = [];
  for (let i = 0; i < n; i++) {
    P[i] = new Array(n).fill(0);
    // Binary search for sigma (bisection within [1e-10, 1e10], at most 50 steps)
    let sigma = 1.0;
    let sigmaMin = 1e-10;
    let sigmaMax = 1e10;
    for (let iter = 0; iter < 50; iter++) {
      let sumP = 0;
      for (let j = 0; j < n; j++) {
        if (i !== j) {
          P[i][j] = Math.exp(-distances[i][j] / (2 * sigma * sigma));
          sumP += P[i][j];
        }
      }
      // Normalize (if every term underflowed, sumP is 0 and the row stays 0)
      for (let j = 0; j < n; j++) {
        P[i][j] /= sumP || 1;
      }
      // Compute entropy (in bits)
      let entropy = 0;
      for (let j = 0; j < n; j++) {
        if (P[i][j] > 1e-10) {
          entropy -= P[i][j] * Math.log2(P[i][j]);
        }
      }
      const perpCurrent = Math.pow(2, entropy);
      if (Math.abs(perpCurrent - perplexity) < 1e-5) break;
      // Distribution too flat -> shrink sigma; too peaked -> grow sigma.
      if (perpCurrent > perplexity) {
        sigmaMax = sigma;
        sigma = (sigma + sigmaMin) / 2;
      } else {
        sigmaMin = sigma;
        sigma = (sigma + sigmaMax) / 2;
      }
    }
  }
  // Symmetrize: p_ij = (P(j|i) + P(i|j)) / (2n); the diagonal stays 0
  for (let i = 0; i < n; i++) {
    for (let j = i + 1; j < n; j++) {
      const pij = (P[i][j] + P[j][i]) / (2 * n);
      P[i][j] = pij;
      P[j][i] = pij;
    }
  }
  // Initialize embedding randomly (uniform in ±5e-5, via Math.random, so the
  // result is not reproducible between runs)
  const Y: number[][] = [];
  for (let i = 0; i < n; i++) {
    Y[i] = [];
    for (let d = 0; d < nComponents; d++) {
      Y[i][d] = (Math.random() - 0.5) * 0.0001;
    }
  }
  // Gradient descent (momentum + per-coordinate adaptive gains)
  const gains: number[][] = Y.map(row => row.map(() => 1.0));
  const momentum: number[][] = Y.map(row => row.map(() => 0));
  for (let iter = 0; iter < iterations; iter++) {
    // Compute Q distribution (Student-t with 1 DoF): q_ij ∝ 1 / (1 + |yi - yj|²)
    const Q: number[][] = [];
    let sumQ = 0;
    for (let i = 0; i < n; i++) {
      Q[i] = [];
      for (let j = 0; j < n; j++) {
        if (i !== j) {
          let d = 0;
          for (let k = 0; k < nComponents; k++) {
            const diff = Y[i][k] - Y[j][k];
            d += diff * diff;
          }
          Q[i][j] = 1 / (1 + d);
          sumQ += Q[i][j];
        } else {
          Q[i][j] = 0;
        }
      }
    }
    // Normalize Q
    for (let i = 0; i < n; i++) {
      for (let j = 0; j < n; j++) {
        Q[i][j] /= sumQ || 1;
      }
    }
    // Compute gradients: dC/dyi = 4 Σ_j (p_ij - q_ij) (1 + |yi - yj|²)⁻¹ (yi - yj).
    // Q[i][j] * sumQ recovers the unnormalised kernel (1 + |yi - yj|²)⁻¹.
    const grad: number[][] = Y.map(row => row.map(() => 0));
    for (let i = 0; i < n; i++) {
      for (let j = 0; j < n; j++) {
        if (i !== j) {
          const mult = 4 * (P[i][j] - Q[i][j]) * Q[i][j] * sumQ;
          for (let d = 0; d < nComponents; d++) {
            grad[i][d] += mult * (Y[i][d] - Y[j][d]);
          }
        }
      }
    }
    // Update with momentum (0.5 for the first 250 iterations, 0.8 afterwards)
    const mom = iter < 250 ? 0.5 : 0.8;
    for (let i = 0; i < n; i++) {
      for (let d = 0; d < nComponents; d++) {
        // Gain shrinks (x0.8) when the gradient and the previous update have the
        // same sign (or either is 0), grows (+0.2) otherwise; floored at 0.01.
        const sign = grad[i][d] * momentum[i][d] >= 0;
        gains[i][d] = sign ? gains[i][d] * 0.8 : gains[i][d] + 0.2;
        gains[i][d] = Math.max(gains[i][d], 0.01);
        momentum[i][d] = mom * momentum[i][d] - learningRate * gains[i][d] * grad[i][d];
        Y[i][d] += momentum[i][d];
      }
    }
    // Center (subtract the per-axis mean so the layout does not drift)
    const means = new Array(nComponents).fill(0);
    for (let i = 0; i < n; i++) {
      for (let d = 0; d < nComponents; d++) {
        means[d] += Y[i][d] / n;
      }
    }
    for (let i = 0; i < n; i++) {
      for (let d = 0; d < nComponents; d++) {
        Y[i][d] -= means[d];
      }
    }
    // Report progress every 10 iterations: KL(P‖Q), skipping p_ij ≤ 1e-10 and
    // offsetting q_ij by 1e-10 to avoid division by zero
    if (onProgress && iter % 10 === 0) {
      let error = 0;
      for (let i = 0; i < n; i++) {
        for (let j = 0; j < n; j++) {
          if (P[i][j] > 1e-10) {
            error += P[i][j] * Math.log(P[i][j] / (Q[i][j] + 1e-10));
          }
        }
      }
      onProgress(iter, error);
    }
  }
  // Normalize to [-1, 1] (an axis with zero range maps to 0)
  const mins = new Array(nComponents).fill(Infinity);
  const maxs = new Array(nComponents).fill(-Infinity);
  for (const point of Y) {
    for (let i = 0; i < nComponents; i++) {
      mins[i] = Math.min(mins[i], point[i]);
      maxs[i] = Math.max(maxs[i], point[i]);
    }
  }
  return Y.map(point =>
    point.map((v, i) => {
      const range = maxs[i] - mins[i];
      return range > 0 ? ((v - mins[i]) / range) * 2 - 1 : 0;
    })
  );
}
/**
 * Compute a UMAP projection with the umap-js library, rescaled to [-1, 1] per axis.
 *
 * The fit runs through umap-js `UMAP.fitAsync`, which optimises asynchronously
 * and invokes a callback with the epoch number. This means `onProgress` is
 * actually reported, and the page gets a chance to update while the fit runs.
 * The previous synchronous `fit` never called `onProgress`.
 * NOTE(review): umap-js chooses the number of epochs itself; any fixed total
 * assumed by callers for a progress display should be verified.
 *
 * @param vectors - Input vectors, one per point.
 * @param nComponents - Output dimensionality.
 * @param options.nNeighbors - Neighbourhood size (default 15), clamped to n - 1.
 * @param options.minDist - Minimum distance between embedded points (default 0.1).
 * @param options.spread - Effective scale of embedded points (default 1.0).
 * @param options.onProgress - Called with the epoch number reported by umap-js.
 * @returns One coordinate array per input point. Each axis is rescaled to
 *   [-1, 1], and an axis with zero range maps to 0. Returns `[]` for empty
 *   input; a single point is placed at the origin.
 */
async function computeUMAP(
  vectors: number[][],
  nComponents: number = 2,
  options: {
    nNeighbors?: number;
    minDist?: number;
    spread?: number;
    onProgress?: (epoch: number) => void;
  } = {}
): Promise<number[][]> {
  const {
    nNeighbors = 15,
    minDist = 0.1,
    spread = 1.0,
    onProgress,
  } = options;
  if (vectors.length === 0) return [];
  // With one point, nNeighbors would clamp to 0 and there is no neighbour graph
  // to optimise. Place the point at the origin, matching how a zero-range axis
  // is mapped to 0 below.
  if (vectors.length === 1) {
    return [Array.from({ length: nComponents }, () => 0)];
  }
  const umap = new UMAP({
    nComponents,
    nNeighbors: Math.min(nNeighbors, vectors.length - 1),
    minDist,
    spread,
  });
  // Fit asynchronously, forwarding each epoch to the caller. The callback
  // returns nothing (not `false`), so the optimisation always runs to completion.
  const embedding = await umap.fitAsync(vectors, (epoch: number) => {
    onProgress?.(epoch);
  });
  // Normalize to [-1, 1]
  const mins = new Array(nComponents).fill(Infinity);
  const maxs = new Array(nComponents).fill(-Infinity);
  for (const point of embedding) {
    for (let i = 0; i < nComponents; i++) {
      mins[i] = Math.min(mins[i], point[i]);
      maxs[i] = Math.max(maxs[i], point[i]);
    }
  }
  return embedding.map(point =>
    point.map((v, i) => {
      const range = maxs[i] - mins[i];
      return range > 0 ? ((v - mins[i]) / range) * 2 - 1 : 0;
    })
  );
}
/**
 * Find the `k` nearest neighbours of one vector in the original
 * (high-dimensional) space, by brute force over all vectors.
 *
 * @param targetIndex - Index into `vectors` of the query point; it is excluded
 *   from its own results. An out-of-range index (e.g. a stale selection after
 *   the data shrank) yields `[]` instead of throwing.
 * @param vectors - All vectors; assumed to share the target's length.
 * @param k - Maximum number of neighbours to return. Values <= 0 (or NaN) yield `[]`.
 * @param metric - `'cosine'` (1 − cosine similarity) or `'euclidean'`. For
 *   cosine, a zero vector is treated as similarity 0, i.e. distance 1.
 * @returns Up to `k` `{ index, distance }` entries, closest first. Ties keep
 *   ascending index order because `Array.prototype.sort` is stable.
 */
function findNearestNeighbors(
  targetIndex: number,
  vectors: number[][],
  k: number = 10,
  metric: 'euclidean' | 'cosine' = 'cosine'
): { index: number; distance: number }[] {
  const target = vectors[targetIndex];
  // `slice(0, k)` with a negative k would drop entries from the END rather than
  // return nothing, so non-positive k is handled explicitly.
  if (target === undefined || !(k > 0)) return [];

  // The target's norm is loop-invariant: compute it once, not once per candidate.
  let targetNormSq = 0;
  for (let j = 0; j < target.length; j++) {
    targetNormSq += target[j] * target[j];
  }
  const targetNorm = Math.sqrt(targetNormSq);

  const results: { index: number; distance: number }[] = [];
  for (let i = 0; i < vectors.length; i++) {
    if (i === targetIndex) continue;
    const candidate = vectors[i];
    let dist: number;
    if (metric === 'cosine') {
      // Cosine distance = 1 - cosine similarity
      let dotProduct = 0;
      let candidateNormSq = 0;
      for (let j = 0; j < target.length; j++) {
        dotProduct += target[j] * candidate[j];
        candidateNormSq += candidate[j] * candidate[j];
      }
      const similarity = dotProduct / (targetNorm * Math.sqrt(candidateNormSq) || 1);
      dist = 1 - similarity;
    } else {
      // Euclidean distance
      let sumSq = 0;
      for (let j = 0; j < target.length; j++) {
        const diff = target[j] - candidate[j];
        sumSq += diff * diff;
      }
      dist = Math.sqrt(sumSq);
    }
    results.push({ index: i, distance: dist });
  }
  return results.sort((a, b) => a.distance - b.distance).slice(0, k);
}
/**
* Main Embedding Projector Component
*/
export function EmbeddingProjector({
points,
onPointSelect,
colorByField: initialColorByField,
height = 600,
width: _width,
simpleMode = false,
highlightedIndices = [],
onContextSelect,
showContextButton = false,
title: titleOverride,
}: EmbeddingProjectorProps) {
const { language } = useLanguage();
const t = (key: keyof typeof TEXT) => TEXT[key][language];
const svgRef = useRef<SVGSVGElement>(null);
const containerRef = useRef<HTMLDivElement>(null);
const threeContainerRef = useRef<HTMLDivElement>(null);
// Three.js refs
const sceneRef = useRef<THREE.Scene | null>(null);
const cameraRef = useRef<THREE.PerspectiveCamera | null>(null);
const rendererRef = useRef<THREE.WebGLRenderer | null>(null);
const controlsRef = useRef<OrbitControls | null>(null);
const pointCloudRef = useRef<THREE.Points | null>(null);
const animationFrameRef = useRef<number | null>(null);
// State
const [projectionMethod, setProjectionMethod] = useState<ProjectionMethod>('pca');
const [viewMode, setViewMode] = useState<ViewMode>('2d');
const [projectedPoints, setProjectedPoints] = useState<ProjectedPoint[]>([]);
const [isComputing, setIsComputing] = useState(false);
const [computeProgress, setComputeProgress] = useState<{ iteration: number; total: number } | null>(null);
const [selectedIndex, setSelectedIndex] = useState<number | null>(null);
const [hoveredIndex, setHoveredIndex] = useState<number | null>(null);
const [nearestNeighbors, setNearestNeighbors] = useState<{ index: number; distance: number }[]>([]);
const [searchQuery, setSearchQuery] = useState('');
const [colorByField, setColorByField] = useState(initialColorByField || '');
const [showLabels, setShowLabels] = useState(false);
const [neighborCount, setNeighborCount] = useState(10);
const [clickPosition, setClickPosition] = useState<{ x: number; y: number } | null>(null);
// UMAP parameters
const [umapNeighbors, setUmapNeighbors] = useState(15);
const [umapMinDist, setUmapMinDist] = useState(0.1);
// t-SNE parameters
const [tsnePerplexity, setTsnePerplexity] = useState(30);
const [tsneLearningRate, setTsneLearningRate] = useState(200);
// PCA variance explained
const [pcaVariance, setPcaVariance] = useState<number[]>([]);
// Extract unique payload fields
const payloadFields = useMemo(() => {
const fields = new Set<string>();
for (const point of points) {
for (const key of Object.keys(point.payload)) {
fields.add(key);
}
}
return Array.from(fields).sort();
}, [points]);
// Get unique categories for color legend
const fieldCategories = useMemo(() => {
if (!colorByField) return [];
const values = new Set<string>();
for (const point of points) {
const value = point.payload[colorByField];
if (value !== undefined && value !== null) {
values.add(String(value));
}
}
return Array.from(values).slice(0, 20);
}, [points, colorByField]);
// Filter points by search query
const filteredIndices = useMemo(() => {
if (!searchQuery.trim()) return null;
const query = searchQuery.toLowerCase();
const matches: number[] = [];
points.forEach((point, index) => {
// Search in ID
if (String(point.id).toLowerCase().includes(query)) {
matches.push(index);
return;
}
// Search in payload
for (const value of Object.values(point.payload)) {
if (String(value).toLowerCase().includes(query)) {
matches.push(index);
return;
}
}
});
return matches;
}, [points, searchQuery]);
// Get color for a point
const getPointColor = useCallback((index: number) => {
// Highlighted points (RAG sources) get a special color
if (highlightedIndices.includes(index)) {
return '#22c55e'; // Green for sources
}
if (!colorByField) return COLORS[0];
const value = String(points[index]?.payload[colorByField] ?? '');
const categoryIndex = fieldCategories.indexOf(value);
return categoryIndex >= 0 ? COLORS[categoryIndex % COLORS.length] : '#94a3b8';
}, [colorByField, fieldCategories, points, highlightedIndices]);
// Compute projection
const runProjection = useCallback(async () => {
if (points.length === 0) return;
setIsComputing(true);
setComputeProgress(null);
const vectors = points.map(p => p.vector);
const nComponents = viewMode === '3d' ? 3 : 2;
try {
let projected: number[][];
switch (projectionMethod) {
case 'pca': {
const result = computePCA(vectors, nComponents);
projected = result.projected;
setPcaVariance(result.explained);
break;
}
case 'umap': {
projected = await computeUMAP(vectors, nComponents, {
nNeighbors: umapNeighbors,
minDist: umapMinDist,
onProgress: (epoch) => setComputeProgress({ iteration: epoch, total: 200 }),
});
break;
}
case 'tsne': {
projected = computeTSNE(vectors, nComponents, {
perplexity: tsnePerplexity,
learningRate: tsneLearningRate,
iterations: 500,
onProgress: (iter) => setComputeProgress({ iteration: iter, total: 500 }),
});
break;
}
default:
projected = [];
}
setProjectedPoints(projected.map((coords, i) => ({
x: coords[0],
y: coords[1],
z: coords[2],
originalIndex: i,
})));
} finally {
setIsComputing(false);
setComputeProgress(null);
}
}, [points, projectionMethod, viewMode, umapNeighbors, umapMinDist, tsnePerplexity, tsneLearningRate]);
// Find nearest neighbors when point is selected
useEffect(() => {
if (selectedIndex !== null && points.length > 0) {
const neighbors = findNearestNeighbors(
selectedIndex,
points.map(p => p.vector),
neighborCount
);
setNearestNeighbors(neighbors);
onPointSelect?.(points[selectedIndex]);
} else {
setNearestNeighbors([]);
onPointSelect?.(null);
}
}, [selectedIndex, points, neighborCount, onPointSelect]);
// D3 visualization
useEffect(() => {
if (!svgRef.current || projectedPoints.length === 0) return;
const svg = d3.select(svgRef.current);
const containerWidth = containerRef.current?.clientWidth || 800;
const containerHeight = height;
// Clear previous content
svg.selectAll('*').remove();
// Set dimensions
svg
.attr('width', containerWidth)
.attr('height', containerHeight)
.attr('viewBox', `0 0 ${containerWidth} ${containerHeight}`);
// Create main group for zoom/pan
const g = svg.append('g');
// Setup zoom
const zoom = d3.zoom<SVGSVGElement, unknown>()
.scaleExtent([0.1, 10])
.on('zoom', (event) => {
g.attr('transform', event.transform);
});
svg.call(zoom);
// Scales
const xScale = d3.scaleLinear()
.domain([-1.2, 1.2])
.range([50, containerWidth - 50]);
const yScale = d3.scaleLinear()
.domain([-1.2, 1.2])
.range([containerHeight - 50, 50]);
// Grid
const gridGroup = g.append('g').attr('class', 'grid');
// Vertical grid lines
for (let x = -1; x <= 1; x += 0.5) {
gridGroup.append('line')
.attr('x1', xScale(x))
.attr('y1', yScale(-1.2))
.attr('x2', xScale(x))
.attr('y2', yScale(1.2))
.attr('stroke', '#e2e8f0')
.attr('stroke-width', 0.5);
}
// Horizontal grid lines
for (let y = -1; y <= 1; y += 0.5) {
gridGroup.append('line')
.attr('x1', xScale(-1.2))
.attr('y1', yScale(y))
.attr('x2', xScale(1.2))
.attr('y2', yScale(y))
.attr('stroke', '#e2e8f0')
.attr('stroke-width', 0.5);
}
// Axes
gridGroup.append('line')
.attr('x1', xScale(-1.2))
.attr('y1', yScale(0))
.attr('x2', xScale(1.2))
.attr('y2', yScale(0))
.attr('stroke', '#94a3b8')
.attr('stroke-width', 1);
gridGroup.append('line')
.attr('x1', xScale(0))
.attr('y1', yScale(-1.2))
.attr('x2', xScale(0))
.attr('y2', yScale(1.2))
.attr('stroke', '#94a3b8')
.attr('stroke-width', 1);
// Points
const pointsGroup = g.append('g').attr('class', 'points');
// Draw neighbor connections first (below points)
if (selectedIndex !== null) {
const selectedPoint = projectedPoints.find(p => p.originalIndex === selectedIndex);
if (selectedPoint) {
nearestNeighbors.forEach(neighbor => {
const neighborPoint = projectedPoints.find(p => p.originalIndex === neighbor.index);
if (neighborPoint) {
pointsGroup.append('line')
.attr('x1', xScale(selectedPoint.x))
.attr('y1', yScale(selectedPoint.y))
.attr('x2', xScale(neighborPoint.x))
.attr('y2', yScale(neighborPoint.y))
.attr('stroke', '#6366f1')
.attr('stroke-width', 1)
.attr('stroke-opacity', 0.3)
.attr('stroke-dasharray', '4,2');
}
});
}
}
// Draw points
const circles = pointsGroup.selectAll('circle')
.data(projectedPoints)
.join('circle')
.attr('cx', d => xScale(d.x))
.attr('cy', d => yScale(d.y))
.attr('r', d => {
if (d.originalIndex === selectedIndex) return 8;
if (d.originalIndex === hoveredIndex) return 6;
if (nearestNeighbors.some(n => n.index === d.originalIndex)) return 5;
if (filteredIndices && !filteredIndices.includes(d.originalIndex)) return 2;
return 4;
})
.attr('fill', d => getPointColor(d.originalIndex))
.attr('fill-opacity', d => {
if (filteredIndices && !filteredIndices.includes(d.originalIndex)) return 0.1;
if (d.originalIndex === selectedIndex) return 1;
if (nearestNeighbors.some(n => n.index === d.originalIndex)) return 0.9;
return 0.7;
})
.attr('stroke', d => d.originalIndex === selectedIndex ? '#1e293b' : 'none')
.attr('stroke-width', 2)
.attr('cursor', 'pointer')
.on('mouseenter', (_, d) => setHoveredIndex(d.originalIndex))
.on('mouseleave', () => setHoveredIndex(null))
.on('click', (event, d) => {
// Capture click position for popup positioning
setClickPosition({ x: event.clientX, y: event.clientY });
setSelectedIndex(prev => prev === d.originalIndex ? null : d.originalIndex);
});
// Labels
if (showLabels) {
const labelField = colorByField || 'id';
pointsGroup.selectAll('text')
.data(projectedPoints.filter((_, i) => i % Math.ceil(projectedPoints.length / 100) === 0))
.join('text')
.attr('x', d => xScale(d.x) + 6)
.attr('y', d => yScale(d.y) + 3)
.attr('font-size', '10px')
.attr('fill', '#475569')
.text(d => {
const point = points[d.originalIndex];
const value = labelField === 'id' ? point.id : point.payload[labelField];
const str = String(value ?? '');
return str.length > 15 ? str.slice(0, 12) + '...' : str;
});
}
// Tooltip
const tooltip = d3.select(containerRef.current)
.selectAll('.projector-tooltip')
.data([null])
.join('div')
.attr('class', 'projector-tooltip')
.style('position', 'absolute')
.style('pointer-events', 'none')
.style('background', 'white')
.style('border', '1px solid #e2e8f0')
.style('border-radius', '4px')
.style('padding', '8px')
.style('font-size', '12px')
.style('box-shadow', '0 2px 4px rgba(0,0,0,0.1)')
.style('display', 'none')
.style('z-index', '100');
circles
.on('mouseenter', function(event, d) {
const point = points[d.originalIndex];
tooltip
.style('display', 'block')
.style('left', `${event.offsetX + 10}px`)
.style('top', `${event.offsetY + 10}px`)
.html(`
<strong>ID:</strong> ${point.id}<br/>
${colorByField ? `<strong>${colorByField}:</strong> ${point.payload[colorByField]}<br/>` : ''}
`);
})
.on('mouseleave', () => {
tooltip.style('display', 'none');
});
}, [projectedPoints, selectedIndex, hoveredIndex, nearestNeighbors, filteredIndices,
showLabels, colorByField, getPointColor, points, height]);
// Three.js 3D visualization - Scene initialization (only runs when viewMode or projectedPoints change)
useEffect(() => {
if (viewMode !== '3d' || !threeContainerRef.current || projectedPoints.length === 0) {
// Cleanup if switching away from 3D
if (animationFrameRef.current) {
cancelAnimationFrame(animationFrameRef.current);
animationFrameRef.current = null;
}
// Dispose of all Three.js resources
if (pointCloudRef.current) {
const pointCloud = pointCloudRef.current;
const material = pointCloud.material as THREE.ShaderMaterial;
if (material.uniforms?.pointTexture?.value) {
material.uniforms.pointTexture.value.dispose();
}
material.dispose();
pointCloud.geometry.dispose();
pointCloudRef.current = null;
}
if (sceneRef.current) {
// Dispose all scene children
sceneRef.current.traverse((object) => {
if (object instanceof THREE.Mesh || object instanceof THREE.Points) {
if (object.geometry) object.geometry.dispose();
if (object.material) {
if (Array.isArray(object.material)) {
object.material.forEach(m => m.dispose());
} else {
object.material.dispose();
}
}
}
});
sceneRef.current.clear();
sceneRef.current = null;
}
if (controlsRef.current) {
controlsRef.current.dispose();
controlsRef.current = null;
}
if (rendererRef.current && threeContainerRef.current) {
threeContainerRef.current.removeChild(rendererRef.current.domElement);
rendererRef.current.dispose();
rendererRef.current.forceContextLoss();
rendererRef.current = null;
}
cameraRef.current = null;
return;
}
const container = threeContainerRef.current;
const containerWidth = container.clientWidth || 800;
const containerHeight = container.clientHeight || 600;
// Initialize scene
const scene = new THREE.Scene();
scene.background = new THREE.Color(0xfafafa);
sceneRef.current = scene;
// Initialize camera
const camera = new THREE.PerspectiveCamera(
60,
containerWidth / containerHeight,
0.1,
1000
);
camera.position.set(2, 2, 2);
camera.lookAt(0, 0, 0);
cameraRef.current = camera;
// Initialize renderer
const renderer = new THREE.WebGLRenderer({ antialias: true });
renderer.setSize(containerWidth, containerHeight);
renderer.setPixelRatio(Math.min(window.devicePixelRatio, 2));
container.appendChild(renderer.domElement);
rendererRef.current = renderer;
// Add orbit controls for rotation/zoom/pan
const controls = new OrbitControls(camera, renderer.domElement);
controls.enableDamping = true;
controls.dampingFactor = 0.05;
controls.screenSpacePanning = true;
controls.minDistance = 0.5;
controls.maxDistance = 20;
controlsRef.current = controls;
// Add grid helper
const gridHelper = new THREE.GridHelper(2, 10, 0xcccccc, 0xe0e0e0);
scene.add(gridHelper);
// Add axes helper
const axesHelper = new THREE.AxesHelper(1.2);
scene.add(axesHelper);
// Create point cloud geometry
const geometry = new THREE.BufferGeometry();
const positions = new Float32Array(projectedPoints.length * 3);
const colors = new Float32Array(projectedPoints.length * 3);
const sizes = new Float32Array(projectedPoints.length);
projectedPoints.forEach((point, i) => {
positions[i * 3] = point.x;
positions[i * 3 + 1] = point.y;
positions[i * 3 + 2] = point.z ?? 0;
// Get color for point (initial color without selection)
const color = new THREE.Color(getPointColor(point.originalIndex));
colors[i * 3] = color.r;
colors[i * 3 + 1] = color.g;
colors[i * 3 + 2] = color.b;
// Initial size (will be updated by selection effect)
sizes[i] = 4;
});
geometry.setAttribute('position', new THREE.BufferAttribute(positions, 3));
geometry.setAttribute('color', new THREE.BufferAttribute(colors, 3));
geometry.setAttribute('size', new THREE.BufferAttribute(sizes, 1));
// Create texture (will be disposed in cleanup)
const pointTexture = createCircleTexture();
// Create shader material for variable-sized points
const material = new THREE.ShaderMaterial({
uniforms: {
pointTexture: { value: pointTexture }
},
vertexShader: `
attribute float size;
varying vec3 vColor;
void main() {
vColor = color;
vec4 mvPosition = modelViewMatrix * vec4(position, 1.0);
gl_PointSize = size * (20.0 / -mvPosition.z);
gl_Position = projectionMatrix * mvPosition;
}
`,
fragmentShader: `
uniform sampler2D pointTexture;
varying vec3 vColor;
void main() {
gl_FragColor = vec4(vColor, 1.0);
gl_FragColor = gl_FragColor * texture2D(pointTexture, gl_PointCoord);
if (gl_FragColor.a < 0.3) discard;
}
`,
vertexColors: true,
transparent: true,
});
const pointCloud = new THREE.Points(geometry, material);
scene.add(pointCloud);
pointCloudRef.current = pointCloud;
// Add ambient light
const ambientLight = new THREE.AmbientLight(0xffffff, 0.6);
scene.add(ambientLight);
// Add directional light
const directionalLight = new THREE.DirectionalLight(0xffffff, 0.8);
directionalLight.position.set(5, 10, 5);
scene.add(directionalLight);
// Add raycaster for point selection
const raycaster = new THREE.Raycaster();
raycaster.params.Points = { threshold: 0.1 };
const mouse = new THREE.Vector2();
const onMouseClick = (event: MouseEvent) => {
const rect = container.getBoundingClientRect();
mouse.x = ((event.clientX - rect.left) / containerWidth) * 2 - 1;
mouse.y = -((event.clientY - rect.top) / containerHeight) * 2 + 1;
raycaster.setFromCamera(mouse, camera);
const intersects = raycaster.intersectObject(pointCloud);
if (intersects.length > 0) {
const index = intersects[0].index;
if (index !== undefined) {
const originalIndex = projectedPoints[index].originalIndex;
// Capture click position for popup positioning
setClickPosition({ x: event.clientX, y: event.clientY });
setSelectedIndex(prev => prev === originalIndex ? null : originalIndex);
}
}
};
container.addEventListener('click', onMouseClick);
// Animation loop
const animate = () => {
animationFrameRef.current = requestAnimationFrame(animate);
controls.update();
renderer.render(scene, camera);
};
animate();
// Handle resize
const handleResize = () => {
if (!container || !renderer || !camera) return;
const width = container.clientWidth;
const height = container.clientHeight;
camera.aspect = width / height;
camera.updateProjectionMatrix();
renderer.setSize(width, height);
};
window.addEventListener('resize', handleResize);
// Cleanup - comprehensive disposal of all Three.js resources
return () => {
window.removeEventListener('resize', handleResize);
container.removeEventListener('click', onMouseClick);
// Cancel animation frame first
if (animationFrameRef.current) {
cancelAnimationFrame(animationFrameRef.current);
animationFrameRef.current = null;
}
// Dispose controls
if (controls) {
controls.dispose();
}
// Dispose texture
if (pointTexture) {
pointTexture.dispose();
}
// Dispose geometry and material
if (geometry) geometry.dispose();
if (material) material.dispose();
// Clear scene (removes all objects)
if (scene) {
scene.traverse((object) => {
if (object instanceof THREE.Mesh || object instanceof THREE.Points) {
if (object.geometry) object.geometry.dispose();
if (object.material) {
if (Array.isArray(object.material)) {
object.material.forEach(m => m.dispose());
} else {
object.material.dispose();
}
}
}
});
scene.clear();
}
// Dispose renderer and force context loss to free WebGL context
if (renderer) {
container.removeChild(renderer.domElement);
renderer.dispose();
renderer.forceContextLoss();
}
// Clear refs
sceneRef.current = null;
cameraRef.current = null;
rendererRef.current = null;
controlsRef.current = null;
pointCloudRef.current = null;
};
}, [viewMode, projectedPoints, getPointColor]);
// Update point sizes and colors when selection changes (without recreating the scene).
// Mutates the existing 'size' and 'color' BufferAttributes of the 3D point cloud in place
// and flags them for re-upload, so selection/filter/colour changes avoid a scene rebuild.
// Size legend: selected = 12, nearest neighbour = 8, filtered-out = 2, everything else = 4.
useEffect(() => {
  if (viewMode !== '3d' || !pointCloudRef.current || projectedPoints.length === 0) {
    return;
  }
  const geometry = pointCloudRef.current.geometry;
  const sizes = geometry.getAttribute('size') as THREE.BufferAttribute;
  const colors = geometry.getAttribute('color') as THREE.BufferAttribute;
  if (!sizes || !colors) return;
  // Build O(1) lookup sets once, instead of scanning nearestNeighbors / filteredIndices
  // for every point (which made this loop O(points × neighbours + points × filtered)).
  const neighborIndexSet = new Set(nearestNeighbors.map(n => n.index));
  // Keep the original truthiness semantics: a null/undefined filter means "no filter",
  // whereas an empty array means "every point is filtered out".
  const filteredIndexSet = filteredIndices ? new Set(filteredIndices) : null;
  projectedPoints.forEach((point, i) => {
    // Update color
    const color = new THREE.Color(getPointColor(point.originalIndex));
    colors.setXYZ(i, color.r, color.g, color.b);
    // Update size based on selection / neighbourhood / filter membership
    let size: number;
    if (point.originalIndex === selectedIndex) {
      size = 12;
    } else if (neighborIndexSet.has(point.originalIndex)) {
      size = 8;
    } else if (filteredIndexSet && !filteredIndexSet.has(point.originalIndex)) {
      size = 2;
    } else {
      size = 4;
    }
    sizes.setX(i, size);
  });
  sizes.needsUpdate = true;
  colors.needsUpdate = true;
}, [viewMode, projectedPoints, selectedIndex, nearestNeighbors, filteredIndices, getPointColor]);
// Helper function to create circle texture for points.
/**
 * Draws a white radial gradient onto an offscreen canvas and wraps it in a
 * THREE.CanvasTexture. The gradient is fully opaque out to 30% of the radius,
 * 80% opaque at 50%, and fades to transparent at the rim.
 *
 * A new texture is created on every call (nothing is cached); the caller owns it
 * and must dispose it.
 *
 * @param size - Edge length of the square canvas in pixels. Defaults to 64,
 *   the size previously hard-coded here.
 * @returns The canvas-backed texture. If no 2D context is available, the texture
 *   wraps a blank (fully transparent) canvas.
 */
function createCircleTexture(size = 64): THREE.Texture {
  const canvas = document.createElement('canvas');
  canvas.width = size;
  canvas.height = size;
  // getContext() is typed to return null (e.g. in some headless/test environments);
  // guard instead of using a non-null assertion so the 3D view does not crash.
  const ctx = canvas.getContext('2d');
  if (ctx) {
    const center = size / 2;
    const gradient = ctx.createRadialGradient(center, center, 0, center, center, center);
    gradient.addColorStop(0, 'rgba(255, 255, 255, 1)');
    gradient.addColorStop(0.3, 'rgba(255, 255, 255, 1)');
    gradient.addColorStop(0.5, 'rgba(255, 255, 255, 0.8)');
    gradient.addColorStop(1, 'rgba(255, 255, 255, 0)');
    ctx.fillStyle = gradient;
    ctx.fillRect(0, 0, size, size);
  } else {
    console.warn('EmbeddingProjector: 2D canvas context unavailable; point texture will be blank.');
  }
  const texture = new THREE.CanvasTexture(canvas);
  texture.needsUpdate = true;
  return texture;
}
// Render tree:
// - the header with point/dimension/highlight counts;
// - the advanced controls (hidden in simpleMode): projection method, view mode,
//   method-specific parameters, colour-by field, search, labels and the run button;
// - a bare run button in simpleMode until a projection exists;
// - the 2D SVG or 3D WebGL canvas;
// - the colour legend;
// - a floating details panel for the selected point.
return (
  <div className={`embedding-projector ${simpleMode ? 'embedding-projector--simple' : ''}`} ref={containerRef}>
    {/* Header */}
    <div className="projector-header">
      <h3>{titleOverride || t('title')}</h3>
      {!simpleMode && <p className="projector-description">{t('description')}</p>}
      <div className="projector-stats">
        <span>{points.length.toLocaleString()} {t('pointsLoaded')}</span>
        {points[0] && !simpleMode && <span>{points[0].vector.length} {t('dimensions')}</span>}
        {highlightedIndices.length > 0 && (
          <span className="projector-stats__highlighted">
            {highlightedIndices.length} {t('sourcesHighlighted')}
          </span>
        )}
      </div>
    </div>
    {/* Controls */}
    {!simpleMode && (
      <div className="projector-controls">
        {/* Left: Method & View (toggle buttons expose their state via aria-pressed) */}
        <div className="control-section">
          <div className="control-group">
            <label>{t('projectionMethod')}</label>
            <div className="button-group">
              <button
                className={projectionMethod === 'pca' ? 'active' : ''}
                aria-pressed={projectionMethod === 'pca'}
                onClick={() => setProjectionMethod('pca')}
              >
                PCA
              </button>
              <button
                className={projectionMethod === 'umap' ? 'active' : ''}
                aria-pressed={projectionMethod === 'umap'}
                onClick={() => setProjectionMethod('umap')}
              >
                UMAP
              </button>
              <button
                className={projectionMethod === 'tsne' ? 'active' : ''}
                aria-pressed={projectionMethod === 'tsne'}
                onClick={() => setProjectionMethod('tsne')}
              >
                t-SNE
              </button>
            </div>
          </div>
          <div className="control-group">
            <label>{t('viewMode')}</label>
            <div className="button-group">
              <button
                className={viewMode === '2d' ? 'active' : ''}
                aria-pressed={viewMode === '2d'}
                onClick={() => setViewMode('2d')}
              >
                2D
              </button>
              <button
                className={viewMode === '3d' ? 'active' : ''}
                aria-pressed={viewMode === '3d'}
                onClick={() => setViewMode('3d')}
              >
                3D
              </button>
            </div>
          </div>
        </div>
        {/* Method-specific parameters */}
        <div className="control-section parameters">
          {projectionMethod === 'umap' && (
            <>
              <div className="control-group small">
                <label>{t('umapNeighbors')}</label>
                <input
                  type="range"
                  min="5"
                  max="50"
                  value={umapNeighbors}
                  onChange={(e) => setUmapNeighbors(Number(e.target.value))}
                />
                <span className="value">{umapNeighbors}</span>
              </div>
              <div className="control-group small">
                <label>{t('minDist')}</label>
                <input
                  type="range"
                  min="0"
                  max="1"
                  step="0.05"
                  value={umapMinDist}
                  onChange={(e) => setUmapMinDist(Number(e.target.value))}
                />
                <span className="value">{umapMinDist.toFixed(2)}</span>
              </div>
            </>
          )}
          {projectionMethod === 'tsne' && (
            <>
              <div className="control-group small">
                <label>{t('perplexity')}</label>
                <input
                  type="range"
                  min="5"
                  max="50"
                  value={tsnePerplexity}
                  onChange={(e) => setTsnePerplexity(Number(e.target.value))}
                />
                <span className="value">{tsnePerplexity}</span>
              </div>
              <div className="control-group small">
                <label>{t('learningRate')}</label>
                <input
                  type="range"
                  min="10"
                  max="500"
                  value={tsneLearningRate}
                  onChange={(e) => setTsneLearningRate(Number(e.target.value))}
                />
                <span className="value">{tsneLearningRate}</span>
              </div>
            </>
          )}
          {projectionMethod === 'pca' && pcaVariance.length > 0 && (
            <div className="pca-variance">
              <span>{t('variance')}: </span>
              {pcaVariance.map((v, i) => (
                <span key={i} className="variance-badge">
                  PC{i + 1}: {v.toFixed(1)}%
                </span>
              ))}
            </div>
          )}
        </div>
        {/* Right: Color & Search */}
        <div className="control-section">
          <div className="control-group">
            <label>{t('colorBy')}</label>
            <select
              value={colorByField}
              onChange={(e) => setColorByField(e.target.value)}
            >
              <option value="">{t('noField')}</option>
              {payloadFields.map(field => (
                <option key={field} value={field}>{field}</option>
              ))}
            </select>
          </div>
          <div className="control-group">
            <label>{t('search')}</label>
            <input
              type="text"
              placeholder={t('search')}
              value={searchQuery}
              onChange={(e) => setSearchQuery(e.target.value)}
              className="search-input"
            />
          </div>
          <div className="control-group checkbox">
            <label>
              <input
                type="checkbox"
                checked={showLabels}
                onChange={(e) => setShowLabels(e.target.checked)}
              />
              {t('showLabels')}
            </label>
          </div>
        </div>
        {/* Run button */}
        <div className="control-section actions">
          <button
            className="run-button"
            onClick={runProjection}
            disabled={isComputing || points.length === 0}
          >
            {isComputing ? (
              <>
                {t('computing')}
                {computeProgress && ` (${computeProgress.iteration}/${computeProgress.total})`}
              </>
            ) : (
              t('run')
            )}
          </button>
        </div>
      </div>
    )}
    {/* Simple mode: just a run button */}
    {simpleMode && projectedPoints.length === 0 && points.length > 0 && (
      <div className="projector-controls projector-controls--simple">
        <button
          className="run-button"
          onClick={runProjection}
          disabled={isComputing}
        >
          {isComputing ? t('computing') : t('run')}
        </button>
      </div>
    )}
    {/* Main visualization area */}
    <div className="projector-main">
      {/* Canvas */}
      <div className="projector-canvas">
        {projectedPoints.length > 0 ? (
          viewMode === '2d' ? (
            <svg ref={svgRef} />
          ) : (
            <div ref={threeContainerRef} className="three-container" />
          )
        ) : (
          <div className="projector-placeholder">
            {/* TODO(i18n): these two placeholder strings bypass t(); add translation keys. */}
            <p>
              {points.length > 0
                ? 'Click "Run" to compute projection'
                : 'Load vectors to visualize'
              }
            </p>
          </div>
        )}
      </div>
      {/* Sidebar - legend only (point details moved to floating panel) */}
      <div className="projector-sidebar">
        {/* Legend */}
        {colorByField && fieldCategories.length > 0 && (
          <div className="projector-legend">
            <h4>{colorByField}</h4>
            <div className="legend-items">
              {fieldCategories.map((value, idx) => (
                <div key={value} className="legend-item">
                  <span
                    className="legend-color"
                    style={{ backgroundColor: COLORS[idx % COLORS.length] }}
                  />
                  {/* Long labels are truncated visually; the title keeps the full value on hover. */}
                  <span className="legend-label" title={value}>
                    {value.length > 20 ? value.slice(0, 17) + '...' : value}
                  </span>
                </div>
              ))}
            </div>
          </div>
        )}
      </div>
    </div>
    {/* Floating Point Details Panel */}
    {selectedIndex !== null && points[selectedIndex] && (
      <PointDetailsPanel
        point={points[selectedIndex]}
        nearestNeighbors={nearestNeighbors}
        allPoints={points}
        neighborCount={neighborCount}
        onNeighborCountChange={setNeighborCount}
        onClose={() => setSelectedIndex(null)}
        onNeighborClick={(index) => {
          setSelectedIndex(index);
          // Keep the panel where it is; only fall back to the viewport centre if no
          // click position has been recorded yet.
          setClickPosition(prev => prev || { x: window.innerWidth / 2, y: window.innerHeight / 2 });
        }}
        clickPosition={clickPosition || undefined}
        language={language as 'nl' | 'en'}
        onAddToContext={onContextSelect}
        showContextButton={showContextButton}
      />
    )}
  </div>
);
}
// Default export of the EmbeddingProjector component.
export default EmbeddingProjector;