From 3326c3b29284ded6ffd7b8ca61660077bdb5ae7f Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Thu, 7 Dec 2017 20:47:55 -0500 Subject: [PATCH] Move accounts to store --- internal/handlers/accounts.go | 161 +++--------------------------- internal/handlers/accounts_lua.go | 4 +- internal/handlers/gnucash_test.go | 18 ++-- internal/handlers/imports.go | 4 +- internal/handlers/ofx_test.go | 2 +- internal/handlers/securities.go | 5 +- internal/handlers/transactions.go | 10 +- internal/models/accounts.go | 2 +- internal/store/db/accounts.go | 133 ++++++++++++++++++++++++ internal/store/db/securities.go | 17 +--- internal/store/store.go | 38 ++++++- 11 files changed, 211 insertions(+), 183 deletions(-) create mode 100644 internal/store/db/accounts.go diff --git a/internal/handlers/accounts.go b/internal/handlers/accounts.go index d5ddcd8..60a2199 100644 --- a/internal/handlers/accounts.go +++ b/internal/handlers/accounts.go @@ -3,44 +3,23 @@ package handlers import ( "errors" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store" "github.com/aclindsa/moneygo/internal/store/db" "log" "net/http" ) -func GetAccount(tx *db.Tx, accountid int64, userid int64) (*models.Account, error) { - var a models.Account - - err := tx.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid) - if err != nil { - return nil, err - } - return &a, nil -} - -func GetAccounts(tx *db.Tx, userid int64) (*[]models.Account, error) { - var accounts []models.Account - - _, err := tx.Select(&accounts, "SELECT * from accounts where UserId=?", userid) - if err != nil { - return nil, err - } - return &accounts, nil -} - // Get (and attempt to create if it doesn't exist). Matches on UserId, // SecurityId, Type, Name, and ParentAccountId func GetCreateAccount(tx *db.Tx, a models.Account) (*models.Account, error) { - var accounts []models.Account var account models.Account - // Try to find the top-level trading account - _, err := tx.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND Type=? AND Name=? AND ParentAccountId=? ORDER BY AccountId ASC LIMIT 1", a.UserId, a.SecurityId, a.Type, a.Name, a.ParentAccountId) + accounts, err := tx.FindMatchingAccounts(&a) if err != nil { return nil, err } - if len(accounts) == 1 { - account = accounts[0] + if len(*accounts) > 0 { + account = *(*accounts)[0] } else { account.UserId = a.UserId account.SecurityId = a.SecurityId @@ -143,120 +122,6 @@ func GetImbalanceAccount(tx *db.Tx, userid int64, securityid int64) (*models.Acc return a, nil } -type ParentAccountMissingError struct{} - -func (pame ParentAccountMissingError) Error() string { - return "Parent account missing" -} - -type TooMuchNestingError struct{} - -func (tmne TooMuchNestingError) Error() string { - return "Too much nesting" -} - -type CircularAccountsError struct{} - -func (cae CircularAccountsError) Error() string { - return "Would result in circular account relationship" -} - -func insertUpdateAccount(tx *db.Tx, a *models.Account, insert bool) error { - found := make(map[int64]bool) - if !insert { - found[a.AccountId] = true - } - parentid := a.ParentAccountId - depth := 0 - for parentid != -1 { - depth += 1 - if depth > 100 { - return TooMuchNestingError{} - } - - var a models.Account - err := tx.SelectOne(&a, "SELECT * from accounts where AccountId=?", parentid) - if err != nil { - return ParentAccountMissingError{} - } - - // Insertion by itself can never result in circular dependencies - if insert { - break - } - - found[parentid] = true - parentid = a.ParentAccountId - if _, ok := found[parentid]; ok { - return CircularAccountsError{} - } - } - - if insert { - err := tx.Insert(a) - if err != nil { - return err - } - } else { - oldacct, err := GetAccount(tx, a.AccountId, a.UserId) - if err != nil { - return err - } - - a.AccountVersion = oldacct.AccountVersion + 1 - - count, err := tx.Update(a) - if err != nil { - return err - } - if count != 1 { - return errors.New("Updated more than one account") - } - } - - return nil -} - -func InsertAccount(tx *db.Tx, a *models.Account) error { - return insertUpdateAccount(tx, a, true) -} - -func UpdateAccount(tx *db.Tx, a *models.Account) error { - return insertUpdateAccount(tx, a, false) -} - -func DeleteAccount(tx *db.Tx, a *models.Account) error { - if a.ParentAccountId != -1 { - // Re-parent splits to this account's parent account if this account isn't a root account - _, err := tx.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", a.ParentAccountId, a.AccountId) - if err != nil { - return err - } - } else { - // Delete splits if this account is a root account - _, err := tx.Exec("DELETE FROM splits WHERE AccountId=?", a.AccountId) - if err != nil { - return err - } - } - - // Re-parent child accounts to this account's parent account - _, err := tx.Exec("UPDATE accounts SET ParentAccountId=? WHERE ParentAccountId=?", a.ParentAccountId, a.AccountId) - if err != nil { - return err - } - - count, err := tx.Delete(a) - if err != nil { - return err - } - if count != 1 { - return errors.New("Was going to delete more than one account") - } - - return nil -} - func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { user, err := GetUserFromSession(context.Tx, r) if err != nil { @@ -289,9 +154,9 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { return NewError(3 /*Invalid Request*/) } - err = InsertAccount(context.Tx, &account) + err = context.Tx.InsertAccount(&account) if err != nil { - if _, ok := err.(ParentAccountMissingError); ok { + if _, ok := err.(store.ParentAccountMissingError); ok { return NewError(3 /*Invalid Request*/) } else { log.Print(err) @@ -304,7 +169,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { if context.LastLevel() { //Return all Accounts var al models.AccountList - accounts, err := GetAccounts(context.Tx, user.UserId) + accounts, err := context.Tx.GetAccounts(user.UserId) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -320,7 +185,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { if context.LastLevel() { // Return Account with this Id - account, err := GetAccount(context.Tx, accountid, user.UserId) + account, err := context.Tx.GetAccount(accountid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } @@ -354,11 +219,11 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { return NewError(3 /*Invalid Request*/) } - err = UpdateAccount(context.Tx, &account) + err = context.Tx.UpdateAccount(&account) if err != nil { - if _, ok := err.(ParentAccountMissingError); ok { + if _, ok := err.(store.ParentAccountMissingError); ok { return NewError(3 /*Invalid Request*/) - } else if _, ok := err.(CircularAccountsError); ok { + } else if _, ok := err.(store.CircularAccountsError); ok { return NewError(3 /*Invalid Request*/) } else { log.Print(err) @@ -368,12 +233,12 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { return &account } else if r.Method == "DELETE" { - account, err := GetAccount(context.Tx, accountid, user.UserId) + account, err := context.Tx.GetAccount(accountid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } - err = DeleteAccount(context.Tx, account) + err = context.Tx.DeleteAccount(account) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/handlers/accounts_lua.go b/internal/handlers/accounts_lua.go index ee91eb2..2036e62 100644 --- a/internal/handlers/accounts_lua.go +++ b/internal/handlers/accounts_lua.go @@ -29,14 +29,14 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*models.Account, error) { return nil, errors.New("Couldn't find User in lua's Context") } - accounts, err := GetAccounts(tx, user.UserId) + accounts, err := tx.GetAccounts(user.UserId) if err != nil { return nil, err } account_map = make(map[int64]*models.Account) for i := range *accounts { - account_map[(*accounts)[i].AccountId] = &(*accounts)[i] + account_map[(*accounts)[i].AccountId] = (*accounts)[i] } ctx = context.WithValue(ctx, accountsContextKey, account_map) diff --git a/internal/handlers/gnucash_test.go b/internal/handlers/gnucash_test.go index 960078f..3bbd2df 100644 --- a/internal/handlers/gnucash_test.go +++ b/internal/handlers/gnucash_test.go @@ -38,13 +38,13 @@ func TestImportGnucash(t *testing.T) { } for i, account := range *accounts.Accounts { if account.Name == "Income" && account.Type == models.Income && account.ParentAccountId == -1 { - income = &(*accounts.Accounts)[i] + income = (*accounts.Accounts)[i] } else if account.Name == "Equity" && account.Type == models.Equity && account.ParentAccountId == -1 { - equity = &(*accounts.Accounts)[i] + equity = (*accounts.Accounts)[i] } else if account.Name == "Liabilities" && account.Type == models.Liability && account.ParentAccountId == -1 { - liabilities = &(*accounts.Accounts)[i] + liabilities = (*accounts.Accounts)[i] } else if account.Name == "Expenses" && account.Type == models.Expense && account.ParentAccountId == -1 { - expenses = &(*accounts.Accounts)[i] + expenses = (*accounts.Accounts)[i] } } if income == nil { @@ -61,15 +61,15 @@ func TestImportGnucash(t *testing.T) { } for i, account := range *accounts.Accounts { if account.Name == "Salary" && account.Type == models.Income && account.ParentAccountId == income.AccountId { - salary = &(*accounts.Accounts)[i] + salary = (*accounts.Accounts)[i] } else if account.Name == "Opening Balances" && account.Type == models.Equity && account.ParentAccountId == equity.AccountId { - openingbalances = &(*accounts.Accounts)[i] + openingbalances = (*accounts.Accounts)[i] } else if account.Name == "Credit Card" && account.Type == models.Liability && account.ParentAccountId == liabilities.AccountId { - creditcard = &(*accounts.Accounts)[i] + creditcard = (*accounts.Accounts)[i] } else if account.Name == "Groceries" && account.Type == models.Expense && account.ParentAccountId == expenses.AccountId { - groceries = &(*accounts.Accounts)[i] + groceries = (*accounts.Accounts)[i] } else if account.Name == "Cable" && account.Type == models.Expense && account.ParentAccountId == expenses.AccountId { - cable = &(*accounts.Accounts)[i] + cable = (*accounts.Accounts)[i] } } if salary == nil { diff --git a/internal/handlers/imports.go b/internal/handlers/imports.go index ea74042..78cee60 100644 --- a/internal/handlers/imports.go +++ b/internal/handlers/imports.go @@ -39,7 +39,7 @@ func ofxImportHelper(tx *db.Tx, r io.Reader, user *models.User, accountid int64) } // Return Account with this Id - account, err := GetAccount(tx, accountid, user.UserId) + account, err := tx.GetAccount(accountid, user.UserId) if err != nil { log.Print(err) return NewError(3 /*Invalid Request*/) @@ -218,7 +218,7 @@ func OFXImportHandler(context *Context, r *http.Request, user *models.User, acco return NewError(3 /*Invalid Request*/) } - account, err := GetAccount(context.Tx, accountid, user.UserId) + account, err := context.Tx.GetAccount(accountid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } diff --git a/internal/handlers/ofx_test.go b/internal/handlers/ofx_test.go index 4d78f25..af8712f 100644 --- a/internal/handlers/ofx_test.go +++ b/internal/handlers/ofx_test.go @@ -83,7 +83,7 @@ func findAccount(client *http.Client, name string, tipe models.AccountType, secu } for _, account := range *accounts.Accounts { if account.Name == name && account.Type == tipe && account.SecurityId == securityid { - return &account, nil + return account, nil } } return nil, fmt.Errorf("Unable to find account: \"%s\"", name) diff --git a/internal/handlers/securities.go b/internal/handlers/securities.go index ec9544b..7c720eb 100644 --- a/internal/handlers/securities.go +++ b/internal/handlers/securities.go @@ -5,6 +5,7 @@ package handlers import ( "errors" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store" "github.com/aclindsa/moneygo/internal/store/db" "log" "net/http" @@ -77,7 +78,7 @@ func ImportGetCreateSecurity(tx *db.Tx, userid int64, security *models.Security) return security, nil } - securities, err := tx.FindMatchingSecurities(userid, security) + securities, err := tx.FindMatchingSecurities(security) if err != nil { return nil, err } @@ -215,7 +216,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { } err = context.Tx.DeleteSecurity(security) - if _, ok := err.(db.SecurityInUseError); ok { + if _, ok := err.(store.SecurityInUseError); ok { return NewError(7 /*In Use Error*/) } else if err != nil { log.Print(err) diff --git a/internal/handlers/transactions.go b/internal/handlers/transactions.go index 707d29a..9ba4952 100644 --- a/internal/handlers/transactions.go +++ b/internal/handlers/transactions.go @@ -32,7 +32,7 @@ func GetTransactionImbalances(tx *db.Tx, t *models.Transaction) (map[int64]big.R if t.Splits[i].AccountId != -1 { var err error var account *models.Account - account, err = GetAccount(tx, t.Splits[i].AccountId, t.UserId) + account, err = tx.GetAccount(t.Splits[i].AccountId, t.UserId) if err != nil { return nil, err } @@ -100,7 +100,7 @@ func GetTransactions(tx *db.Tx, userid int64) (*[]models.Transaction, error) { func incrementAccountVersions(tx *db.Tx, user *models.User, accountids []int64) error { for i := range accountids { - account, err := GetAccount(tx, accountids[i], user.UserId) + account, err := tx.GetAccount(accountids[i], user.UserId) if err != nil { return err } @@ -297,7 +297,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter for i := range transaction.Splits { transaction.Splits[i].SplitId = -1 - _, err := GetAccount(context.Tx, transaction.Splits[i].AccountId, user.UserId) + _, err := context.Tx.GetAccount(transaction.Splits[i].AccountId, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } @@ -371,7 +371,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter } for i := range transaction.Splits { - _, err := GetAccount(context.Tx, transaction.Splits[i].AccountId, user.UserId) + _, err := context.Tx.GetAccount(transaction.Splits[i].AccountId, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } @@ -518,7 +518,7 @@ func GetAccountTransactions(tx *db.Tx, user *models.User, accountid int64, sort sqloffset = fmt.Sprintf(" OFFSET %d", page*limit) } - account, err := GetAccount(tx, accountid, user.UserId) + account, err := tx.GetAccount(accountid, user.UserId) if err != nil { return nil, err } diff --git a/internal/models/accounts.go b/internal/models/accounts.go index fdfac98..7941ed7 100644 --- a/internal/models/accounts.go +++ b/internal/models/accounts.go @@ -94,7 +94,7 @@ type Account struct { } type AccountList struct { - Accounts *[]Account `json:"accounts"` + Accounts *[]*Account `json:"accounts"` } func (a *Account) Write(w http.ResponseWriter) error { diff --git a/internal/store/db/accounts.go b/internal/store/db/accounts.go new file mode 100644 index 0000000..07d0208 --- /dev/null +++ b/internal/store/db/accounts.go @@ -0,0 +1,133 @@ +package db + +import ( + "errors" + "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store" +) + +func (tx *Tx) GetAccount(accountid int64, userid int64) (*models.Account, error) { + var account models.Account + + err := tx.SelectOne(&account, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid) + if err != nil { + return nil, err + } + return &account, nil +} + +func (tx *Tx) GetAccounts(userid int64) (*[]*models.Account, error) { + var accounts []*models.Account + + _, err := tx.Select(&accounts, "SELECT * from accounts where UserId=?", userid) + if err != nil { + return nil, err + } + return &accounts, nil +} + +func (tx *Tx) FindMatchingAccounts(account *models.Account) (*[]*models.Account, error) { + var accounts []*models.Account + + _, err := tx.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND Type=? AND Name=? AND ParentAccountId=? ORDER BY AccountId ASC", account.UserId, account.SecurityId, account.Type, account.Name, account.ParentAccountId) + if err != nil { + return nil, err + } + return &accounts, nil +} + +func (tx *Tx) insertUpdateAccount(account *models.Account, insert bool) error { + found := make(map[int64]bool) + if !insert { + found[account.AccountId] = true + } + parentid := account.ParentAccountId + depth := 0 + for parentid != -1 { + depth += 1 + if depth > 100 { + return store.TooMuchNestingError{} + } + + var a models.Account + err := tx.SelectOne(&a, "SELECT * from accounts where AccountId=?", parentid) + if err != nil { + return store.ParentAccountMissingError{} + } + + // Insertion by itself can never result in circular dependencies + if insert { + break + } + + found[parentid] = true + parentid = a.ParentAccountId + if _, ok := found[parentid]; ok { + return store.CircularAccountsError{} + } + } + + if insert { + err := tx.Insert(account) + if err != nil { + return err + } + } else { + oldacct, err := tx.GetAccount(account.AccountId, account.UserId) + if err != nil { + return err + } + + account.AccountVersion = oldacct.AccountVersion + 1 + + count, err := tx.Update(account) + if err != nil { + return err + } + if count != 1 { + return errors.New("Updated more than one account") + } + } + + return nil +} + +func (tx *Tx) InsertAccount(account *models.Account) error { + return tx.insertUpdateAccount(account, true) +} + +func (tx *Tx) UpdateAccount(account *models.Account) error { + return tx.insertUpdateAccount(account, false) +} + +func (tx *Tx) DeleteAccount(account *models.Account) error { + if account.ParentAccountId != -1 { + // Re-parent splits to this account's parent account if this account isn't a root account + _, err := tx.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", account.ParentAccountId, account.AccountId) + if err != nil { + return err + } + } else { + // Delete splits if this account is a root account + _, err := tx.Exec("DELETE FROM splits WHERE AccountId=?", account.AccountId) + if err != nil { + return err + } + } + + // Re-parent child accounts to this account's parent account + _, err := tx.Exec("UPDATE accounts SET ParentAccountId=? WHERE ParentAccountId=?", account.ParentAccountId, account.AccountId) + if err != nil { + return err + } + + count, err := tx.Delete(account) + if err != nil { + return err + } + if count != 1 { + return errors.New("Was going to delete more than one account") + } + + return nil +} diff --git a/internal/store/db/securities.go b/internal/store/db/securities.go index 0acba29..83a5659 100644 --- a/internal/store/db/securities.go +++ b/internal/store/db/securities.go @@ -3,16 +3,9 @@ package db import ( "fmt" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store" ) -type SecurityInUseError struct { - message string -} - -func (e SecurityInUseError) Error() string { - return e.message -} - func (tx *Tx) GetSecurity(securityid int64, userid int64) (*models.Security, error) { var s models.Security @@ -33,10 +26,10 @@ func (tx *Tx) GetSecurities(userid int64) (*[]*models.Security, error) { return &securities, nil } -func (tx *Tx) FindMatchingSecurities(userid int64, security *models.Security) (*[]*models.Security, error) { +func (tx *Tx) FindMatchingSecurities(security *models.Security) (*[]*models.Security, error) { var securities []*models.Security - _, err := tx.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Preciseness=?", userid, security.Type, security.AlternateId, security.Precision) + _, err := tx.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Preciseness=?", security.UserId, security.Type, security.AlternateId, security.Precision) if err != nil { return nil, err } @@ -67,14 +60,14 @@ func (tx *Tx) DeleteSecurity(s *models.Security) error { accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId) if accounts != 0 { - return SecurityInUseError{"One or more accounts still use this security"} + return store.SecurityInUseError{"One or more accounts still use this security"} } user, err := tx.GetUser(s.UserId) if err != nil { return err } else if user.DefaultCurrency == s.SecurityId { - return SecurityInUseError{"Cannot delete security which is user's default currency"} + return store.SecurityInUseError{"Cannot delete security which is user's default currency"} } // Remove all prices involving this security (either of this security, or diff --git a/internal/store/store.go b/internal/store/store.go index 86d6c66..fe99c9c 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -20,15 +20,50 @@ type UserStore interface { DeleteUser(user *models.User) error } +type SecurityInUseError struct { + Message string +} + +func (e SecurityInUseError) Error() string { + return e.Message +} + type SecurityStore interface { InsertSecurity(security *models.Security) error GetSecurity(securityid int64, userid int64) (*models.Security, error) GetSecurities(userid int64) (*[]*models.Security, error) - FindMatchingSecurities(userid int64, security *models.Security) (*[]*models.Security, error) + FindMatchingSecurities(security *models.Security) (*[]*models.Security, error) UpdateSecurity(security *models.Security) error DeleteSecurity(security *models.Security) error } +type ParentAccountMissingError struct{} + +func (pame ParentAccountMissingError) Error() string { + return "Parent account missing" +} + +type TooMuchNestingError struct{} + +func (tmne TooMuchNestingError) Error() string { + return "Too much account nesting" +} + +type CircularAccountsError struct{} + +func (cae CircularAccountsError) Error() string { + return "Would result in circular account relationship" +} + +type AccountStore interface { + InsertAccount(account *models.Account) error + GetAccount(accountid int64, userid int64) (*models.Account, error) + GetAccounts(userid int64) (*[]*models.Account, error) + FindMatchingAccounts(account *models.Account) (*[]*models.Account, error) + UpdateAccount(account *models.Account) error + DeleteAccount(account *models.Account) error +} + type Tx interface { Commit() error Rollback() error @@ -36,6 +71,7 @@ type Tx interface { SessionStore UserStore SecurityStore + AccountStore } type Store interface {