From da7e025509838e8bf45e51054d235a30f316d1dd Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Fri, 8 Dec 2017 21:27:03 -0500 Subject: [PATCH] Move splits/transactions to store --- internal/handlers/accounts_lua.go | 24 +- internal/handlers/gnucash.go | 4 +- internal/handlers/imports.go | 4 +- internal/handlers/transactions.go | 399 +------------------------ internal/handlers/transactions_test.go | 4 +- internal/models/transactions.go | 4 +- internal/store/db/transactions.go | 361 ++++++++++++++++++++++ internal/store/store.go | 16 + 8 files changed, 410 insertions(+), 406 deletions(-) create mode 100644 internal/store/db/transactions.go diff --git a/internal/handlers/accounts_lua.go b/internal/handlers/accounts_lua.go index 2036e62..2985ff5 100644 --- a/internal/handlers/accounts_lua.go +++ b/internal/handlers/accounts_lua.go @@ -6,7 +6,6 @@ import ( "github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/store/db" "github.com/yuin/gopher-lua" - "math/big" "strings" ) @@ -168,24 +167,29 @@ func luaAccountBalance(L *lua.LState) int { panic("SecurityId not in lua security_map") } date := luaWeakCheckTime(L, 2) - var b Balance - var rat *big.Rat + var splits *[]*models.Split if date != nil { end := luaWeakCheckTime(L, 3) if end != nil { - rat, err = GetAccountBalanceDateRange(tx, user, a.AccountId, date, end) + splits, err = tx.GetAccountSplitsDateRange(user, a.AccountId, date, end) } else { - rat, err = GetAccountBalanceDate(tx, user, a.AccountId, date) + splits, err = tx.GetAccountSplitsDate(user, a.AccountId, date) } } else { - rat, err = GetAccountBalance(tx, user, a.AccountId) + splits, err = tx.GetAccountSplits(user, a.AccountId) } if err != nil { - panic("Failed to GetAccountBalance:" + err.Error()) + panic("Failed to fetch splits for account:" + err.Error()) } - b.Amount = rat - b.Security = security - L.Push(BalanceToLua(L, &b)) + rat, err := BalanceFromSplits(splits) + if err != nil { + panic("Failed to calculate balance for account:" + err.Error()) + } + b := &Balance{ + Amount: rat, + Security: security, + } + L.Push(BalanceToLua(L, b)) return 1 } diff --git a/internal/handlers/gnucash.go b/internal/handlers/gnucash.go index 2399a6b..05fa0af 100644 --- a/internal/handlers/gnucash.go +++ b/internal/handlers/gnucash.go @@ -437,7 +437,7 @@ func GnucashImportHandler(r *http.Request, context *Context) ResponseWriterWrite } split.AccountId = acctId - exists, err := SplitAlreadyImported(context.Tx, split) + exists, err := context.Tx.SplitExists(split) if err != nil { log.Print("Error checking if split was already imported:", err) return NewError(999 /*Internal Error*/) @@ -446,7 +446,7 @@ func GnucashImportHandler(r *http.Request, context *Context) ResponseWriterWrite } } if !already_imported { - err := InsertTransaction(context.Tx, &transaction, user) + err := context.Tx.InsertTransaction(&transaction, user) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/handlers/imports.go b/internal/handlers/imports.go index 78cee60..b76ae68 100644 --- a/internal/handlers/imports.go +++ b/internal/handlers/imports.go @@ -187,7 +187,7 @@ func ofxImportHelper(tx *db.Tx, r io.Reader, user *models.User, accountid int64) split.SecurityId = -1 } - exists, err := SplitAlreadyImported(tx, split) + exists, err := tx.SplitExists(split) if err != nil { log.Print("Error checking if split was already imported:", err) return NewError(999 /*Internal Error*/) @@ -202,7 +202,7 @@ func ofxImportHelper(tx *db.Tx, r io.Reader, user *models.User, accountid int64) } for _, transaction := range transactions { - err := InsertTransaction(tx, &transaction, user) + err := tx.InsertTransaction(&transaction, user) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/handlers/transactions.go b/internal/handlers/transactions.go index 9ba4952..2d16da1 100644 --- a/internal/handlers/transactions.go +++ b/internal/handlers/transactions.go @@ -2,22 +2,16 @@ package handlers import ( "errors" - "fmt" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store" "github.com/aclindsa/moneygo/internal/store/db" "log" "math/big" "net/http" "net/url" "strconv" - "time" ) -func SplitAlreadyImported(tx *db.Tx, s *models.Split) (bool, error) { - count, err := tx.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId) - return count == 1, err -} - // Return a map of security ID's to big.Rat's containing the amount that // security is imbalanced by func GetTransactionImbalances(tx *db.Tx, t *models.Transaction) (map[int64]big.Rat, error) { @@ -64,219 +58,6 @@ func TransactionBalanced(tx *db.Tx, t *models.Transaction) (bool, error) { return true, nil } -func GetTransaction(tx *db.Tx, transactionid int64, userid int64) (*models.Transaction, error) { - var t models.Transaction - - err := tx.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid) - if err != nil { - return nil, err - } - - _, err = tx.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid) - if err != nil { - return nil, err - } - - return &t, nil -} - -func GetTransactions(tx *db.Tx, userid int64) (*[]models.Transaction, error) { - var transactions []models.Transaction - - _, err := tx.Select(&transactions, "SELECT * from transactions where UserId=?", userid) - if err != nil { - return nil, err - } - - for i := range transactions { - _, err := tx.Select(&transactions[i].Splits, "SELECT * from splits where TransactionId=?", transactions[i].TransactionId) - if err != nil { - return nil, err - } - } - - return &transactions, nil -} - -func incrementAccountVersions(tx *db.Tx, user *models.User, accountids []int64) error { - for i := range accountids { - account, err := tx.GetAccount(accountids[i], user.UserId) - if err != nil { - return err - } - account.AccountVersion++ - count, err := tx.Update(account) - if err != nil { - return err - } - if count != 1 { - return errors.New("Updated more than one account") - } - } - return nil -} - -type AccountMissingError struct{} - -func (ame AccountMissingError) Error() string { - return "Account missing" -} - -func InsertTransaction(tx *db.Tx, t *models.Transaction, user *models.User) error { - // Map of any accounts with transaction splits being added - a_map := make(map[int64]bool) - for i := range t.Splits { - if t.Splits[i].AccountId != -1 { - existing, err := tx.SelectInt("SELECT count(*) from accounts where AccountId=?", t.Splits[i].AccountId) - if err != nil { - return err - } - if existing != 1 { - return AccountMissingError{} - } - a_map[t.Splits[i].AccountId] = true - } else if t.Splits[i].SecurityId == -1 { - return AccountMissingError{} - } - } - - //increment versions for all accounts - var a_ids []int64 - for id := range a_map { - a_ids = append(a_ids, id) - } - // ensure at least one of the splits is associated with an actual account - if len(a_ids) < 1 { - return AccountMissingError{} - } - err := incrementAccountVersions(tx, user, a_ids) - if err != nil { - return err - } - - t.UserId = user.UserId - err = tx.Insert(t) - if err != nil { - return err - } - - for i := range t.Splits { - t.Splits[i].TransactionId = t.TransactionId - t.Splits[i].SplitId = -1 - err = tx.Insert(t.Splits[i]) - if err != nil { - return err - } - } - - return nil -} - -func UpdateTransaction(tx *db.Tx, t *models.Transaction, user *models.User) error { - var existing_splits []*models.Split - - _, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId) - if err != nil { - return err - } - - // Map of any accounts with transaction splits being added - a_map := make(map[int64]bool) - - // Make a map with any existing splits for this transaction - s_map := make(map[int64]bool) - for i := range existing_splits { - s_map[existing_splits[i].SplitId] = true - } - - // Insert splits, updating any pre-existing ones - for i := range t.Splits { - t.Splits[i].TransactionId = t.TransactionId - _, ok := s_map[t.Splits[i].SplitId] - if ok { - count, err := tx.Update(t.Splits[i]) - if err != nil { - return err - } - if count > 1 { - return fmt.Errorf("Updated %d transaction splits while attempting to update only 1", count) - } - delete(s_map, t.Splits[i].SplitId) - } else { - t.Splits[i].SplitId = -1 - err := tx.Insert(t.Splits[i]) - if err != nil { - return err - } - } - if t.Splits[i].AccountId != -1 { - a_map[t.Splits[i].AccountId] = true - } - } - - // Delete any remaining pre-existing splits - for i := range existing_splits { - _, ok := s_map[existing_splits[i].SplitId] - if existing_splits[i].AccountId != -1 { - a_map[existing_splits[i].AccountId] = true - } - if ok { - _, err := tx.Delete(existing_splits[i]) - if err != nil { - return err - } - } - } - - // Increment versions for all accounts with modified splits - var a_ids []int64 - for id := range a_map { - a_ids = append(a_ids, id) - } - err = incrementAccountVersions(tx, user, a_ids) - if err != nil { - return err - } - - count, err := tx.Update(t) - if err != nil { - return err - } - if count > 1 { - return fmt.Errorf("Updated %d transactions (expected 1)", count) - } - - return nil -} - -func DeleteTransaction(tx *db.Tx, t *models.Transaction, user *models.User) error { - var accountids []int64 - _, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId) - if err != nil { - return err - } - - _, err = tx.Exec("DELETE FROM splits WHERE TransactionId=?", t.TransactionId) - if err != nil { - return err - } - - count, err := tx.Delete(t) - if err != nil { - return err - } - if count != 1 { - return errors.New("Deleted more than one transaction") - } - - err = incrementAccountVersions(tx, user, accountids) - if err != nil { - return err - } - - return nil -} - func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter { user, err := GetUserFromSession(context.Tx, r) if err != nil { @@ -311,9 +92,9 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter return NewError(3 /*Invalid Request*/) } - err = InsertTransaction(context.Tx, &transaction, user) + err = context.Tx.InsertTransaction(&transaction, user) if err != nil { - if _, ok := err.(AccountMissingError); ok { + if _, ok := err.(store.AccountMissingError); ok { return NewError(3 /*Invalid Request*/) } else { log.Print(err) @@ -326,7 +107,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter if context.LastLevel() { //Return all Transactions var al models.TransactionList - transactions, err := GetTransactions(context.Tx, user.UserId) + transactions, err := context.Tx.GetTransactions(user.UserId) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -339,7 +120,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter if err != nil { return NewError(3 /*Invalid Request*/) } - transaction, err := GetTransaction(context.Tx, transactionid, user.UserId) + transaction, err := context.Tx.GetTransaction(transactionid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } @@ -377,7 +158,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter } } - err = UpdateTransaction(context.Tx, &transaction, user) + err = context.Tx.UpdateTransaction(&transaction, user) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -385,12 +166,12 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter return &transaction } else if r.Method == "DELETE" { - transaction, err := GetTransaction(context.Tx, transactionid, user.UserId) + transaction, err := context.Tx.GetTransaction(transactionid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } - err = DeleteTransaction(context.Tx, transaction, user) + err = context.Tx.DeleteTransaction(transaction, user) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -402,41 +183,9 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter return NewError(3 /*Invalid Request*/) } -func TransactionsBalanceDifference(tx *db.Tx, accountid int64, transactions []models.Transaction) (*big.Rat, error) { - var pageDifference, tmp big.Rat - for i := range transactions { - _, err := tx.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId) - if err != nil { - return nil, err - } - - // Sum up the amounts from the splits we're returning so we can return - // an ending balance - for j := range transactions[i].Splits { - if transactions[i].Splits[j].AccountId == accountid { - rat_amount, err := models.GetBigAmount(transactions[i].Splits[j].Amount) - if err != nil { - return nil, err - } - tmp.Add(&pageDifference, rat_amount) - pageDifference.Set(&tmp) - } - } - } - return &pageDifference, nil -} - -func GetAccountBalance(tx *db.Tx, user *models.User, accountid int64) (*big.Rat, error) { - var splits []models.Split - - sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?" - _, err := tx.Select(&splits, sql, accountid, user.UserId) - if err != nil { - return nil, err - } - +func BalanceFromSplits(splits *[]*models.Split) (*big.Rat, error) { var balance, tmp big.Rat - for _, s := range splits { + for _, s := range *splits { rat_amount, err := models.GetBigAmount(s.Amount) if err != nil { return nil, err @@ -448,132 +197,6 @@ func GetAccountBalance(tx *db.Tx, user *models.User, accountid int64) (*big.Rat, return &balance, nil } -// Assumes accountid is valid and is owned by the current user -func GetAccountBalanceDate(tx *db.Tx, user *models.User, accountid int64, date *time.Time) (*big.Rat, error) { - var splits []models.Split - - sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?" - _, err := tx.Select(&splits, sql, accountid, user.UserId, date) - if err != nil { - return nil, err - } - - var balance, tmp big.Rat - for _, s := range splits { - rat_amount, err := models.GetBigAmount(s.Amount) - if err != nil { - return nil, err - } - tmp.Add(&balance, rat_amount) - balance.Set(&tmp) - } - - return &balance, nil -} - -func GetAccountBalanceDateRange(tx *db.Tx, user *models.User, accountid int64, begin, end *time.Time) (*big.Rat, error) { - var splits []models.Split - - sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?" - _, err := tx.Select(&splits, sql, accountid, user.UserId, begin, end) - if err != nil { - return nil, err - } - - var balance, tmp big.Rat - for _, s := range splits { - rat_amount, err := models.GetBigAmount(s.Amount) - if err != nil { - return nil, err - } - tmp.Add(&balance, rat_amount) - balance.Set(&tmp) - } - - return &balance, nil -} - -func GetAccountTransactions(tx *db.Tx, user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) { - var transactions []models.Transaction - var atl models.AccountTransactionsList - - var sqlsort, balanceLimitOffset string - var balanceLimitOffsetArg uint64 - if sort == "date-asc" { - sqlsort = " ORDER BY transactions.Date ASC, transactions.TransactionId ASC" - balanceLimitOffset = " LIMIT ?" - balanceLimitOffsetArg = page * limit - } else if sort == "date-desc" { - numSplits, err := tx.SelectInt("SELECT count(*) FROM splits") - if err != nil { - return nil, err - } - sqlsort = " ORDER BY transactions.Date DESC, transactions.TransactionId DESC" - balanceLimitOffset = fmt.Sprintf(" LIMIT %d OFFSET ?", numSplits) - balanceLimitOffsetArg = (page + 1) * limit - } - - var sqloffset string - if page > 0 { - sqloffset = fmt.Sprintf(" OFFSET %d", page*limit) - } - - account, err := tx.GetAccount(accountid, user.UserId) - if err != nil { - return nil, err - } - atl.Account = account - - sql := "SELECT DISTINCT transactions.* FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + " LIMIT ?" + sqloffset - _, err = tx.Select(&transactions, sql, user.UserId, accountid, limit) - if err != nil { - return nil, err - } - atl.Transactions = &transactions - - pageDifference, err := TransactionsBalanceDifference(tx, accountid, transactions) - if err != nil { - return nil, err - } - - count, err := tx.SelectInt("SELECT count(DISTINCT transactions.TransactionId) FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?", user.UserId, accountid) - if err != nil { - return nil, err - } - atl.TotalTransactions = count - - security, err := tx.GetSecurity(atl.Account.SecurityId, user.UserId) - if err != nil { - return nil, err - } - if security == nil { - return nil, errors.New("Security not found") - } - - // Sum all the splits for all transaction splits for this account that - // occurred before the page we're returning - var amounts []string - sql = "SELECT s.Amount FROM splits AS s INNER JOIN (SELECT DISTINCT transactions.Date, transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ") as t ON s.TransactionId = t.TransactionId WHERE s.AccountId=?" - _, err = tx.Select(&amounts, sql, user.UserId, accountid, balanceLimitOffsetArg, accountid) - if err != nil { - return nil, err - } - - var tmp, balance big.Rat - for _, amount := range amounts { - rat_amount, err := models.GetBigAmount(amount) - if err != nil { - return nil, err - } - tmp.Add(&balance, rat_amount) - balance.Set(&tmp) - } - atl.BeginningBalance = balance.FloatString(security.Precision) - atl.EndingBalance = tmp.Add(&balance, pageDifference).FloatString(security.Precision) - - return &atl, nil -} - // Return only those transactions which have at least one split pertaining to // an account func AccountTransactionsHandler(context *Context, r *http.Request, user *models.User, accountid int64) ResponseWriterWriter { @@ -609,7 +232,7 @@ func AccountTransactionsHandler(context *Context, r *http.Request, user *models. sort = sortstring } - accountTransactions, err := GetAccountTransactions(context.Tx, user, accountid, sort, page, limit) + accountTransactions, err := context.Tx.GetAccountTransactions(user, accountid, sort, page, limit) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/handlers/transactions_test.go b/internal/handlers/transactions_test.go index 0f3b68d..4d2ee54 100644 --- a/internal/handlers/transactions_test.go +++ b/internal/handlers/transactions_test.go @@ -276,7 +276,7 @@ func TestGetTransactions(t *testing.T) { found := false for _, tran := range *tl.Transactions { if tran.TransactionId == curr.TransactionId { - ensureTransactionsMatch(t, &curr, &tran, nil, true, true) + ensureTransactionsMatch(t, &curr, tran, nil, true, true) if _, ok := foundIds[tran.TransactionId]; ok { continue } @@ -410,7 +410,7 @@ func helperTestAccountTransactions(t *testing.T, d *TestData, account *models.Ac } if atl.Transactions != nil { for _, tran := range *atl.Transactions { - transactions = append(transactions, tran) + transactions = append(transactions, *tran) } lastFetchCount = int64(len(*atl.Transactions)) } else { diff --git a/internal/models/transactions.go b/internal/models/transactions.go index 8076995..5046e65 100644 --- a/internal/models/transactions.go +++ b/internal/models/transactions.go @@ -82,12 +82,12 @@ type Transaction struct { } type TransactionList struct { - Transactions *[]Transaction `json:"transactions"` + Transactions *[]*Transaction `json:"transactions"` } type AccountTransactionsList struct { Account *Account - Transactions *[]Transaction + Transactions *[]*Transaction TotalTransactions int64 BeginningBalance string EndingBalance string diff --git a/internal/store/db/transactions.go b/internal/store/db/transactions.go new file mode 100644 index 0000000..29168df --- /dev/null +++ b/internal/store/db/transactions.go @@ -0,0 +1,361 @@ +package db + +import ( + "errors" + "fmt" + "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store" + "math/big" + "time" +) + +func (tx *Tx) incrementAccountVersions(user *models.User, accountids []int64) error { + for i := range accountids { + account, err := tx.GetAccount(accountids[i], user.UserId) + if err != nil { + return err + } + account.AccountVersion++ + count, err := tx.Update(account) + if err != nil { + return err + } + if count != 1 { + return errors.New("Updated more than one account") + } + } + return nil +} + +func (tx *Tx) InsertTransaction(t *models.Transaction, user *models.User) error { + // Map of any accounts with transaction splits being added + a_map := make(map[int64]bool) + for i := range t.Splits { + if t.Splits[i].AccountId != -1 { + existing, err := tx.SelectInt("SELECT count(*) from accounts where AccountId=?", t.Splits[i].AccountId) + if err != nil { + return err + } + if existing != 1 { + return store.AccountMissingError{} + } + a_map[t.Splits[i].AccountId] = true + } else if t.Splits[i].SecurityId == -1 { + return store.AccountMissingError{} + } + } + + //increment versions for all accounts + var a_ids []int64 + for id := range a_map { + a_ids = append(a_ids, id) + } + // ensure at least one of the splits is associated with an actual account + if len(a_ids) < 1 { + return store.AccountMissingError{} + } + err := tx.incrementAccountVersions(user, a_ids) + if err != nil { + return err + } + + t.UserId = user.UserId + err = tx.Insert(t) + if err != nil { + return err + } + + for i := range t.Splits { + t.Splits[i].TransactionId = t.TransactionId + t.Splits[i].SplitId = -1 + err = tx.Insert(t.Splits[i]) + if err != nil { + return err + } + } + + return nil +} + +func (tx *Tx) SplitExists(s *models.Split) (bool, error) { + count, err := tx.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId) + return count == 1, err +} + +func (tx *Tx) GetTransaction(transactionid int64, userid int64) (*models.Transaction, error) { + var t models.Transaction + + err := tx.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid) + if err != nil { + return nil, err + } + + _, err = tx.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid) + if err != nil { + return nil, err + } + + return &t, nil +} + +func (tx *Tx) GetTransactions(userid int64) (*[]*models.Transaction, error) { + var transactions []*models.Transaction + + _, err := tx.Select(&transactions, "SELECT * from transactions where UserId=?", userid) + if err != nil { + return nil, err + } + + for i := range transactions { + _, err := tx.Select(&transactions[i].Splits, "SELECT * from splits where TransactionId=?", transactions[i].TransactionId) + if err != nil { + return nil, err + } + } + + return &transactions, nil +} + +func (tx *Tx) UpdateTransaction(t *models.Transaction, user *models.User) error { + var existing_splits []*models.Split + + _, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId) + if err != nil { + return err + } + + // Map of any accounts with transaction splits being added + a_map := make(map[int64]bool) + + // Make a map with any existing splits for this transaction + s_map := make(map[int64]bool) + for i := range existing_splits { + s_map[existing_splits[i].SplitId] = true + } + + // Insert splits, updating any pre-existing ones + for i := range t.Splits { + t.Splits[i].TransactionId = t.TransactionId + _, ok := s_map[t.Splits[i].SplitId] + if ok { + count, err := tx.Update(t.Splits[i]) + if err != nil { + return err + } + if count > 1 { + return fmt.Errorf("Updated %d transaction splits while attempting to update only 1", count) + } + delete(s_map, t.Splits[i].SplitId) + } else { + t.Splits[i].SplitId = -1 + err := tx.Insert(t.Splits[i]) + if err != nil { + return err + } + } + if t.Splits[i].AccountId != -1 { + a_map[t.Splits[i].AccountId] = true + } + } + + // Delete any remaining pre-existing splits + for i := range existing_splits { + _, ok := s_map[existing_splits[i].SplitId] + if existing_splits[i].AccountId != -1 { + a_map[existing_splits[i].AccountId] = true + } + if ok { + _, err := tx.Delete(existing_splits[i]) + if err != nil { + return err + } + } + } + + // Increment versions for all accounts with modified splits + var a_ids []int64 + for id := range a_map { + a_ids = append(a_ids, id) + } + err = tx.incrementAccountVersions(user, a_ids) + if err != nil { + return err + } + + count, err := tx.Update(t) + if err != nil { + return err + } + if count > 1 { + return fmt.Errorf("Updated %d transactions (expected 1)", count) + } + + return nil +} + +func (tx *Tx) DeleteTransaction(t *models.Transaction, user *models.User) error { + var accountids []int64 + _, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId) + if err != nil { + return err + } + + _, err = tx.Exec("DELETE FROM splits WHERE TransactionId=?", t.TransactionId) + if err != nil { + return err + } + + count, err := tx.Delete(t) + if err != nil { + return err + } + if count != 1 { + return errors.New("Deleted more than one transaction") + } + + err = tx.incrementAccountVersions(user, accountids) + if err != nil { + return err + } + + return nil +} + +func (tx *Tx) GetAccountSplits(user *models.User, accountid int64) (*[]*models.Split, error) { + var splits []*models.Split + + sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?" + _, err := tx.Select(&splits, sql, accountid, user.UserId) + if err != nil { + return nil, err + } + return &splits, nil +} + +// Assumes accountid is valid and is owned by the current user +func (tx *Tx) GetAccountSplitsDate(user *models.User, accountid int64, date *time.Time) (*[]*models.Split, error) { + var splits []*models.Split + + sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?" + _, err := tx.Select(&splits, sql, accountid, user.UserId, date) + if err != nil { + return nil, err + } + return &splits, err +} + +func (tx *Tx) GetAccountSplitsDateRange(user *models.User, accountid int64, begin, end *time.Time) (*[]*models.Split, error) { + var splits []*models.Split + + sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?" + _, err := tx.Select(&splits, sql, accountid, user.UserId, begin, end) + if err != nil { + return nil, err + } + return &splits, nil +} + +func (tx *Tx) transactionsBalanceDifference(accountid int64, transactions []*models.Transaction) (*big.Rat, error) { + var pageDifference, tmp big.Rat + for i := range transactions { + _, err := tx.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId) + if err != nil { + return nil, err + } + + // Sum up the amounts from the splits we're returning so we can return + // an ending balance + for j := range transactions[i].Splits { + if transactions[i].Splits[j].AccountId == accountid { + rat_amount, err := models.GetBigAmount(transactions[i].Splits[j].Amount) + if err != nil { + return nil, err + } + tmp.Add(&pageDifference, rat_amount) + pageDifference.Set(&tmp) + } + } + } + return &pageDifference, nil +} + +func (tx *Tx) GetAccountTransactions(user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) { + var transactions []*models.Transaction + var atl models.AccountTransactionsList + + var sqlsort, balanceLimitOffset string + var balanceLimitOffsetArg uint64 + if sort == "date-asc" { + sqlsort = " ORDER BY transactions.Date ASC, transactions.TransactionId ASC" + balanceLimitOffset = " LIMIT ?" + balanceLimitOffsetArg = page * limit + } else if sort == "date-desc" { + numSplits, err := tx.SelectInt("SELECT count(*) FROM splits") + if err != nil { + return nil, err + } + sqlsort = " ORDER BY transactions.Date DESC, transactions.TransactionId DESC" + balanceLimitOffset = fmt.Sprintf(" LIMIT %d OFFSET ?", numSplits) + balanceLimitOffsetArg = (page + 1) * limit + } + + var sqloffset string + if page > 0 { + sqloffset = fmt.Sprintf(" OFFSET %d", page*limit) + } + + account, err := tx.GetAccount(accountid, user.UserId) + if err != nil { + return nil, err + } + atl.Account = account + + sql := "SELECT DISTINCT transactions.* FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + " LIMIT ?" + sqloffset + _, err = tx.Select(&transactions, sql, user.UserId, accountid, limit) + if err != nil { + return nil, err + } + atl.Transactions = &transactions + + pageDifference, err := tx.transactionsBalanceDifference(accountid, transactions) + if err != nil { + return nil, err + } + + count, err := tx.SelectInt("SELECT count(DISTINCT transactions.TransactionId) FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?", user.UserId, accountid) + if err != nil { + return nil, err + } + atl.TotalTransactions = count + + security, err := tx.GetSecurity(atl.Account.SecurityId, user.UserId) + if err != nil { + return nil, err + } + if security == nil { + return nil, errors.New("Security not found") + } + + // Sum all the splits for all transaction splits for this account that + // occurred before the page we're returning + var amounts []string + sql = "SELECT s.Amount FROM splits AS s INNER JOIN (SELECT DISTINCT transactions.Date, transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ") as t ON s.TransactionId = t.TransactionId WHERE s.AccountId=?" + _, err = tx.Select(&amounts, sql, user.UserId, accountid, balanceLimitOffsetArg, accountid) + if err != nil { + return nil, err + } + + var tmp, balance big.Rat + for _, amount := range amounts { + rat_amount, err := models.GetBigAmount(amount) + if err != nil { + return nil, err + } + tmp.Add(&balance, rat_amount) + balance.Set(&tmp) + } + atl.BeginningBalance = balance.FloatString(security.Precision) + atl.EndingBalance = tmp.Add(&balance, pageDifference).FloatString(security.Precision) + + return &atl, nil +} diff --git a/internal/store/store.go b/internal/store/store.go index b1ffc8a..c890dd7 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -76,7 +76,23 @@ type AccountStore interface { DeleteAccount(account *models.Account) error } +type AccountMissingError struct{} + +func (ame AccountMissingError) Error() string { + return "Account missing" +} + type TransactionStore interface { + SplitExists(s *models.Split) (bool, error) + InsertTransaction(t *models.Transaction, user *models.User) error + GetTransaction(transactionid int64, userid int64) (*models.Transaction, error) + GetTransactions(userid int64) (*[]*models.Transaction, error) + UpdateTransaction(t *models.Transaction, user *models.User) error + DeleteTransaction(t *models.Transaction, user *models.User) error + GetAccountSplits(user *models.User, accountid int64) (*[]*models.Split, error) + GetAccountSplitsDate(user *models.User, accountid int64, date *time.Time) (*[]*models.Split, error) + GetAccountSplitsDateRange(user *models.User, accountid int64, begin, end *time.Time) (*[]*models.Split, error) + GetAccountTransactions(user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) } type ReportStore interface {