Skip to content

Commit

Permalink
fixed: fixed case when overridden type of column does not work
Browse files Browse the repository at this point in the history
  • Loading branch information
wwoytenko committed Aug 20, 2024
1 parent 14e328e commit 3f26398
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 5 deletions.
4 changes: 4 additions & 0 deletions internal/db/postgres/context/pg_catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ func getTables(

// Assigning columns, pk and fk for each table
for _, t := range tables {
if len(t.Columns) > 0 {
// Columns were already initialized during the transformer initialization
continue
}
columns, err := getColumnsConfig(ctx, tx, t.Oid, version)
if err != nil {
return nil, nil, fmt.Errorf("unable to collect table columns: %w", err)
Expand Down
6 changes: 6 additions & 0 deletions internal/db/postgres/context/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ func validateAndBuildTablesConfig(
}
table.Columns = columns

pkColumns, err := getPrimaryKeyColumns(ctx, tx, table.Oid)
if err != nil {
return nil, nil, fmt.Errorf("unable to collect primary key columns: %w", err)
}
table.PrimaryKey = pkColumns

// Assigning overridden column types for driver initialization
if tableCfg.ColumnsTypeOverride != nil {
for _, c := range table.Columns {
Expand Down
7 changes: 7 additions & 0 deletions pkg/toolkit/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,10 @@ func (c *Column) GetType() (string, Oid) {
}
return c.TypeName, c.TypeOid
}

func (c *Column) GetTypeOid() Oid {
if c.OverriddenTypeName != "" {
return c.OverriddenTypeOid
}
return c.TypeOid
}
6 changes: 3 additions & 3 deletions pkg/toolkit/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func (d *Driver) EncodeValueByColumnIdx(idx int, src any, buf []byte) ([]byte, e
return nil, fmt.Errorf("index out ouf range: must be between 0 and %d received %d", d.maxIdx, idx)
}
c := d.Table.Columns[idx]
oid := uint32(c.TypeOid)
oid := uint32(c.GetTypeOid())
if c.OverriddenTypeOid != 0 {
oid = uint32(c.OverriddenTypeOid)
}
Expand Down Expand Up @@ -158,7 +158,7 @@ func (d *Driver) ScanValueByColumnIdx(idx int, src []byte, dest any) error {
return fmt.Errorf("index out ouf range: must be between 0 and %d received %d", d.maxIdx, idx)
}
c := d.Table.Columns[idx]
oid := uint32(c.TypeOid)
oid := uint32(c.GetTypeOid())
if c.OverriddenTypeOid != 0 {
oid = uint32(c.OverriddenTypeOid)
}
Expand Down Expand Up @@ -189,7 +189,7 @@ func (d *Driver) DecodeValueByColumnIdx(idx int, src []byte) (any, error) {
return nil, fmt.Errorf("index out ouf range: must be between 0 and %d received %d", d.maxIdx, idx)
}
c := d.Table.Columns[idx]
oid := uint32(c.TypeOid)
oid := uint32(c.GetTypeOid())
if c.OverriddenTypeOid != 0 {
oid = uint32(c.OverriddenTypeOid)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/toolkit/meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type Meta struct {
Table *Table `json:"table"`
Parameters *Parameters `json:"parameters"`
Types []*Type `json:"types"`
ColumnTypeOverrides map[string]string `json:"column_type_overrides"`
ColumnsTypeOverride map[string]string `json:"columns_type_override"`
}

type Parameters struct {
Expand Down
2 changes: 1 addition & 1 deletion pkg/toolkit/static_parameter.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ func scanValue(driver *Driver, definition *ParameterDefinition, rawValue ParamsV

var typeOid uint32
if linkedColumnParameter != nil {
typeOid = uint32(linkedColumnParameter.Column.TypeOid)
typeOid = uint32(linkedColumnParameter.Column.GetTypeOid())
} else {
t, ok := driver.GetTypeMap().TypeForName(definition.CastDbType)
if !ok {
Expand Down

0 comments on commit 3f26398

Please sign in to comment.