From be57d44ffe6b8ff9f434c759d6f2980fd8dbb41f Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Sat, 11 Jul 2015 08:58:36 -0400 Subject: [PATCH] backend: Add ability to get Transactions by Account --- accounts.go | 47 ++++++++++- transactions.go | 208 +++++++++++++++++++++++++++++++++++++++++++++--- util.go | 6 ++ 3 files changed, 247 insertions(+), 14 deletions(-) diff --git a/accounts.go b/accounts.go index dc2665c..2cb578f 100644 --- a/accounts.go +++ b/accounts.go @@ -3,8 +3,10 @@ package main import ( "encoding/json" "errors" + "gopkg.in/gorp.v1" "log" "net/http" + "regexp" "strings" ) @@ -25,12 +27,23 @@ type Account struct { ParentAccountId int64 // -1 if this account is at the root Type int64 Name string + + // monotonically-increasing account transaction version number. Used for + // allowing a client to ensure they have a consistent version when paging + // through transactions. + Version int64 } type AccountList struct { Accounts *[]Account `json:"accounts"` } +var accountTransactionsRE *regexp.Regexp + +func init() { + accountTransactionsRE = regexp.MustCompile(`^/account/[0-9]+/transactions/?$`) +} + func (a *Account) Write(w http.ResponseWriter) error { enc := json.NewEncoder(w) return enc.Encode(a) @@ -56,6 +69,17 @@ func GetAccount(accountid int64, userid int64) (*Account, error) { return &a, nil } +func GetAccountTx(transaction *gorp.Transaction, accountid int64, userid int64) (*Account, error) { + var a Account + + err := transaction.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid) + if err != nil { + return nil, err + } + + return &a, nil +} + func GetAccounts(userid int64) (*[]Account, error) { var accounts []Account @@ -97,6 +121,14 @@ func insertUpdateAccount(a *Account, insert bool) error { return err } } else { + oldacct, err := GetAccountTx(transaction, a.AccountId, a.UserId) + if err != nil { + transaction.Rollback() + return err + } + + a.Version = oldacct.Version + 1 + count, err := transaction.Update(a) if err != nil { transaction.Rollback() @@ -195,6 +227,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) { } account.AccountId = -1 account.UserId = user.UserId + account.Version = 0 if GetSecurity(account.SecurityId) == nil { WriteError(w, 3 /*Invalid Request*/) @@ -214,8 +247,10 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) { WriteSuccess(w) } else if r.Method == "GET" { - accountid, err := GetURLID(r.URL.Path) - if err != nil { + var accountid int64 + n, err := GetURLPieces(r.URL.Path, "/account/%d", &accountid) + + if err != nil || n != 1 { //Return all Accounts var al AccountList accounts, err := GetAccounts(user.UserId) @@ -232,12 +267,20 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) { return } } else { + // if URL looks like /account/[0-9]+/transactions, use the account + // transaction handler + if accountTransactionsRE.MatchString(r.URL.Path) { + AccountTransactionsHandler(w, r, user, accountid) + return + } + // Return Account with this Id account, err := GetAccount(accountid, user.UserId) if err != nil { WriteError(w, 3 /*Invalid Request*/) return } + err = account.Write(w) if err != nil { WriteError(w, 999 /*Internal Error*/) diff --git a/transactions.go b/transactions.go index 03f8304..93ac4de 100644 --- a/transactions.go +++ b/transactions.go @@ -3,9 +3,13 @@ package main import ( "encoding/json" "errors" + "fmt" + "gopkg.in/gorp.v1" "log" "math/big" "net/http" + "net/url" + "strconv" "strings" "time" ) @@ -56,6 +60,11 @@ type TransactionList struct { Transactions *[]Transaction `json:"transactions"` } +type AccountTransactionsList struct { + Account *Account `json:"account"` + Transactions *[]Transaction `json:"transactions"` +} + func (t *Transaction) Write(w http.ResponseWriter) error { enc := json.NewEncoder(w) return enc.Encode(t) @@ -71,6 +80,11 @@ func (tl *TransactionList) Write(w http.ResponseWriter) error { return enc.Encode(tl) } +func (atl *AccountTransactionsList) Write(w http.ResponseWriter) error { + enc := json.NewEncoder(w) + return enc.Encode(atl) +} + func (t *Transaction) Valid() bool { for i := range t.Splits { if !t.Splits[i].Valid() { @@ -152,18 +166,38 @@ func GetTransactions(userid int64) (*[]Transaction, error) { return &transactions, nil } +func incrementAccountVersions(transaction *gorp.Transaction, user *User, accountids []int64) error { + for i := range accountids { + account, err := GetAccountTx(transaction, accountids[i], user.UserId) + if err != nil { + return err + } + account.Version++ + count, err := transaction.Update(account) + if err != nil { + return err + } + if count != 1 { + return errors.New("Updated more than one account") + } + } + return nil +} + type AccountMissingError struct{} func (ame AccountMissingError) Error() string { return "Account missing" } -func InsertTransaction(t *Transaction) error { +func InsertTransaction(t *Transaction, user *User) error { transaction, err := DB.Begin() if err != nil { return err } + // Map of any accounts with transaction splits being added + a_map := make(map[int64]bool) for i := range t.Splits { existing, err := transaction.SelectInt("SELECT count(*) from accounts where AccountId=?", t.Splits[i].AccountId) if err != nil { @@ -174,6 +208,18 @@ func InsertTransaction(t *Transaction) error { transaction.Rollback() return AccountMissingError{} } + a_map[t.Splits[i].AccountId] = true + } + + //increment versions for all accounts + var a_ids []int64 + for id := range a_map { + a_ids = append(a_ids, id) + } + err = incrementAccountVersions(transaction, user, a_ids) + if err != nil { + transaction.Rollback() + return err } err = transaction.Insert(t) @@ -201,7 +247,7 @@ func InsertTransaction(t *Transaction) error { return nil } -func UpdateTransaction(t *Transaction) error { +func UpdateTransaction(t *Transaction, user *User) error { transaction, err := DB.Begin() if err != nil { return err @@ -215,16 +261,19 @@ func UpdateTransaction(t *Transaction) error { return err } + // Map of any accounts with transaction splits being added + a_map := make(map[int64]bool) + // Make a map with any existing splits for this transaction - m := make(map[int64]int64) + s_map := make(map[int64]bool) for i := range existing_splits { - m[existing_splits[i].SplitId] = existing_splits[i].SplitId + s_map[existing_splits[i].SplitId] = true } // Insert splits, updating any pre-existing ones for i := range t.Splits { t.Splits[i].TransactionId = t.TransactionId - _, ok := m[t.Splits[i].SplitId] + _, ok := s_map[t.Splits[i].SplitId] if ok { count, err := transaction.Update(t.Splits[i]) if err != nil { @@ -243,13 +292,15 @@ func UpdateTransaction(t *Transaction) error { return err } } + a_map[t.Splits[i].AccountId] = true } // Delete any remaining pre-existing splits for i := range existing_splits { - s, ok := m[existing_splits[i].SplitId] + _, ok := s_map[existing_splits[i].SplitId] + a_map[existing_splits[i].AccountId] = true if ok { - _, err := transaction.Delete(s) + _, err := transaction.Delete(existing_splits[i]) if err != nil { transaction.Rollback() return err @@ -257,6 +308,17 @@ func UpdateTransaction(t *Transaction) error { } } + // Increment versions for all accounts with modified splits + var a_ids []int64 + for id := range a_map { + a_ids = append(a_ids, id) + } + err = incrementAccountVersions(transaction, user, a_ids) + if err != nil { + transaction.Rollback() + return err + } + count, err := transaction.Update(t) if err != nil { transaction.Rollback() @@ -276,13 +338,20 @@ func UpdateTransaction(t *Transaction) error { return nil } -func DeleteTransaction(t *Transaction) error { +func DeleteTransaction(t *Transaction, user *User) error { transaction, err := DB.Begin() if err != nil { return err } - _, err = transaction.Exec("DELETE from splits where TransactionId=?", t.TransactionId) + var accountids []int64 + _, err = transaction.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=?", t.TransactionId) + if err != nil { + transaction.Rollback() + return err + } + + _, err = transaction.Exec("DELETE FROM splits WHERE TransactionId=?", t.TransactionId) if err != nil { transaction.Rollback() return err @@ -298,6 +367,12 @@ func DeleteTransaction(t *Transaction) error { return errors.New("Deleted more than one transaction") } + err = incrementAccountVersions(transaction, user, accountids) + if err != nil { + transaction.Rollback() + return err + } + err = transaction.Commit() if err != nil { transaction.Rollback() @@ -344,7 +419,7 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) { } } - err = InsertTransaction(&transaction) + err = InsertTransaction(&transaction, user) if err != nil { if _, ok := err.(AccountMissingError); ok { WriteError(w, 3 /*Invalid Request*/) @@ -358,6 +433,7 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) { WriteSuccess(w) } else if r.Method == "GET" { transactionid, err := GetURLID(r.URL.Path) + if err != nil { //Return all Transactions var al TransactionList @@ -423,7 +499,7 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) { } } - err = UpdateTransaction(&transaction) + err = UpdateTransaction(&transaction, user) if err != nil { WriteError(w, 999 /*Internal Error*/) log.Print(err) @@ -444,7 +520,7 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) { return } - err = DeleteTransaction(transaction) + err = DeleteTransaction(transaction, user) if err != nil { WriteError(w, 999 /*Internal Error*/) log.Print(err) @@ -455,3 +531,111 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) { } } } + +func GetAccountTransactions(user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) { + var transactions []Transaction + var atl AccountTransactionsList + + var sqlsort string + if sort == "date-asc" { + sqlsort = " ORDER BY transactions.Date ASC" + } else if sort == "date-desc" { + sqlsort = " ORDER BY transactions.Date DESC" + } + + var sqloffset string + if page > 0 { + sqloffset = fmt.Sprintf(" OFFSET %d", page*limit) + } + + transaction, err := DB.Begin() + if err != nil { + return nil, err + } + + account, err := GetAccountTx(transaction, accountid, user.UserId) + if err != nil { + transaction.Rollback() + return nil, err + } + atl.Account = account + + sql := "SELECT transactions.* from transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + " LIMIT ?" + sqloffset + _, err = transaction.Select(&transactions, sql, user.UserId, accountid, limit) + if err != nil { + transaction.Rollback() + return nil, err + } + atl.Transactions = &transactions + + for i := range transactions { + _, err = transaction.Select(&transactions[i].Splits, "SELECT * from splits where TransactionId=?", transactions[i].TransactionId) + if err != nil { + transaction.Rollback() + return nil, err + } + } + + err = transaction.Commit() + if err != nil { + transaction.Rollback() + return nil, err + } + + return &atl, nil +} + +// Return only those transactions which have at least one split pertaining to +// an account +func AccountTransactionsHandler(w http.ResponseWriter, r *http.Request, + user *User, accountid int64) { + + var page uint64 = 0 + var limit uint64 = 50 + var sort string = "date-desc" + + query, _ := url.ParseQuery(r.URL.RawQuery) + + pagestring := query.Get("page") + if pagestring != "" { + p, err := strconv.ParseUint(pagestring, 10, 0) + if err != nil { + WriteError(w, 3 /*Invalid Request*/) + return + } + page = p + } + + limitstring := query.Get("limit") + if limitstring != "" { + l, err := strconv.ParseUint(limitstring, 10, 0) + if err != nil || l > 100 { + WriteError(w, 3 /*Invalid Request*/) + return + } + limit = l + } + + sortstring := query.Get("sort") + if sortstring != "" { + if sortstring != "date-asc" && sortstring != "date-desc" { + WriteError(w, 3 /*Invalid Request*/) + return + } + sort = sortstring + } + + accountTransactions, err := GetAccountTransactions(user, accountid, sort, page, limit) + if err != nil { + WriteError(w, 999 /*Internal Error*/) + log.Print(err) + return + } + + err = accountTransactions.Write(w) + if err != nil { + WriteError(w, 999 /*Internal Error*/) + log.Print(err) + return + } +} diff --git a/util.go b/util.go index d25ec55..1627920 100644 --- a/util.go +++ b/util.go @@ -12,6 +12,12 @@ func GetURLID(url string) (int64, error) { return strconv.ParseInt(pieces[len(pieces)-1], 10, 0) } +func GetURLPieces(url string, format string, a ...interface{}) (int, error) { + url = strings.Replace(url, "/", " ", -1) + format = strings.Replace(format, "/", " ", -1) + return fmt.Sscanf(url, format, a...) +} + func WriteSuccess(w http.ResponseWriter) { fmt.Fprint(w, "{}") }