diff --git a/internal/handlers/accounts.go b/internal/handlers/accounts.go index cb6cdde..fa9fce8 100644 --- a/internal/handlers/accounts.go +++ b/internal/handlers/accounts.go @@ -276,22 +276,55 @@ func (pame ParentAccountMissingError) Error() string { return "Parent account missing" } +type TooMuchNestingError struct{} + +func (tmne TooMuchNestingError) Error() string { + return "Too much nesting" +} + +type CircularAccountsError struct{} + +func (cae CircularAccountsError) Error() string { + return "Would result in circular account relationship" +} + func insertUpdateAccount(db *DB, a *Account, insert bool) error { transaction, err := db.Begin() if err != nil { return err } - if a.ParentAccountId != -1 { - existing, err := transaction.SelectInt("SELECT count(*) from accounts where AccountId=?", a.ParentAccountId) + found := make(map[int64]bool) + if !insert { + found[a.AccountId] = true + } + parentid := a.ParentAccountId + depth := 0 + for parentid != -1 { + depth += 1 + if depth > 100 { + transaction.Rollback() + return TooMuchNestingError{} + } + + var a Account + err := transaction.SelectOne(&a, "SELECT * from accounts where AccountId=?", parentid) if err != nil { - transaction.Rollback() - return err - } - if existing != 1 { transaction.Rollback() return ParentAccountMissingError{} } + + // Insertion by itself can never result in circular dependencies + if insert { + break + } + + found[parentid] = true + parentid = a.ParentAccountId + if _, ok := found[parentid]; ok { + transaction.Rollback() + return CircularAccountsError{} + } } if insert { @@ -537,6 +570,8 @@ func AccountHandler(w http.ResponseWriter, r *http.Request, db *DB) { if err != nil { if _, ok := err.(ParentAccountMissingError); ok { WriteError(w, 3 /*Invalid Request*/) + } else if _, ok := err.(CircularAccountsError); ok { + WriteError(w, 3 /*Invalid Request*/) } else { WriteError(w, 999 /*Internal Error*/) log.Print(err) diff --git a/internal/handlers/accounts_test.go b/internal/handlers/accounts_test.go index 6541b58..5b2f3da 100644 --- a/internal/handlers/accounts_test.go +++ b/internal/handlers/accounts_test.go @@ -128,7 +128,15 @@ func TestUpdateAccount(t *testing.T) { t.Fatalf("Expected error updating account with invalid parent: %+v\n", a) } - // TODO ensure you can't create cycles with ParentAccountId + orig = data[0].accounts[0] + curr = d.accounts[0] + child := d.accounts[1] + curr.ParentAccountId = child.AccountId + + a, err = updateAccount(d.clients[orig.UserId], &curr) + if err == nil { + t.Fatalf("Expected error updating account with circular parent relationship: %+v\n", a) + } }) }