From 4e53a5e59c209a742a8bb1ee36538945be6bc559 Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Sat, 14 Oct 2017 14:20:50 -0400 Subject: [PATCH 1/2] Use SQL transactions for the entirety of every request --- internal/handlers/accounts.go | 188 +++++--------- internal/handlers/accounts_lua.go | 16 +- internal/handlers/errors.go | 8 +- internal/handlers/gnucash.go | 82 ++----- internal/handlers/handlers.go | 62 +++-- internal/handlers/imports.go | 159 ++++-------- internal/handlers/prices.go | 21 +- internal/handlers/reports.go | 133 ++++------ internal/handlers/securities.go | 150 ++++------- internal/handlers/securities_lua.go | 12 +- internal/handlers/sessions.go | 92 ++++--- internal/handlers/transactions.go | 369 +++++++--------------------- internal/handlers/users.go | 176 ++++--------- internal/handlers/util.go | 17 ++ 14 files changed, 496 insertions(+), 989 deletions(-) diff --git a/internal/handlers/accounts.go b/internal/handlers/accounts.go index d278b6c..5ba8708 100644 --- a/internal/handlers/accounts.go +++ b/internal/handlers/accounts.go @@ -129,10 +129,10 @@ func (al *AccountList) Read(json_str string) error { return dec.Decode(al) } -func GetAccount(db *DB, accountid int64, userid int64) (*Account, error) { +func GetAccount(tx *Tx, accountid int64, userid int64) (*Account, error) { var a Account - err := db.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid) + err := tx.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid) if err != nil { return nil, err } @@ -150,10 +150,10 @@ func GetAccountTx(transaction *gorp.Transaction, accountid int64, userid int64) return &a, nil } -func GetAccounts(db *DB, userid int64) (*[]Account, error) { +func GetAccounts(tx *Tx, userid int64) (*[]Account, error) { var accounts []Account - _, err := db.Select(&accounts, "SELECT * from accounts where UserId=?", userid) + _, err := tx.Select(&accounts, "SELECT * from accounts where UserId=?", userid) if err != nil { return nil, err } @@ -293,12 +293,7 @@ func (cae CircularAccountsError) Error() string { return "Would result in circular account relationship" } -func insertUpdateAccount(db *DB, a *Account, insert bool) error { - transaction, err := db.Begin() - if err != nil { - return err - } - +func insertUpdateAccount(tx *Tx, a *Account, insert bool) error { found := make(map[int64]bool) if !insert { found[a.AccountId] = true @@ -308,14 +303,12 @@ func insertUpdateAccount(db *DB, a *Account, insert bool) error { for parentid != -1 { depth += 1 if depth > 100 { - transaction.Rollback() return TooMuchNestingError{} } var a Account - err := transaction.SelectOne(&a, "SELECT * from accounts where AccountId=?", parentid) + err := tx.SelectOne(&a, "SELECT * from accounts where AccountId=?", parentid) if err != nil { - transaction.Rollback() return ParentAccountMissingError{} } @@ -327,107 +320,79 @@ func insertUpdateAccount(db *DB, a *Account, insert bool) error { found[parentid] = true parentid = a.ParentAccountId if _, ok := found[parentid]; ok { - transaction.Rollback() return CircularAccountsError{} } } if insert { - err = transaction.Insert(a) + err := tx.Insert(a) if err != nil { - transaction.Rollback() return err } } else { - oldacct, err := GetAccountTx(transaction, a.AccountId, a.UserId) + oldacct, err := GetAccountTx(tx, a.AccountId, a.UserId) if err != nil { - transaction.Rollback() return err } a.AccountVersion = oldacct.AccountVersion + 1 - count, err := transaction.Update(a) + count, err := tx.Update(a) if err != nil { - transaction.Rollback() return err } if count != 1 { - transaction.Rollback() return errors.New("Updated more than one account") } } - err = transaction.Commit() - if err != nil { - transaction.Rollback() - return err - } - return nil } -func InsertAccount(db *DB, a *Account) error { - return insertUpdateAccount(db, a, true) +func InsertAccount(tx *Tx, a *Account) error { + return insertUpdateAccount(tx, a, true) } -func UpdateAccount(db *DB, a *Account) error { - return insertUpdateAccount(db, a, false) +func UpdateAccount(tx *Tx, a *Account) error { + return insertUpdateAccount(tx, a, false) } -func DeleteAccount(db *DB, a *Account) error { - transaction, err := db.Begin() - if err != nil { - return err - } - +func DeleteAccount(tx *Tx, a *Account) error { if a.ParentAccountId != -1 { // Re-parent splits to this account's parent account if this account isn't a root account - _, err = transaction.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", a.ParentAccountId, a.AccountId) + _, err := tx.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", a.ParentAccountId, a.AccountId) if err != nil { - transaction.Rollback() return err } } else { // Delete splits if this account is a root account - _, err = transaction.Exec("DELETE FROM splits WHERE AccountId=?", a.AccountId) + _, err := tx.Exec("DELETE FROM splits WHERE AccountId=?", a.AccountId) if err != nil { - transaction.Rollback() return err } } // Re-parent child accounts to this account's parent account - _, err = transaction.Exec("UPDATE accounts SET ParentAccountId=? WHERE ParentAccountId=?", a.ParentAccountId, a.AccountId) + _, err := tx.Exec("UPDATE accounts SET ParentAccountId=? WHERE ParentAccountId=?", a.ParentAccountId, a.AccountId) if err != nil { - transaction.Rollback() return err } - count, err := transaction.Delete(a) + count, err := tx.Delete(a) if err != nil { - transaction.Rollback() return err } if count != 1 { - transaction.Rollback() return errors.New("Was going to delete more than one account") } - err = transaction.Commit() - if err != nil { - transaction.Rollback() - return err - } - return nil } -func AccountHandler(w http.ResponseWriter, r *http.Request, db *DB) { - user, err := GetUserFromSession(db, r) +func AccountHandler(r *http.Request, tx *Tx) ResponseWriterWriter { + user, err := GetUserFromSession(tx, r) if err != nil { - WriteError(w, 1 /*Not Signed In*/) - return + return NewError(1 /*Not Signed In*/) } if r.Method == "POST" { @@ -439,59 +404,46 @@ func AccountHandler(w http.ResponseWriter, r *http.Request, db *DB) { n, err := GetURLPieces(r.URL.Path, "/account/%d/import/%s", &accountid, &importtype) if err != nil || n != 2 { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - AccountImportHandler(db, w, r, user, accountid, importtype) - return + return AccountImportHandler(tx, r, user, accountid, importtype) } account_json := r.PostFormValue("account") if account_json == "" { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } var account Account err := account.Read(account_json) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } account.AccountId = -1 account.UserId = user.UserId account.AccountVersion = 0 - security, err := GetSecurity(db, account.SecurityId, user.UserId) + security, err := GetSecurity(tx, account.SecurityId, user.UserId) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } if security == nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } - err = InsertAccount(db, &account) + err = InsertAccount(tx, &account) if err != nil { if _, ok := err.(ParentAccountMissingError); ok { - WriteError(w, 3 /*Invalid Request*/) + return NewError(3 /*Invalid Request*/) } else { - WriteError(w, 999 /*Internal Error*/) log.Print(err) + return NewError(999 /*Internal Error*/) } - return } - w.WriteHeader(201 /*Created*/) - err = account.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return ResponseWrapper{201, &account} } else if r.Method == "GET" { var accountid int64 n, err := GetURLPieces(r.URL.Path, "/account/%d", &accountid) @@ -499,112 +451,86 @@ func AccountHandler(w http.ResponseWriter, r *http.Request, db *DB) { if err != nil || n != 1 { //Return all Accounts var al AccountList - accounts, err := GetAccounts(db, user.UserId) + accounts, err := GetAccounts(tx, user.UserId) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } al.Accounts = accounts - err = (&al).Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return &al } else { // if URL looks like /account/[0-9]+/transactions, use the account // transaction handler if accountTransactionsRE.MatchString(r.URL.Path) { - AccountTransactionsHandler(db, w, r, user, accountid) - return + return AccountTransactionsHandler(tx, r, user, accountid) } // Return Account with this Id - account, err := GetAccount(db, accountid, user.UserId) + account, err := GetAccount(tx, accountid, user.UserId) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } - err = account.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return account } } else { accountid, err := GetURLID(r.URL.Path) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } if r.Method == "PUT" { account_json := r.PostFormValue("account") if account_json == "" { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } var account Account err := account.Read(account_json) if err != nil || account.AccountId != accountid { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } account.UserId = user.UserId - security, err := GetSecurity(db, account.SecurityId, user.UserId) + security, err := GetSecurity(tx, account.SecurityId, user.UserId) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } if security == nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } if account.ParentAccountId == account.AccountId { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } - err = UpdateAccount(db, &account) + err = UpdateAccount(tx, &account) if err != nil { if _, ok := err.(ParentAccountMissingError); ok { - WriteError(w, 3 /*Invalid Request*/) + return NewError(3 /*Invalid Request*/) } else if _, ok := err.(CircularAccountsError); ok { - WriteError(w, 3 /*Invalid Request*/) + return NewError(3 /*Invalid Request*/) } else { - WriteError(w, 999 /*Internal Error*/) log.Print(err) + return NewError(999 /*Internal Error*/) } - return } - err = account.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return &account } else if r.Method == "DELETE" { - account, err := GetAccount(db, accountid, user.UserId) + account, err := GetAccount(tx, accountid, user.UserId) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } - err = DeleteAccount(db, account) + err = DeleteAccount(tx, account) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - WriteSuccess(w) + return SuccessWriter{} } } + return NewError(3 /*Invalid Request*/) } diff --git a/internal/handlers/accounts_lua.go b/internal/handlers/accounts_lua.go index 928b0ba..4edf440 100644 --- a/internal/handlers/accounts_lua.go +++ b/internal/handlers/accounts_lua.go @@ -15,9 +15,9 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*Account, error) { ctx := L.Context() - db, ok := ctx.Value(dbContextKey).(*DB) + tx, ok := ctx.Value(dbContextKey).(*Tx) if !ok { - return nil, errors.New("Couldn't find DB in lua's Context") + return nil, errors.New("Couldn't find tx in lua's Context") } account_map, ok = ctx.Value(accountsContextKey).(map[int64]*Account) @@ -27,7 +27,7 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*Account, error) { return nil, errors.New("Couldn't find User in lua's Context") } - accounts, err := GetAccounts(db, user.UserId) + accounts, err := GetAccounts(tx, user.UserId) if err != nil { return nil, err } @@ -149,9 +149,9 @@ func luaAccountBalance(L *lua.LState) int { a := luaCheckAccount(L, 1) ctx := L.Context() - db, ok := ctx.Value(dbContextKey).(*DB) + tx, ok := ctx.Value(dbContextKey).(*Tx) if !ok { - panic("Couldn't find DB in lua's Context") + panic("Couldn't find tx in lua's Context") } user, ok := ctx.Value(userContextKey).(*User) if !ok { @@ -171,12 +171,12 @@ func luaAccountBalance(L *lua.LState) int { if date != nil { end := luaWeakCheckTime(L, 3) if end != nil { - rat, err = GetAccountBalanceDateRange(db, user, a.AccountId, date, end) + rat, err = GetAccountBalanceDateRange(tx, user, a.AccountId, date, end) } else { - rat, err = GetAccountBalanceDate(db, user, a.AccountId, date) + rat, err = GetAccountBalanceDate(tx, user, a.AccountId, date) } } else { - rat, err = GetAccountBalance(db, user, a.AccountId) + rat, err = GetAccountBalance(tx, user, a.AccountId) } if err != nil { panic("Failed to GetAccountBalance:" + err.Error()) diff --git a/internal/handlers/errors.go b/internal/handlers/errors.go index 17a5c0c..e5f4041 100644 --- a/internal/handlers/errors.go +++ b/internal/handlers/errors.go @@ -38,13 +38,17 @@ var error_codes = map[int]string{ 999: "Internal Error", } -func WriteError(w http.ResponseWriter, error_code int) { +func NewError(error_code int) *Error { msg, ok := error_codes[error_code] if !ok { log.Printf("Error: WriteError received error code of %d", error_code) msg = error_codes[999] } - e := Error{error_code, msg} + return &Error{error_code, msg} +} + +func WriteError(w http.ResponseWriter, error_code int) { + e := NewError(error_code) err := e.Write(w) if err != nil { diff --git a/internal/handlers/gnucash.go b/internal/handlers/gnucash.go index 24f6b7c..3ce1358 100644 --- a/internal/handlers/gnucash.go +++ b/internal/handlers/gnucash.go @@ -308,42 +308,37 @@ func ImportGnucash(r io.Reader) (*GnucashImport, error) { return &gncimport, nil } -func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) { - user, err := GetUserFromSession(db, r) +func GnucashImportHandler(r *http.Request, tx *Tx) ResponseWriterWriter { + user, err := GetUserFromSession(tx, r) if err != nil { - WriteError(w, 1 /*Not Signed In*/) - return + return NewError(1 /*Not Signed In*/) } if r.Method != "POST" { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } multipartReader, err := r.MultipartReader() if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } // Assume there is only one 'part' and it's the one we care about part, err := multipartReader.NextPart() if err != nil { if err == io.EOF { - WriteError(w, 3 /*Invalid Request*/) + return NewError(3 /*Invalid Request*/) } else { - WriteError(w, 999 /*Internal Error*/) log.Print(err) + return NewError(999 /*Internal Error*/) } - return } bufread := bufio.NewReader(part) gzHeader, err := bufread.Peek(2) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } // Does this look like a gzipped file? @@ -351,9 +346,8 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) { if gzHeader[0] == 0x1f && gzHeader[1] == 0x8b { gzr, err := gzip.NewReader(bufread) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } gnucashImport, err = ImportGnucash(gzr) } else { @@ -361,15 +355,7 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) { } if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return - } - - sqltransaction, err := db.Begin() - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return + return NewError(3 /*Invalid Request*/) } // Import securities, building map from Gnucash security IDs to our @@ -377,13 +363,11 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) { securityMap := make(map[int64]int64) for _, security := range gnucashImport.Securities { securityId := security.SecurityId // save off because it could be updated - s, err := ImportGetCreateSecurity(sqltransaction, user.UserId, &security) + s, err := ImportGetCreateSecurity(tx, user.UserId, &security) if err != nil { - sqltransaction.Rollback() - WriteError(w, 6 /*Import Error*/) log.Print(err) log.Print(security) - return + return NewError(6 /*Import Error*/) } securityMap[securityId] = s.SecurityId } @@ -394,12 +378,10 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) { price.CurrencyId = securityMap[price.CurrencyId] price.PriceId = 0 - err := CreatePriceIfNotExist(sqltransaction, &price) + err := CreatePriceIfNotExist(tx, &price) if err != nil { - sqltransaction.Rollback() - WriteError(w, 6 /*Import Error*/) log.Print(err) - return + return NewError(6 /*Import Error*/) } } @@ -425,12 +407,10 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) { account.ParentAccountId = accountMap[account.ParentAccountId] } account.SecurityId = securityMap[account.SecurityId] - a, err := GetCreateAccountTx(sqltransaction, account) + a, err := GetCreateAccountTx(tx, account) if err != nil { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } accountMap[account.AccountId] = a.AccountId accountsRemaining-- @@ -438,10 +418,8 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) { } if accountsRemaining == accountsRemainingLast { //We didn't make any progress in importing the next level of accounts, so there must be a circular parent-child relationship, so give up and tell the user they're wrong - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print(fmt.Errorf("Circular account parent-child relationship when importing %s", part.FileName())) - return + return NewError(999 /*Internal Error*/) } accountsRemainingLast = accountsRemaining } @@ -453,41 +431,27 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) { for _, split := range transaction.Splits { acctId, ok := accountMap[split.AccountId] if !ok { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print(fmt.Errorf("Error: Split's AccountID Doesn't exist: %d\n", split.AccountId)) - return + return NewError(999 /*Internal Error*/) } split.AccountId = acctId - exists, err := split.AlreadyImportedTx(sqltransaction) + exists, err := split.AlreadyImportedTx(tx) if err != nil { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print("Error checking if split was already imported:", err) - return + return NewError(999 /*Internal Error*/) } else if exists { already_imported = true } } if !already_imported { - err := InsertTransactionTx(sqltransaction, &transaction, user) + err := InsertTransactionTx(tx, &transaction, user) if err != nil { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } } } - err = sqltransaction.Commit() - if err != nil { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } - - WriteSuccess(w) + return SuccessWriter{} } diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 985d8d5..d518e92 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -2,30 +2,64 @@ package handlers import ( "gopkg.in/gorp.v1" + "log" "net/http" ) -// Create a closure over db, allowing the handlers to look like a -// http.HandlerFunc -type DB = gorp.DbMap -type DBHandler func(http.ResponseWriter, *http.Request, *DB) +// But who writes the ResponseWriterWriter? +type ResponseWriterWriter interface { + Write(http.ResponseWriter) error +} +type Tx = gorp.Transaction +type TxHandler func(*http.Request, *Tx) ResponseWriterWriter -func DBHandlerFunc(h DBHandler, db *DB) http.HandlerFunc { +func TxHandlerFunc(t TxHandler, db *gorp.DbMap) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - h(w, r, db) + tx, err := db.Begin() + if err != nil { + log.Print(err) + WriteError(w, 999 /*Internal Error*/) + return + } + defer func() { + if r := recover(); r != nil { + tx.Rollback() + WriteError(w, 999 /*Internal Error*/) + panic(r) + } + }() + + writer := t(r, tx) + + if e, ok := writer.(*Error); ok { + tx.Rollback() + e.Write(w) + } else { + err = tx.Commit() + if err != nil { + log.Print(err) + WriteError(w, 999 /*Internal Error*/) + } else { + err = writer.Write(w) + if err != nil { + log.Print(err) + WriteError(w, 999 /*Internal Error*/) + } + } + } } } -func GetHandler(db *DB) *http.ServeMux { +func GetHandler(db *gorp.DbMap) *http.ServeMux { servemux := http.NewServeMux() - servemux.HandleFunc("/session/", DBHandlerFunc(SessionHandler, db)) - servemux.HandleFunc("/user/", DBHandlerFunc(UserHandler, db)) - servemux.HandleFunc("/security/", DBHandlerFunc(SecurityHandler, db)) + servemux.HandleFunc("/session/", TxHandlerFunc(SessionHandler, db)) + servemux.HandleFunc("/user/", TxHandlerFunc(UserHandler, db)) + servemux.HandleFunc("/security/", TxHandlerFunc(SecurityHandler, db)) servemux.HandleFunc("/securitytemplate/", SecurityTemplateHandler) - servemux.HandleFunc("/account/", DBHandlerFunc(AccountHandler, db)) - servemux.HandleFunc("/transaction/", DBHandlerFunc(TransactionHandler, db)) - servemux.HandleFunc("/import/gnucash", DBHandlerFunc(GnucashImportHandler, db)) - servemux.HandleFunc("/report/", DBHandlerFunc(ReportHandler, db)) + servemux.HandleFunc("/account/", TxHandlerFunc(AccountHandler, db)) + servemux.HandleFunc("/transaction/", TxHandlerFunc(TransactionHandler, db)) + servemux.HandleFunc("/import/gnucash", TxHandlerFunc(GnucashImportHandler, db)) + servemux.HandleFunc("/report/", TxHandlerFunc(ReportHandler, db)) return servemux } diff --git a/internal/handlers/imports.go b/internal/handlers/imports.go index 28a5922..9470a26 100644 --- a/internal/handlers/imports.go +++ b/internal/handlers/imports.go @@ -22,48 +22,35 @@ func (od *OFXDownload) Read(json_str string) error { return dec.Decode(od) } -func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, accountid int64) { +func ofxImportHelper(tx *Tx, r io.Reader, user *User, accountid int64) ResponseWriterWriter { itl, err := ImportOFX(r) if err != nil { //TODO is this necessarily an invalid request (what if it was an error on our end)? - WriteError(w, 3 /*Invalid Request*/) log.Print(err) - return + return NewError(3 /*Invalid Request*/) } if len(itl.Accounts) != 1 { - WriteError(w, 3 /*Invalid Request*/) log.Printf("Found %d accounts when importing OFX, expected 1", len(itl.Accounts)) - return - } - - sqltransaction, err := db.Begin() - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return + return NewError(3 /*Invalid Request*/) } // Return Account with this Id - account, err := GetAccountTx(sqltransaction, accountid, user.UserId) + account, err := GetAccountTx(tx, accountid, user.UserId) if err != nil { - sqltransaction.Rollback() - WriteError(w, 3 /*Invalid Request*/) log.Print(err) - return + return NewError(3 /*Invalid Request*/) } importedAccount := itl.Accounts[0] if len(account.ExternalAccountId) > 0 && account.ExternalAccountId != importedAccount.ExternalAccountId { - sqltransaction.Rollback() - WriteError(w, 3 /*Invalid Request*/) log.Printf("OFX import has \"%s\" as ExternalAccountId, but the account being imported to has\"%s\"", importedAccount.ExternalAccountId, account.ExternalAccountId) - return + return NewError(3 /*Invalid Request*/) } // Find matching existing securities or create new ones for those @@ -74,21 +61,17 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc // save off since ImportGetCreateSecurity overwrites SecurityId on // ofxsecurity oldsecurityid := ofxsecurity.SecurityId - security, err := ImportGetCreateSecurity(sqltransaction, user.UserId, &ofxsecurity) + security, err := ImportGetCreateSecurity(tx, user.UserId, &ofxsecurity) if err != nil { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } securitymap[oldsecurityid] = *security } if account.SecurityId != securitymap[importedAccount.SecurityId].SecurityId { - sqltransaction.Rollback() - WriteError(w, 3 /*Invalid Request*/) log.Printf("OFX import account's SecurityId (%d) does not match this account's (%d)", securitymap[importedAccount.SecurityId].SecurityId, account.SecurityId) - return + return NewError(3 /*Invalid Request*/) } // TODO Ensure all transactions have at least one split in the account @@ -99,10 +82,8 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc transaction.UserId = user.UserId if !transaction.Valid() { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print("Unexpected invalid transaction from OFX import") - return + return NewError(999 /*Internal Error*/) } // Ensure that either AccountId or SecurityId is set for this split, @@ -112,10 +93,8 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc split.Status = Imported if split.AccountId != -1 { if split.AccountId != importedAccount.AccountId { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print("Imported split's AccountId wasn't -1 but also didn't match the account") - return + return NewError(999 /*Internal Error*/) } split.AccountId = account.AccountId } else if split.SecurityId != -1 { @@ -123,12 +102,10 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc // TODO try to auto-match splits to existing accounts based on past transactions that look like this one if split.ImportSplitType == TradingAccount { // Find/make trading account if we're that type of split - trading_account, err := GetTradingAccount(sqltransaction, user.UserId, sec.SecurityId) + trading_account, err := GetTradingAccount(tx, user.UserId, sec.SecurityId) if err != nil { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print("Couldn't find split's SecurityId in map during OFX import") - return + return NewError(999 /*Internal Error*/) } split.AccountId = trading_account.AccountId split.SecurityId = -1 @@ -140,12 +117,10 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc SecurityId: sec.SecurityId, Type: account.Type, } - subaccount, err := GetCreateAccountTx(sqltransaction, *subaccount) + subaccount, err := GetCreateAccountTx(tx, *subaccount) if err != nil { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } split.AccountId = subaccount.AccountId split.SecurityId = -1 @@ -153,49 +128,39 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc split.SecurityId = sec.SecurityId } } else { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print("Couldn't find split's SecurityId in map during OFX import") - return + return NewError(999 /*Internal Error*/) } } else { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print("Neither Split.AccountId Split.SecurityId was set during OFX import") - return + return NewError(999 /*Internal Error*/) } } - imbalances, err := transaction.GetImbalancesTx(sqltransaction) + imbalances, err := transaction.GetImbalancesTx(tx) if err != nil { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } // Fixup any imbalances in transactions var zero big.Rat for imbalanced_security, imbalance := range imbalances { if imbalance.Cmp(&zero) != 0 { - imbalanced_account, err := GetImbalanceAccount(sqltransaction, user.UserId, imbalanced_security) + imbalanced_account, err := GetImbalanceAccount(tx, user.UserId, imbalanced_security) if err != nil { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } // Add new split to fixup imbalance split := new(Split) r := new(big.Rat) r.Neg(&imbalance) - security, err := GetSecurityTx(sqltransaction, imbalanced_security, user.UserId) + security, err := GetSecurityTx(tx, imbalanced_security, user.UserId) if err != nil { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } split.Amount = r.FloatString(security.Precision) split.SecurityId = -1 @@ -210,24 +175,20 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc var already_imported bool for _, split := range transaction.Splits { if split.SecurityId != -1 || split.AccountId == -1 { - imbalanced_account, err := GetImbalanceAccount(sqltransaction, user.UserId, split.SecurityId) + imbalanced_account, err := GetImbalanceAccount(tx, user.UserId, split.SecurityId) if err != nil { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } split.AccountId = imbalanced_account.AccountId split.SecurityId = -1 } - exists, err := split.AlreadyImportedTx(sqltransaction) + exists, err := split.AlreadyImportedTx(tx) if err != nil { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print("Error checking if split was already imported:", err) - return + return NewError(999 /*Internal Error*/) } else if exists { already_imported = true } @@ -239,55 +200,38 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc } for _, transaction := range transactions { - err := InsertTransactionTx(sqltransaction, &transaction, user) + err := InsertTransactionTx(tx, &transaction, user) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } } - err = sqltransaction.Commit() - if err != nil { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } - - WriteSuccess(w) + return SuccessWriter{} } -func OFXImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64) { +func OFXImportHandler(tx *Tx, r *http.Request, user *User, accountid int64) ResponseWriterWriter { download_json := r.PostFormValue("ofxdownload") if download_json == "" { - log.Print("download_json") - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } var ofxdownload OFXDownload err := ofxdownload.Read(download_json) if err != nil { - log.Print("ofxdownload.Read") - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } - account, err := GetAccount(db, accountid, user.UserId) + account, err := GetAccount(tx, accountid, user.UserId) if err != nil { - log.Print("GetAccount") - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } ofxver := ofxgo.OfxVersion203 if len(account.OFXVersion) != 0 { ofxver, err = ofxgo.NewOfxVersion(account.OFXVersion) if err != nil { - log.Print("NewOfxVersion") - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } } @@ -308,9 +252,8 @@ func OFXImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User transactionuid, err := ofxgo.RandomUID() if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Println("Error creating uid for transaction:", err) - return + return NewError(999 /*Internal Error*/) } if account.Type == Investment { @@ -343,8 +286,7 @@ func OFXImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User // Import generic bank transactions acctTypeEnum, err := ofxgo.NewAcctType(account.OFXAcctType) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } statementRequest := ofxgo.StatementRequest{ TrnUID: *transactionuid, @@ -361,49 +303,46 @@ func OFXImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User response, err := client.RequestNoParse(&query) if err != nil { // TODO this could be an error talking with the OFX server... - WriteError(w, 3 /*Invalid Request*/) log.Print(err) - return + return NewError(3 /*Invalid Request*/) } defer response.Body.Close() - ofxImportHelper(db, response.Body, w, user, accountid) + return ofxImportHelper(tx, response.Body, user, accountid) } -func OFXFileImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64) { +func OFXFileImportHandler(tx *Tx, r *http.Request, user *User, accountid int64) ResponseWriterWriter { multipartReader, err := r.MultipartReader() if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } // assume there is only one 'part' part, err := multipartReader.NextPart() if err != nil { if err == io.EOF { - WriteError(w, 3 /*Invalid Request*/) + return NewError(3 /*Invalid Request*/) log.Print("Encountered unexpected EOF") } else { - WriteError(w, 999 /*Internal Error*/) + return NewError(999 /*Internal Error*/) log.Print(err) } - return } - ofxImportHelper(db, part, w, user, accountid) + return ofxImportHelper(tx, part, user, accountid) } /* * Assumes the User is a valid, signed-in user, but accountid has not yet been validated */ -func AccountImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64, importtype string) { +func AccountImportHandler(tx *Tx, r *http.Request, user *User, accountid int64, importtype string) ResponseWriterWriter { switch importtype { case "ofx": - OFXImportHandler(db, w, r, user, accountid) + return OFXImportHandler(tx, r, user, accountid) case "ofxfile": - OFXFileImportHandler(db, w, r, user, accountid) + return OFXFileImportHandler(tx, r, user, accountid) default: - WriteError(w, 3 /*Invalid Request*/) + return NewError(3 /*Invalid Request*/) } } diff --git a/internal/handlers/prices.go b/internal/handlers/prices.go index 42fa512..9be73f0 100644 --- a/internal/handlers/prices.go +++ b/internal/handlers/prices.go @@ -91,23 +91,6 @@ func GetClosestPriceTx(transaction *gorp.Transaction, security, currency *Securi } } -func GetClosestPrice(db *DB, security, currency *Security, date *time.Time) (*Price, error) { - transaction, err := db.Begin() - if err != nil { - return nil, err - } - - price, err := GetClosestPriceTx(transaction, security, currency, date) - if err != nil { - transaction.Rollback() - return nil, err - } - - err = transaction.Commit() - if err != nil { - transaction.Rollback() - return nil, err - } - - return price, nil +func GetClosestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Price, error) { + return GetClosestPriceTx(tx, security, currency, date) } diff --git a/internal/handlers/reports.go b/internal/handlers/reports.go index f42d00f..7ddbb4e 100644 --- a/internal/handlers/reports.go +++ b/internal/handlers/reports.go @@ -77,36 +77,36 @@ func (r *Tabulation) Write(w http.ResponseWriter) error { return enc.Encode(r) } -func GetReport(db *DB, reportid int64, userid int64) (*Report, error) { +func GetReport(tx *Tx, reportid int64, userid int64) (*Report, error) { var r Report - err := db.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid) + err := tx.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid) if err != nil { return nil, err } return &r, nil } -func GetReports(db *DB, userid int64) (*[]Report, error) { +func GetReports(tx *Tx, userid int64) (*[]Report, error) { var reports []Report - _, err := db.Select(&reports, "SELECT * from reports where UserId=?", userid) + _, err := tx.Select(&reports, "SELECT * from reports where UserId=?", userid) if err != nil { return nil, err } return &reports, nil } -func InsertReport(db *DB, r *Report) error { - err := db.Insert(r) +func InsertReport(tx *Tx, r *Report) error { + err := tx.Insert(r) if err != nil { return err } return nil } -func UpdateReport(db *DB, r *Report) error { - count, err := db.Update(r) +func UpdateReport(tx *Tx, r *Report) error { + count, err := tx.Update(r) if err != nil { return err } @@ -116,8 +116,8 @@ func UpdateReport(db *DB, r *Report) error { return nil } -func DeleteReport(db *DB, r *Report) error { - count, err := db.Delete(r) +func DeleteReport(tx *Tx, r *Report) error { + count, err := tx.Delete(r) if err != nil { return err } @@ -127,14 +127,14 @@ func DeleteReport(db *DB, r *Report) error { return nil } -func runReport(db *DB, user *User, report *Report) (*Tabulation, error) { +func runReport(tx *Tx, user *User, report *Report) (*Tabulation, error) { // Create a new LState without opening the default libs for security L := lua.NewState(lua.Options{SkipOpenLibs: true}) defer L.Close() // Create a new context holding the current user with a timeout ctx := context.WithValue(context.Background(), userContextKey, user) - ctx = context.WithValue(ctx, dbContextKey, db) + ctx = context.WithValue(ctx, dbContextKey, tx) ctx, cancel := context.WithTimeout(ctx, luaTimeoutSeconds*time.Second) defer cancel() L.SetContext(ctx) @@ -191,79 +191,60 @@ func runReport(db *DB, user *User, report *Report) (*Tabulation, error) { } } -func ReportTabulationHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, reportid int64) { - report, err := GetReport(db, reportid, user.UserId) +func ReportTabulationHandler(tx *Tx, r *http.Request, user *User, reportid int64) ResponseWriterWriter { + report, err := GetReport(tx, reportid, user.UserId) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } - tabulation, err := runReport(db, user, report) + tabulation, err := runReport(tx, user, report) if err != nil { // TODO handle different failure cases differently log.Print("runReport returned:", err) - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } tabulation.ReportId = reportid - err = tabulation.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return tabulation } -func ReportHandler(w http.ResponseWriter, r *http.Request, db *DB) { - user, err := GetUserFromSession(db, r) +func ReportHandler(r *http.Request, tx *Tx) ResponseWriterWriter { + user, err := GetUserFromSession(tx, r) if err != nil { - WriteError(w, 1 /*Not Signed In*/) - return + return NewError(1 /*Not Signed In*/) } if r.Method == "POST" { report_json := r.PostFormValue("report") if report_json == "" { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } var report Report err := report.Read(report_json) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } report.ReportId = -1 report.UserId = user.UserId - err = InsertReport(db, &report) + err = InsertReport(tx, &report) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - w.WriteHeader(201 /*Created*/) - err = report.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return ResponseWrapper{201, &report} } else if r.Method == "GET" { if reportTabulationRE.MatchString(r.URL.Path) { var reportid int64 n, err := GetURLPieces(r.URL.Path, "/report/%d/tabulation", &reportid) if err != nil || n != 1 { - WriteError(w, 999 /*InternalError*/) log.Print(err) - return + return NewError(999 /*InternalError*/) } - ReportTabulationHandler(db, w, r, user, reportid) - return + return ReportTabulationHandler(tx, r, user, reportid) } var reportid int64 @@ -271,84 +252,62 @@ func ReportHandler(w http.ResponseWriter, r *http.Request, db *DB) { if err != nil || n != 1 { //Return all Reports var rl ReportList - reports, err := GetReports(db, user.UserId) + reports, err := GetReports(tx, user.UserId) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } rl.Reports = reports - err = (&rl).Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return &rl } else { // Return Report with this Id - report, err := GetReport(db, reportid, user.UserId) + report, err := GetReport(tx, reportid, user.UserId) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } - err = report.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return report } } else { reportid, err := GetURLID(r.URL.Path) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } if r.Method == "PUT" { report_json := r.PostFormValue("report") if report_json == "" { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } var report Report err := report.Read(report_json) if err != nil || report.ReportId != reportid { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } report.UserId = user.UserId - err = UpdateReport(db, &report) + err = UpdateReport(tx, &report) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - err = report.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return &report } else if r.Method == "DELETE" { - report, err := GetReport(db, reportid, user.UserId) + report, err := GetReport(tx, reportid, user.UserId) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } - err = DeleteReport(db, report) + err = DeleteReport(tx, report) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - WriteSuccess(w) + return SuccessWriter{} } } + return NewError(3 /*Invalid Request*/) } diff --git a/internal/handlers/securities.go b/internal/handlers/securities.go index 7333d65..c4e7a47 100644 --- a/internal/handlers/securities.go +++ b/internal/handlers/securities.go @@ -103,10 +103,10 @@ func FindCurrencyTemplate(iso4217 int64) *Security { return nil } -func GetSecurity(db *DB, securityid int64, userid int64) (*Security, error) { +func GetSecurity(tx *Tx, securityid int64, userid int64) (*Security, error) { var s Security - err := db.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid) + err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid) if err != nil { return nil, err } @@ -123,18 +123,18 @@ func GetSecurityTx(transaction *gorp.Transaction, securityid int64, userid int64 return &s, nil } -func GetSecurities(db *DB, userid int64) (*[]*Security, error) { +func GetSecurities(tx *Tx, userid int64) (*[]*Security, error) { var securities []*Security - _, err := db.Select(&securities, "SELECT * from securities where UserId=?", userid) + _, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid) if err != nil { return nil, err } return &securities, nil } -func InsertSecurity(db *DB, s *Security) error { - err := db.Insert(s) +func InsertSecurity(tx *Tx, s *Security) error { + err := tx.Insert(s) if err != nil { return err } @@ -149,37 +149,22 @@ func InsertSecurityTx(transaction *gorp.Transaction, s *Security) error { return nil } -func UpdateSecurity(db *DB, s *Security) error { - transaction, err := db.Begin() +func UpdateSecurity(tx *Tx, s *Security) (err error) { + user, err := GetUserTx(tx, s.UserId) if err != nil { - return err - } - - user, err := GetUserTx(transaction, s.UserId) - if err != nil { - transaction.Rollback() - return err + return } else if user.DefaultCurrency == s.SecurityId && s.Type != Currency { - transaction.Rollback() return errors.New("Cannot change security which is user's default currency to be non-currency") } - count, err := transaction.Update(s) + count, err := tx.Update(s) if err != nil { - transaction.Rollback() - return err + return } if count != 1 { - transaction.Rollback() return errors.New("Updated more than one security") } - err = transaction.Commit() - if err != nil { - transaction.Rollback() - return err - } - return nil } @@ -191,53 +176,36 @@ func (e SecurityInUseError) Error() string { return e.message } -func DeleteSecurity(db *DB, s *Security) error { - transaction, err := db.Begin() - if err != nil { - return err - } - +func DeleteSecurity(tx *Tx, s *Security) error { // First, ensure no accounts are using this security - accounts, err := transaction.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId) + accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId) if accounts != 0 { - transaction.Rollback() return SecurityInUseError{"One or more accounts still use this security"} } - user, err := GetUserTx(transaction, s.UserId) + user, err := GetUserTx(tx, s.UserId) if err != nil { - transaction.Rollback() return err } else if user.DefaultCurrency == s.SecurityId { - transaction.Rollback() return SecurityInUseError{"Cannot delete security which is user's default currency"} } // Remove all prices involving this security (either of this security, or // using it as a currency) - _, err = transaction.Exec("DELETE FROM prices WHERE SecurityId=? OR CurrencyId=?", s.SecurityId, s.SecurityId) + _, err = tx.Exec("DELETE FROM prices WHERE SecurityId=? OR CurrencyId=?", s.SecurityId, s.SecurityId) if err != nil { - transaction.Rollback() return err } - count, err := transaction.Delete(s) + count, err := tx.Delete(s) if err != nil { - transaction.Rollback() return err } if count != 1 { - transaction.Rollback() return errors.New("Deleted more than one security") } - err = transaction.Commit() - if err != nil { - transaction.Rollback() - return err - } - return nil } @@ -294,43 +262,33 @@ func ImportGetCreateSecurity(transaction *gorp.Transaction, userid int64, securi return security, nil } -func SecurityHandler(w http.ResponseWriter, r *http.Request, db *DB) { - user, err := GetUserFromSession(db, r) +func SecurityHandler(r *http.Request, tx *Tx) ResponseWriterWriter { + user, err := GetUserFromSessionTx(tx, r) if err != nil { - WriteError(w, 1 /*Not Signed In*/) - return + return NewError(1 /*Not Signed In*/) } if r.Method == "POST" { security_json := r.PostFormValue("security") if security_json == "" { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } var security Security err := security.Read(security_json) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } security.SecurityId = -1 security.UserId = user.UserId - err = InsertSecurity(db, &security) + err = InsertSecurityTx(tx, &security) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - w.WriteHeader(201 /*Created*/) - err = security.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return ResponseWrapper{201, &security} } else if r.Method == "GET" { var securityid int64 n, err := GetURLPieces(r.URL.Path, "/security/%d", &securityid) @@ -339,87 +297,65 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request, db *DB) { //Return all securities var sl SecurityList - securities, err := GetSecurities(db, user.UserId) + securities, err := GetSecurities(tx, user.UserId) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } sl.Securities = securities - err = (&sl).Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return &sl } else { - security, err := GetSecurity(db, securityid, user.UserId) + security, err := GetSecurityTx(tx, securityid, user.UserId) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } - err = security.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return security } } else { securityid, err := GetURLID(r.URL.Path) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } if r.Method == "PUT" { security_json := r.PostFormValue("security") if security_json == "" { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } var security Security err := security.Read(security_json) if err != nil || security.SecurityId != securityid { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } security.UserId = user.UserId - err = UpdateSecurity(db, &security) + err = UpdateSecurity(tx, &security) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - err = security.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return &security } else if r.Method == "DELETE" { - security, err := GetSecurity(db, securityid, user.UserId) + security, err := GetSecurityTx(tx, securityid, user.UserId) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } - err = DeleteSecurity(db, security) + err = DeleteSecurity(tx, security) if _, ok := err.(SecurityInUseError); ok { - WriteError(w, 7 /*In Use Error*/) + return NewError(7 /*In Use Error*/) } else if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - WriteSuccess(w) + return SuccessWriter{} } } + return NewError(3 /*Invalid Request*/) } func SecurityTemplateHandler(w http.ResponseWriter, r *http.Request) { diff --git a/internal/handlers/securities_lua.go b/internal/handlers/securities_lua.go index a13614b..4acd979 100644 --- a/internal/handlers/securities_lua.go +++ b/internal/handlers/securities_lua.go @@ -13,9 +13,9 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) { ctx := L.Context() - db, ok := ctx.Value(dbContextKey).(*DB) + tx, ok := ctx.Value(dbContextKey).(*Tx) if !ok { - return nil, errors.New("Couldn't find DB in lua's Context") + return nil, errors.New("Couldn't find tx in lua's Context") } security_map, ok = ctx.Value(securitiesContextKey).(map[int64]*Security) @@ -25,7 +25,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) { return nil, errors.New("Couldn't find User in lua's Context") } - securities, err := GetSecurities(db, user.UserId) + securities, err := GetSecurities(tx, user.UserId) if err != nil { return nil, err } @@ -155,12 +155,12 @@ func luaClosestPrice(L *lua.LState) int { date := luaCheckTime(L, 3) ctx := L.Context() - db, ok := ctx.Value(dbContextKey).(*DB) + tx, ok := ctx.Value(dbContextKey).(*Tx) if !ok { - panic("Couldn't find DB in lua's Context") + panic("Couldn't find tx in lua's Context") } - p, err := GetClosestPrice(db, s, c, date) + p, err := GetClosestPrice(tx, s, c, date) if err != nil { L.Push(lua.LNil) } else { diff --git a/internal/handlers/sessions.go b/internal/handlers/sessions.go index 2cfa56a..3cc38d2 100644 --- a/internal/handlers/sessions.go +++ b/internal/handlers/sessions.go @@ -28,7 +28,7 @@ func (s *Session) Read(json_str string) error { return dec.Decode(s) } -func GetSession(db *DB, r *http.Request) (*Session, error) { +func GetSession(tx *Tx, r *http.Request) (*Session, error) { var s Session cookie, err := r.Cookie("moneygo-session") @@ -37,18 +37,33 @@ func GetSession(db *DB, r *http.Request) (*Session, error) { } s.SessionSecret = cookie.Value - err = db.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret) + err = tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret) if err != nil { return nil, err } return &s, nil } -func DeleteSessionIfExists(db *DB, r *http.Request) error { - // TODO do this in one transaction - session, err := GetSession(db, r) +func GetSessionTx(tx *Tx, r *http.Request) (*Session, error) { + var s Session + + cookie, err := r.Cookie("moneygo-session") + if err != nil { + return nil, fmt.Errorf("moneygo-session cookie not set") + } + s.SessionSecret = cookie.Value + + err = tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret) + if err != nil { + return nil, err + } + return &s, nil +} + +func DeleteSessionIfExists(tx *Tx, r *http.Request) error { + session, err := GetSessionTx(tx, r) if err == nil { - _, err := db.Delete(session) + _, err := tx.Delete(session) if err != nil { return err } @@ -64,7 +79,17 @@ func NewSessionCookie() (string, error) { return base64.StdEncoding.EncodeToString(bits), nil } -func NewSession(db *DB, w http.ResponseWriter, r *http.Request, userid int64) (*Session, error) { +type NewSessionWriter struct { + session *Session + cookie *http.Cookie +} + +func (n *NewSessionWriter) Write(w http.ResponseWriter) error { + http.SetCookie(w, n.cookie) + return n.session.Write(w) +} + +func NewSession(tx *Tx, r *http.Request, userid int64) (*NewSessionWriter, error) { s := Session{} session_secret, err := NewSessionCookie() @@ -81,79 +106,66 @@ func NewSession(db *DB, w http.ResponseWriter, r *http.Request, userid int64) (* Secure: true, HttpOnly: true, } - http.SetCookie(w, &cookie) s.SessionSecret = session_secret s.UserId = userid - err = db.Insert(&s) + err = tx.Insert(&s) if err != nil { return nil, err } - return &s, nil + return &NewSessionWriter{&s, &cookie}, nil } -func SessionHandler(w http.ResponseWriter, r *http.Request, db *DB) { +func SessionHandler(r *http.Request, tx *Tx) ResponseWriterWriter { if r.Method == "POST" || r.Method == "PUT" { user_json := r.PostFormValue("user") if user_json == "" { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } user := User{} err := user.Read(user_json) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } - dbuser, err := GetUserByUsername(db, user.Username) + dbuser, err := GetUserByUsername(tx, user.Username) if err != nil { - WriteError(w, 2 /*Unauthorized Access*/) - return + return NewError(2 /*Unauthorized Access*/) } user.HashPassword() if user.PasswordHash != dbuser.PasswordHash { - WriteError(w, 2 /*Unauthorized Access*/) - return + return NewError(2 /*Unauthorized Access*/) } - err = DeleteSessionIfExists(db, r) + err = DeleteSessionIfExists(tx, r) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - session, err := NewSession(db, w, r, dbuser.UserId) + sessionwriter, err := NewSession(tx, r, dbuser.UserId) if err != nil { - WriteError(w, 999 /*Internal Error*/) - return - } - - err = session.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } + return sessionwriter } else if r.Method == "GET" { - s, err := GetSession(db, r) + s, err := GetSessionTx(tx, r) if err != nil { - WriteError(w, 1 /*Not Signed In*/) - return + return NewError(1 /*Not Signed In*/) } - s.Write(w) + return s } else if r.Method == "DELETE" { - err := DeleteSessionIfExists(db, r) + err := DeleteSessionIfExists(tx, r) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - WriteSuccess(w) + return SuccessWriter{} } + return NewError(3 /*Invalid Request*/) } diff --git a/internal/handlers/transactions.go b/internal/handlers/transactions.go index 72c5ab2..5a75ec9 100644 --- a/internal/handlers/transactions.go +++ b/internal/handlers/transactions.go @@ -178,72 +178,48 @@ func (t *Transaction) Balanced(transaction *gorp.Transaction) (bool, error) { return true, nil } -func GetTransaction(db *DB, transactionid int64, userid int64) (*Transaction, error) { +func GetTransaction(tx *Tx, transactionid int64, userid int64) (*Transaction, error) { var t Transaction - transaction, err := db.Begin() + err := tx.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid) if err != nil { return nil, err } - err = transaction.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid) + _, err = tx.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid) if err != nil { - transaction.Rollback() - return nil, err - } - - _, err = transaction.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid) - if err != nil { - transaction.Rollback() - return nil, err - } - - err = transaction.Commit() - if err != nil { - transaction.Rollback() return nil, err } return &t, nil } -func GetTransactions(db *DB, userid int64) (*[]Transaction, error) { +func GetTransactions(tx *Tx, userid int64) (*[]Transaction, error) { var transactions []Transaction - transaction, err := db.Begin() - if err != nil { - return nil, err - } - - _, err = transaction.Select(&transactions, "SELECT * from transactions where UserId=?", userid) + _, err := tx.Select(&transactions, "SELECT * from transactions where UserId=?", userid) if err != nil { return nil, err } for i := range transactions { - _, err := transaction.Select(&transactions[i].Splits, "SELECT * from splits where TransactionId=?", transactions[i].TransactionId) + _, err := tx.Select(&transactions[i].Splits, "SELECT * from splits where TransactionId=?", transactions[i].TransactionId) if err != nil { return nil, err } } - err = transaction.Commit() - if err != nil { - transaction.Rollback() - return nil, err - } - return &transactions, nil } -func incrementAccountVersions(transaction *gorp.Transaction, user *User, accountids []int64) error { +func incrementAccountVersions(tx *Tx, user *User, accountids []int64) error { for i := range accountids { - account, err := GetAccountTx(transaction, accountids[i], user.UserId) + account, err := GetAccountTx(tx, accountids[i], user.UserId) if err != nil { return err } account.AccountVersion++ - count, err := transaction.Update(account) + count, err := tx.Update(account) if err != nil { return err } @@ -260,12 +236,12 @@ func (ame AccountMissingError) Error() string { return "Account missing" } -func InsertTransactionTx(transaction *gorp.Transaction, t *Transaction, user *User) error { +func InsertTransactionTx(tx *Tx, t *Transaction, user *User) error { // Map of any accounts with transaction splits being added a_map := make(map[int64]bool) for i := range t.Splits { if t.Splits[i].AccountId != -1 { - existing, err := transaction.SelectInt("SELECT count(*) from accounts where AccountId=?", t.Splits[i].AccountId) + existing, err := tx.SelectInt("SELECT count(*) from accounts where AccountId=?", t.Splits[i].AccountId) if err != nil { return err } @@ -287,13 +263,13 @@ func InsertTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us if len(a_ids) < 1 { return AccountMissingError{} } - err := incrementAccountVersions(transaction, user, a_ids) + err := incrementAccountVersions(tx, user, a_ids) if err != nil { return err } t.UserId = user.UserId - err = transaction.Insert(t) + err = tx.Insert(t) if err != nil { return err } @@ -301,7 +277,7 @@ func InsertTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us for i := range t.Splits { t.Splits[i].TransactionId = t.TransactionId t.Splits[i].SplitId = -1 - err = transaction.Insert(t.Splits[i]) + err = tx.Insert(t.Splits[i]) if err != nil { return err } @@ -310,31 +286,19 @@ func InsertTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us return nil } -func InsertTransaction(db *DB, t *Transaction, user *User) error { - transaction, err := db.Begin() +func InsertTransaction(tx *Tx, t *Transaction, user *User) error { + err := InsertTransactionTx(tx, t, user) if err != nil { return err } - err = InsertTransactionTx(transaction, t, user) - if err != nil { - transaction.Rollback() - return err - } - - err = transaction.Commit() - if err != nil { - transaction.Rollback() - return err - } - return nil } -func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *User) error { +func UpdateTransactionTx(tx *Tx, t *Transaction, user *User) error { var existing_splits []*Split - _, err := transaction.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId) + _, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId) if err != nil { return err } @@ -353,7 +317,7 @@ func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us t.Splits[i].TransactionId = t.TransactionId _, ok := s_map[t.Splits[i].SplitId] if ok { - count, err := transaction.Update(t.Splits[i]) + count, err := tx.Update(t.Splits[i]) if err != nil { return err } @@ -363,7 +327,7 @@ func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us delete(s_map, t.Splits[i].SplitId) } else { t.Splits[i].SplitId = -1 - err := transaction.Insert(t.Splits[i]) + err := tx.Insert(t.Splits[i]) if err != nil { return err } @@ -380,7 +344,7 @@ func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us a_map[existing_splits[i].AccountId] = true } if ok { - _, err := transaction.Delete(existing_splits[i]) + _, err := tx.Delete(existing_splits[i]) if err != nil { return err } @@ -392,12 +356,12 @@ func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us for id := range a_map { a_ids = append(a_ids, id) } - err = incrementAccountVersions(transaction, user, a_ids) + err = incrementAccountVersions(tx, user, a_ids) if err != nil { return err } - count, err := transaction.Update(t) + count, err := tx.Update(t) if err != nil { return err } @@ -408,257 +372,165 @@ func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us return nil } -func DeleteTransaction(db *DB, t *Transaction, user *User) error { - transaction, err := db.Begin() - if err != nil { - return err - } - +func DeleteTransaction(tx *Tx, t *Transaction, user *User) error { var accountids []int64 - _, err = transaction.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId) + _, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId) if err != nil { - transaction.Rollback() return err } - _, err = transaction.Exec("DELETE FROM splits WHERE TransactionId=?", t.TransactionId) + _, err = tx.Exec("DELETE FROM splits WHERE TransactionId=?", t.TransactionId) if err != nil { - transaction.Rollback() return err } - count, err := transaction.Delete(t) + count, err := tx.Delete(t) if err != nil { - transaction.Rollback() return err } if count != 1 { - transaction.Rollback() return errors.New("Deleted more than one transaction") } - err = incrementAccountVersions(transaction, user, accountids) + err = incrementAccountVersions(tx, user, accountids) if err != nil { - transaction.Rollback() - return err - } - - err = transaction.Commit() - if err != nil { - transaction.Rollback() return err } return nil } -func TransactionHandler(w http.ResponseWriter, r *http.Request, db *DB) { - user, err := GetUserFromSession(db, r) +func TransactionHandler(r *http.Request, tx *Tx) ResponseWriterWriter { + user, err := GetUserFromSession(tx, r) if err != nil { - WriteError(w, 1 /*Not Signed In*/) - return + return NewError(1 /*Not Signed In*/) } if r.Method == "POST" { transaction_json := r.PostFormValue("transaction") if transaction_json == "" { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } var transaction Transaction err := transaction.Read(transaction_json) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } transaction.TransactionId = -1 transaction.UserId = user.UserId - sqltx, err := db.Begin() + balanced, err := transaction.Balanced(tx) if err != nil { - WriteError(w, 999 /*Internal Error*/) + return NewError(999 /*Internal Error*/) log.Print(err) - return - } - - balanced, err := transaction.Balanced(sqltx) - if err != nil { - sqltx.Rollback() - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return } if !transaction.Valid() || !balanced { - sqltx.Rollback() - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } for i := range transaction.Splits { transaction.Splits[i].SplitId = -1 - _, err := GetAccountTx(sqltx, transaction.Splits[i].AccountId, user.UserId) + _, err := GetAccountTx(tx, transaction.Splits[i].AccountId, user.UserId) if err != nil { - sqltx.Rollback() - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } } - err = InsertTransactionTx(sqltx, &transaction, user) + err = InsertTransactionTx(tx, &transaction, user) if err != nil { if _, ok := err.(AccountMissingError); ok { - WriteError(w, 3 /*Invalid Request*/) + return NewError(3 /*Invalid Request*/) } else { - WriteError(w, 999 /*Internal Error*/) log.Print(err) + return NewError(999 /*Internal Error*/) } - sqltx.Rollback() - return } - err = sqltx.Commit() - if err != nil { - sqltx.Rollback() - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } - - err = transaction.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return &transaction } else if r.Method == "GET" { transactionid, err := GetURLID(r.URL.Path) if err != nil { //Return all Transactions var al TransactionList - transactions, err := GetTransactions(db, user.UserId) + transactions, err := GetTransactions(tx, user.UserId) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } al.Transactions = transactions - err = (&al).Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return &al } else { //Return Transaction with this Id - transaction, err := GetTransaction(db, transactionid, user.UserId) + transaction, err := GetTransaction(tx, transactionid, user.UserId) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return - } - err = transaction.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return + return NewError(3 /*Invalid Request*/) } + return transaction } } else { transactionid, err := GetURLID(r.URL.Path) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } if r.Method == "PUT" { transaction_json := r.PostFormValue("transaction") if transaction_json == "" { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } var transaction Transaction err := transaction.Read(transaction_json) if err != nil || transaction.TransactionId != transactionid { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } transaction.UserId = user.UserId - sqltx, err := db.Begin() + balanced, err := transaction.Balanced(tx) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return - } - - balanced, err := transaction.Balanced(sqltx) - if err != nil { - sqltx.Rollback() - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return + return NewError(999 /*Internal Error*/) } if !transaction.Valid() || !balanced { - sqltx.Rollback() - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } for i := range transaction.Splits { - _, err := GetAccountTx(sqltx, transaction.Splits[i].AccountId, user.UserId) + _, err := GetAccountTx(tx, transaction.Splits[i].AccountId, user.UserId) if err != nil { - sqltx.Rollback() - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } } - err = UpdateTransactionTx(sqltx, &transaction, user) + err = UpdateTransactionTx(tx, &transaction, user) if err != nil { - sqltx.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - err = sqltx.Commit() - if err != nil { - sqltx.Rollback() - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } - - err = transaction.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return &transaction } else if r.Method == "DELETE" { transactionid, err := GetURLID(r.URL.Path) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } - transaction, err := GetTransaction(db, transactionid, user.UserId) + transaction, err := GetTransaction(tx, transactionid, user.UserId) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } - err = DeleteTransaction(db, transaction, user) + err = DeleteTransaction(tx, transaction, user) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - WriteSuccess(w) + return SuccessWriter{} } } + return NewError(3 /*Invalid Request*/) } func TransactionsBalanceDifference(transaction *gorp.Transaction, accountid int64, transactions []Transaction) (*big.Rat, error) { @@ -685,17 +557,12 @@ func TransactionsBalanceDifference(transaction *gorp.Transaction, accountid int6 return &pageDifference, nil } -func GetAccountBalance(db *DB, user *User, accountid int64) (*big.Rat, error) { +func GetAccountBalance(tx *Tx, user *User, accountid int64) (*big.Rat, error) { var splits []Split - transaction, err := db.Begin() - if err != nil { - return nil, err - } sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?" - _, err = transaction.Select(&splits, sql, accountid, user.UserId) + _, err := tx.Select(&splits, sql, accountid, user.UserId) if err != nil { - transaction.Rollback() return nil, err } @@ -703,34 +570,22 @@ func GetAccountBalance(db *DB, user *User, accountid int64) (*big.Rat, error) { for _, s := range splits { rat_amount, err := GetBigAmount(s.Amount) if err != nil { - transaction.Rollback() return nil, err } tmp.Add(&balance, rat_amount) balance.Set(&tmp) } - err = transaction.Commit() - if err != nil { - transaction.Rollback() - return nil, err - } - return &balance, nil } // Assumes accountid is valid and is owned by the current user -func GetAccountBalanceDate(db *DB, user *User, accountid int64, date *time.Time) (*big.Rat, error) { +func GetAccountBalanceDate(tx *Tx, user *User, accountid int64, date *time.Time) (*big.Rat, error) { var splits []Split - transaction, err := db.Begin() - if err != nil { - return nil, err - } sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?" - _, err = transaction.Select(&splits, sql, accountid, user.UserId, date) + _, err := tx.Select(&splits, sql, accountid, user.UserId, date) if err != nil { - transaction.Rollback() return nil, err } @@ -738,33 +593,21 @@ func GetAccountBalanceDate(db *DB, user *User, accountid int64, date *time.Time) for _, s := range splits { rat_amount, err := GetBigAmount(s.Amount) if err != nil { - transaction.Rollback() return nil, err } tmp.Add(&balance, rat_amount) balance.Set(&tmp) } - err = transaction.Commit() - if err != nil { - transaction.Rollback() - return nil, err - } - return &balance, nil } -func GetAccountBalanceDateRange(db *DB, user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) { +func GetAccountBalanceDateRange(tx *Tx, user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) { var splits []Split - transaction, err := db.Begin() - if err != nil { - return nil, err - } sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?" - _, err = transaction.Select(&splits, sql, accountid, user.UserId, begin, end) + _, err := tx.Select(&splits, sql, accountid, user.UserId, begin, end) if err != nil { - transaction.Rollback() return nil, err } @@ -772,31 +615,19 @@ func GetAccountBalanceDateRange(db *DB, user *User, accountid int64, begin, end for _, s := range splits { rat_amount, err := GetBigAmount(s.Amount) if err != nil { - transaction.Rollback() return nil, err } tmp.Add(&balance, rat_amount) balance.Set(&tmp) } - err = transaction.Commit() - if err != nil { - transaction.Rollback() - return nil, err - } - return &balance, nil } -func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) { +func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) { var transactions []Transaction var atl AccountTransactionsList - transaction, err := db.Begin() - if err != nil { - return nil, err - } - var sqlsort, balanceLimitOffset string var balanceLimitOffsetArg uint64 if sort == "date-asc" { @@ -804,9 +635,8 @@ func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, pa balanceLimitOffset = " LIMIT ?" balanceLimitOffsetArg = page * limit } else if sort == "date-desc" { - numSplits, err := transaction.SelectInt("SELECT count(*) FROM splits") + numSplits, err := tx.SelectInt("SELECT count(*) FROM splits") if err != nil { - transaction.Rollback() return nil, err } sqlsort = " ORDER BY transactions.Date DESC" @@ -819,41 +649,35 @@ func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, pa sqloffset = fmt.Sprintf(" OFFSET %d", page*limit) } - account, err := GetAccountTx(transaction, accountid, user.UserId) + account, err := GetAccountTx(tx, accountid, user.UserId) if err != nil { - transaction.Rollback() return nil, err } atl.Account = account sql := "SELECT DISTINCT transactions.* FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + " LIMIT ?" + sqloffset - _, err = transaction.Select(&transactions, sql, user.UserId, accountid, limit) + _, err = tx.Select(&transactions, sql, user.UserId, accountid, limit) if err != nil { - transaction.Rollback() return nil, err } atl.Transactions = &transactions - pageDifference, err := TransactionsBalanceDifference(transaction, accountid, transactions) + pageDifference, err := TransactionsBalanceDifference(tx, accountid, transactions) if err != nil { - transaction.Rollback() return nil, err } - count, err := transaction.SelectInt("SELECT count(DISTINCT transactions.TransactionId) FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?", user.UserId, accountid) + count, err := tx.SelectInt("SELECT count(DISTINCT transactions.TransactionId) FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?", user.UserId, accountid) if err != nil { - transaction.Rollback() return nil, err } atl.TotalTransactions = count - security, err := GetSecurityTx(transaction, atl.Account.SecurityId, user.UserId) + security, err := GetSecurityTx(tx, atl.Account.SecurityId, user.UserId) if err != nil { - transaction.Rollback() return nil, err } if security == nil { - transaction.Rollback() return nil, errors.New("Security not found") } @@ -861,9 +685,8 @@ func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, pa // occurred before the page we're returning var amounts []string sql = "SELECT splits.Amount FROM splits WHERE splits.AccountId=? AND splits.TransactionId IN (SELECT DISTINCT transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ")" - _, err = transaction.Select(&amounts, sql, accountid, user.UserId, accountid, balanceLimitOffsetArg) + _, err = tx.Select(&amounts, sql, accountid, user.UserId, accountid, balanceLimitOffsetArg) if err != nil { - transaction.Rollback() return nil, err } @@ -871,7 +694,6 @@ func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, pa for _, amount := range amounts { rat_amount, err := GetBigAmount(amount) if err != nil { - transaction.Rollback() return nil, err } tmp.Add(&balance, rat_amount) @@ -880,20 +702,12 @@ func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, pa atl.BeginningBalance = balance.FloatString(security.Precision) atl.EndingBalance = tmp.Add(&balance, pageDifference).FloatString(security.Precision) - err = transaction.Commit() - if err != nil { - transaction.Rollback() - return nil, err - } - return &atl, nil } // Return only those transactions which have at least one split pertaining to // an account -func AccountTransactionsHandler(db *DB, w http.ResponseWriter, r *http.Request, - user *User, accountid int64) { - +func AccountTransactionsHandler(tx *Tx, r *http.Request, user *User, accountid int64) ResponseWriterWriter { var page uint64 = 0 var limit uint64 = 50 var sort string = "date-desc" @@ -904,8 +718,7 @@ func AccountTransactionsHandler(db *DB, w http.ResponseWriter, r *http.Request, if pagestring != "" { p, err := strconv.ParseUint(pagestring, 10, 0) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } page = p } @@ -914,8 +727,7 @@ func AccountTransactionsHandler(db *DB, w http.ResponseWriter, r *http.Request, if limitstring != "" { l, err := strconv.ParseUint(limitstring, 10, 0) if err != nil || l > 100 { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } limit = l } @@ -923,23 +735,16 @@ func AccountTransactionsHandler(db *DB, w http.ResponseWriter, r *http.Request, sortstring := query.Get("sort") if sortstring != "" { if sortstring != "date-asc" && sortstring != "date-desc" { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } sort = sortstring } - accountTransactions, err := GetAccountTransactions(db, user, accountid, sort, page, limit) + accountTransactions, err := GetAccountTransactions(tx, user, accountid, sort, page, limit) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - err = accountTransactions.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return accountTransactions } diff --git a/internal/handlers/users.go b/internal/handlers/users.go index 99ea60a..f6f75e1 100644 --- a/internal/handlers/users.go +++ b/internal/handlers/users.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "gopkg.in/gorp.v1" "io" "log" "net/http" @@ -47,61 +46,52 @@ func (u *User) HashPassword() { u.Password = "" } -func GetUser(db *DB, userid int64) (*User, error) { +func GetUser(tx *Tx, userid int64) (*User, error) { var u User - err := db.SelectOne(&u, "SELECT * from users where UserId=?", userid) + err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid) if err != nil { return nil, err } return &u, nil } -func GetUserTx(transaction *gorp.Transaction, userid int64) (*User, error) { +func GetUserTx(tx *Tx, userid int64) (*User, error) { var u User - err := transaction.SelectOne(&u, "SELECT * from users where UserId=?", userid) + err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid) if err != nil { return nil, err } return &u, nil } -func GetUserByUsername(db *DB, username string) (*User, error) { +func GetUserByUsername(tx *Tx, username string) (*User, error) { var u User - err := db.SelectOne(&u, "SELECT * from users where Username=?", username) + err := tx.SelectOne(&u, "SELECT * from users where Username=?", username) if err != nil { return nil, err } return &u, nil } -func InsertUser(db *DB, u *User) error { - transaction, err := db.Begin() - if err != nil { - return err - } - +func InsertUser(tx *Tx, u *User) error { security_template := FindCurrencyTemplate(u.DefaultCurrency) if security_template == nil { - transaction.Rollback() return errors.New("Invalid ISO4217 Default Currency") } - existing, err := transaction.SelectInt("SELECT count(*) from users where Username=?", u.Username) + existing, err := tx.SelectInt("SELECT count(*) from users where Username=?", u.Username) if err != nil { - transaction.Rollback() return err } if existing > 0 { - transaction.Rollback() return UserExistsError{} } - err = transaction.Insert(u) + err = tx.Insert(u) if err != nil { - transaction.Rollback() return err } @@ -110,201 +100,146 @@ func InsertUser(db *DB, u *User) error { security = *security_template security.UserId = u.UserId - err = InsertSecurityTx(transaction, &security) + err = InsertSecurityTx(tx, &security) if err != nil { - transaction.Rollback() return err } // Update the user's DefaultCurrency to our new SecurityId u.DefaultCurrency = security.SecurityId - count, err := transaction.Update(u) + count, err := tx.Update(u) if err != nil { - transaction.Rollback() return err } else if count != 1 { - transaction.Rollback() return errors.New("Would have updated more than one user") } - err = transaction.Commit() - if err != nil { - transaction.Rollback() - return err - } - return nil } -func GetUserFromSession(db *DB, r *http.Request) (*User, error) { - s, err := GetSession(db, r) +func GetUserFromSession(tx *Tx, r *http.Request) (*User, error) { + s, err := GetSession(tx, r) if err != nil { return nil, err } - return GetUser(db, s.UserId) + return GetUser(tx, s.UserId) } -func UpdateUser(db *DB, u *User) error { - transaction, err := db.Begin() +func GetUserFromSessionTx(tx *Tx, r *http.Request) (*User, error) { + s, err := GetSessionTx(tx, r) if err != nil { - return err + return nil, err } + return GetUserTx(tx, s.UserId) +} - security, err := GetSecurityTx(transaction, u.DefaultCurrency, u.UserId) +func UpdateUser(tx *Tx, u *User) error { + security, err := GetSecurityTx(tx, u.DefaultCurrency, u.UserId) if err != nil { - transaction.Rollback() return err } else if security.UserId != u.UserId || security.SecurityId != u.DefaultCurrency { - transaction.Rollback() return errors.New("UserId and DefaultCurrency don't match the fetched security") } else if security.Type != Currency { - transaction.Rollback() return errors.New("New DefaultCurrency security is not a currency") } - count, err := transaction.Update(u) + count, err := tx.Update(u) if err != nil { - transaction.Rollback() return err } else if count != 1 { - transaction.Rollback() return errors.New("Would have updated more than one user") } - err = transaction.Commit() - if err != nil { - transaction.Rollback() - return err - } - return nil } -func DeleteUser(db *DB, u *User) error { - transaction, err := db.Begin() +func DeleteUser(tx *Tx, u *User) error { + count, err := tx.Delete(u) if err != nil { return err } - - count, err := transaction.Delete(u) - if err != nil { - transaction.Rollback() - return err - } if count != 1 { - transaction.Rollback() return fmt.Errorf("No user to delete") } - _, err = transaction.Exec("DELETE FROM prices WHERE prices.PriceId IN (SELECT prices.PriceId FROM prices INNER JOIN securities ON prices.SecurityId=securities.SecurityId WHERE securities.UserId=?)", u.UserId) + _, err = tx.Exec("DELETE FROM prices WHERE prices.PriceId IN (SELECT prices.PriceId FROM prices INNER JOIN securities ON prices.SecurityId=securities.SecurityId WHERE securities.UserId=?)", u.UserId) if err != nil { - transaction.Rollback() return err } - _, err = transaction.Exec("DELETE FROM splits WHERE splits.SplitId IN (SELECT splits.SplitId FROM splits INNER JOIN transactions ON splits.TransactionId=transactions.TransactionId WHERE transactions.UserId=?)", u.UserId) + _, err = tx.Exec("DELETE FROM splits WHERE splits.SplitId IN (SELECT splits.SplitId FROM splits INNER JOIN transactions ON splits.TransactionId=transactions.TransactionId WHERE transactions.UserId=?)", u.UserId) if err != nil { - transaction.Rollback() return err } - _, err = transaction.Exec("DELETE FROM transactions WHERE transactions.UserId=?", u.UserId) + _, err = tx.Exec("DELETE FROM transactions WHERE transactions.UserId=?", u.UserId) if err != nil { - transaction.Rollback() return err } - _, err = transaction.Exec("DELETE FROM securities WHERE securities.UserId=?", u.UserId) + _, err = tx.Exec("DELETE FROM securities WHERE securities.UserId=?", u.UserId) if err != nil { - transaction.Rollback() return err } - _, err = transaction.Exec("DELETE FROM accounts WHERE accounts.UserId=?", u.UserId) + _, err = tx.Exec("DELETE FROM accounts WHERE accounts.UserId=?", u.UserId) if err != nil { - transaction.Rollback() return err } - _, err = transaction.Exec("DELETE FROM reports WHERE reports.UserId=?", u.UserId) + _, err = tx.Exec("DELETE FROM reports WHERE reports.UserId=?", u.UserId) if err != nil { - transaction.Rollback() return err } - _, err = transaction.Exec("DELETE FROM sessions WHERE sessions.UserId=?", u.UserId) + _, err = tx.Exec("DELETE FROM sessions WHERE sessions.UserId=?", u.UserId) if err != nil { - transaction.Rollback() - return err - } - - err = transaction.Commit() - if err != nil { - transaction.Rollback() return err } return nil } -func UserHandler(w http.ResponseWriter, r *http.Request, db *DB) { +func UserHandler(r *http.Request, tx *Tx) ResponseWriterWriter { if r.Method == "POST" { user_json := r.PostFormValue("user") if user_json == "" { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } var user User err := user.Read(user_json) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } user.UserId = -1 user.HashPassword() - err = InsertUser(db, &user) + err = InsertUser(tx, &user) if err != nil { if _, ok := err.(UserExistsError); ok { - WriteError(w, 4 /*User Exists*/) + return NewError(4 /*User Exists*/) } else { - WriteError(w, 999 /*Internal Error*/) log.Print(err) + return NewError(999 /*Internal Error*/) } - return } - w.WriteHeader(201 /*Created*/) - err = user.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return ResponseWrapper{201, &user} } else { - user, err := GetUserFromSession(db, r) + user, err := GetUserFromSession(tx, r) if err != nil { - WriteError(w, 1 /*Not Signed In*/) - return + return NewError(1 /*Not Signed In*/) } userid, err := GetURLID(r.URL.Path) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } if userid != user.UserId { - WriteError(w, 2 /*Unauthorized Access*/) - return + return NewError(2 /*Unauthorized Access*/) } if r.Method == "GET" { - err = user.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return user } else if r.Method == "PUT" { user_json := r.PostFormValue("user") if user_json == "" { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } // Save old PWHash in case the new password is bogus @@ -312,8 +247,7 @@ func UserHandler(w http.ResponseWriter, r *http.Request, db *DB) { err = user.Read(user_json) if err != nil || user.UserId != userid { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } // If the user didn't create a new password, keep their old one @@ -324,27 +258,21 @@ func UserHandler(w http.ResponseWriter, r *http.Request, db *DB) { user.PasswordHash = old_pwhash } - err = UpdateUser(db, user) + err = UpdateUser(tx, user) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - err = user.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return user } else if r.Method == "DELETE" { - err := DeleteUser(db, user) + err := DeleteUser(tx, user) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - WriteSuccess(w) + return SuccessWriter{} } } + return NewError(3 /*Invalid Request*/) } diff --git a/internal/handlers/util.go b/internal/handlers/util.go index 3579a13..b3e9fa7 100644 --- a/internal/handlers/util.go +++ b/internal/handlers/util.go @@ -18,6 +18,23 @@ func GetURLPieces(url string, format string, a ...interface{}) (int, error) { return fmt.Sscanf(url, format, a...) } +type ResponseWrapper struct { + Code int + Writer ResponseWriterWriter +} + +func (r ResponseWrapper) Write(w http.ResponseWriter) error { + w.WriteHeader(r.Code) + return r.Writer.Write(w) +} + +type SuccessWriter struct{} + +func (s SuccessWriter) Write(w http.ResponseWriter) error { + fmt.Fprint(w, "{}") + return nil +} + func WriteSuccess(w http.ResponseWriter) { fmt.Fprint(w, "{}") } From 2ff1f474323037ac26c59e4cc32227a72f6af7fd Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Sat, 14 Oct 2017 19:41:13 -0400 Subject: [PATCH 2/2] Remove duplicate *Tx versions of database access methods Simplify naming to remove "Tx" now that all handlers only have access to transactions anyway, and always use "tx" as the name of the variable representing the SQL transactions (to make it less likely to cause confusion with monetary transactions). --- internal/handlers/accounts.go | 40 ++++++++++------------------ internal/handlers/gnucash.go | 6 ++--- internal/handlers/imports.go | 12 ++++----- internal/handlers/prices.go | 31 +++++++++------------- internal/handlers/securities.go | 39 +++++++-------------------- internal/handlers/sessions.go | 20 ++------------ internal/handlers/transactions.go | 44 ++++++++++++------------------- internal/handlers/users.go | 22 ++-------------- 8 files changed, 67 insertions(+), 147 deletions(-) diff --git a/internal/handlers/accounts.go b/internal/handlers/accounts.go index 5ba8708..4d1ff86 100644 --- a/internal/handlers/accounts.go +++ b/internal/handlers/accounts.go @@ -3,7 +3,6 @@ package handlers import ( "encoding/json" "errors" - "gopkg.in/gorp.v1" "log" "net/http" "regexp" @@ -139,17 +138,6 @@ func GetAccount(tx *Tx, accountid int64, userid int64) (*Account, error) { return &a, nil } -func GetAccountTx(transaction *gorp.Transaction, accountid int64, userid int64) (*Account, error) { - var a Account - - err := transaction.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid) - if err != nil { - return nil, err - } - - return &a, nil -} - func GetAccounts(tx *Tx, userid int64) (*[]Account, error) { var accounts []Account @@ -162,12 +150,12 @@ func GetAccounts(tx *Tx, userid int64) (*[]Account, error) { // Get (and attempt to create if it doesn't exist). Matches on UserId, // SecurityId, Type, Name, and ParentAccountId -func GetCreateAccountTx(transaction *gorp.Transaction, a Account) (*Account, error) { +func GetCreateAccount(tx *Tx, a Account) (*Account, error) { var accounts []Account var account Account // Try to find the top-level trading account - _, err := transaction.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND Type=? AND Name=? AND ParentAccountId=? ORDER BY AccountId ASC LIMIT 1", a.UserId, a.SecurityId, a.Type, a.Name, a.ParentAccountId) + _, err := tx.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND Type=? AND Name=? AND ParentAccountId=? ORDER BY AccountId ASC LIMIT 1", a.UserId, a.SecurityId, a.Type, a.Name, a.ParentAccountId) if err != nil { return nil, err } @@ -180,7 +168,7 @@ func GetCreateAccountTx(transaction *gorp.Transaction, a Account) (*Account, err account.Name = a.Name account.ParentAccountId = a.ParentAccountId - err = transaction.Insert(&account) + err = tx.Insert(&account) if err != nil { return nil, err } @@ -190,11 +178,11 @@ func GetCreateAccountTx(transaction *gorp.Transaction, a Account) (*Account, err // Get (and attempt to create if it doesn't exist) the security/currency // trading account for the supplied security/currency -func GetTradingAccount(transaction *gorp.Transaction, userid int64, securityid int64) (*Account, error) { +func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*Account, error) { var tradingAccount Account var account Account - user, err := GetUserTx(transaction, userid) + user, err := GetUser(tx, userid) if err != nil { return nil, err } @@ -206,12 +194,12 @@ func GetTradingAccount(transaction *gorp.Transaction, userid int64, securityid i tradingAccount.ParentAccountId = -1 // Find/create the top-level trading account - ta, err := GetCreateAccountTx(transaction, tradingAccount) + ta, err := GetCreateAccount(tx, tradingAccount) if err != nil { return nil, err } - security, err := GetSecurityTx(transaction, securityid, userid) + security, err := GetSecurity(tx, securityid, userid) if err != nil { return nil, err } @@ -222,7 +210,7 @@ func GetTradingAccount(transaction *gorp.Transaction, userid int64, securityid i account.SecurityId = securityid account.Type = Trading - a, err := GetCreateAccountTx(transaction, account) + a, err := GetCreateAccount(tx, account) if err != nil { return nil, err } @@ -232,14 +220,14 @@ func GetTradingAccount(transaction *gorp.Transaction, userid int64, securityid i // Get (and attempt to create if it doesn't exist) the security/currency // imbalance account for the supplied security/currency -func GetImbalanceAccount(transaction *gorp.Transaction, userid int64, securityid int64) (*Account, error) { +func GetImbalanceAccount(tx *Tx, userid int64, securityid int64) (*Account, error) { var imbalanceAccount Account var account Account xxxtemplate := FindSecurityTemplate("XXX", Currency) if xxxtemplate == nil { return nil, errors.New("Couldn't find XXX security template") } - xxxsecurity, err := ImportGetCreateSecurity(transaction, userid, xxxtemplate) + xxxsecurity, err := ImportGetCreateSecurity(tx, userid, xxxtemplate) if err != nil { return nil, errors.New("Couldn't create XXX security") } @@ -251,12 +239,12 @@ func GetImbalanceAccount(transaction *gorp.Transaction, userid int64, securityid imbalanceAccount.Type = Bank // Find/create the top-level trading account - ia, err := GetCreateAccountTx(transaction, imbalanceAccount) + ia, err := GetCreateAccount(tx, imbalanceAccount) if err != nil { return nil, err } - security, err := GetSecurityTx(transaction, securityid, userid) + security, err := GetSecurity(tx, securityid, userid) if err != nil { return nil, err } @@ -267,7 +255,7 @@ func GetImbalanceAccount(transaction *gorp.Transaction, userid int64, securityid account.SecurityId = securityid account.Type = Bank - a, err := GetCreateAccountTx(transaction, account) + a, err := GetCreateAccount(tx, account) if err != nil { return nil, err } @@ -330,7 +318,7 @@ func insertUpdateAccount(tx *Tx, a *Account, insert bool) error { return err } } else { - oldacct, err := GetAccountTx(tx, a.AccountId, a.UserId) + oldacct, err := GetAccount(tx, a.AccountId, a.UserId) if err != nil { return err } diff --git a/internal/handlers/gnucash.go b/internal/handlers/gnucash.go index 3ce1358..c787c31 100644 --- a/internal/handlers/gnucash.go +++ b/internal/handlers/gnucash.go @@ -407,7 +407,7 @@ func GnucashImportHandler(r *http.Request, tx *Tx) ResponseWriterWriter { account.ParentAccountId = accountMap[account.ParentAccountId] } account.SecurityId = securityMap[account.SecurityId] - a, err := GetCreateAccountTx(tx, account) + a, err := GetCreateAccount(tx, account) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -436,7 +436,7 @@ func GnucashImportHandler(r *http.Request, tx *Tx) ResponseWriterWriter { } split.AccountId = acctId - exists, err := split.AlreadyImportedTx(tx) + exists, err := split.AlreadyImported(tx) if err != nil { log.Print("Error checking if split was already imported:", err) return NewError(999 /*Internal Error*/) @@ -445,7 +445,7 @@ func GnucashImportHandler(r *http.Request, tx *Tx) ResponseWriterWriter { } } if !already_imported { - err := InsertTransactionTx(tx, &transaction, user) + err := InsertTransaction(tx, &transaction, user) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/handlers/imports.go b/internal/handlers/imports.go index 9470a26..1f712af 100644 --- a/internal/handlers/imports.go +++ b/internal/handlers/imports.go @@ -37,7 +37,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *User, accountid int64) ResponseW } // Return Account with this Id - account, err := GetAccountTx(tx, accountid, user.UserId) + account, err := GetAccount(tx, accountid, user.UserId) if err != nil { log.Print(err) return NewError(3 /*Invalid Request*/) @@ -117,7 +117,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *User, accountid int64) ResponseW SecurityId: sec.SecurityId, Type: account.Type, } - subaccount, err := GetCreateAccountTx(tx, *subaccount) + subaccount, err := GetCreateAccount(tx, *subaccount) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -137,7 +137,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *User, accountid int64) ResponseW } } - imbalances, err := transaction.GetImbalancesTx(tx) + imbalances, err := transaction.GetImbalances(tx) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -157,7 +157,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *User, accountid int64) ResponseW split := new(Split) r := new(big.Rat) r.Neg(&imbalance) - security, err := GetSecurityTx(tx, imbalanced_security, user.UserId) + security, err := GetSecurity(tx, imbalanced_security, user.UserId) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -185,7 +185,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *User, accountid int64) ResponseW split.SecurityId = -1 } - exists, err := split.AlreadyImportedTx(tx) + exists, err := split.AlreadyImported(tx) if err != nil { log.Print("Error checking if split was already imported:", err) return NewError(999 /*Internal Error*/) @@ -200,7 +200,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *User, accountid int64) ResponseW } for _, transaction := range transactions { - err := InsertTransactionTx(tx, &transaction, user) + err := InsertTransaction(tx, &transaction, user) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/handlers/prices.go b/internal/handlers/prices.go index 9be73f0..81b32d0 100644 --- a/internal/handlers/prices.go +++ b/internal/handlers/prices.go @@ -1,7 +1,6 @@ package handlers import ( - "gopkg.in/gorp.v1" "time" ) @@ -14,18 +13,18 @@ type Price struct { RemoteId string // unique ID from source, for detecting duplicates } -func InsertPriceTx(transaction *gorp.Transaction, p *Price) error { - err := transaction.Insert(p) +func InsertPrice(tx *Tx, p *Price) error { + err := tx.Insert(p) if err != nil { return err } return nil } -func CreatePriceIfNotExist(transaction *gorp.Transaction, price *Price) error { +func CreatePriceIfNotExist(tx *Tx, price *Price) error { if len(price.RemoteId) == 0 { // Always create a new price if we can't match on the RemoteId - err := InsertPriceTx(transaction, price) + err := InsertPrice(tx, price) if err != nil { return err } @@ -34,7 +33,7 @@ func CreatePriceIfNotExist(transaction *gorp.Transaction, price *Price) error { var prices []*Price - _, err := transaction.Select(&prices, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date=? AND Value=?", price.SecurityId, price.CurrencyId, price.Date, price.Value) + _, err := tx.Select(&prices, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date=? AND Value=?", price.SecurityId, price.CurrencyId, price.Date, price.Value) if err != nil { return err } @@ -43,7 +42,7 @@ func CreatePriceIfNotExist(transaction *gorp.Transaction, price *Price) error { return nil // price already exists } - err = InsertPriceTx(transaction, price) + err = InsertPrice(tx, price) if err != nil { return err } @@ -51,9 +50,9 @@ func CreatePriceIfNotExist(transaction *gorp.Transaction, price *Price) error { } // Return the latest price for security in currency units before date -func GetLatestPrice(transaction *gorp.Transaction, security, currency *Security, date *time.Time) (*Price, error) { +func GetLatestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Price, error) { var p Price - err := transaction.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date) + err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date) if err != nil { return nil, err } @@ -61,9 +60,9 @@ func GetLatestPrice(transaction *gorp.Transaction, security, currency *Security, } // Return the earliest price for security in currency units after date -func GetEarliestPrice(transaction *gorp.Transaction, security, currency *Security, date *time.Time) (*Price, error) { +func GetEarliestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Price, error) { var p Price - err := transaction.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date) + err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date) if err != nil { return nil, err } @@ -71,9 +70,9 @@ func GetEarliestPrice(transaction *gorp.Transaction, security, currency *Securit } // Return the price for security in currency closest to date -func GetClosestPriceTx(transaction *gorp.Transaction, security, currency *Security, date *time.Time) (*Price, error) { - earliest, _ := GetEarliestPrice(transaction, security, currency, date) - latest, err := GetLatestPrice(transaction, security, currency, date) +func GetClosestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Price, error) { + earliest, _ := GetEarliestPrice(tx, security, currency, date) + latest, err := GetLatestPrice(tx, security, currency, date) // Return early if either earliest or latest are invalid if earliest == nil { @@ -90,7 +89,3 @@ func GetClosestPriceTx(transaction *gorp.Transaction, security, currency *Securi return earliest, nil } } - -func GetClosestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Price, error) { - return GetClosestPriceTx(tx, security, currency, date) -} diff --git a/internal/handlers/securities.go b/internal/handlers/securities.go index c4e7a47..9d047d3 100644 --- a/internal/handlers/securities.go +++ b/internal/handlers/securities.go @@ -5,7 +5,6 @@ package handlers import ( "encoding/json" "errors" - "gopkg.in/gorp.v1" "log" "net/http" "net/url" @@ -113,16 +112,6 @@ func GetSecurity(tx *Tx, securityid int64, userid int64) (*Security, error) { return &s, nil } -func GetSecurityTx(transaction *gorp.Transaction, securityid int64, userid int64) (*Security, error) { - var s Security - - err := transaction.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid) - if err != nil { - return nil, err - } - return &s, nil -} - func GetSecurities(tx *Tx, userid int64) (*[]*Security, error) { var securities []*Security @@ -141,16 +130,8 @@ func InsertSecurity(tx *Tx, s *Security) error { return nil } -func InsertSecurityTx(transaction *gorp.Transaction, s *Security) error { - err := transaction.Insert(s) - if err != nil { - return err - } - return nil -} - func UpdateSecurity(tx *Tx, s *Security) (err error) { - user, err := GetUserTx(tx, s.UserId) + user, err := GetUser(tx, s.UserId) if err != nil { return } else if user.DefaultCurrency == s.SecurityId && s.Type != Currency { @@ -184,7 +165,7 @@ func DeleteSecurity(tx *Tx, s *Security) error { return SecurityInUseError{"One or more accounts still use this security"} } - user, err := GetUserTx(tx, s.UserId) + user, err := GetUser(tx, s.UserId) if err != nil { return err } else if user.DefaultCurrency == s.SecurityId { @@ -209,11 +190,11 @@ func DeleteSecurity(tx *Tx, s *Security) error { return nil } -func ImportGetCreateSecurity(transaction *gorp.Transaction, userid int64, security *Security) (*Security, error) { +func ImportGetCreateSecurity(tx *Tx, userid int64, security *Security) (*Security, error) { security.UserId = userid if len(security.AlternateId) == 0 { // Always create a new local security if we can't match on the AlternateId - err := InsertSecurityTx(transaction, security) + err := InsertSecurity(tx, security) if err != nil { return nil, err } @@ -222,7 +203,7 @@ func ImportGetCreateSecurity(transaction *gorp.Transaction, userid int64, securi var securities []*Security - _, err := transaction.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Precision=?", userid, security.Type, security.AlternateId, security.Precision) + _, err := tx.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Precision=?", userid, security.Type, security.AlternateId, security.Precision) if err != nil { return nil, err } @@ -254,7 +235,7 @@ func ImportGetCreateSecurity(transaction *gorp.Transaction, userid int64, securi } // If there wasn't even one security in the list, make a new one - err = InsertSecurityTx(transaction, security) + err = InsertSecurity(tx, security) if err != nil { return nil, err } @@ -263,7 +244,7 @@ func ImportGetCreateSecurity(transaction *gorp.Transaction, userid int64, securi } func SecurityHandler(r *http.Request, tx *Tx) ResponseWriterWriter { - user, err := GetUserFromSessionTx(tx, r) + user, err := GetUserFromSession(tx, r) if err != nil { return NewError(1 /*Not Signed In*/) } @@ -282,7 +263,7 @@ func SecurityHandler(r *http.Request, tx *Tx) ResponseWriterWriter { security.SecurityId = -1 security.UserId = user.UserId - err = InsertSecurityTx(tx, &security) + err = InsertSecurity(tx, &security) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -306,7 +287,7 @@ func SecurityHandler(r *http.Request, tx *Tx) ResponseWriterWriter { sl.Securities = securities return &sl } else { - security, err := GetSecurityTx(tx, securityid, user.UserId) + security, err := GetSecurity(tx, securityid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } @@ -339,7 +320,7 @@ func SecurityHandler(r *http.Request, tx *Tx) ResponseWriterWriter { return &security } else if r.Method == "DELETE" { - security, err := GetSecurityTx(tx, securityid, user.UserId) + security, err := GetSecurity(tx, securityid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } diff --git a/internal/handlers/sessions.go b/internal/handlers/sessions.go index 3cc38d2..8f61a7f 100644 --- a/internal/handlers/sessions.go +++ b/internal/handlers/sessions.go @@ -44,24 +44,8 @@ func GetSession(tx *Tx, r *http.Request) (*Session, error) { return &s, nil } -func GetSessionTx(tx *Tx, r *http.Request) (*Session, error) { - var s Session - - cookie, err := r.Cookie("moneygo-session") - if err != nil { - return nil, fmt.Errorf("moneygo-session cookie not set") - } - s.SessionSecret = cookie.Value - - err = tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret) - if err != nil { - return nil, err - } - return &s, nil -} - func DeleteSessionIfExists(tx *Tx, r *http.Request) error { - session, err := GetSessionTx(tx, r) + session, err := GetSession(tx, r) if err == nil { _, err := tx.Delete(session) if err != nil { @@ -153,7 +137,7 @@ func SessionHandler(r *http.Request, tx *Tx) ResponseWriterWriter { } return sessionwriter } else if r.Method == "GET" { - s, err := GetSessionTx(tx, r) + s, err := GetSession(tx, r) if err != nil { return NewError(1 /*Not Signed In*/) } diff --git a/internal/handlers/transactions.go b/internal/handlers/transactions.go index 5a75ec9..d7879cb 100644 --- a/internal/handlers/transactions.go +++ b/internal/handlers/transactions.go @@ -4,7 +4,6 @@ import ( "encoding/json" "errors" "fmt" - "gopkg.in/gorp.v1" "log" "math/big" "net/http" @@ -78,8 +77,8 @@ func (s *Split) Valid() bool { return err == nil } -func (s *Split) AlreadyImportedTx(transaction *gorp.Transaction) (bool, error) { - count, err := transaction.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId) +func (s *Split) AlreadyImported(tx *Tx) (bool, error) { + count, err := tx.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId) return count == 1, err } @@ -134,7 +133,7 @@ func (t *Transaction) Valid() bool { // Return a map of security ID's to big.Rat's containing the amount that // security is imbalanced by -func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64]big.Rat, error) { +func (t *Transaction) GetImbalances(tx *Tx) (map[int64]big.Rat, error) { sums := make(map[int64]big.Rat) if !t.Valid() { @@ -146,7 +145,7 @@ func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64] if t.Splits[i].AccountId != -1 { var err error var account *Account - account, err = GetAccountTx(transaction, t.Splits[i].AccountId, t.UserId) + account, err = GetAccount(tx, t.Splits[i].AccountId, t.UserId) if err != nil { return nil, err } @@ -162,10 +161,10 @@ func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64] // Returns true if all securities contained in this transaction are balanced, // false otherwise -func (t *Transaction) Balanced(transaction *gorp.Transaction) (bool, error) { +func (t *Transaction) Balanced(tx *Tx) (bool, error) { var zero big.Rat - sums, err := t.GetImbalancesTx(transaction) + sums, err := t.GetImbalances(tx) if err != nil { return false, err } @@ -214,7 +213,7 @@ func GetTransactions(tx *Tx, userid int64) (*[]Transaction, error) { func incrementAccountVersions(tx *Tx, user *User, accountids []int64) error { for i := range accountids { - account, err := GetAccountTx(tx, accountids[i], user.UserId) + account, err := GetAccount(tx, accountids[i], user.UserId) if err != nil { return err } @@ -236,7 +235,7 @@ func (ame AccountMissingError) Error() string { return "Account missing" } -func InsertTransactionTx(tx *Tx, t *Transaction, user *User) error { +func InsertTransaction(tx *Tx, t *Transaction, user *User) error { // Map of any accounts with transaction splits being added a_map := make(map[int64]bool) for i := range t.Splits { @@ -286,16 +285,7 @@ func InsertTransactionTx(tx *Tx, t *Transaction, user *User) error { return nil } -func InsertTransaction(tx *Tx, t *Transaction, user *User) error { - err := InsertTransactionTx(tx, t, user) - if err != nil { - return err - } - - return nil -} - -func UpdateTransactionTx(tx *Tx, t *Transaction, user *User) error { +func UpdateTransaction(tx *Tx, t *Transaction, user *User) error { var existing_splits []*Split _, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId) @@ -431,13 +421,13 @@ func TransactionHandler(r *http.Request, tx *Tx) ResponseWriterWriter { for i := range transaction.Splits { transaction.Splits[i].SplitId = -1 - _, err := GetAccountTx(tx, transaction.Splits[i].AccountId, user.UserId) + _, err := GetAccount(tx, transaction.Splits[i].AccountId, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } } - err = InsertTransactionTx(tx, &transaction, user) + err = InsertTransaction(tx, &transaction, user) if err != nil { if _, ok := err.(AccountMissingError); ok { return NewError(3 /*Invalid Request*/) @@ -497,13 +487,13 @@ func TransactionHandler(r *http.Request, tx *Tx) ResponseWriterWriter { } for i := range transaction.Splits { - _, err := GetAccountTx(tx, transaction.Splits[i].AccountId, user.UserId) + _, err := GetAccount(tx, transaction.Splits[i].AccountId, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } } - err = UpdateTransactionTx(tx, &transaction, user) + err = UpdateTransaction(tx, &transaction, user) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -533,10 +523,10 @@ func TransactionHandler(r *http.Request, tx *Tx) ResponseWriterWriter { return NewError(3 /*Invalid Request*/) } -func TransactionsBalanceDifference(transaction *gorp.Transaction, accountid int64, transactions []Transaction) (*big.Rat, error) { +func TransactionsBalanceDifference(tx *Tx, accountid int64, transactions []Transaction) (*big.Rat, error) { var pageDifference, tmp big.Rat for i := range transactions { - _, err := transaction.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId) + _, err := tx.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId) if err != nil { return nil, err } @@ -649,7 +639,7 @@ func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, pa sqloffset = fmt.Sprintf(" OFFSET %d", page*limit) } - account, err := GetAccountTx(tx, accountid, user.UserId) + account, err := GetAccount(tx, accountid, user.UserId) if err != nil { return nil, err } @@ -673,7 +663,7 @@ func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, pa } atl.TotalTransactions = count - security, err := GetSecurityTx(tx, atl.Account.SecurityId, user.UserId) + security, err := GetSecurity(tx, atl.Account.SecurityId, user.UserId) if err != nil { return nil, err } diff --git a/internal/handlers/users.go b/internal/handlers/users.go index f6f75e1..63f3ccb 100644 --- a/internal/handlers/users.go +++ b/internal/handlers/users.go @@ -56,16 +56,6 @@ func GetUser(tx *Tx, userid int64) (*User, error) { return &u, nil } -func GetUserTx(tx *Tx, userid int64) (*User, error) { - var u User - - err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid) - if err != nil { - return nil, err - } - return &u, nil -} - func GetUserByUsername(tx *Tx, username string) (*User, error) { var u User @@ -100,7 +90,7 @@ func InsertUser(tx *Tx, u *User) error { security = *security_template security.UserId = u.UserId - err = InsertSecurityTx(tx, &security) + err = InsertSecurity(tx, &security) if err != nil { return err } @@ -125,16 +115,8 @@ func GetUserFromSession(tx *Tx, r *http.Request) (*User, error) { return GetUser(tx, s.UserId) } -func GetUserFromSessionTx(tx *Tx, r *http.Request) (*User, error) { - s, err := GetSessionTx(tx, r) - if err != nil { - return nil, err - } - return GetUserTx(tx, s.UserId) -} - func UpdateUser(tx *Tx, u *User) error { - security, err := GetSecurityTx(tx, u.DefaultCurrency, u.UserId) + security, err := GetSecurity(tx, u.DefaultCurrency, u.UserId) if err != nil { return err } else if security.UserId != u.UserId || security.SecurityId != u.DefaultCurrency {