Skip to content

Commit

Permalink
use matrix API to bootstrap graph view (#216)
Browse files Browse the repository at this point in the history
* use matrix API to bootstrap graph view

* add score to links (but do not use it yet)

* spanning tree (#225)

* Fix graph (#226)

* tests for getMinimalSpanningTree

* fixes

* canvas resize fix

* Update src/lib/graph-visualization-helpers.js

---------

Co-authored-by: trean <[email protected]>
  • Loading branch information
generall and trean authored Oct 4, 2024
1 parent ccfbc4f commit a65daf1
Show file tree
Hide file tree
Showing 5 changed files with 239 additions and 14 deletions.
33 changes: 29 additions & 4 deletions src/components/GraphVisualisation/GraphVisualisation.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import { deduplicatePoints, getSimilarPoints, initGraph } from '../../lib/graph-
import ForceGraph from 'force-graph';
import { useClient } from '../../context/client-context';
import { useSnackbar } from 'notistack';
import { debounce } from 'lodash';
import { resizeObserverWithCallback } from '../../lib/common-helpers';

const GraphVisualisation = ({ initNode, options, onDataDisplay, wrapperRef }) => {
const GraphVisualisation = ({ initNode, options, onDataDisplay, wrapperRef, sampleLinks }) => {
const graphRef = useRef(null);
const { client: qdrantClient } = useClient();
const { enqueueSnackbar } = useSnackbar();
Expand Down Expand Up @@ -53,7 +55,9 @@ const GraphVisualisation = ({ initNode, options, onDataDisplay, wrapperRef }) =>
onDataDisplay(node);
})
.autoPauseRedraw(false)
.nodeCanvasObjectMode((node) => (node?.id === highlightedNode?.id ? 'before' : undefined))
.nodeCanvasObjectMode((node) => {
return node?.id === highlightedNode?.id ? 'before' : undefined;
})
.nodeCanvasObject((node, ctx) => {
if (!node) return;
// add ring for last hovered nodes
Expand All @@ -62,18 +66,33 @@ const GraphVisualisation = ({ initNode, options, onDataDisplay, wrapperRef }) =>
ctx.fillStyle = node.id === highlightedNode?.id ? '#817' : 'transparent';
ctx.fill();
})
.linkLabel('score')
.linkColor(() => '#a6a6a6');

graphRef.current.d3Force('charge').strength(-10);
}, [initNode, options]);

useEffect(() => {
if (!wrapperRef) return;

const debouncedResizeCallback = debounce((width, height) => {
graphRef.current.width(width).height(height);
}, 500);

graphRef.current.width(wrapperRef?.clientWidth).height(wrapperRef?.clientHeight);
resizeObserverWithCallback(debouncedResizeCallback).observe(wrapperRef);

return () => {
resizeObserverWithCallback(debouncedResizeCallback).unobserve(wrapperRef);
};
}, [wrapperRef, initNode, options]);

useEffect(() => {
const initNewGraph = async () => {
const graphData = await initGraph(qdrantClient, {
...options,
initNode,
sampleLinks,
});
if (graphRef.current && options) {
const initialActiveNode = graphData.nodes[0];
Expand All @@ -83,9 +102,14 @@ const GraphVisualisation = ({ initNode, options, onDataDisplay, wrapperRef }) =>
}
};
initNewGraph().catch((e) => {
enqueueSnackbar(JSON.stringify(e.getActualType()), { variant: 'error' });
console.error(e);
if (e.getActualType) {
enqueueSnackbar(JSON.stringify(e.getActualType()), { variant: 'error' });
} else {
enqueueSnackbar(e.message, { variant: 'error' });
}
});
}, [initNode, options]);
}, [initNode, options, sampleLinks]);

return <div id="graph"></div>;
};
Expand All @@ -95,6 +119,7 @@ GraphVisualisation.propTypes = {
options: PropTypes.object.isRequired,
onDataDisplay: PropTypes.func.isRequired,
wrapperRef: PropTypes.object,
sampleLinks: PropTypes.array,
};

export default GraphVisualisation;
9 changes: 9 additions & 0 deletions src/lib/common-helpers.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
export const resizeObserverWithCallback = (callback) => {
return new ResizeObserver((entries) => {
for (const entry of entries) {
const { target } = entry;
const { width, height } = target.getBoundingClientRect();
if (typeof callback === 'function') callback(width, height);
}
});
};
124 changes: 117 additions & 7 deletions src/lib/graph-visualization-helpers.js
Original file line number Diff line number Diff line change
@@ -1,17 +1,42 @@
export const initGraph = async (qdrantClient, { collectionName, initNode, limit, filter, using }) => {
if (!initNode) {
import { axiosInstance } from '../common/axios';

export const initGraph = async (
qdrantClient,
{ collectionName, initNode, limit, filter, using, sampleLinks, tree = false }
) => {
let nodes = [];
let links = [];

if (sampleLinks) {
const uniquePoints = new Set();

for (const link of sampleLinks) {
links.push({ source: link.a, target: link.b, score: link.score });
uniquePoints.add(link.a);
uniquePoints.add(link.b);
}

if (tree) {
// ToDo acs should depend on metric type
links = getMinimalSpanningTree(links, true);
}

nodes = await getPointsWithPayload(qdrantClient, { collectionName, pointIds: Array.from(uniquePoints) });
} else if (initNode) {
initNode.clicked = true;
nodes = await getSimilarPoints(qdrantClient, { collectionName, pointId: initNode.id, limit, filter, using });
links = nodes.map((point) => ({ source: initNode.id, target: point.id, score: point.score }));
nodes = [initNode, ...nodes];
} else {
return {
nodes: [],
links: [],
};
}
initNode.clicked = true;

const points = await getSimilarPoints(qdrantClient, { collectionName, pointId: initNode.id, limit, filter, using });

const graphData = {
nodes: [initNode, ...points],
links: points.map((point) => ({ source: initNode.id, target: point.id })),
nodes,
links,
};
return graphData;
};
Expand Down Expand Up @@ -44,9 +69,94 @@ export const getFirstPoint = async (qdrantClient, { collectionName, filter }) =>
return points[0];
};

const getPointsWithPayload = async (qdrantClient, { collectionName, pointIds }) => {
const points = await qdrantClient.retrieve(collectionName, {
ids: pointIds,
with_payload: true,
with_vector: false,
});

return points;
};

export const getSamplePoints = async ({ collectionName, filter, sample, using, limit }) => {
// ToDo: replace it with qdrantClient when it will be implemented

const response = await axiosInstance({
method: 'POST',
url: `collections/${collectionName}/points/search/matrix/pairs`,
data: {
filter,
sample,
using,
limit,
},
});

return response.data.result.pairs;
};

export const deduplicatePoints = (existingPoints, foundPoints) => {
// Returns array of found points that are not in existing points
// deduplication is done by id
const existingIds = new Set(existingPoints.map((point) => point.id));
return foundPoints.filter((point) => !existingIds.has(point.id));
};

export const getMinimalSpanningTree = (links, acs = true) => {
// Sort links by score (assuming each link has a score property)

let sortedLinks = [];
if (acs) {
sortedLinks = links.sort((a, b) => b.score - a.score);
} else {
sortedLinks = links.sort((a, b) => a.score - b.score);
}
// Helper function to find the root of a node
const findRoot = (parent, i) => {
if (parent[i] === i) {
return i;
}
return findRoot(parent, parent[i]);
};

// Helper function to perform union of two sets
const union = (parent, rank, x, y) => {
const rootX = findRoot(parent, x);
const rootY = findRoot(parent, y);

if (rank[rootX] < rank[rootY]) {
parent[rootX] = rootY;
} else if (rank[rootX] > rank[rootY]) {
parent[rootY] = rootX;
} else {
parent[rootY] = rootX;
rank[rootX]++;
}
};

const parent = {};
const rank = {};
const mstLinks = [];

// Initialize parent and rank arrays
links.forEach((link) => {
parent[link.source] = link.source;
parent[link.target] = link.target;
rank[link.source] = 0;
rank[link.target] = 0;
});

// Kruskal's algorithm
sortedLinks.forEach((link) => {
const sourceRoot = findRoot(parent, link.source);
const targetRoot = findRoot(parent, link.target);

if (sourceRoot !== targetRoot) {
mstLinks.push(link);
union(parent, rank, sourceRoot, targetRoot);
}
});

return mstLinks;
};
56 changes: 56 additions & 0 deletions src/lib/tests/graph-visualization-helpers.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import { describe, it, expect } from 'vitest';
import { getMinimalSpanningTree } from '../graph-visualization-helpers';

describe('getMinimalSpanningTree', () => {
it('should return the minimal spanning tree for a given set of links (ascending order)', () => {
const links = [
{ source: 'A', target: 'B', score: 1 },
{ source: 'B', target: 'C', score: 2 },
{ source: 'A', target: 'C', score: 3 },
{ source: 'C', target: 'D', score: 4 },
{ source: 'B', target: 'D', score: 5 },
];

const expectedMST = [
{ source: 'B', target: 'D', score: 5 },
{ source: 'C', target: 'D', score: 4 },
{ source: 'A', target: 'C', score: 3 },
];

const result = getMinimalSpanningTree(links, true);
expect(result).toEqual(expectedMST);
});

it('should return the minimal spanning tree for a given set of links (descending order)', () => {
const links = [
{ source: 'A', target: 'B', score: 1 },
{ source: 'B', target: 'C', score: 2 },
{ source: 'A', target: 'C', score: 3 },
{ source: 'C', target: 'D', score: 4 },
{ source: 'B', target: 'D', score: 5 },
];

const expectedMST = [
{ source: 'A', target: 'B', score: 1 },
{ source: 'B', target: 'C', score: 2 },
{ source: 'C', target: 'D', score: 4 },
];

const result = getMinimalSpanningTree(links, false);
expect(result).toEqual(expectedMST);
});

it('should return an empty array if no links are provided', () => {
const links = [];
const expectedMST = [];
const result = getMinimalSpanningTree(links, true);
expect(result).toEqual(expectedMST);
});

it('should handle a single link correctly', () => {
const links = [{ source: 'A', target: 'B', score: 1 }];
const expectedMST = [{ source: 'A', target: 'B', score: 1 }];
const result = getMinimalSpanningTree(links, true);
expect(result).toEqual(expectedMST);
});
});
31 changes: 28 additions & 3 deletions src/pages/Graph.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { useWindowResize } from '../hooks/windowHooks';
import PointPreview from '../components/Common/PointPreview';
import CodeEditorWindow from '../components/FilterEditorWindow';
import { useClient } from '../context/client-context';
import { getFirstPoint } from '../lib/graph-visualization-helpers';
import { getFirstPoint, getSamplePoints } from '../lib/graph-visualization-helpers';
import { useSnackbar } from 'notistack';

const explanation = `
Expand All @@ -19,12 +19,15 @@ const explanation = `
// Available parameters:
//
// - 'limit': number of records to use on each step.
// - 'sample': bootstrap graph with sample data from collection.
//
// - 'filter': filter expression to select vectors for visualization.
// See https://qdrant.tech/documentation/concepts/filtering/
//
// - 'using': specify which vector to use for visualization
// if there are multiple.
//
// - 'tree': if true, will use show spanning tree instead of full graph.
`;

Expand All @@ -45,6 +48,8 @@ function Graph() {
const location = useLocation();
const { newInitNode, vectorName } = location.state || {};
const [initNode, setInitNode] = useState(null);
const [sampleLinks, setSampleLinks] = useState(null);

const [options, setOptions] = useState({
limit: 5,
filter: null,
Expand Down Expand Up @@ -92,8 +97,17 @@ function Graph() {
const handleRunCode = async (data, collectionName) => {
// scroll
try {
const firstPoint = await getFirstPoint(qdrantClient, { collectionName: collectionName, filter: data?.filter });
setInitNode(firstPoint);
if (data.sample) {
const sampleLinks = await getSamplePoints({
collectionName: collectionName,
...data,
});
setSampleLinks(sampleLinks);
setInitNode(null);
} else {
const firstPoint = await getFirstPoint(qdrantClient, { collectionName: collectionName, filter: data?.filter });
setInitNode(firstPoint);
}
setOptions({
collectionName: collectionName,
...data,
Expand Down Expand Up @@ -130,6 +144,16 @@ function Graph() {
type: 'string',
enum: vectorNames,
},
sample: {
description: 'Bootstrap graph with sample data from collection',
type: 'integer',
nullable: true,
},
tree: {
description: 'Show spanning tree instead of full graph',
type: 'boolean',
nullable: true,
},
},
});

Expand Down Expand Up @@ -170,6 +194,7 @@ function Graph() {
initNode={initNode}
onDataDisplay={handlePointDisplay}
wrapperRef={VisualizeChartWrapper.current}
sampleLinks={sampleLinks}
/>
</Box>
</Box>
Expand Down

0 comments on commit a65daf1

Please sign in to comment.