1
0
mirror of https://github.com/aclindsa/moneygo.git synced 2024-12-26 23:42:29 -05:00

Merge pull request #26 from aclindsa/postgres

Postgres Support
This commit is contained in:
Aaron Lindsay 2017-11-17 08:05:35 -05:00 committed by GitHub
commit 144c89655a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 76 additions and 13 deletions

View File

@ -14,7 +14,7 @@ services:
env: env:
- MONEYGO_TEST_DB=sqlite - MONEYGO_TEST_DB=sqlite
- MONEYGO_TEST_DB=mysql MONEYGO_TEST_DSN="root@tcp(127.0.0.1)/moneygo_test?parseTime=true" - MONEYGO_TEST_DB=mysql MONEYGO_TEST_DSN="root@tcp(127.0.0.1)/moneygo_test?parseTime=true"
# - MONEYGO_TEST_DB=postgres MONEYGO_TEST_DSN="postgres://postgres@localhost/moneygo_test" - MONEYGO_TEST_DB=postgres MONEYGO_TEST_DSN="postgres://postgres@localhost/moneygo_test"
before_script: before_script:
- sh -c "if [ $MONEYGO_TEST_DB = 'postgres' ]; then psql -c 'DROP DATABASE IF EXISTS moneygo_test;' -U postgres; fi" - sh -c "if [ $MONEYGO_TEST_DB = 'postgres' ]; then psql -c 'DROP DATABASE IF EXISTS moneygo_test;' -U postgres; fi"

View File

@ -14,8 +14,6 @@ type ResponseWriterWriter interface {
Write(http.ResponseWriter) error Write(http.ResponseWriter) error
} }
type Tx = gorp.Transaction
type Context struct { type Context struct {
Tx *Tx Tx *Tx
User *User User *User
@ -51,7 +49,7 @@ type APIHandler struct {
} }
func (ah *APIHandler) txWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) { func (ah *APIHandler) txWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) {
tx, err := ah.DB.Begin() tx, err := GetTx(ah.DB)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)

View File

@ -64,7 +64,7 @@ func TestCreatePrice(t *testing.T) {
if p.CurrencyId != d.securities[orig.CurrencyId].SecurityId { if p.CurrencyId != d.securities[orig.CurrencyId].SecurityId {
t.Errorf("CurrencyId doesn't match") t.Errorf("CurrencyId doesn't match")
} }
if p.Date != orig.Date { if !p.Date.Equal(orig.Date) {
t.Errorf("Date doesn't match") t.Errorf("Date doesn't match")
} }
if p.Value != orig.Value { if p.Value != orig.Value {
@ -94,7 +94,7 @@ func TestGetPrice(t *testing.T) {
if p.CurrencyId != d.securities[orig.CurrencyId].SecurityId { if p.CurrencyId != d.securities[orig.CurrencyId].SecurityId {
t.Errorf("CurrencyId doesn't match") t.Errorf("CurrencyId doesn't match")
} }
if p.Date != orig.Date { if !p.Date.Equal(orig.Date) {
t.Errorf("Date doesn't match") t.Errorf("Date doesn't match")
} }
if p.Value != orig.Value { if p.Value != orig.Value {
@ -131,7 +131,7 @@ func TestGetPrices(t *testing.T) {
found := false found := false
for _, p := range *pl.Prices { for _, p := range *pl.Prices {
if p.SecurityId == d.securities[orig.SecurityId].SecurityId && p.CurrencyId == d.securities[orig.CurrencyId].SecurityId && p.Date == orig.Date && p.Value == orig.Value && p.RemoteId == orig.RemoteId { if p.SecurityId == d.securities[orig.SecurityId].SecurityId && p.CurrencyId == d.securities[orig.CurrencyId].SecurityId && p.Date.Equal(orig.Date) && p.Value == orig.Value && p.RemoteId == orig.RemoteId {
if _, ok := foundIds[p.PriceId]; ok { if _, ok := foundIds[p.PriceId]; ok {
continue continue
} }
@ -177,7 +177,7 @@ func TestUpdatePrice(t *testing.T) {
if p.CurrencyId != curr.CurrencyId { if p.CurrencyId != curr.CurrencyId {
t.Errorf("CurrencyId doesn't match") t.Errorf("CurrencyId doesn't match")
} }
if p.Date != curr.Date { if !p.Date.Equal(curr.Date) {
t.Errorf("Date doesn't match") t.Errorf("Date doesn't match")
} }
if p.Value != curr.Value { if p.Value != curr.Value {

View File

@ -623,7 +623,7 @@ func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, pa
var sqlsort, balanceLimitOffset string var sqlsort, balanceLimitOffset string
var balanceLimitOffsetArg uint64 var balanceLimitOffsetArg uint64
if sort == "date-asc" { if sort == "date-asc" {
sqlsort = " ORDER BY transactions.Date ASC" sqlsort = " ORDER BY transactions.Date ASC, transactions.TransactionId ASC"
balanceLimitOffset = " LIMIT ?" balanceLimitOffset = " LIMIT ?"
balanceLimitOffsetArg = page * limit balanceLimitOffsetArg = page * limit
} else if sort == "date-desc" { } else if sort == "date-desc" {
@ -631,7 +631,7 @@ func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, pa
if err != nil { if err != nil {
return nil, err return nil, err
} }
sqlsort = " ORDER BY transactions.Date DESC" sqlsort = " ORDER BY transactions.Date DESC, transactions.TransactionId DESC"
balanceLimitOffset = fmt.Sprintf(" LIMIT %d OFFSET ?", numSplits) balanceLimitOffset = fmt.Sprintf(" LIMIT %d OFFSET ?", numSplits)
balanceLimitOffsetArg = (page + 1) * limit balanceLimitOffsetArg = (page + 1) * limit
} }
@ -676,7 +676,7 @@ func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, pa
// Sum all the splits for all transaction splits for this account that // Sum all the splits for all transaction splits for this account that
// occurred before the page we're returning // occurred before the page we're returning
var amounts []string var amounts []string
sql = "SELECT s.Amount FROM splits AS s INNER JOIN (SELECT DISTINCT transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ") as t ON s.TransactionId = t.TransactionId WHERE s.AccountId=?" sql = "SELECT s.Amount FROM splits AS s INNER JOIN (SELECT DISTINCT transactions.Date, transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ") as t ON s.TransactionId = t.TransactionId WHERE s.AccountId=?"
_, err = tx.Select(&amounts, sql, user.UserId, accountid, balanceLimitOffsetArg, accountid) _, err = tx.Select(&amounts, sql, user.UserId, accountid, balanceLimitOffsetArg, accountid)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -87,8 +87,8 @@ func ensureTransactionsMatch(t *testing.T, expected, tran *handlers.Transaction,
if tran.Description != expected.Description { if tran.Description != expected.Description {
t.Errorf("Description doesn't match") t.Errorf("Description doesn't match")
} }
if tran.Date != expected.Date { if !tran.Date.Equal(expected.Date) {
t.Errorf("Date doesn't match") t.Errorf("Date (%+v) differs from expected (%+v)", tran.Date, expected.Date)
} }
if len(tran.Splits) != len(expected.Splits) { if len(tran.Splits) != len(expected.Splits) {

65
internal/handlers/tx.go Normal file
View File

@ -0,0 +1,65 @@
package handlers
import (
"database/sql"
"gopkg.in/gorp.v1"
"strings"
)
type Tx struct {
Dialect gorp.Dialect
Tx *gorp.Transaction
}
func (tx *Tx) Rebind(query string) string {
chunks := strings.Split(query, "?")
str := chunks[0]
for i := 1; i < len(chunks); i++ {
str += tx.Dialect.BindVar(i-1) + chunks[i]
}
return str
}
func (tx *Tx) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) {
return tx.Tx.Select(i, tx.Rebind(query), args...)
}
func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
return tx.Tx.Exec(tx.Rebind(query), args...)
}
func (tx *Tx) SelectInt(query string, args ...interface{}) (int64, error) {
return tx.Tx.SelectInt(tx.Rebind(query), args...)
}
func (tx *Tx) SelectOne(holder interface{}, query string, args ...interface{}) error {
return tx.Tx.SelectOne(holder, tx.Rebind(query), args...)
}
func (tx *Tx) Insert(list ...interface{}) error {
return tx.Tx.Insert(list...)
}
func (tx *Tx) Update(list ...interface{}) (int64, error) {
return tx.Tx.Update(list...)
}
func (tx *Tx) Delete(list ...interface{}) (int64, error) {
return tx.Tx.Delete(list...)
}
func (tx *Tx) Commit() error {
return tx.Tx.Commit()
}
func (tx *Tx) Rollback() error {
return tx.Tx.Rollback()
}
func GetTx(db *gorp.DbMap) (*Tx, error) {
tx, err := db.Begin()
if err != nil {
return nil, err
}
return &Tx{db.Dialect, tx}, nil
}