diff --git a/graph/dijkstra.py b/graph/dijkstra.py new file mode 100644 index 0000000..ae84c0f --- /dev/null +++ b/graph/dijkstra.py @@ -0,0 +1,49 @@ +from typing import Dict, List, Tuple +import heapq + +def dijkstra(graph: Dict[int, List[Tuple[int, int]]], start: int) -> Dict[int, int]: + """ + Implements Dijkstra's algorithm for finding the shortest path in a graph. + + Args: + graph (Dict[int, List[Tuple[int, int]]]): A dictionary representing the graph. + Keys are nodes, values are lists of (neighbor, weight) tuples. + start (int): The starting node. + + Returns: + Dict[int, int]: A dictionary with nodes as keys and shortest distances from start as values. + """ + distances = {node: float('infinity') for node in graph} + distances[start] = 0 + pq = [(0, start)] + + while pq: + current_distance, current_node = heapq.heappop(pq) + + if current_distance > distances[current_node]: + continue + + for neighbor, weight in graph[current_node]: + distance = current_distance + weight + if distance < distances[neighbor]: + distances[neighbor] = distance + heapq.heappush(pq, (distance, neighbor)) + + return distances + +# Example usage +if __name__ == "__main__": + # Example graph + graph = { + 0: [(1, 4), (2, 1)], + 1: [(3, 1)], + 2: [(1, 2), (3, 5)], + 3: [(4, 3)], + 4: [] + } + + start_node = 0 + shortest_paths = dijkstra(graph, start_node) + print(f"Shortest paths from node {start_node}:") + for node, distance in shortest_paths.items(): + print(f"To node {node}: {distance}") diff --git a/graph/test_dijkstra.py b/graph/test_dijkstra.py new file mode 100644 index 0000000..5372bf4 --- /dev/null +++ b/graph/test_dijkstra.py @@ -0,0 +1,54 @@ +import unittest +from dijkstra import dijkstra + +class TestDijkstra(unittest.TestCase): + def test_simple_graph(self): + graph = { + 0: [(1, 4), (2, 1)], + 1: [(3, 1)], + 2: [(1, 2), (3, 5)], + 3: [(4, 3)], + 4: [] + } + start_node = 0 + expected = {0: 0, 1: 3, 2: 1, 3: 4, 4: 7} + self.assertEqual(dijkstra(graph, start_node), expected) + + def test_disconnected_graph(self): + graph = { + 0: [(1, 1)], + 1: [(0, 1)], + 2: [(3, 1)], + 3: [(2, 1)] + } + start_node = 0 + expected = {0: 0, 1: 1, 2: float('infinity'), 3: float('infinity')} + self.assertEqual(dijkstra(graph, start_node), expected) + + def test_single_node_graph(self): + graph = {0: []} + start_node = 0 + expected = {0: 0} + self.assertEqual(dijkstra(graph, start_node), expected) + + def test_complex_graph(self): + graph = { + 0: [(1, 4), (2, 2)], + 1: [(2, 1), (3, 5)], + 2: [(3, 8), (4, 10)], + 3: [(4, 2), (5, 6)], + 4: [(5, 3)], + 5: [] + } + start_node = 0 + expected = {0: 0, 1: 4, 2: 2, 3: 9, 4: 11, 5: 14} + self.assertEqual(dijkstra(graph, start_node), expected) + + def test_start_node_not_in_graph(self): + graph = {0: [(1, 1)], 1: [(0, 1)]} + start_node = 2 + with self.assertRaises(KeyError): + dijkstra(graph, start_node) + +if __name__ == '__main__': + unittest.main()