diff --git a/config/csv/config.go b/config/csv/config.go index 01207bb..199da06 100644 --- a/config/csv/config.go +++ b/config/csv/config.go @@ -141,3 +141,10 @@ func Header(header bool) ToConfigFunc { c.Header = header } } + +// Columns holds the order to write CSV columns. +func Columns(cols []string) ToConfigFunc { + return func(c *ToConfig) { + c.Columns = cols + } +} diff --git a/internal/io/csv.go b/internal/io/csv.go index 6fe8107..dd07776 100644 --- a/internal/io/csv.go +++ b/internal/io/csv.go @@ -34,7 +34,8 @@ type CSVConfig struct { // For writing CSV type ToCsvConfig struct { - Header bool + Header bool + Columns []string } func isEmptyLine(fields [][]byte) bool { diff --git a/qframe.go b/qframe.go index 55aad08..86d5c30 100644 --- a/qframe.go +++ b/qframe.go @@ -1106,8 +1106,26 @@ func (qf QFrame) ToCSV(writer io.Writer, confFuncs ...csv.ToConfigFunc) error { return qerrors.Propagate("ToCSV", qf.Err) } - row := make([]string, 0, len(qf.columns)) - for _, s := range qf.columns { + var iterCols []namedColumn + if conf.Columns != nil { + if len(conf.Columns) != len(qf.columns) { + return qerrors.New("ToCSV", fmt.Sprintf("wrong number of columns: expected: %d", len(qf.columns))) + } + iterCols = make([]namedColumn, len(qf.columns)) + for i := range conf.Columns { + cName := conf.Columns[i] + if col, ok := qf.columnsByName[cName]; !ok { + return qerrors.New("ToCSV", fmt.Sprintf("%s: column does not exist in QFrame", cName)) + } else { + iterCols[i] = col + } + } + } else { + iterCols = qf.columns + } + + row := make([]string, 0, len(iterCols)) + for _, s := range iterCols { row = append(row, s.name) } columns := make([]column.Column, 0, len(qf.columns)) diff --git a/qframe_test.go b/qframe_test.go index d9b576d..8a904a9 100644 --- a/qframe_test.go +++ b/qframe_test.go @@ -1475,6 +1475,60 @@ func TestQFrame_ReadJSON(t *testing.T) { } } +func TestQFrame_ToCSV_ColOrder(t *testing.T) { + table := []struct { + input string + config []string + expected string + err string + }{ + { + input: `COL1,COL2,COL3 +1a,2a,3a +1b,2b,3b`, + config: []string{"COL1", "COL3", "COL2"}, + expected: `COL1,COL3,COL2 +1a,3a,2a +1b,3b,2b`, + err: "", + }, + { + input: `COL1,COL2,COL3 +1a,2a,3a +1b,2b,3b`, + config: []string{"COL1", "COLX", "COL2"}, + err: "COLX: column does not exist in QFrame", + expected: "", + }, + { + input: `COL1,COL2,COL3 +1a,2a,3a +1b,2b,3b`, + config: []string{"COL1", "COL3", "COL2", "COL4"}, + err: "wrong number of columns: expected: 3", + expected: "", + }, + } + + for i, tc := range table { + t.Run(fmt.Sprintf("ToCSV (ordered) %d", i), func(t *testing.T) { + in := qframe.ReadCSV(strings.NewReader(tc.input)) + assertNotErr(t, in.Err) + buf := new(bytes.Buffer) + err := in.ToCSV(buf, csv.Columns(tc.config)) + if tc.err == "" { + assertNotErr(t, err) + output := strings.TrimSpace(buf.String()) + if output != tc.expected { + t.Errorf("CSV columns not in order. \nGot:\n|%s|\nExpected:\n|%s|", output, tc.expected) + } + } else { + assertErr(t, err, tc.err) + } + }) + } +} + func TestQFrame_ToCSV(t *testing.T) { table := []struct { input map[string]interface{}