Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add exec' returning number of affected rows #52

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
0.2.7:
* Add support for DATETIMEOFFSET
* Add exec variant returning number of affected rows
0.2.6:
* Add support for SQLSTATE
* Fix copying issues for error messages
Expand Down
2 changes: 0 additions & 2 deletions app/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@

module Main (main) where

import Data.List
import Data.Time.LocalTime (ZonedTime(..))
import qualified Data.Text as T
import qualified Data.Text.IO as T
import Control.Exception
import qualified Data.Text as T
import qualified Database.ODBC.Internal as ODBC
import System.Environment
import System.IO
Expand Down
9 changes: 8 additions & 1 deletion cbits/odbc.c
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,13 @@ RETCODE odbc_SQLNumResultCols(SQLHSTMT *hstmt, SQLSMALLINT *cols){
return SQLNumResultCols(*hstmt, cols);
}

////////////////////////////////////////////////////////////////////////////////
// Get rows

RETCODE odbc_SQLRowCount(SQLHSTMT *hstmt, SQLLEN *rows){
return SQLRowCount(*hstmt, rows);
}

////////////////////////////////////////////////////////////////////////////////
// Logs

Expand Down Expand Up @@ -402,4 +409,4 @@ SQLSMALLINT TIMESTAMPOFFSET_STRUCT_timezone_hour(TIMESTAMPOFFSET_STRUCT *t){

SQLSMALLINT TIMESTAMPOFFSET_STRUCT_timezone_minute(TIMESTAMPOFFSET_STRUCT *t){
return t->timezone_minute;
}
}
2 changes: 1 addition & 1 deletion odbc.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ description: Haskell binding to the ODBC API. This has been tested
suite runs on OS X, Windows and Linux.
copyright: FP Complete 2018
maintainer: [email protected]
version: 0.2.6
version: 0.2.7
license: BSD3
license-file: LICENSE
build-type: Simple
Expand Down
51 changes: 49 additions & 2 deletions src/Database/ODBC/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ module Database.ODBC.Internal
, Connection
-- * Executing queries
, exec
, execAffectedRows
, query
, Value(..)
, Binary(..)
Expand All @@ -38,6 +39,7 @@ module Database.ODBC.Internal
, Step(..)
-- * Parameters
, execWithParams
, execAffectedRowsWithParams
, queryWithParams
, streamWithParams
, Param(..)
Expand Down Expand Up @@ -280,7 +282,18 @@ exec ::
-> m ()
exec conn string = execWithParams conn string mempty

-- | Same as 'exec' but with parameters.
-- | Execute a statement on the database and returns number of affected rows.
--
-- @since 0.2.7
execAffectedRows ::
MonadIO m
=> Connection -- ^ A connection to the database.
-> Text -- ^ SQL statement.
-> m Int
execAffectedRows conn string = execAffectedRowsWithParams conn string mempty
{-# INLINE execAffectedRows #-}

-- | Same as 'execAffectedRows but with parameters.
--
-- @since 0.2.4
execWithParams ::
Expand All @@ -296,6 +309,22 @@ execWithParams conn string params =
"exec"
(\dbc -> withExecDirect dbc string params (fetchAllResults dbc)))

-- | Same as 'execAffectedRowsWithParams but returns number of affected rows.
--
-- @since 0.2.7
execAffectedRowsWithParams ::
MonadIO m
=> Connection -- ^ A connection to the database.
-> Text -- ^ SQL query with ? inside.
-> [Param] -- ^ Params matching the ? in the query string.
-> m Int
execAffectedRowsWithParams conn string params =
withBound
(withHDBC
conn
"exec"
(\dbc -> withExecDirect dbc string params (fetchAllResults' dbc)))

-- | Query and return a list of rows.
query ::
MonadIO m
Expand Down Expand Up @@ -549,6 +578,21 @@ fetchAllResults dbc stmt = do
(retcode == sql_success || retcode == sql_success_with_info)
(fetchAllResults dbc stmt)

-- | Fetch all results from possible multiple statements.
fetchAllResults' :: Ptr EnvAndDbc -> SQLHSTMT s -> IO Int
fetchAllResults' dbc stmt = countRows <* fetchAllResults dbc stmt
where
countRows = do
SQLLEN rows <-
withMalloc
(\sizep -> do
assertSuccess
dbc
"odbc_SQLRowCount"
(odbc_SQLRowCount stmt sizep)
peek sizep)
pure $! fromIntegral (max 0 rows)

-- | Fetch all rows from a statement.
fetchStatementRows :: Ptr EnvAndDbc -> SQLHSTMT s -> IO [[(Column,Value)]]
fetchStatementRows dbc stmt = do
Expand Down Expand Up @@ -1089,7 +1133,7 @@ newtype SQLCHAR = SQLCHAR CChar deriving (Show, Eq, Storable)
-- https://github.com/Microsoft/ODBC-Specification/blob/753d7e714b7eab9eaab4ad6105fdf4267d6ad6f6/Windows/inc/sqltypes.h#L88
newtype SQLSMALLINT = SQLSMALLINT Int16 deriving (Show, Eq, Storable, Num, Integral, Enum, Ord, Real)

-- https://github.com/Microsoft/ODBC-Specification/blob/753d7e714b7eab9eaab4ad6105fdf4267d6ad6f6/Windows/inc/sqltypes.h#L64
-- https://github.com/Microsoft/ODBC-Specification/blob/753d7e714b7eab9eaab4ad6105fdf4267d6ad6f6/Windows/inc/sqltypes.h#L641
newtype SQLLEN = SQLLEN Int64 deriving (Show, Eq, Storable, Num)

-- https://github.com/Microsoft/ODBC-Specification/blob/753d7e714b7eab9eaab4ad6105fdf4267d6ad6f6/Windows/inc/sqltypes.h#L65..L65
Expand Down Expand Up @@ -1168,6 +1212,9 @@ foreign import ccall "odbc odbc_SQLMoreResults"
foreign import ccall "odbc odbc_SQLNumResultCols"
odbc_SQLNumResultCols :: SQLHSTMT s -> Ptr SQLSMALLINT -> IO RETCODE

foreign import ccall "odbc odbc_SQLRowCount"
odbc_SQLRowCount :: SQLHSTMT s -> Ptr SQLLEN -> IO RETCODE

foreign import ccall "odbc odbc_SQLGetData"
odbc_SQLGetData
:: Ptr EnvAndDbc
Expand Down
16 changes: 13 additions & 3 deletions src/Database/ODBC/SQLServer.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ module Database.ODBC.SQLServer

-- * Executing queries
, exec
, execAffectedRows
, query
, Value(..)
, Query
Expand Down Expand Up @@ -79,8 +80,6 @@ import Data.Fixed
import Data.Foldable
import Data.Int
import Data.Maybe
import Data.Monoid (Monoid, (<>))
import Data.Semigroup (Semigroup)
import Data.Sequence (Seq)
import qualified Data.Sequence as Seq
import Data.String
Expand Down Expand Up @@ -482,6 +481,17 @@ exec c q = Internal.execWithParams c rendered params
where
(rendered, params) = renderedAndParams q

-- | Execute a statement on the database and return number of affected rows.
execAffectedRows ::
MonadIO m
=> Connection -- ^ A connection to the database.
-> Query -- ^ SQL statement.
-> m Int
execAffectedRows c q = Internal.execAffectedRowsWithParams c rendered params
where
(rendered, params) = renderedAndParams q
{-# INLINE execAffectedRows #-}

--------------------------------------------------------------------------------
-- Query building

Expand All @@ -496,7 +506,7 @@ renderedAndParams q = (renderParts parts', params)
ValuePart v
| Just {} <- valueToParam v ->
case v of
TextValue t -> TextPart "CAST(? AS NVARCHAR(MAX))"
TextValue _ -> TextPart "CAST(? AS NVARCHAR(MAX))"
_ -> TextPart "?"
p -> p)
parts
Expand Down
1 change: 0 additions & 1 deletion src/Database/ODBC/TH.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ module Database.ODBC.TH
import Control.DeepSeq
import Data.Char
import Data.List (foldl1')
import Data.Monoid ((<>))
import Language.Haskell.TH (Q, Exp)
import qualified Language.Haskell.TH as TH
import Language.Haskell.TH.Quote (QuasiQuoter(..))
Expand Down
20 changes: 20 additions & 0 deletions test/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ spec = do
(do describe "Connectivity" connectivity
describe "Regression tests" regressions
describe "Data retrieval" dataRetrieval
describe "Data affected" dataAffected
describe "Big data" bigData)
describe
"Database.ODBC.SQLServer"
Expand Down Expand Up @@ -192,6 +193,25 @@ connectivity = do
(do sequence_ [connectWithString >>= Internal.close | _ <- [1 :: Int .. 10]]
shouldBe True True)

dataAffected :: Spec
dataAffected = do
it
"Basic sanity check"
(do c <- connectWithString
_ <- Internal.execAffectedRows c "DROP TABLE IF EXISTS test"
arOnCreate <- Internal.execAffectedRows
c
"CREATE TABLE test (int integer, text text, bool bit, nt ntext, fl float)"
_ <- Internal.execAffectedRows
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this return 3 too ?

c
"INSERT INTO test VALUES (123, 'abc', 1, 'wib', 2.415), (456, 'def', 0, 'wibble',0.9999999999999), (NULL, NULL, NULL, NULL, NULL)"
arOnDelete <- Internal.execAffectedRows c "delete from test"
arOnDelete' <- Internal.execAffectedRows c "delete from test"
Internal.close c
shouldBe
[("create", arOnCreate), ("delete", arOnDelete), ("delete'", arOnDelete')]
[("create", 0), ("delete", 3), ("delete'", 0)])

dataRetrieval :: Spec
dataRetrieval = do
it
Expand Down