backend: Add ability to get Transactions by Account

This commit is contained in:
Aaron Lindsay 2015-07-11 08:58:36 -04:00
parent 73536302bf
commit be57d44ffe
3 changed files with 247 additions and 14 deletions

View File

@ -3,8 +3,10 @@ package main
import (
"encoding/json"
"errors"
"gopkg.in/gorp.v1"
"log"
"net/http"
"regexp"
"strings"
)
@ -25,12 +27,23 @@ type Account struct {
ParentAccountId int64 // -1 if this account is at the root
Type int64
Name string
// monotonically-increasing account transaction version number. Used for
// allowing a client to ensure they have a consistent version when paging
// through transactions.
Version int64
}
type AccountList struct {
Accounts *[]Account `json:"accounts"`
}
var accountTransactionsRE *regexp.Regexp
func init() {
accountTransactionsRE = regexp.MustCompile(`^/account/[0-9]+/transactions/?$`)
}
func (a *Account) Write(w http.ResponseWriter) error {
enc := json.NewEncoder(w)
return enc.Encode(a)
@ -56,6 +69,17 @@ func GetAccount(accountid int64, userid int64) (*Account, error) {
return &a, nil
}
func GetAccountTx(transaction *gorp.Transaction, accountid int64, userid int64) (*Account, error) {
var a Account
err := transaction.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid)
if err != nil {
return nil, err
}
return &a, nil
}
func GetAccounts(userid int64) (*[]Account, error) {
var accounts []Account
@ -97,6 +121,14 @@ func insertUpdateAccount(a *Account, insert bool) error {
return err
}
} else {
oldacct, err := GetAccountTx(transaction, a.AccountId, a.UserId)
if err != nil {
transaction.Rollback()
return err
}
a.Version = oldacct.Version + 1
count, err := transaction.Update(a)
if err != nil {
transaction.Rollback()
@ -195,6 +227,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
}
account.AccountId = -1
account.UserId = user.UserId
account.Version = 0
if GetSecurity(account.SecurityId) == nil {
WriteError(w, 3 /*Invalid Request*/)
@ -214,8 +247,10 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
WriteSuccess(w)
} else if r.Method == "GET" {
accountid, err := GetURLID(r.URL.Path)
if err != nil {
var accountid int64
n, err := GetURLPieces(r.URL.Path, "/account/%d", &accountid)
if err != nil || n != 1 {
//Return all Accounts
var al AccountList
accounts, err := GetAccounts(user.UserId)
@ -232,12 +267,20 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
return
}
} else {
// if URL looks like /account/[0-9]+/transactions, use the account
// transaction handler
if accountTransactionsRE.MatchString(r.URL.Path) {
AccountTransactionsHandler(w, r, user, accountid)
return
}
// Return Account with this Id
account, err := GetAccount(accountid, user.UserId)
if err != nil {
WriteError(w, 3 /*Invalid Request*/)
return
}
err = account.Write(w)
if err != nil {
WriteError(w, 999 /*Internal Error*/)

View File

@ -3,9 +3,13 @@ package main
import (
"encoding/json"
"errors"
"fmt"
"gopkg.in/gorp.v1"
"log"
"math/big"
"net/http"
"net/url"
"strconv"
"strings"
"time"
)
@ -56,6 +60,11 @@ type TransactionList struct {
Transactions *[]Transaction `json:"transactions"`
}
type AccountTransactionsList struct {
Account *Account `json:"account"`
Transactions *[]Transaction `json:"transactions"`
}
func (t *Transaction) Write(w http.ResponseWriter) error {
enc := json.NewEncoder(w)
return enc.Encode(t)
@ -71,6 +80,11 @@ func (tl *TransactionList) Write(w http.ResponseWriter) error {
return enc.Encode(tl)
}
func (atl *AccountTransactionsList) Write(w http.ResponseWriter) error {
enc := json.NewEncoder(w)
return enc.Encode(atl)
}
func (t *Transaction) Valid() bool {
for i := range t.Splits {
if !t.Splits[i].Valid() {
@ -152,18 +166,38 @@ func GetTransactions(userid int64) (*[]Transaction, error) {
return &transactions, nil
}
func incrementAccountVersions(transaction *gorp.Transaction, user *User, accountids []int64) error {
for i := range accountids {
account, err := GetAccountTx(transaction, accountids[i], user.UserId)
if err != nil {
return err
}
account.Version++
count, err := transaction.Update(account)
if err != nil {
return err
}
if count != 1 {
return errors.New("Updated more than one account")
}
}
return nil
}
type AccountMissingError struct{}
func (ame AccountMissingError) Error() string {
return "Account missing"
}
func InsertTransaction(t *Transaction) error {
func InsertTransaction(t *Transaction, user *User) error {
transaction, err := DB.Begin()
if err != nil {
return err
}
// Map of any accounts with transaction splits being added
a_map := make(map[int64]bool)
for i := range t.Splits {
existing, err := transaction.SelectInt("SELECT count(*) from accounts where AccountId=?", t.Splits[i].AccountId)
if err != nil {
@ -174,6 +208,18 @@ func InsertTransaction(t *Transaction) error {
transaction.Rollback()
return AccountMissingError{}
}
a_map[t.Splits[i].AccountId] = true
}
//increment versions for all accounts
var a_ids []int64
for id := range a_map {
a_ids = append(a_ids, id)
}
err = incrementAccountVersions(transaction, user, a_ids)
if err != nil {
transaction.Rollback()
return err
}
err = transaction.Insert(t)
@ -201,7 +247,7 @@ func InsertTransaction(t *Transaction) error {
return nil
}
func UpdateTransaction(t *Transaction) error {
func UpdateTransaction(t *Transaction, user *User) error {
transaction, err := DB.Begin()
if err != nil {
return err
@ -215,16 +261,19 @@ func UpdateTransaction(t *Transaction) error {
return err
}
// Map of any accounts with transaction splits being added
a_map := make(map[int64]bool)
// Make a map with any existing splits for this transaction
m := make(map[int64]int64)
s_map := make(map[int64]bool)
for i := range existing_splits {
m[existing_splits[i].SplitId] = existing_splits[i].SplitId
s_map[existing_splits[i].SplitId] = true
}
// Insert splits, updating any pre-existing ones
for i := range t.Splits {
t.Splits[i].TransactionId = t.TransactionId
_, ok := m[t.Splits[i].SplitId]
_, ok := s_map[t.Splits[i].SplitId]
if ok {
count, err := transaction.Update(t.Splits[i])
if err != nil {
@ -243,13 +292,15 @@ func UpdateTransaction(t *Transaction) error {
return err
}
}
a_map[t.Splits[i].AccountId] = true
}
// Delete any remaining pre-existing splits
for i := range existing_splits {
s, ok := m[existing_splits[i].SplitId]
_, ok := s_map[existing_splits[i].SplitId]
a_map[existing_splits[i].AccountId] = true
if ok {
_, err := transaction.Delete(s)
_, err := transaction.Delete(existing_splits[i])
if err != nil {
transaction.Rollback()
return err
@ -257,6 +308,17 @@ func UpdateTransaction(t *Transaction) error {
}
}
// Increment versions for all accounts with modified splits
var a_ids []int64
for id := range a_map {
a_ids = append(a_ids, id)
}
err = incrementAccountVersions(transaction, user, a_ids)
if err != nil {
transaction.Rollback()
return err
}
count, err := transaction.Update(t)
if err != nil {
transaction.Rollback()
@ -276,13 +338,20 @@ func UpdateTransaction(t *Transaction) error {
return nil
}
func DeleteTransaction(t *Transaction) error {
func DeleteTransaction(t *Transaction, user *User) error {
transaction, err := DB.Begin()
if err != nil {
return err
}
_, err = transaction.Exec("DELETE from splits where TransactionId=?", t.TransactionId)
var accountids []int64
_, err = transaction.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=?", t.TransactionId)
if err != nil {
transaction.Rollback()
return err
}
_, err = transaction.Exec("DELETE FROM splits WHERE TransactionId=?", t.TransactionId)
if err != nil {
transaction.Rollback()
return err
@ -298,6 +367,12 @@ func DeleteTransaction(t *Transaction) error {
return errors.New("Deleted more than one transaction")
}
err = incrementAccountVersions(transaction, user, accountids)
if err != nil {
transaction.Rollback()
return err
}
err = transaction.Commit()
if err != nil {
transaction.Rollback()
@ -344,7 +419,7 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
}
}
err = InsertTransaction(&transaction)
err = InsertTransaction(&transaction, user)
if err != nil {
if _, ok := err.(AccountMissingError); ok {
WriteError(w, 3 /*Invalid Request*/)
@ -358,6 +433,7 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
WriteSuccess(w)
} else if r.Method == "GET" {
transactionid, err := GetURLID(r.URL.Path)
if err != nil {
//Return all Transactions
var al TransactionList
@ -423,7 +499,7 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
}
}
err = UpdateTransaction(&transaction)
err = UpdateTransaction(&transaction, user)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
@ -444,7 +520,7 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
return
}
err = DeleteTransaction(transaction)
err = DeleteTransaction(transaction, user)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
@ -455,3 +531,111 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) {
}
}
}
func GetAccountTransactions(user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) {
var transactions []Transaction
var atl AccountTransactionsList
var sqlsort string
if sort == "date-asc" {
sqlsort = " ORDER BY transactions.Date ASC"
} else if sort == "date-desc" {
sqlsort = " ORDER BY transactions.Date DESC"
}
var sqloffset string
if page > 0 {
sqloffset = fmt.Sprintf(" OFFSET %d", page*limit)
}
transaction, err := DB.Begin()
if err != nil {
return nil, err
}
account, err := GetAccountTx(transaction, accountid, user.UserId)
if err != nil {
transaction.Rollback()
return nil, err
}
atl.Account = account
sql := "SELECT transactions.* from transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + " LIMIT ?" + sqloffset
_, err = transaction.Select(&transactions, sql, user.UserId, accountid, limit)
if err != nil {
transaction.Rollback()
return nil, err
}
atl.Transactions = &transactions
for i := range transactions {
_, err = transaction.Select(&transactions[i].Splits, "SELECT * from splits where TransactionId=?", transactions[i].TransactionId)
if err != nil {
transaction.Rollback()
return nil, err
}
}
err = transaction.Commit()
if err != nil {
transaction.Rollback()
return nil, err
}
return &atl, nil
}
// Return only those transactions which have at least one split pertaining to
// an account
func AccountTransactionsHandler(w http.ResponseWriter, r *http.Request,
user *User, accountid int64) {
var page uint64 = 0
var limit uint64 = 50
var sort string = "date-desc"
query, _ := url.ParseQuery(r.URL.RawQuery)
pagestring := query.Get("page")
if pagestring != "" {
p, err := strconv.ParseUint(pagestring, 10, 0)
if err != nil {
WriteError(w, 3 /*Invalid Request*/)
return
}
page = p
}
limitstring := query.Get("limit")
if limitstring != "" {
l, err := strconv.ParseUint(limitstring, 10, 0)
if err != nil || l > 100 {
WriteError(w, 3 /*Invalid Request*/)
return
}
limit = l
}
sortstring := query.Get("sort")
if sortstring != "" {
if sortstring != "date-asc" && sortstring != "date-desc" {
WriteError(w, 3 /*Invalid Request*/)
return
}
sort = sortstring
}
accountTransactions, err := GetAccountTransactions(user, accountid, sort, page, limit)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return
}
err = accountTransactions.Write(w)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return
}
}

View File

@ -12,6 +12,12 @@ func GetURLID(url string) (int64, error) {
return strconv.ParseInt(pieces[len(pieces)-1], 10, 0)
}
func GetURLPieces(url string, format string, a ...interface{}) (int, error) {
url = strings.Replace(url, "/", " ", -1)
format = strings.Replace(format, "/", " ", -1)
return fmt.Sscanf(url, format, a...)
}
func WriteSuccess(w http.ResponseWriter) {
fmt.Fprint(w, "{}")
}