mirror of
https://github.com/aclindsa/moneygo.git
synced 2024-12-25 23:23:21 -05:00
commit
144c89655a
@ -14,7 +14,7 @@ services:
|
||||
env:
|
||||
- MONEYGO_TEST_DB=sqlite
|
||||
- MONEYGO_TEST_DB=mysql MONEYGO_TEST_DSN="root@tcp(127.0.0.1)/moneygo_test?parseTime=true"
|
||||
# - MONEYGO_TEST_DB=postgres MONEYGO_TEST_DSN="postgres://postgres@localhost/moneygo_test"
|
||||
- MONEYGO_TEST_DB=postgres MONEYGO_TEST_DSN="postgres://postgres@localhost/moneygo_test"
|
||||
|
||||
before_script:
|
||||
- sh -c "if [ $MONEYGO_TEST_DB = 'postgres' ]; then psql -c 'DROP DATABASE IF EXISTS moneygo_test;' -U postgres; fi"
|
||||
|
@ -14,8 +14,6 @@ type ResponseWriterWriter interface {
|
||||
Write(http.ResponseWriter) error
|
||||
}
|
||||
|
||||
type Tx = gorp.Transaction
|
||||
|
||||
type Context struct {
|
||||
Tx *Tx
|
||||
User *User
|
||||
@ -51,7 +49,7 @@ type APIHandler struct {
|
||||
}
|
||||
|
||||
func (ah *APIHandler) txWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) {
|
||||
tx, err := ah.DB.Begin()
|
||||
tx, err := GetTx(ah.DB)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return NewError(999 /*Internal Error*/)
|
||||
|
@ -64,7 +64,7 @@ func TestCreatePrice(t *testing.T) {
|
||||
if p.CurrencyId != d.securities[orig.CurrencyId].SecurityId {
|
||||
t.Errorf("CurrencyId doesn't match")
|
||||
}
|
||||
if p.Date != orig.Date {
|
||||
if !p.Date.Equal(orig.Date) {
|
||||
t.Errorf("Date doesn't match")
|
||||
}
|
||||
if p.Value != orig.Value {
|
||||
@ -94,7 +94,7 @@ func TestGetPrice(t *testing.T) {
|
||||
if p.CurrencyId != d.securities[orig.CurrencyId].SecurityId {
|
||||
t.Errorf("CurrencyId doesn't match")
|
||||
}
|
||||
if p.Date != orig.Date {
|
||||
if !p.Date.Equal(orig.Date) {
|
||||
t.Errorf("Date doesn't match")
|
||||
}
|
||||
if p.Value != orig.Value {
|
||||
@ -131,7 +131,7 @@ func TestGetPrices(t *testing.T) {
|
||||
|
||||
found := false
|
||||
for _, p := range *pl.Prices {
|
||||
if p.SecurityId == d.securities[orig.SecurityId].SecurityId && p.CurrencyId == d.securities[orig.CurrencyId].SecurityId && p.Date == orig.Date && p.Value == orig.Value && p.RemoteId == orig.RemoteId {
|
||||
if p.SecurityId == d.securities[orig.SecurityId].SecurityId && p.CurrencyId == d.securities[orig.CurrencyId].SecurityId && p.Date.Equal(orig.Date) && p.Value == orig.Value && p.RemoteId == orig.RemoteId {
|
||||
if _, ok := foundIds[p.PriceId]; ok {
|
||||
continue
|
||||
}
|
||||
@ -177,7 +177,7 @@ func TestUpdatePrice(t *testing.T) {
|
||||
if p.CurrencyId != curr.CurrencyId {
|
||||
t.Errorf("CurrencyId doesn't match")
|
||||
}
|
||||
if p.Date != curr.Date {
|
||||
if !p.Date.Equal(curr.Date) {
|
||||
t.Errorf("Date doesn't match")
|
||||
}
|
||||
if p.Value != curr.Value {
|
||||
|
@ -623,7 +623,7 @@ func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, pa
|
||||
var sqlsort, balanceLimitOffset string
|
||||
var balanceLimitOffsetArg uint64
|
||||
if sort == "date-asc" {
|
||||
sqlsort = " ORDER BY transactions.Date ASC"
|
||||
sqlsort = " ORDER BY transactions.Date ASC, transactions.TransactionId ASC"
|
||||
balanceLimitOffset = " LIMIT ?"
|
||||
balanceLimitOffsetArg = page * limit
|
||||
} else if sort == "date-desc" {
|
||||
@ -631,7 +631,7 @@ func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, pa
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sqlsort = " ORDER BY transactions.Date DESC"
|
||||
sqlsort = " ORDER BY transactions.Date DESC, transactions.TransactionId DESC"
|
||||
balanceLimitOffset = fmt.Sprintf(" LIMIT %d OFFSET ?", numSplits)
|
||||
balanceLimitOffsetArg = (page + 1) * limit
|
||||
}
|
||||
@ -676,7 +676,7 @@ func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, pa
|
||||
// Sum all the splits for all transaction splits for this account that
|
||||
// occurred before the page we're returning
|
||||
var amounts []string
|
||||
sql = "SELECT s.Amount FROM splits AS s INNER JOIN (SELECT DISTINCT transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ") as t ON s.TransactionId = t.TransactionId WHERE s.AccountId=?"
|
||||
sql = "SELECT s.Amount FROM splits AS s INNER JOIN (SELECT DISTINCT transactions.Date, transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ") as t ON s.TransactionId = t.TransactionId WHERE s.AccountId=?"
|
||||
_, err = tx.Select(&amounts, sql, user.UserId, accountid, balanceLimitOffsetArg, accountid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -87,8 +87,8 @@ func ensureTransactionsMatch(t *testing.T, expected, tran *handlers.Transaction,
|
||||
if tran.Description != expected.Description {
|
||||
t.Errorf("Description doesn't match")
|
||||
}
|
||||
if tran.Date != expected.Date {
|
||||
t.Errorf("Date doesn't match")
|
||||
if !tran.Date.Equal(expected.Date) {
|
||||
t.Errorf("Date (%+v) differs from expected (%+v)", tran.Date, expected.Date)
|
||||
}
|
||||
|
||||
if len(tran.Splits) != len(expected.Splits) {
|
||||
|
65
internal/handlers/tx.go
Normal file
65
internal/handlers/tx.go
Normal file
@ -0,0 +1,65 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"gopkg.in/gorp.v1"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Tx struct {
|
||||
Dialect gorp.Dialect
|
||||
Tx *gorp.Transaction
|
||||
}
|
||||
|
||||
func (tx *Tx) Rebind(query string) string {
|
||||
chunks := strings.Split(query, "?")
|
||||
str := chunks[0]
|
||||
for i := 1; i < len(chunks); i++ {
|
||||
str += tx.Dialect.BindVar(i-1) + chunks[i]
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
func (tx *Tx) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) {
|
||||
return tx.Tx.Select(i, tx.Rebind(query), args...)
|
||||
}
|
||||
|
||||
func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
return tx.Tx.Exec(tx.Rebind(query), args...)
|
||||
}
|
||||
|
||||
func (tx *Tx) SelectInt(query string, args ...interface{}) (int64, error) {
|
||||
return tx.Tx.SelectInt(tx.Rebind(query), args...)
|
||||
}
|
||||
|
||||
func (tx *Tx) SelectOne(holder interface{}, query string, args ...interface{}) error {
|
||||
return tx.Tx.SelectOne(holder, tx.Rebind(query), args...)
|
||||
}
|
||||
|
||||
func (tx *Tx) Insert(list ...interface{}) error {
|
||||
return tx.Tx.Insert(list...)
|
||||
}
|
||||
|
||||
func (tx *Tx) Update(list ...interface{}) (int64, error) {
|
||||
return tx.Tx.Update(list...)
|
||||
}
|
||||
|
||||
func (tx *Tx) Delete(list ...interface{}) (int64, error) {
|
||||
return tx.Tx.Delete(list...)
|
||||
}
|
||||
|
||||
func (tx *Tx) Commit() error {
|
||||
return tx.Tx.Commit()
|
||||
}
|
||||
|
||||
func (tx *Tx) Rollback() error {
|
||||
return tx.Tx.Rollback()
|
||||
}
|
||||
|
||||
func GetTx(db *gorp.DbMap) (*Tx, error) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Tx{db.Dialect, tx}, nil
|
||||
}
|
Loading…
Reference in New Issue
Block a user