From 41b2a5e842153483f47034c651b218be11571518 Mon Sep 17 00:00:00 2001 From: Sebastian Boyd Date: Wed, 11 Oct 2023 18:03:10 -0700 Subject: [PATCH] Allow for multiple foreign_key in CreateQueryBuilder --- pypika/queries.py | 73 ++++++++++++++++++++++--------------- pypika/tests/test_create.py | 31 +++++++++++++--- 2 files changed, 69 insertions(+), 35 deletions(-) diff --git a/pypika/queries.py b/pypika/queries.py index 76eeea25..ce538e32 100644 --- a/pypika/queries.py +++ b/pypika/queries.py @@ -1707,6 +1707,36 @@ def replace_table(self, current_table: Optional[Table], new_table: Optional[Tabl self.fields = [field.replace_table(current_table, new_table) for field in self.fields] +class ForeignKey: + """Represents a foreign key constraint.""" + + def __init__( + self, + columns: List[Column], + reference_table: Union[str, Table], + reference_columns: List[Column], + on_delete: ReferenceOption = None, + on_update: ReferenceOption = None, + ) -> None: + self.columns = columns + self.reference_table = reference_table + self.reference_columns = reference_columns + self.on_delete = on_delete + self.on_update = on_update + + def get_sql(self, **kwargs: Any) -> str: + foreign_key_sql = "FOREIGN KEY ({columns}) REFERENCES {table_name} ({reference_columns})".format( + columns=",".join(column.get_name_sql(**kwargs) for column in self.columns), + table_name=self.reference_table.get_sql(**kwargs), + reference_columns=",".join(column.get_name_sql(**kwargs) for column in self.reference_columns), + ) + if self.on_delete: + foreign_key_sql += " ON DELETE " + self.on_delete.value + if self.on_update: + foreign_key_sql += " ON UPDATE " + self.on_update.value + return foreign_key_sql + + class CreateQueryBuilder: """ Query builder used to build CREATE queries. @@ -1729,11 +1759,7 @@ def __init__(self, dialect: Optional[Dialects] = None) -> None: self._uniques = [] self._if_not_exists = False self.dialect = dialect - self._foreign_key = None - self._foreign_key_reference_table = None - self._foreign_key_reference = None - self._foreign_key_on_update: ReferenceOption = None - self._foreign_key_on_delete: ReferenceOption = None + self._foreign_keys = [] def _set_kwargs_defaults(self, kwargs: dict) -> None: kwargs.setdefault("quote_char", self.QUOTE_CHAR) @@ -1908,19 +1934,19 @@ def foreign_key( Update option. - :raises AttributeError: - If the foreign key is already defined. - :return: CreateQueryBuilder. """ - if self._foreign_key: - raise AttributeError("'Query' object already has attribute foreign_key") - self._foreign_key = self._prepare_columns_input(columns) - self._foreign_key_reference_table = reference_table - self._foreign_key_reference = self._prepare_columns_input(reference_columns) - self._foreign_key_on_delete = on_delete - self._foreign_key_on_update = on_update + + self._foreign_keys.append( + ForeignKey( + columns=self._prepare_columns_input(columns), + reference_table=reference_table, + reference_columns=self._prepare_columns_input(reference_columns), + on_delete=on_delete, + on_update=on_update, + ) + ) @builder def as_select(self, query_builder: QueryBuilder) -> "CreateQueryBuilder": @@ -2017,28 +2043,17 @@ def _primary_key_clause(self, **kwargs) -> str: columns=",".join(column.get_name_sql(**kwargs) for column in self._primary_key) ) - def _foreign_key_clause(self, **kwargs) -> str: - clause = "FOREIGN KEY ({columns}) REFERENCES {table_name} ({reference_columns})".format( - columns=",".join(column.get_name_sql(**kwargs) for column in self._foreign_key), - table_name=self._foreign_key_reference_table.get_sql(**kwargs), - reference_columns=",".join(column.get_name_sql(**kwargs) for column in self._foreign_key_reference), - ) - if self._foreign_key_on_delete: - clause += " ON DELETE " + self._foreign_key_on_delete.value - if self._foreign_key_on_update: - clause += " ON UPDATE " + self._foreign_key_on_update.value - - return clause + def _foreign_key_clauses(self, **kwargs) -> str: + return [foreign_key.get_sql(**kwargs) for foreign_key in self._foreign_keys] def _body_sql(self, **kwargs) -> str: clauses = self._column_clauses(**kwargs) clauses += self._period_for_clauses(**kwargs) clauses += self._unique_key_clauses(**kwargs) + clauses += self._foreign_key_clauses(**kwargs) if self._primary_key: clauses.append(self._primary_key_clause(**kwargs)) - if self._foreign_key: - clauses.append(self._foreign_key_clause(**kwargs)) return ",".join(clauses) diff --git a/pypika/tests/test_create.py b/pypika/tests/test_create.py index 32507654..ccda1494 100644 --- a/pypika/tests/test_create.py +++ b/pypika/tests/test_create.py @@ -99,6 +99,31 @@ def test_create_table_with_columns(self): str(q), ) + with self.subTest("with multiple foreign key constrains"): + secondary_table = Table("secondary_table") + cref, dref = Columns(("c", "INT"), ("d", "VARCHAR(100)")) + q = ( + Query.create_table(self.new_table) + .columns(self.foo, self.bar) + .foreign_key([self.foo], self.existing_table, [cref]) + .foreign_key( + [self.bar], + secondary_table, + [dref], + on_delete=ReferenceOption.cascade, + on_update=ReferenceOption.restrict, + ) + ) + + self.assertEqual( + 'CREATE TABLE "abc" (' + '"a" INT,' + '"b" VARCHAR(100),' + 'FOREIGN KEY ("a") REFERENCES "efg" ("c"),' + 'FOREIGN KEY ("b") REFERENCES "secondary_table" ("d") ON DELETE CASCADE ON UPDATE RESTRICT)', + str(q), + ) + with self.subTest("with unique keys"): q = ( Query.create_table(self.new_table) @@ -156,12 +181,6 @@ def test_create_table_with_select_and_columns_fails(self): with self.assertRaises(AttributeError): Query.create_table(self.new_table).as_select(select).columns(self.foo, self.bar) - with self.subTest("repeated foreign key"): - with self.assertRaises(AttributeError): - Query.create_table(self.new_table).foreign_key([self.foo], self.existing_table, [self.bar]).foreign_key( - [self.foo], self.existing_table, [self.bar] - ) - def test_create_table_as_select_not_query_raises_error(self): with self.assertRaises(TypeError): Query.create_table(self.new_table).as_select("abc")