mirror of
				https://github.com/aclindsa/moneygo.git
				synced 2025-11-03 18:13:27 -05:00 
			
		
		
		
	Move accounts to store
This commit is contained in:
		@@ -3,44 +3,23 @@ package handlers
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/models"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/store"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/store/db"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net/http"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func GetAccount(tx *db.Tx, accountid int64, userid int64) (*models.Account, error) {
 | 
			
		||||
	var a models.Account
 | 
			
		||||
 | 
			
		||||
	err := tx.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return &a, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAccounts(tx *db.Tx, userid int64) (*[]models.Account, error) {
 | 
			
		||||
	var accounts []models.Account
 | 
			
		||||
 | 
			
		||||
	_, err := tx.Select(&accounts, "SELECT * from accounts where UserId=?", userid)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return &accounts, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Get (and attempt to create if it doesn't exist). Matches on UserId,
 | 
			
		||||
// SecurityId, Type, Name, and ParentAccountId
 | 
			
		||||
func GetCreateAccount(tx *db.Tx, a models.Account) (*models.Account, error) {
 | 
			
		||||
	var accounts []models.Account
 | 
			
		||||
	var account models.Account
 | 
			
		||||
 | 
			
		||||
	// Try to find the top-level trading account
 | 
			
		||||
	_, err := tx.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND Type=? AND Name=? AND ParentAccountId=? ORDER BY AccountId ASC LIMIT 1", a.UserId, a.SecurityId, a.Type, a.Name, a.ParentAccountId)
 | 
			
		||||
	accounts, err := tx.FindMatchingAccounts(&a)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	if len(accounts) == 1 {
 | 
			
		||||
		account = accounts[0]
 | 
			
		||||
	if len(*accounts) > 0 {
 | 
			
		||||
		account = *(*accounts)[0]
 | 
			
		||||
	} else {
 | 
			
		||||
		account.UserId = a.UserId
 | 
			
		||||
		account.SecurityId = a.SecurityId
 | 
			
		||||
@@ -143,120 +122,6 @@ func GetImbalanceAccount(tx *db.Tx, userid int64, securityid int64) (*models.Acc
 | 
			
		||||
	return a, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ParentAccountMissingError struct{}
 | 
			
		||||
 | 
			
		||||
func (pame ParentAccountMissingError) Error() string {
 | 
			
		||||
	return "Parent account missing"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type TooMuchNestingError struct{}
 | 
			
		||||
 | 
			
		||||
func (tmne TooMuchNestingError) Error() string {
 | 
			
		||||
	return "Too much nesting"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CircularAccountsError struct{}
 | 
			
		||||
 | 
			
		||||
func (cae CircularAccountsError) Error() string {
 | 
			
		||||
	return "Would result in circular account relationship"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func insertUpdateAccount(tx *db.Tx, a *models.Account, insert bool) error {
 | 
			
		||||
	found := make(map[int64]bool)
 | 
			
		||||
	if !insert {
 | 
			
		||||
		found[a.AccountId] = true
 | 
			
		||||
	}
 | 
			
		||||
	parentid := a.ParentAccountId
 | 
			
		||||
	depth := 0
 | 
			
		||||
	for parentid != -1 {
 | 
			
		||||
		depth += 1
 | 
			
		||||
		if depth > 100 {
 | 
			
		||||
			return TooMuchNestingError{}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var a models.Account
 | 
			
		||||
		err := tx.SelectOne(&a, "SELECT * from accounts where AccountId=?", parentid)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return ParentAccountMissingError{}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Insertion by itself can never result in circular dependencies
 | 
			
		||||
		if insert {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		found[parentid] = true
 | 
			
		||||
		parentid = a.ParentAccountId
 | 
			
		||||
		if _, ok := found[parentid]; ok {
 | 
			
		||||
			return CircularAccountsError{}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if insert {
 | 
			
		||||
		err := tx.Insert(a)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		oldacct, err := GetAccount(tx, a.AccountId, a.UserId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		a.AccountVersion = oldacct.AccountVersion + 1
 | 
			
		||||
 | 
			
		||||
		count, err := tx.Update(a)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		if count != 1 {
 | 
			
		||||
			return errors.New("Updated more than one account")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func InsertAccount(tx *db.Tx, a *models.Account) error {
 | 
			
		||||
	return insertUpdateAccount(tx, a, true)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateAccount(tx *db.Tx, a *models.Account) error {
 | 
			
		||||
	return insertUpdateAccount(tx, a, false)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func DeleteAccount(tx *db.Tx, a *models.Account) error {
 | 
			
		||||
	if a.ParentAccountId != -1 {
 | 
			
		||||
		// Re-parent splits to this account's parent account if this account isn't a root account
 | 
			
		||||
		_, err := tx.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", a.ParentAccountId, a.AccountId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		// Delete splits if this account is a root account
 | 
			
		||||
		_, err := tx.Exec("DELETE FROM splits WHERE AccountId=?", a.AccountId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Re-parent child accounts to this account's parent account
 | 
			
		||||
	_, err := tx.Exec("UPDATE accounts SET ParentAccountId=? WHERE ParentAccountId=?", a.ParentAccountId, a.AccountId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	count, err := tx.Delete(a)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if count != 1 {
 | 
			
		||||
		return errors.New("Was going to delete more than one account")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
 | 
			
		||||
	user, err := GetUserFromSession(context.Tx, r)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -289,9 +154,9 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
 | 
			
		||||
			return NewError(3 /*Invalid Request*/)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		err = InsertAccount(context.Tx, &account)
 | 
			
		||||
		err = context.Tx.InsertAccount(&account)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			if _, ok := err.(ParentAccountMissingError); ok {
 | 
			
		||||
			if _, ok := err.(store.ParentAccountMissingError); ok {
 | 
			
		||||
				return NewError(3 /*Invalid Request*/)
 | 
			
		||||
			} else {
 | 
			
		||||
				log.Print(err)
 | 
			
		||||
@@ -304,7 +169,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
 | 
			
		||||
		if context.LastLevel() {
 | 
			
		||||
			//Return all Accounts
 | 
			
		||||
			var al models.AccountList
 | 
			
		||||
			accounts, err := GetAccounts(context.Tx, user.UserId)
 | 
			
		||||
			accounts, err := context.Tx.GetAccounts(user.UserId)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				log.Print(err)
 | 
			
		||||
				return NewError(999 /*Internal Error*/)
 | 
			
		||||
@@ -320,7 +185,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
 | 
			
		||||
 | 
			
		||||
		if context.LastLevel() {
 | 
			
		||||
			// Return Account with this Id
 | 
			
		||||
			account, err := GetAccount(context.Tx, accountid, user.UserId)
 | 
			
		||||
			account, err := context.Tx.GetAccount(accountid, user.UserId)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return NewError(3 /*Invalid Request*/)
 | 
			
		||||
			}
 | 
			
		||||
@@ -354,11 +219,11 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
 | 
			
		||||
				return NewError(3 /*Invalid Request*/)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			err = UpdateAccount(context.Tx, &account)
 | 
			
		||||
			err = context.Tx.UpdateAccount(&account)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				if _, ok := err.(ParentAccountMissingError); ok {
 | 
			
		||||
				if _, ok := err.(store.ParentAccountMissingError); ok {
 | 
			
		||||
					return NewError(3 /*Invalid Request*/)
 | 
			
		||||
				} else if _, ok := err.(CircularAccountsError); ok {
 | 
			
		||||
				} else if _, ok := err.(store.CircularAccountsError); ok {
 | 
			
		||||
					return NewError(3 /*Invalid Request*/)
 | 
			
		||||
				} else {
 | 
			
		||||
					log.Print(err)
 | 
			
		||||
@@ -368,12 +233,12 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
 | 
			
		||||
 | 
			
		||||
			return &account
 | 
			
		||||
		} else if r.Method == "DELETE" {
 | 
			
		||||
			account, err := GetAccount(context.Tx, accountid, user.UserId)
 | 
			
		||||
			account, err := context.Tx.GetAccount(accountid, user.UserId)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return NewError(3 /*Invalid Request*/)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			err = DeleteAccount(context.Tx, account)
 | 
			
		||||
			err = context.Tx.DeleteAccount(account)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				log.Print(err)
 | 
			
		||||
				return NewError(999 /*Internal Error*/)
 | 
			
		||||
 
 | 
			
		||||
@@ -29,14 +29,14 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*models.Account, error) {
 | 
			
		||||
			return nil, errors.New("Couldn't find User in lua's Context")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		accounts, err := GetAccounts(tx, user.UserId)
 | 
			
		||||
		accounts, err := tx.GetAccounts(user.UserId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		account_map = make(map[int64]*models.Account)
 | 
			
		||||
		for i := range *accounts {
 | 
			
		||||
			account_map[(*accounts)[i].AccountId] = &(*accounts)[i]
 | 
			
		||||
			account_map[(*accounts)[i].AccountId] = (*accounts)[i]
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		ctx = context.WithValue(ctx, accountsContextKey, account_map)
 | 
			
		||||
 
 | 
			
		||||
@@ -38,13 +38,13 @@ func TestImportGnucash(t *testing.T) {
 | 
			
		||||
		}
 | 
			
		||||
		for i, account := range *accounts.Accounts {
 | 
			
		||||
			if account.Name == "Income" && account.Type == models.Income && account.ParentAccountId == -1 {
 | 
			
		||||
				income = &(*accounts.Accounts)[i]
 | 
			
		||||
				income = (*accounts.Accounts)[i]
 | 
			
		||||
			} else if account.Name == "Equity" && account.Type == models.Equity && account.ParentAccountId == -1 {
 | 
			
		||||
				equity = &(*accounts.Accounts)[i]
 | 
			
		||||
				equity = (*accounts.Accounts)[i]
 | 
			
		||||
			} else if account.Name == "Liabilities" && account.Type == models.Liability && account.ParentAccountId == -1 {
 | 
			
		||||
				liabilities = &(*accounts.Accounts)[i]
 | 
			
		||||
				liabilities = (*accounts.Accounts)[i]
 | 
			
		||||
			} else if account.Name == "Expenses" && account.Type == models.Expense && account.ParentAccountId == -1 {
 | 
			
		||||
				expenses = &(*accounts.Accounts)[i]
 | 
			
		||||
				expenses = (*accounts.Accounts)[i]
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		if income == nil {
 | 
			
		||||
@@ -61,15 +61,15 @@ func TestImportGnucash(t *testing.T) {
 | 
			
		||||
		}
 | 
			
		||||
		for i, account := range *accounts.Accounts {
 | 
			
		||||
			if account.Name == "Salary" && account.Type == models.Income && account.ParentAccountId == income.AccountId {
 | 
			
		||||
				salary = &(*accounts.Accounts)[i]
 | 
			
		||||
				salary = (*accounts.Accounts)[i]
 | 
			
		||||
			} else if account.Name == "Opening Balances" && account.Type == models.Equity && account.ParentAccountId == equity.AccountId {
 | 
			
		||||
				openingbalances = &(*accounts.Accounts)[i]
 | 
			
		||||
				openingbalances = (*accounts.Accounts)[i]
 | 
			
		||||
			} else if account.Name == "Credit Card" && account.Type == models.Liability && account.ParentAccountId == liabilities.AccountId {
 | 
			
		||||
				creditcard = &(*accounts.Accounts)[i]
 | 
			
		||||
				creditcard = (*accounts.Accounts)[i]
 | 
			
		||||
			} else if account.Name == "Groceries" && account.Type == models.Expense && account.ParentAccountId == expenses.AccountId {
 | 
			
		||||
				groceries = &(*accounts.Accounts)[i]
 | 
			
		||||
				groceries = (*accounts.Accounts)[i]
 | 
			
		||||
			} else if account.Name == "Cable" && account.Type == models.Expense && account.ParentAccountId == expenses.AccountId {
 | 
			
		||||
				cable = &(*accounts.Accounts)[i]
 | 
			
		||||
				cable = (*accounts.Accounts)[i]
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		if salary == nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -39,7 +39,7 @@ func ofxImportHelper(tx *db.Tx, r io.Reader, user *models.User, accountid int64)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Return Account with this Id
 | 
			
		||||
	account, err := GetAccount(tx, accountid, user.UserId)
 | 
			
		||||
	account, err := tx.GetAccount(accountid, user.UserId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Print(err)
 | 
			
		||||
		return NewError(3 /*Invalid Request*/)
 | 
			
		||||
@@ -218,7 +218,7 @@ func OFXImportHandler(context *Context, r *http.Request, user *models.User, acco
 | 
			
		||||
		return NewError(3 /*Invalid Request*/)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	account, err := GetAccount(context.Tx, accountid, user.UserId)
 | 
			
		||||
	account, err := context.Tx.GetAccount(accountid, user.UserId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return NewError(3 /*Invalid Request*/)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -83,7 +83,7 @@ func findAccount(client *http.Client, name string, tipe models.AccountType, secu
 | 
			
		||||
	}
 | 
			
		||||
	for _, account := range *accounts.Accounts {
 | 
			
		||||
		if account.Name == name && account.Type == tipe && account.SecurityId == securityid {
 | 
			
		||||
			return &account, nil
 | 
			
		||||
			return account, nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil, fmt.Errorf("Unable to find account: \"%s\"", name)
 | 
			
		||||
 
 | 
			
		||||
@@ -5,6 +5,7 @@ package handlers
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/models"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/store"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/store/db"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net/http"
 | 
			
		||||
@@ -77,7 +78,7 @@ func ImportGetCreateSecurity(tx *db.Tx, userid int64, security *models.Security)
 | 
			
		||||
		return security, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	securities, err := tx.FindMatchingSecurities(userid, security)
 | 
			
		||||
	securities, err := tx.FindMatchingSecurities(security)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@@ -215,7 +216,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter {
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			err = context.Tx.DeleteSecurity(security)
 | 
			
		||||
			if _, ok := err.(db.SecurityInUseError); ok {
 | 
			
		||||
			if _, ok := err.(store.SecurityInUseError); ok {
 | 
			
		||||
				return NewError(7 /*In Use Error*/)
 | 
			
		||||
			} else if err != nil {
 | 
			
		||||
				log.Print(err)
 | 
			
		||||
 
 | 
			
		||||
@@ -32,7 +32,7 @@ func GetTransactionImbalances(tx *db.Tx, t *models.Transaction) (map[int64]big.R
 | 
			
		||||
		if t.Splits[i].AccountId != -1 {
 | 
			
		||||
			var err error
 | 
			
		||||
			var account *models.Account
 | 
			
		||||
			account, err = GetAccount(tx, t.Splits[i].AccountId, t.UserId)
 | 
			
		||||
			account, err = tx.GetAccount(t.Splits[i].AccountId, t.UserId)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return nil, err
 | 
			
		||||
			}
 | 
			
		||||
@@ -100,7 +100,7 @@ func GetTransactions(tx *db.Tx, userid int64) (*[]models.Transaction, error) {
 | 
			
		||||
 | 
			
		||||
func incrementAccountVersions(tx *db.Tx, user *models.User, accountids []int64) error {
 | 
			
		||||
	for i := range accountids {
 | 
			
		||||
		account, err := GetAccount(tx, accountids[i], user.UserId)
 | 
			
		||||
		account, err := tx.GetAccount(accountids[i], user.UserId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
@@ -297,7 +297,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
 | 
			
		||||
 | 
			
		||||
		for i := range transaction.Splits {
 | 
			
		||||
			transaction.Splits[i].SplitId = -1
 | 
			
		||||
			_, err := GetAccount(context.Tx, transaction.Splits[i].AccountId, user.UserId)
 | 
			
		||||
			_, err := context.Tx.GetAccount(transaction.Splits[i].AccountId, user.UserId)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return NewError(3 /*Invalid Request*/)
 | 
			
		||||
			}
 | 
			
		||||
@@ -371,7 +371,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			for i := range transaction.Splits {
 | 
			
		||||
				_, err := GetAccount(context.Tx, transaction.Splits[i].AccountId, user.UserId)
 | 
			
		||||
				_, err := context.Tx.GetAccount(transaction.Splits[i].AccountId, user.UserId)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return NewError(3 /*Invalid Request*/)
 | 
			
		||||
				}
 | 
			
		||||
@@ -518,7 +518,7 @@ func GetAccountTransactions(tx *db.Tx, user *models.User, accountid int64, sort
 | 
			
		||||
		sqloffset = fmt.Sprintf(" OFFSET %d", page*limit)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	account, err := GetAccount(tx, accountid, user.UserId)
 | 
			
		||||
	account, err := tx.GetAccount(accountid, user.UserId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user