diff --git a/internal/handlers/transactions.go b/internal/handlers/transactions.go index bf89ae9..d85ea65 100644 --- a/internal/handlers/transactions.go +++ b/internal/handlers/transactions.go @@ -127,6 +127,11 @@ func (atl *AccountTransactionsList) Write(w http.ResponseWriter) error { return enc.Encode(atl) } +func (atl *AccountTransactionsList) Read(json_str string) error { + dec := json.NewDecoder(strings.NewReader(json_str)) + return dec.Decode(atl) +} + func (t *Transaction) Valid() bool { for i := range t.Splits { if !t.Splits[i].Valid() { @@ -686,8 +691,8 @@ func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, pa // Sum all the splits for all transaction splits for this account that // occurred before the page we're returning var amounts []string - sql = "SELECT splits.Amount FROM splits WHERE splits.AccountId=? AND splits.TransactionId IN (SELECT DISTINCT transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ")" - _, err = tx.Select(&amounts, sql, accountid, user.UserId, accountid, balanceLimitOffsetArg) + sql = "SELECT s.Amount FROM splits AS s INNER JOIN (SELECT DISTINCT transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ") as t ON s.TransactionId = t.TransactionId WHERE s.AccountId=?" + _, err = tx.Select(&amounts, sql, user.UserId, accountid, balanceLimitOffsetArg, accountid) if err != nil { return nil, err } diff --git a/internal/handlers/transactions_test.go b/internal/handlers/transactions_test.go index fac64eb..76ba941 100644 --- a/internal/handlers/transactions_test.go +++ b/internal/handlers/transactions_test.go @@ -1,8 +1,10 @@ package handlers_test import ( + "fmt" "github.com/aclindsa/moneygo/internal/handlers" "net/http" + "net/url" "strconv" "testing" "time" @@ -32,6 +34,29 @@ func getTransactions(client *http.Client) (*handlers.TransactionList, error) { return &tl, nil } +func getAccountTransactions(client *http.Client, accountid, page, limit int64, sort string) (*handlers.AccountTransactionsList, error) { + var atl handlers.AccountTransactionsList + params := url.Values{} + + query := fmt.Sprintf("/account/%d/transactions/", accountid) + if page != 0 { + params.Set("page", fmt.Sprintf("%d", page)) + } + if limit != 0 { + params.Set("limit", fmt.Sprintf("%d", limit)) + } + if len(sort) != 0 { + params.Set("sort", sort) + query += "?" + params.Encode() + } + + err := read(client, &atl, query, "accounttransactions") + if err != nil { + return nil, err + } + return &atl, nil +} + func updateTransaction(client *http.Client, transaction *handlers.Transaction) (*handlers.Transaction, error) { var s handlers.Transaction err := update(client, transaction, &s, "/transaction/"+strconv.FormatInt(transaction.TransactionId, 10), "transaction") @@ -325,3 +350,69 @@ func TestDeleteTransaction(t *testing.T) { } }) } + +func TestAccountTransactions(t *testing.T) { + RunWith(t, &data[0], func(t *testing.T, d *TestData) { + for _, account := range d.accounts { + if account.UserId != d.users[0].UserId { + continue + } + atl, err := getAccountTransactions(d.clients[0], account.AccountId, 0, 0, "date-desc") + if err != nil { + t.Fatalf("Error fetching account transactions: %s\n", err) + } + + numtransactions := 0 + foundIds := make(map[int64]bool) + for i := 0; i < len(d.transactions); i++ { + curr := d.transactions[i] + + if curr.UserId != d.users[0].UserId { + continue + } + + // Don't consider this transaction if we didn't find a split + // for the account we're considering + account_found := false + for _, s := range curr.Splits { + if s.AccountId == account.AccountId { + account_found = true + break + } + } + if !account_found { + continue + } + + numtransactions += 1 + + found := false + if atl.Transactions != nil { + for _, tran := range *atl.Transactions { + if tran.TransactionId == curr.TransactionId { + ensureTransactionsMatch(t, &curr, &tran, nil, true, true) + if _, ok := foundIds[tran.TransactionId]; ok { + continue + } + foundIds[tran.TransactionId] = true + found = true + break + } + } + } + if !found { + t.Errorf("Unable to find matching transaction: %+v", curr) + t.Errorf("Transactions: %+v\n", atl.Transactions) + } + } + + if atl.Transactions == nil { + if numtransactions != 0 { + t.Fatalf("Expected %d transactions, received 0", numtransactions) + } + } else if numtransactions != len(*atl.Transactions) { + t.Fatalf("Expected %d transactions, received %d", numtransactions, len(*atl.Transactions)) + } + } + }) +}