diff --git a/src/components/GraphVisualisation/GraphVisualisation.jsx b/src/components/GraphVisualisation/GraphVisualisation.jsx index 4d6d0acc..afb84be8 100644 --- a/src/components/GraphVisualisation/GraphVisualisation.jsx +++ b/src/components/GraphVisualisation/GraphVisualisation.jsx @@ -4,8 +4,10 @@ import { deduplicatePoints, getSimilarPoints, initGraph } from '../../lib/graph- import ForceGraph from 'force-graph'; import { useClient } from '../../context/client-context'; import { useSnackbar } from 'notistack'; +import { debounce } from 'lodash'; +import { resizeObserverWithCallback } from '../../lib/common-helpers'; -const GraphVisualisation = ({ initNode, options, onDataDisplay, wrapperRef }) => { +const GraphVisualisation = ({ initNode, options, onDataDisplay, wrapperRef, sampleLinks }) => { const graphRef = useRef(null); const { client: qdrantClient } = useClient(); const { enqueueSnackbar } = useSnackbar(); @@ -53,7 +55,9 @@ const GraphVisualisation = ({ initNode, options, onDataDisplay, wrapperRef }) => onDataDisplay(node); }) .autoPauseRedraw(false) - .nodeCanvasObjectMode((node) => (node?.id === highlightedNode?.id ? 'before' : undefined)) + .nodeCanvasObjectMode((node) => { + return node?.id === highlightedNode?.id ? 'before' : undefined; + }) .nodeCanvasObject((node, ctx) => { if (!node) return; // add ring for last hovered nodes @@ -62,11 +66,25 @@ const GraphVisualisation = ({ initNode, options, onDataDisplay, wrapperRef }) => ctx.fillStyle = node.id === highlightedNode?.id ? '#817' : 'transparent'; ctx.fill(); }) + .linkLabel('score') .linkColor(() => '#a6a6a6'); + + graphRef.current.d3Force('charge').strength(-10); }, [initNode, options]); useEffect(() => { + if (!wrapperRef) return; + + const debouncedResizeCallback = debounce((width, height) => { + graphRef.current.width(width).height(height); + }, 500); + graphRef.current.width(wrapperRef?.clientWidth).height(wrapperRef?.clientHeight); + resizeObserverWithCallback(debouncedResizeCallback).observe(wrapperRef); + + return () => { + resizeObserverWithCallback(debouncedResizeCallback).unobserve(wrapperRef); + }; }, [wrapperRef, initNode, options]); useEffect(() => { @@ -74,6 +92,7 @@ const GraphVisualisation = ({ initNode, options, onDataDisplay, wrapperRef }) => const graphData = await initGraph(qdrantClient, { ...options, initNode, + sampleLinks, }); if (graphRef.current && options) { const initialActiveNode = graphData.nodes[0]; @@ -83,9 +102,14 @@ const GraphVisualisation = ({ initNode, options, onDataDisplay, wrapperRef }) => } }; initNewGraph().catch((e) => { - enqueueSnackbar(JSON.stringify(e.getActualType()), { variant: 'error' }); + console.error(e); + if (e.getActualType) { + enqueueSnackbar(JSON.stringify(e.getActualType()), { variant: 'error' }); + } else { + enqueueSnackbar(e.message, { variant: 'error' }); + } }); - }, [initNode, options]); + }, [initNode, options, sampleLinks]); return
; }; @@ -95,6 +119,7 @@ GraphVisualisation.propTypes = { options: PropTypes.object.isRequired, onDataDisplay: PropTypes.func.isRequired, wrapperRef: PropTypes.object, + sampleLinks: PropTypes.array, }; export default GraphVisualisation; diff --git a/src/lib/common-helpers.js b/src/lib/common-helpers.js new file mode 100644 index 00000000..d6ad6b9c --- /dev/null +++ b/src/lib/common-helpers.js @@ -0,0 +1,9 @@ +export const resizeObserverWithCallback = (callback) => { + return new ResizeObserver((entries) => { + for (const entry of entries) { + const { target } = entry; + const { width, height } = target.getBoundingClientRect(); + if (typeof callback === 'function') callback(width, height); + } + }); +}; diff --git a/src/lib/graph-visualization-helpers.js b/src/lib/graph-visualization-helpers.js index 4efed78c..6196fbc6 100644 --- a/src/lib/graph-visualization-helpers.js +++ b/src/lib/graph-visualization-helpers.js @@ -1,17 +1,42 @@ -export const initGraph = async (qdrantClient, { collectionName, initNode, limit, filter, using }) => { - if (!initNode) { +import { axiosInstance } from '../common/axios'; + +export const initGraph = async ( + qdrantClient, + { collectionName, initNode, limit, filter, using, sampleLinks, tree = false } +) => { + let nodes = []; + let links = []; + + if (sampleLinks) { + const uniquePoints = new Set(); + + for (const link of sampleLinks) { + links.push({ source: link.a, target: link.b, score: link.score }); + uniquePoints.add(link.a); + uniquePoints.add(link.b); + } + + if (tree) { + // ToDo acs should depend on metric type + links = getMinimalSpanningTree(links, true); + } + + nodes = await getPointsWithPayload(qdrantClient, { collectionName, pointIds: Array.from(uniquePoints) }); + } else if (initNode) { + initNode.clicked = true; + nodes = await getSimilarPoints(qdrantClient, { collectionName, pointId: initNode.id, limit, filter, using }); + links = nodes.map((point) => ({ source: initNode.id, target: point.id, score: point.score })); + nodes = [initNode, ...nodes]; + } else { return { nodes: [], links: [], }; } - initNode.clicked = true; - - const points = await getSimilarPoints(qdrantClient, { collectionName, pointId: initNode.id, limit, filter, using }); const graphData = { - nodes: [initNode, ...points], - links: points.map((point) => ({ source: initNode.id, target: point.id })), + nodes, + links, }; return graphData; }; @@ -44,9 +69,94 @@ export const getFirstPoint = async (qdrantClient, { collectionName, filter }) => return points[0]; }; +const getPointsWithPayload = async (qdrantClient, { collectionName, pointIds }) => { + const points = await qdrantClient.retrieve(collectionName, { + ids: pointIds, + with_payload: true, + with_vector: false, + }); + + return points; +}; + +export const getSamplePoints = async ({ collectionName, filter, sample, using, limit }) => { + // ToDo: replace it with qdrantClient when it will be implemented + + const response = await axiosInstance({ + method: 'POST', + url: `collections/${collectionName}/points/search/matrix/pairs`, + data: { + filter, + sample, + using, + limit, + }, + }); + + return response.data.result.pairs; +}; + export const deduplicatePoints = (existingPoints, foundPoints) => { // Returns array of found points that are not in existing points // deduplication is done by id const existingIds = new Set(existingPoints.map((point) => point.id)); return foundPoints.filter((point) => !existingIds.has(point.id)); }; + +export const getMinimalSpanningTree = (links, acs = true) => { + // Sort links by score (assuming each link has a score property) + + let sortedLinks = []; + if (acs) { + sortedLinks = links.sort((a, b) => b.score - a.score); + } else { + sortedLinks = links.sort((a, b) => a.score - b.score); + } + // Helper function to find the root of a node + const findRoot = (parent, i) => { + if (parent[i] === i) { + return i; + } + return findRoot(parent, parent[i]); + }; + + // Helper function to perform union of two sets + const union = (parent, rank, x, y) => { + const rootX = findRoot(parent, x); + const rootY = findRoot(parent, y); + + if (rank[rootX] < rank[rootY]) { + parent[rootX] = rootY; + } else if (rank[rootX] > rank[rootY]) { + parent[rootY] = rootX; + } else { + parent[rootY] = rootX; + rank[rootX]++; + } + }; + + const parent = {}; + const rank = {}; + const mstLinks = []; + + // Initialize parent and rank arrays + links.forEach((link) => { + parent[link.source] = link.source; + parent[link.target] = link.target; + rank[link.source] = 0; + rank[link.target] = 0; + }); + + // Kruskal's algorithm + sortedLinks.forEach((link) => { + const sourceRoot = findRoot(parent, link.source); + const targetRoot = findRoot(parent, link.target); + + if (sourceRoot !== targetRoot) { + mstLinks.push(link); + union(parent, rank, sourceRoot, targetRoot); + } + }); + + return mstLinks; +}; diff --git a/src/lib/tests/graph-visualization-helpers.test.js b/src/lib/tests/graph-visualization-helpers.test.js new file mode 100644 index 00000000..ce728195 --- /dev/null +++ b/src/lib/tests/graph-visualization-helpers.test.js @@ -0,0 +1,56 @@ +import { describe, it, expect } from 'vitest'; +import { getMinimalSpanningTree } from '../graph-visualization-helpers'; + +describe('getMinimalSpanningTree', () => { + it('should return the minimal spanning tree for a given set of links (ascending order)', () => { + const links = [ + { source: 'A', target: 'B', score: 1 }, + { source: 'B', target: 'C', score: 2 }, + { source: 'A', target: 'C', score: 3 }, + { source: 'C', target: 'D', score: 4 }, + { source: 'B', target: 'D', score: 5 }, + ]; + + const expectedMST = [ + { source: 'B', target: 'D', score: 5 }, + { source: 'C', target: 'D', score: 4 }, + { source: 'A', target: 'C', score: 3 }, + ]; + + const result = getMinimalSpanningTree(links, true); + expect(result).toEqual(expectedMST); + }); + + it('should return the minimal spanning tree for a given set of links (descending order)', () => { + const links = [ + { source: 'A', target: 'B', score: 1 }, + { source: 'B', target: 'C', score: 2 }, + { source: 'A', target: 'C', score: 3 }, + { source: 'C', target: 'D', score: 4 }, + { source: 'B', target: 'D', score: 5 }, + ]; + + const expectedMST = [ + { source: 'A', target: 'B', score: 1 }, + { source: 'B', target: 'C', score: 2 }, + { source: 'C', target: 'D', score: 4 }, + ]; + + const result = getMinimalSpanningTree(links, false); + expect(result).toEqual(expectedMST); + }); + + it('should return an empty array if no links are provided', () => { + const links = []; + const expectedMST = []; + const result = getMinimalSpanningTree(links, true); + expect(result).toEqual(expectedMST); + }); + + it('should handle a single link correctly', () => { + const links = [{ source: 'A', target: 'B', score: 1 }]; + const expectedMST = [{ source: 'A', target: 'B', score: 1 }]; + const result = getMinimalSpanningTree(links, true); + expect(result).toEqual(expectedMST); + }); +}); diff --git a/src/pages/Graph.jsx b/src/pages/Graph.jsx index 4ee6b6d5..49dc909f 100644 --- a/src/pages/Graph.jsx +++ b/src/pages/Graph.jsx @@ -9,7 +9,7 @@ import { useWindowResize } from '../hooks/windowHooks'; import PointPreview from '../components/Common/PointPreview'; import CodeEditorWindow from '../components/FilterEditorWindow'; import { useClient } from '../context/client-context'; -import { getFirstPoint } from '../lib/graph-visualization-helpers'; +import { getFirstPoint, getSamplePoints } from '../lib/graph-visualization-helpers'; import { useSnackbar } from 'notistack'; const explanation = ` @@ -19,12 +19,15 @@ const explanation = ` // Available parameters: // // - 'limit': number of records to use on each step. +// - 'sample': bootstrap graph with sample data from collection. // // - 'filter': filter expression to select vectors for visualization. // See https://qdrant.tech/documentation/concepts/filtering/ // // - 'using': specify which vector to use for visualization // if there are multiple. +// +// - 'tree': if true, will use show spanning tree instead of full graph. `; @@ -45,6 +48,8 @@ function Graph() { const location = useLocation(); const { newInitNode, vectorName } = location.state || {}; const [initNode, setInitNode] = useState(null); + const [sampleLinks, setSampleLinks] = useState(null); + const [options, setOptions] = useState({ limit: 5, filter: null, @@ -92,8 +97,17 @@ function Graph() { const handleRunCode = async (data, collectionName) => { // scroll try { - const firstPoint = await getFirstPoint(qdrantClient, { collectionName: collectionName, filter: data?.filter }); - setInitNode(firstPoint); + if (data.sample) { + const sampleLinks = await getSamplePoints({ + collectionName: collectionName, + ...data, + }); + setSampleLinks(sampleLinks); + setInitNode(null); + } else { + const firstPoint = await getFirstPoint(qdrantClient, { collectionName: collectionName, filter: data?.filter }); + setInitNode(firstPoint); + } setOptions({ collectionName: collectionName, ...data, @@ -130,6 +144,16 @@ function Graph() { type: 'string', enum: vectorNames, }, + sample: { + description: 'Bootstrap graph with sample data from collection', + type: 'integer', + nullable: true, + }, + tree: { + description: 'Show spanning tree instead of full graph', + type: 'boolean', + nullable: true, + }, }, }); @@ -170,6 +194,7 @@ function Graph() { initNode={initNode} onDataDisplay={handlePointDisplay} wrapperRef={VisualizeChartWrapper.current} + sampleLinks={sampleLinks} />