Skip to content

Commit

Permalink
Add initial implemention for plpgsql ast traversal
Browse files Browse the repository at this point in the history
  • Loading branch information
svenklemm committed Aug 25, 2024
1 parent 39e44d2 commit b326070
Show file tree
Hide file tree
Showing 5 changed files with 395 additions and 1 deletion.
112 changes: 112 additions & 0 deletions src/pgspot/path.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
class Path:
"""A path is a sequence of steps that will be executed in a PLpgSQL function."""

def __init__(self, root, steps=None, stack=None):
self.root = root
# steps is the list of nodes that have been processed
self.steps = steps.copy() if steps else []
# stack is a list of nodes that are yet to be processed
self.stack = stack.copy() if stack else []

def copy(self):
return Path(self.root, self.steps, self.stack)

def __str__(self):
return " -> ".join([str(step) for step in self.steps])


def paths(root):
p = Path(root)
pathes = []
dfs(root, p, pathes)
yield p

while pathes:
p = pathes.pop(0)
t = p.stack.pop(0)
dfs(t, p, pathes)
yield p


def dfs(node, path, pathes):
"""traverse tree depth first similar to how it would get executed"""
if not node:
return
if node:
match node.type:
case "PLpgSQL_function":
# This should be top level node and so stack should be empty
assert not path.stack
path.stack = [node.action] + path.stack
case "PLpgSQL_stmt_block":
# FIXME: Add support for exception handling
path.stack = node.body + path.stack
case "PLpgSQL_stmt_if":
path.steps.append(node)
if node.elsif_list:
for elsif in node.elsif_list:
alt = path.copy()
alt.stack = elsif.stmts + alt.stack
pathes.append(alt)
if node.else_body:
alt = path.copy()
alt.stack = node.else_body + alt.stack
pathes.append(alt)

path.stack = node.then_body + path.stack

# different types of loops
# FIXME: Add support for loop exit
case (
"PLpgSQL_stmt_loop"
| "PLpgSQL_stmt_while"
| "PLpgSQL_stmt_forc"
| "PLpgSQL_stmt_fori"
| "PLpgSQL_stmt_fors"
| "PLpgSQL_stmt_dynfors"
):
path.stack = node.body + path.stack

# nodes with no children
case (
"PLpgSQL_stmt_assert"
| "PLpgSQL_stmt_assign"
| "PLpgSQL_stmt_call"
| "PLpgSQL_stmt_close"
| "PLpgSQL_stmt_commit"
| "PLpgSQL_stmt_dynexecute"
| "PLpgSQL_stmt_execsql"
| "PLpgSQL_stmt_fetch"
| "PLpgSQL_stmt_getdiag"
| "PLpgSQL_stmt_open"
| "PLpgSQL_stmt_perform"
| "PLpgSQL_stmt_raise"
| "PLpgSQL_stmt_rollback"
):
path.steps.append(node)

# nodes not yet implemented
case (
"PLpgSQL_stmt_case"
| "PLpgSQL_stmt_exit"
| "PLpgSQL_stmt_forc"
| "PLpgSQL_stmt_foreach_a"
):
raise Exception(f"Not yet implemented {node.type}")

# nodes that will end current path
case (
"PLpgSQL_stmt_return"
| "PLpgSQL_stmt_return_next"
| "PLpgSQL_stmt_return_query"
):
path.steps.append(node)
path.stack.clear()
return

case _:
raise Exception(f"Unknown node type {node.type}")

while path.stack:
t = path.stack.pop(0)
dfs(t, path, pathes)
20 changes: 19 additions & 1 deletion src/pgspot/plpgsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
class PLpgSQLNode:
def __init__(self, raw):
self.type = list(raw.keys())[0]
self.lineno = None
self.lineno = ""
for k, v in raw[self.type].items():
setattr(self, k, build_node(v))

Expand All @@ -24,6 +24,24 @@ def __repr__(self):
return f"{self.type}({fields})"


class PLpgSQL_stmt_if(PLpgSQLNode):
def __init__(self, raw):
self.then_body = None
self.elsif_list = None
self.else_body = None
super().__init__(raw)


class PLpgSQL_row(PLpgSQLNode):
def __init__(self, raw):
# PLpgSQL_row has a fields attribute which is a list of dicts that
# don't have the same structure as other node dicts. So we pop it out
# and set it as an attribute directly instead of having it handled by
# recursion.
self.fields = raw["PLpgSQL_row"].pop("fields")
super().__init__(raw)


class PLpgSQL_var(PLpgSQLNode):
def __init__(self, raw):
self.refname = None
Expand Down
111 changes: 111 additions & 0 deletions tests/plpgsql_path_if_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@

from pglast import parse_plpgsql
from pgspot.plpgsql import build_node
from pgspot.path import paths

def test_if_minimal_stmt():
sql = """
CREATE FUNCTION foo(cmd TEXT) RETURNS void AS $$
BEGIN
IF EXISTS (SELECT FROM pg_stat_activity) THEN
EXECUTE cmd || '1';
END IF;
END
$$;
"""
parsed = parse_plpgsql(sql)
node = build_node(parsed[0])
assert node.type == "PLpgSQL_function"

pathes = list(paths(node))
assert len(pathes) == 1

assert str(pathes[0]) == "PLpgSQL_stmt_if(3) -> PLpgSQL_stmt_dynexecute(4) -> PLpgSQL_stmt_return()"

def test_if_else():
sql = """
CREATE FUNCTION foo(cmd TEXT) RETURNS void AS $$
BEGIN
IF EXISTS (SELECT FROM pg_stat_activity) THEN
EXECUTE cmd || '1';
ELSIF EXISTS (SELECT FROM pg_stat_activity) THEN
EXECUTE cmd || '2';
ELSIF EXISTS (SELECT FROM pg_stat_activity) THEN
EXECUTE cmd || '3';
ELSIF EXISTS (SELECT FROM pg_stat_activity) THEN
EXECUTE cmd || '4';
ELSE
EXECUTE cmd || '5';
END IF;
END
$$;
"""
parsed = parse_plpgsql(sql)
node = build_node(parsed[0])
assert node.type == "PLpgSQL_function"

pathes = list(paths(node))
assert len(pathes) == 5

assert str(pathes[0]) == "PLpgSQL_stmt_if(3) -> PLpgSQL_stmt_dynexecute(4) -> PLpgSQL_stmt_return()"
assert str(pathes[1]) == "PLpgSQL_stmt_if(3) -> PLpgSQL_stmt_dynexecute(6) -> PLpgSQL_stmt_return()"
assert str(pathes[2]) == "PLpgSQL_stmt_if(3) -> PLpgSQL_stmt_dynexecute(8) -> PLpgSQL_stmt_return()"
assert str(pathes[3]) == "PLpgSQL_stmt_if(3) -> PLpgSQL_stmt_dynexecute(10) -> PLpgSQL_stmt_return()"
assert str(pathes[4]) == "PLpgSQL_stmt_if(3) -> PLpgSQL_stmt_dynexecute(12) -> PLpgSQL_stmt_return()"

def test_if_stmt():
sql = """
CREATE FUNCTION foo(cmd TEXT) RETURNS void AS $$
BEGIN
IF EXISTS (SELECT 1 FROM pg_stat_activity) THEN
EXECUTE cmd || '1';
ELSE
EXECUTE cmd || '2';
RETURN 'foo';
END IF;
IF EXISTS (SELECT 1 FROM pg_stat_activity) THEN
EXECUTE cmd;
ELSE
EXECUTE cmd;
END IF;
END
$$;
"""
parsed = parse_plpgsql(sql)
node = build_node(parsed[0])
assert node.type == "PLpgSQL_function"

pathes = list(paths(node))
assert len(pathes) == 3

assert str(pathes[0]) == "PLpgSQL_stmt_if(3) -> PLpgSQL_stmt_dynexecute(4) -> PLpgSQL_stmt_if(9) -> PLpgSQL_stmt_dynexecute(10) -> PLpgSQL_stmt_return()"
assert str(pathes[1]) == "PLpgSQL_stmt_if(3) -> PLpgSQL_stmt_dynexecute(6) -> PLpgSQL_stmt_return(7)"
assert str(pathes[2]) == "PLpgSQL_stmt_if(3) -> PLpgSQL_stmt_dynexecute(4) -> PLpgSQL_stmt_if(9) -> PLpgSQL_stmt_dynexecute(12) -> PLpgSQL_stmt_return()"

def test_nested_if_stmt():
sql = """
CREATE FUNCTION foo(cmd TEXT) RETURNS void AS $$
BEGIN
IF EXISTS (SELECT FROM pg_stat_activity) THEN
EXECUTE cmd || '1';
ELSE
IF EXISTS (SELECT FROM pg_stat_activity) THEN
EXECUTE cmd || '2';
ELSE
EXECUTE cmd || '3';
END IF;
END IF;
END
$$;
"""
parsed = parse_plpgsql(sql)
node = build_node(parsed[0])
assert node.type == "PLpgSQL_function"

pathes = list(paths(node))
assert len(pathes) == 3

assert str(pathes[0]) == "PLpgSQL_stmt_if(3) -> PLpgSQL_stmt_dynexecute(4) -> PLpgSQL_stmt_return()"
assert str(pathes[1]) == "PLpgSQL_stmt_if(3) -> PLpgSQL_stmt_if(6) -> PLpgSQL_stmt_dynexecute(7) -> PLpgSQL_stmt_return()"
assert str(pathes[2]) == "PLpgSQL_stmt_if(3) -> PLpgSQL_stmt_if(6) -> PLpgSQL_stmt_dynexecute(9) -> PLpgSQL_stmt_return()"

117 changes: 117 additions & 0 deletions tests/plpgsql_path_loop_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@

from pglast import parse_plpgsql
from pgspot.plpgsql import build_node
from pgspot.path import paths

def test_loop():
sql = """
CREATE FUNCTION foo(cmd TEXT) RETURNS void AS $$
BEGIN
LOOP
EXECUTE cmd || '1';
EXECUTE cmd || '2';
END LOOP;
END
$$;
"""
parsed = parse_plpgsql(sql)
node = build_node(parsed[0])
pathes = list(paths(node))
assert len(pathes) == 1

assert str(pathes[0]) == "PLpgSQL_stmt_dynexecute(4) -> PLpgSQL_stmt_dynexecute(5) -> PLpgSQL_stmt_return()"

def test_while_loop():
sql = """
CREATE FUNCTION foo(cmd TEXT) RETURNS void AS $$
BEGIN
WHILE true LOOP
EXECUTE cmd || '1';
EXECUTE cmd || '2';
END LOOP;
END
$$;
"""
parsed = parse_plpgsql(sql)
node = build_node(parsed[0])
pathes = list(paths(node))
assert len(pathes) == 1

assert str(pathes[0]) == "PLpgSQL_stmt_dynexecute(4) -> PLpgSQL_stmt_dynexecute(5) -> PLpgSQL_stmt_return()"

def test_fori_loop():
sql = """
CREATE FUNCTION foo(cmd TEXT) RETURNS void AS $$
DECLARE
i INT;
BEGIN
FOR i IN 1..10 LOOP
RAISE NOTICE 'i is %',i;
END LOOP;
END
$$;
"""
parsed = parse_plpgsql(sql)
node = build_node(parsed[0])
pathes = list(paths(node))
assert len(pathes) == 1

assert str(pathes[0]) == "PLpgSQL_stmt_raise(6) -> PLpgSQL_stmt_return()"

def test_fors_loop():
sql = """
CREATE FUNCTION foo(cmd TEXT) RETURNS void AS $$
DECLARE
i INT;
BEGIN
FOR i IN SELECT generate_series(1,10) LOOP
RAISE NOTICE 'i is %',i;
END LOOP;
END
$$ LANGUAGE plpgsql;
"""
parsed = parse_plpgsql(sql)
node = build_node(parsed[0])
pathes = list(paths(node))
assert len(pathes) == 1

assert str(pathes[0]) == "PLpgSQL_stmt_raise(6) -> PLpgSQL_stmt_return()"

def test_dynfors_loop():
sql = """
CREATE FUNCTION foo(cmd TEXT) RETURNS void AS $$
DECLARE
i INT;
BEGIN
FOR i IN EXECUTE 'SELECT generate_series(1,10)' LOOP
RAISE NOTICE 'i is %',i;
END LOOP;
END
$$ LANGUAGE plpgsql;
"""
parsed = parse_plpgsql(sql)
node = build_node(parsed[0])
pathes = list(paths(node))
assert len(pathes) == 1

assert str(pathes[0]) == "PLpgSQL_stmt_raise(6) -> PLpgSQL_stmt_return()"

#def test_forc_loop():
# sql = """
# CREATE FUNCTION foo(cmd TEXT) RETURNS void AS $$
# DECLARE
# i INT;
# c CURSOR FOR SELECT generate_series(1,10);
# BEGIN
# FOR i IN c LOOP
# RAISE NOTICE 'i is %',i;
# END LOOP;
# END
# $$ LANGUAGE plpgsql;
# """
# parsed = parse_plpgsql(sql)
# node = build_node(parsed[0])
# pathes = list(paths(node))
# assert len(pathes) == 1
#
# assert str(pathes[0]) == "PLpgSQL_stmt_raise(6) -> PLpgSQL_stmt_return()"
Loading

0 comments on commit b326070

Please sign in to comment.