Merge pull request #26 from aclindsa/postgres

Postgres Support
This commit is contained in:
Aaron Lindsay 2017-11-17 08:05:35 -05:00 committed by GitHub
commit 144c89655a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 76 additions and 13 deletions

View File

@ -14,7 +14,7 @@ services:
env:
- MONEYGO_TEST_DB=sqlite
- MONEYGO_TEST_DB=mysql MONEYGO_TEST_DSN="root@tcp(127.0.0.1)/moneygo_test?parseTime=true"
# - MONEYGO_TEST_DB=postgres MONEYGO_TEST_DSN="postgres://postgres@localhost/moneygo_test"
- MONEYGO_TEST_DB=postgres MONEYGO_TEST_DSN="postgres://postgres@localhost/moneygo_test"
before_script:
- sh -c "if [ $MONEYGO_TEST_DB = 'postgres' ]; then psql -c 'DROP DATABASE IF EXISTS moneygo_test;' -U postgres; fi"

View File

@ -14,8 +14,6 @@ type ResponseWriterWriter interface {
Write(http.ResponseWriter) error
}
type Tx = gorp.Transaction
type Context struct {
Tx *Tx
User *User
@ -51,7 +49,7 @@ type APIHandler struct {
}
func (ah *APIHandler) txWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) {
tx, err := ah.DB.Begin()
tx, err := GetTx(ah.DB)
if err != nil {
log.Print(err)
return NewError(999 /*Internal Error*/)

View File

@ -64,7 +64,7 @@ func TestCreatePrice(t *testing.T) {
if p.CurrencyId != d.securities[orig.CurrencyId].SecurityId {
t.Errorf("CurrencyId doesn't match")
}
if p.Date != orig.Date {
if !p.Date.Equal(orig.Date) {
t.Errorf("Date doesn't match")
}
if p.Value != orig.Value {
@ -94,7 +94,7 @@ func TestGetPrice(t *testing.T) {
if p.CurrencyId != d.securities[orig.CurrencyId].SecurityId {
t.Errorf("CurrencyId doesn't match")
}
if p.Date != orig.Date {
if !p.Date.Equal(orig.Date) {
t.Errorf("Date doesn't match")
}
if p.Value != orig.Value {
@ -131,7 +131,7 @@ func TestGetPrices(t *testing.T) {
found := false
for _, p := range *pl.Prices {
if p.SecurityId == d.securities[orig.SecurityId].SecurityId && p.CurrencyId == d.securities[orig.CurrencyId].SecurityId && p.Date == orig.Date && p.Value == orig.Value && p.RemoteId == orig.RemoteId {
if p.SecurityId == d.securities[orig.SecurityId].SecurityId && p.CurrencyId == d.securities[orig.CurrencyId].SecurityId && p.Date.Equal(orig.Date) && p.Value == orig.Value && p.RemoteId == orig.RemoteId {
if _, ok := foundIds[p.PriceId]; ok {
continue
}
@ -177,7 +177,7 @@ func TestUpdatePrice(t *testing.T) {
if p.CurrencyId != curr.CurrencyId {
t.Errorf("CurrencyId doesn't match")
}
if p.Date != curr.Date {
if !p.Date.Equal(curr.Date) {
t.Errorf("Date doesn't match")
}
if p.Value != curr.Value {

View File

@ -623,7 +623,7 @@ func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, pa
var sqlsort, balanceLimitOffset string
var balanceLimitOffsetArg uint64
if sort == "date-asc" {
sqlsort = " ORDER BY transactions.Date ASC"
sqlsort = " ORDER BY transactions.Date ASC, transactions.TransactionId ASC"
balanceLimitOffset = " LIMIT ?"
balanceLimitOffsetArg = page * limit
} else if sort == "date-desc" {
@ -631,7 +631,7 @@ func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, pa
if err != nil {
return nil, err
}
sqlsort = " ORDER BY transactions.Date DESC"
sqlsort = " ORDER BY transactions.Date DESC, transactions.TransactionId DESC"
balanceLimitOffset = fmt.Sprintf(" LIMIT %d OFFSET ?", numSplits)
balanceLimitOffsetArg = (page + 1) * limit
}
@ -676,7 +676,7 @@ func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, pa
// Sum all the splits for all transaction splits for this account that
// occurred before the page we're returning
var amounts []string
sql = "SELECT s.Amount FROM splits AS s INNER JOIN (SELECT DISTINCT transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ") as t ON s.TransactionId = t.TransactionId WHERE s.AccountId=?"
sql = "SELECT s.Amount FROM splits AS s INNER JOIN (SELECT DISTINCT transactions.Date, transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ") as t ON s.TransactionId = t.TransactionId WHERE s.AccountId=?"
_, err = tx.Select(&amounts, sql, user.UserId, accountid, balanceLimitOffsetArg, accountid)
if err != nil {
return nil, err

View File

@ -87,8 +87,8 @@ func ensureTransactionsMatch(t *testing.T, expected, tran *handlers.Transaction,
if tran.Description != expected.Description {
t.Errorf("Description doesn't match")
}
if tran.Date != expected.Date {
t.Errorf("Date doesn't match")
if !tran.Date.Equal(expected.Date) {
t.Errorf("Date (%+v) differs from expected (%+v)", tran.Date, expected.Date)
}
if len(tran.Splits) != len(expected.Splits) {

65
internal/handlers/tx.go Normal file
View File

@ -0,0 +1,65 @@
package handlers
import (
"database/sql"
"gopkg.in/gorp.v1"
"strings"
)
type Tx struct {
Dialect gorp.Dialect
Tx *gorp.Transaction
}
func (tx *Tx) Rebind(query string) string {
chunks := strings.Split(query, "?")
str := chunks[0]
for i := 1; i < len(chunks); i++ {
str += tx.Dialect.BindVar(i-1) + chunks[i]
}
return str
}
func (tx *Tx) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) {
return tx.Tx.Select(i, tx.Rebind(query), args...)
}
func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
return tx.Tx.Exec(tx.Rebind(query), args...)
}
func (tx *Tx) SelectInt(query string, args ...interface{}) (int64, error) {
return tx.Tx.SelectInt(tx.Rebind(query), args...)
}
func (tx *Tx) SelectOne(holder interface{}, query string, args ...interface{}) error {
return tx.Tx.SelectOne(holder, tx.Rebind(query), args...)
}
func (tx *Tx) Insert(list ...interface{}) error {
return tx.Tx.Insert(list...)
}
func (tx *Tx) Update(list ...interface{}) (int64, error) {
return tx.Tx.Update(list...)
}
func (tx *Tx) Delete(list ...interface{}) (int64, error) {
return tx.Tx.Delete(list...)
}
func (tx *Tx) Commit() error {
return tx.Tx.Commit()
}
func (tx *Tx) Rollback() error {
return tx.Tx.Rollback()
}
func GetTx(db *gorp.DbMap) (*Tx, error) {
tx, err := db.Begin()
if err != nil {
return nil, err
}
return &Tx{db.Dialect, tx}, nil
}