mirror of
				https://github.com/aclindsa/moneygo.git
				synced 2025-11-03 18:13:27 -05:00 
			
		
		
		
	testing: Test for and protect against circular accounts
This commit is contained in:
		@@ -276,22 +276,55 @@ func (pame ParentAccountMissingError) Error() string {
 | 
				
			|||||||
	return "Parent account missing"
 | 
						return "Parent account missing"
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type TooMuchNestingError struct{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (tmne TooMuchNestingError) Error() string {
 | 
				
			||||||
 | 
						return "Too much nesting"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type CircularAccountsError struct{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (cae CircularAccountsError) Error() string {
 | 
				
			||||||
 | 
						return "Would result in circular account relationship"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func insertUpdateAccount(db *DB, a *Account, insert bool) error {
 | 
					func insertUpdateAccount(db *DB, a *Account, insert bool) error {
 | 
				
			||||||
	transaction, err := db.Begin()
 | 
						transaction, err := db.Begin()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if a.ParentAccountId != -1 {
 | 
						found := make(map[int64]bool)
 | 
				
			||||||
		existing, err := transaction.SelectInt("SELECT count(*) from accounts where AccountId=?", a.ParentAccountId)
 | 
						if !insert {
 | 
				
			||||||
 | 
							found[a.AccountId] = true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						parentid := a.ParentAccountId
 | 
				
			||||||
 | 
						depth := 0
 | 
				
			||||||
 | 
						for parentid != -1 {
 | 
				
			||||||
 | 
							depth += 1
 | 
				
			||||||
 | 
							if depth > 100 {
 | 
				
			||||||
 | 
								transaction.Rollback()
 | 
				
			||||||
 | 
								return TooMuchNestingError{}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							var a Account
 | 
				
			||||||
 | 
							err := transaction.SelectOne(&a, "SELECT * from accounts where AccountId=?", parentid)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			transaction.Rollback()
 | 
					 | 
				
			||||||
			return err
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		if existing != 1 {
 | 
					 | 
				
			||||||
			transaction.Rollback()
 | 
								transaction.Rollback()
 | 
				
			||||||
			return ParentAccountMissingError{}
 | 
								return ParentAccountMissingError{}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Insertion by itself can never result in circular dependencies
 | 
				
			||||||
 | 
							if insert {
 | 
				
			||||||
 | 
								break
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							found[parentid] = true
 | 
				
			||||||
 | 
							parentid = a.ParentAccountId
 | 
				
			||||||
 | 
							if _, ok := found[parentid]; ok {
 | 
				
			||||||
 | 
								transaction.Rollback()
 | 
				
			||||||
 | 
								return CircularAccountsError{}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if insert {
 | 
						if insert {
 | 
				
			||||||
@@ -537,6 +570,8 @@ func AccountHandler(w http.ResponseWriter, r *http.Request, db *DB) {
 | 
				
			|||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				if _, ok := err.(ParentAccountMissingError); ok {
 | 
									if _, ok := err.(ParentAccountMissingError); ok {
 | 
				
			||||||
					WriteError(w, 3 /*Invalid Request*/)
 | 
										WriteError(w, 3 /*Invalid Request*/)
 | 
				
			||||||
 | 
									} else if _, ok := err.(CircularAccountsError); ok {
 | 
				
			||||||
 | 
										WriteError(w, 3 /*Invalid Request*/)
 | 
				
			||||||
				} else {
 | 
									} else {
 | 
				
			||||||
					WriteError(w, 999 /*Internal Error*/)
 | 
										WriteError(w, 999 /*Internal Error*/)
 | 
				
			||||||
					log.Print(err)
 | 
										log.Print(err)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -128,7 +128,15 @@ func TestUpdateAccount(t *testing.T) {
 | 
				
			|||||||
			t.Fatalf("Expected error updating account with invalid parent: %+v\n", a)
 | 
								t.Fatalf("Expected error updating account with invalid parent: %+v\n", a)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// TODO ensure you can't create cycles with ParentAccountId
 | 
							orig = data[0].accounts[0]
 | 
				
			||||||
 | 
							curr = d.accounts[0]
 | 
				
			||||||
 | 
							child := d.accounts[1]
 | 
				
			||||||
 | 
							curr.ParentAccountId = child.AccountId
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							a, err = updateAccount(d.clients[orig.UserId], &curr)
 | 
				
			||||||
 | 
							if err == nil {
 | 
				
			||||||
 | 
								t.Fatalf("Expected error updating account with circular parent relationship: %+v\n", a)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user