From 00e1e899c036bd55382ececc24d2bdb52c807372 Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Thu, 16 Nov 2017 21:29:36 -0500 Subject: [PATCH 1/4] Rebind all SQL queries to acommodate Postgres --- internal/handlers/handlers.go | 4 +-- internal/handlers/tx.go | 65 +++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 3 deletions(-) create mode 100644 internal/handlers/tx.go diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 2022a2a..b330ed6 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -14,8 +14,6 @@ type ResponseWriterWriter interface { Write(http.ResponseWriter) error } -type Tx = gorp.Transaction - type Context struct { Tx *Tx User *User @@ -51,7 +49,7 @@ type APIHandler struct { } func (ah *APIHandler) txWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) { - tx, err := ah.DB.Begin() + tx, err := GetTx(ah.DB) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/handlers/tx.go b/internal/handlers/tx.go new file mode 100644 index 0000000..ae19a4b --- /dev/null +++ b/internal/handlers/tx.go @@ -0,0 +1,65 @@ +package handlers + +import ( + "database/sql" + "gopkg.in/gorp.v1" + "strings" +) + +type Tx struct { + Dialect gorp.Dialect + Tx *gorp.Transaction +} + +func (tx *Tx) Rebind(query string) string { + chunks := strings.Split(query, "?") + str := chunks[0] + for i := 1; i < len(chunks); i++ { + str += tx.Dialect.BindVar(i-1) + chunks[i] + } + return str +} + +func (tx *Tx) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) { + return tx.Tx.Select(i, tx.Rebind(query), args...) +} + +func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) { + return tx.Tx.Exec(tx.Rebind(query), args...) +} + +func (tx *Tx) SelectInt(query string, args ...interface{}) (int64, error) { + return tx.Tx.SelectInt(tx.Rebind(query), args...) +} + +func (tx *Tx) SelectOne(holder interface{}, query string, args ...interface{}) error { + return tx.Tx.SelectOne(holder, tx.Rebind(query), args...) +} + +func (tx *Tx) Insert(list ...interface{}) error { + return tx.Tx.Insert(list...) +} + +func (tx *Tx) Update(list ...interface{}) (int64, error) { + return tx.Tx.Update(list...) +} + +func (tx *Tx) Delete(list ...interface{}) (int64, error) { + return tx.Tx.Delete(list...) +} + +func (tx *Tx) Commit() error { + return tx.Tx.Commit() +} + +func (tx *Tx) Rollback() error { + return tx.Tx.Rollback() +} + +func GetTx(db *gorp.DbMap) (*Tx, error) { + tx, err := db.Begin() + if err != nil { + return nil, err + } + return &Tx{db.Dialect, tx}, nil +} From 50dd7b1d2631b01b495e04f46d17512d16f5da38 Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Fri, 17 Nov 2017 05:19:27 -0500 Subject: [PATCH 2/4] testing: Use Time.Equal for date comparisons Postgres actually preserves the timezone, unlike sqlite and mysql... --- internal/handlers/prices_test.go | 8 ++++---- internal/handlers/transactions_test.go | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/internal/handlers/prices_test.go b/internal/handlers/prices_test.go index d362614..1cbca93 100644 --- a/internal/handlers/prices_test.go +++ b/internal/handlers/prices_test.go @@ -64,7 +64,7 @@ func TestCreatePrice(t *testing.T) { if p.CurrencyId != d.securities[orig.CurrencyId].SecurityId { t.Errorf("CurrencyId doesn't match") } - if p.Date != orig.Date { + if !p.Date.Equal(orig.Date) { t.Errorf("Date doesn't match") } if p.Value != orig.Value { @@ -94,7 +94,7 @@ func TestGetPrice(t *testing.T) { if p.CurrencyId != d.securities[orig.CurrencyId].SecurityId { t.Errorf("CurrencyId doesn't match") } - if p.Date != orig.Date { + if !p.Date.Equal(orig.Date) { t.Errorf("Date doesn't match") } if p.Value != orig.Value { @@ -131,7 +131,7 @@ func TestGetPrices(t *testing.T) { found := false for _, p := range *pl.Prices { - if p.SecurityId == d.securities[orig.SecurityId].SecurityId && p.CurrencyId == d.securities[orig.CurrencyId].SecurityId && p.Date == orig.Date && p.Value == orig.Value && p.RemoteId == orig.RemoteId { + if p.SecurityId == d.securities[orig.SecurityId].SecurityId && p.CurrencyId == d.securities[orig.CurrencyId].SecurityId && p.Date.Equal(orig.Date) && p.Value == orig.Value && p.RemoteId == orig.RemoteId { if _, ok := foundIds[p.PriceId]; ok { continue } @@ -177,7 +177,7 @@ func TestUpdatePrice(t *testing.T) { if p.CurrencyId != curr.CurrencyId { t.Errorf("CurrencyId doesn't match") } - if p.Date != curr.Date { + if !p.Date.Equal(curr.Date) { t.Errorf("Date doesn't match") } if p.Value != curr.Value { diff --git a/internal/handlers/transactions_test.go b/internal/handlers/transactions_test.go index 84f1c2e..9a7a9f2 100644 --- a/internal/handlers/transactions_test.go +++ b/internal/handlers/transactions_test.go @@ -87,8 +87,8 @@ func ensureTransactionsMatch(t *testing.T, expected, tran *handlers.Transaction, if tran.Description != expected.Description { t.Errorf("Description doesn't match") } - if tran.Date != expected.Date { - t.Errorf("Date doesn't match") + if !tran.Date.Equal(expected.Date) { + t.Errorf("Date (%+v) differs from expected (%+v)", tran.Date, expected.Date) } if len(tran.Splits) != len(expected.Splits) { From d5bb6ae26cf401376ceae6cf64e32f97dbeab412 Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Thu, 16 Nov 2017 21:30:15 -0500 Subject: [PATCH 3/4] .travis.yml: Enable postgres --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 85bc80d..39468c2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -14,7 +14,7 @@ services: env: - MONEYGO_TEST_DB=sqlite - MONEYGO_TEST_DB=mysql MONEYGO_TEST_DSN="root@tcp(127.0.0.1)/moneygo_test?parseTime=true" -# - MONEYGO_TEST_DB=postgres MONEYGO_TEST_DSN="postgres://postgres@localhost/moneygo_test" + - MONEYGO_TEST_DB=postgres MONEYGO_TEST_DSN="postgres://postgres@localhost/moneygo_test" before_script: - sh -c "if [ $MONEYGO_TEST_DB = 'postgres' ]; then psql -c 'DROP DATABASE IF EXISTS moneygo_test;' -U postgres; fi" From b2359e126756320eecc0b9ac3d400e7ea70840c6 Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Fri, 17 Nov 2017 05:50:00 -0500 Subject: [PATCH 4/4] Fixup account transactions to work for both Postgres and MySQL --- internal/handlers/transactions.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/handlers/transactions.go b/internal/handlers/transactions.go index d1b729b..a6adc8e 100644 --- a/internal/handlers/transactions.go +++ b/internal/handlers/transactions.go @@ -623,7 +623,7 @@ func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, pa var sqlsort, balanceLimitOffset string var balanceLimitOffsetArg uint64 if sort == "date-asc" { - sqlsort = " ORDER BY transactions.Date ASC" + sqlsort = " ORDER BY transactions.Date ASC, transactions.TransactionId ASC" balanceLimitOffset = " LIMIT ?" balanceLimitOffsetArg = page * limit } else if sort == "date-desc" { @@ -631,7 +631,7 @@ func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, pa if err != nil { return nil, err } - sqlsort = " ORDER BY transactions.Date DESC" + sqlsort = " ORDER BY transactions.Date DESC, transactions.TransactionId DESC" balanceLimitOffset = fmt.Sprintf(" LIMIT %d OFFSET ?", numSplits) balanceLimitOffsetArg = (page + 1) * limit } @@ -676,7 +676,7 @@ func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, pa // Sum all the splits for all transaction splits for this account that // occurred before the page we're returning var amounts []string - sql = "SELECT s.Amount FROM splits AS s INNER JOIN (SELECT DISTINCT transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ") as t ON s.TransactionId = t.TransactionId WHERE s.AccountId=?" + sql = "SELECT s.Amount FROM splits AS s INNER JOIN (SELECT DISTINCT transactions.Date, transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ") as t ON s.TransactionId = t.TransactionId WHERE s.AccountId=?" _, err = tx.Select(&amounts, sql, user.UserId, accountid, balanceLimitOffsetArg, accountid) if err != nil { return nil, err