diff --git a/internal/handlers/accounts.go b/internal/handlers/accounts.go index d278b6c..4d1ff86 100644 --- a/internal/handlers/accounts.go +++ b/internal/handlers/accounts.go @@ -3,7 +3,6 @@ package handlers import ( "encoding/json" "errors" - "gopkg.in/gorp.v1" "log" "net/http" "regexp" @@ -129,31 +128,20 @@ func (al *AccountList) Read(json_str string) error { return dec.Decode(al) } -func GetAccount(db *DB, accountid int64, userid int64) (*Account, error) { +func GetAccount(tx *Tx, accountid int64, userid int64) (*Account, error) { var a Account - err := db.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid) + err := tx.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid) if err != nil { return nil, err } return &a, nil } -func GetAccountTx(transaction *gorp.Transaction, accountid int64, userid int64) (*Account, error) { - var a Account - - err := transaction.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid) - if err != nil { - return nil, err - } - - return &a, nil -} - -func GetAccounts(db *DB, userid int64) (*[]Account, error) { +func GetAccounts(tx *Tx, userid int64) (*[]Account, error) { var accounts []Account - _, err := db.Select(&accounts, "SELECT * from accounts where UserId=?", userid) + _, err := tx.Select(&accounts, "SELECT * from accounts where UserId=?", userid) if err != nil { return nil, err } @@ -162,12 +150,12 @@ func GetAccounts(db *DB, userid int64) (*[]Account, error) { // Get (and attempt to create if it doesn't exist). Matches on UserId, // SecurityId, Type, Name, and ParentAccountId -func GetCreateAccountTx(transaction *gorp.Transaction, a Account) (*Account, error) { +func GetCreateAccount(tx *Tx, a Account) (*Account, error) { var accounts []Account var account Account // Try to find the top-level trading account - _, err := transaction.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND Type=? AND Name=? AND ParentAccountId=? ORDER BY AccountId ASC LIMIT 1", a.UserId, a.SecurityId, a.Type, a.Name, a.ParentAccountId) + _, err := tx.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND Type=? AND Name=? AND ParentAccountId=? ORDER BY AccountId ASC LIMIT 1", a.UserId, a.SecurityId, a.Type, a.Name, a.ParentAccountId) if err != nil { return nil, err } @@ -180,7 +168,7 @@ func GetCreateAccountTx(transaction *gorp.Transaction, a Account) (*Account, err account.Name = a.Name account.ParentAccountId = a.ParentAccountId - err = transaction.Insert(&account) + err = tx.Insert(&account) if err != nil { return nil, err } @@ -190,11 +178,11 @@ func GetCreateAccountTx(transaction *gorp.Transaction, a Account) (*Account, err // Get (and attempt to create if it doesn't exist) the security/currency // trading account for the supplied security/currency -func GetTradingAccount(transaction *gorp.Transaction, userid int64, securityid int64) (*Account, error) { +func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*Account, error) { var tradingAccount Account var account Account - user, err := GetUserTx(transaction, userid) + user, err := GetUser(tx, userid) if err != nil { return nil, err } @@ -206,12 +194,12 @@ func GetTradingAccount(transaction *gorp.Transaction, userid int64, securityid i tradingAccount.ParentAccountId = -1 // Find/create the top-level trading account - ta, err := GetCreateAccountTx(transaction, tradingAccount) + ta, err := GetCreateAccount(tx, tradingAccount) if err != nil { return nil, err } - security, err := GetSecurityTx(transaction, securityid, userid) + security, err := GetSecurity(tx, securityid, userid) if err != nil { return nil, err } @@ -222,7 +210,7 @@ func GetTradingAccount(transaction *gorp.Transaction, userid int64, securityid i account.SecurityId = securityid account.Type = Trading - a, err := GetCreateAccountTx(transaction, account) + a, err := GetCreateAccount(tx, account) if err != nil { return nil, err } @@ -232,14 +220,14 @@ func GetTradingAccount(transaction *gorp.Transaction, userid int64, securityid i // Get (and attempt to create if it doesn't exist) the security/currency // imbalance account for the supplied security/currency -func GetImbalanceAccount(transaction *gorp.Transaction, userid int64, securityid int64) (*Account, error) { +func GetImbalanceAccount(tx *Tx, userid int64, securityid int64) (*Account, error) { var imbalanceAccount Account var account Account xxxtemplate := FindSecurityTemplate("XXX", Currency) if xxxtemplate == nil { return nil, errors.New("Couldn't find XXX security template") } - xxxsecurity, err := ImportGetCreateSecurity(transaction, userid, xxxtemplate) + xxxsecurity, err := ImportGetCreateSecurity(tx, userid, xxxtemplate) if err != nil { return nil, errors.New("Couldn't create XXX security") } @@ -251,12 +239,12 @@ func GetImbalanceAccount(transaction *gorp.Transaction, userid int64, securityid imbalanceAccount.Type = Bank // Find/create the top-level trading account - ia, err := GetCreateAccountTx(transaction, imbalanceAccount) + ia, err := GetCreateAccount(tx, imbalanceAccount) if err != nil { return nil, err } - security, err := GetSecurityTx(transaction, securityid, userid) + security, err := GetSecurity(tx, securityid, userid) if err != nil { return nil, err } @@ -267,7 +255,7 @@ func GetImbalanceAccount(transaction *gorp.Transaction, userid int64, securityid account.SecurityId = securityid account.Type = Bank - a, err := GetCreateAccountTx(transaction, account) + a, err := GetCreateAccount(tx, account) if err != nil { return nil, err } @@ -293,12 +281,7 @@ func (cae CircularAccountsError) Error() string { return "Would result in circular account relationship" } -func insertUpdateAccount(db *DB, a *Account, insert bool) error { - transaction, err := db.Begin() - if err != nil { - return err - } - +func insertUpdateAccount(tx *Tx, a *Account, insert bool) error { found := make(map[int64]bool) if !insert { found[a.AccountId] = true @@ -308,14 +291,12 @@ func insertUpdateAccount(db *DB, a *Account, insert bool) error { for parentid != -1 { depth += 1 if depth > 100 { - transaction.Rollback() return TooMuchNestingError{} } var a Account - err := transaction.SelectOne(&a, "SELECT * from accounts where AccountId=?", parentid) + err := tx.SelectOne(&a, "SELECT * from accounts where AccountId=?", parentid) if err != nil { - transaction.Rollback() return ParentAccountMissingError{} } @@ -327,107 +308,79 @@ func insertUpdateAccount(db *DB, a *Account, insert bool) error { found[parentid] = true parentid = a.ParentAccountId if _, ok := found[parentid]; ok { - transaction.Rollback() return CircularAccountsError{} } } if insert { - err = transaction.Insert(a) + err := tx.Insert(a) if err != nil { - transaction.Rollback() return err } } else { - oldacct, err := GetAccountTx(transaction, a.AccountId, a.UserId) + oldacct, err := GetAccount(tx, a.AccountId, a.UserId) if err != nil { - transaction.Rollback() return err } a.AccountVersion = oldacct.AccountVersion + 1 - count, err := transaction.Update(a) + count, err := tx.Update(a) if err != nil { - transaction.Rollback() return err } if count != 1 { - transaction.Rollback() return errors.New("Updated more than one account") } } - err = transaction.Commit() - if err != nil { - transaction.Rollback() - return err - } - return nil } -func InsertAccount(db *DB, a *Account) error { - return insertUpdateAccount(db, a, true) +func InsertAccount(tx *Tx, a *Account) error { + return insertUpdateAccount(tx, a, true) } -func UpdateAccount(db *DB, a *Account) error { - return insertUpdateAccount(db, a, false) +func UpdateAccount(tx *Tx, a *Account) error { + return insertUpdateAccount(tx, a, false) } -func DeleteAccount(db *DB, a *Account) error { - transaction, err := db.Begin() - if err != nil { - return err - } - +func DeleteAccount(tx *Tx, a *Account) error { if a.ParentAccountId != -1 { // Re-parent splits to this account's parent account if this account isn't a root account - _, err = transaction.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", a.ParentAccountId, a.AccountId) + _, err := tx.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", a.ParentAccountId, a.AccountId) if err != nil { - transaction.Rollback() return err } } else { // Delete splits if this account is a root account - _, err = transaction.Exec("DELETE FROM splits WHERE AccountId=?", a.AccountId) + _, err := tx.Exec("DELETE FROM splits WHERE AccountId=?", a.AccountId) if err != nil { - transaction.Rollback() return err } } // Re-parent child accounts to this account's parent account - _, err = transaction.Exec("UPDATE accounts SET ParentAccountId=? WHERE ParentAccountId=?", a.ParentAccountId, a.AccountId) + _, err := tx.Exec("UPDATE accounts SET ParentAccountId=? WHERE ParentAccountId=?", a.ParentAccountId, a.AccountId) if err != nil { - transaction.Rollback() return err } - count, err := transaction.Delete(a) + count, err := tx.Delete(a) if err != nil { - transaction.Rollback() return err } if count != 1 { - transaction.Rollback() return errors.New("Was going to delete more than one account") } - err = transaction.Commit() - if err != nil { - transaction.Rollback() - return err - } - return nil } -func AccountHandler(w http.ResponseWriter, r *http.Request, db *DB) { - user, err := GetUserFromSession(db, r) +func AccountHandler(r *http.Request, tx *Tx) ResponseWriterWriter { + user, err := GetUserFromSession(tx, r) if err != nil { - WriteError(w, 1 /*Not Signed In*/) - return + return NewError(1 /*Not Signed In*/) } if r.Method == "POST" { @@ -439,59 +392,46 @@ func AccountHandler(w http.ResponseWriter, r *http.Request, db *DB) { n, err := GetURLPieces(r.URL.Path, "/account/%d/import/%s", &accountid, &importtype) if err != nil || n != 2 { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - AccountImportHandler(db, w, r, user, accountid, importtype) - return + return AccountImportHandler(tx, r, user, accountid, importtype) } account_json := r.PostFormValue("account") if account_json == "" { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } var account Account err := account.Read(account_json) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } account.AccountId = -1 account.UserId = user.UserId account.AccountVersion = 0 - security, err := GetSecurity(db, account.SecurityId, user.UserId) + security, err := GetSecurity(tx, account.SecurityId, user.UserId) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } if security == nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } - err = InsertAccount(db, &account) + err = InsertAccount(tx, &account) if err != nil { if _, ok := err.(ParentAccountMissingError); ok { - WriteError(w, 3 /*Invalid Request*/) + return NewError(3 /*Invalid Request*/) } else { - WriteError(w, 999 /*Internal Error*/) log.Print(err) + return NewError(999 /*Internal Error*/) } - return } - w.WriteHeader(201 /*Created*/) - err = account.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return ResponseWrapper{201, &account} } else if r.Method == "GET" { var accountid int64 n, err := GetURLPieces(r.URL.Path, "/account/%d", &accountid) @@ -499,112 +439,86 @@ func AccountHandler(w http.ResponseWriter, r *http.Request, db *DB) { if err != nil || n != 1 { //Return all Accounts var al AccountList - accounts, err := GetAccounts(db, user.UserId) + accounts, err := GetAccounts(tx, user.UserId) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } al.Accounts = accounts - err = (&al).Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return &al } else { // if URL looks like /account/[0-9]+/transactions, use the account // transaction handler if accountTransactionsRE.MatchString(r.URL.Path) { - AccountTransactionsHandler(db, w, r, user, accountid) - return + return AccountTransactionsHandler(tx, r, user, accountid) } // Return Account with this Id - account, err := GetAccount(db, accountid, user.UserId) + account, err := GetAccount(tx, accountid, user.UserId) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } - err = account.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return account } } else { accountid, err := GetURLID(r.URL.Path) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } if r.Method == "PUT" { account_json := r.PostFormValue("account") if account_json == "" { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } var account Account err := account.Read(account_json) if err != nil || account.AccountId != accountid { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } account.UserId = user.UserId - security, err := GetSecurity(db, account.SecurityId, user.UserId) + security, err := GetSecurity(tx, account.SecurityId, user.UserId) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } if security == nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } if account.ParentAccountId == account.AccountId { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } - err = UpdateAccount(db, &account) + err = UpdateAccount(tx, &account) if err != nil { if _, ok := err.(ParentAccountMissingError); ok { - WriteError(w, 3 /*Invalid Request*/) + return NewError(3 /*Invalid Request*/) } else if _, ok := err.(CircularAccountsError); ok { - WriteError(w, 3 /*Invalid Request*/) + return NewError(3 /*Invalid Request*/) } else { - WriteError(w, 999 /*Internal Error*/) log.Print(err) + return NewError(999 /*Internal Error*/) } - return } - err = account.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return &account } else if r.Method == "DELETE" { - account, err := GetAccount(db, accountid, user.UserId) + account, err := GetAccount(tx, accountid, user.UserId) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } - err = DeleteAccount(db, account) + err = DeleteAccount(tx, account) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - WriteSuccess(w) + return SuccessWriter{} } } + return NewError(3 /*Invalid Request*/) } diff --git a/internal/handlers/accounts_lua.go b/internal/handlers/accounts_lua.go index 928b0ba..4edf440 100644 --- a/internal/handlers/accounts_lua.go +++ b/internal/handlers/accounts_lua.go @@ -15,9 +15,9 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*Account, error) { ctx := L.Context() - db, ok := ctx.Value(dbContextKey).(*DB) + tx, ok := ctx.Value(dbContextKey).(*Tx) if !ok { - return nil, errors.New("Couldn't find DB in lua's Context") + return nil, errors.New("Couldn't find tx in lua's Context") } account_map, ok = ctx.Value(accountsContextKey).(map[int64]*Account) @@ -27,7 +27,7 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*Account, error) { return nil, errors.New("Couldn't find User in lua's Context") } - accounts, err := GetAccounts(db, user.UserId) + accounts, err := GetAccounts(tx, user.UserId) if err != nil { return nil, err } @@ -149,9 +149,9 @@ func luaAccountBalance(L *lua.LState) int { a := luaCheckAccount(L, 1) ctx := L.Context() - db, ok := ctx.Value(dbContextKey).(*DB) + tx, ok := ctx.Value(dbContextKey).(*Tx) if !ok { - panic("Couldn't find DB in lua's Context") + panic("Couldn't find tx in lua's Context") } user, ok := ctx.Value(userContextKey).(*User) if !ok { @@ -171,12 +171,12 @@ func luaAccountBalance(L *lua.LState) int { if date != nil { end := luaWeakCheckTime(L, 3) if end != nil { - rat, err = GetAccountBalanceDateRange(db, user, a.AccountId, date, end) + rat, err = GetAccountBalanceDateRange(tx, user, a.AccountId, date, end) } else { - rat, err = GetAccountBalanceDate(db, user, a.AccountId, date) + rat, err = GetAccountBalanceDate(tx, user, a.AccountId, date) } } else { - rat, err = GetAccountBalance(db, user, a.AccountId) + rat, err = GetAccountBalance(tx, user, a.AccountId) } if err != nil { panic("Failed to GetAccountBalance:" + err.Error()) diff --git a/internal/handlers/errors.go b/internal/handlers/errors.go index 17a5c0c..e5f4041 100644 --- a/internal/handlers/errors.go +++ b/internal/handlers/errors.go @@ -38,13 +38,17 @@ var error_codes = map[int]string{ 999: "Internal Error", } -func WriteError(w http.ResponseWriter, error_code int) { +func NewError(error_code int) *Error { msg, ok := error_codes[error_code] if !ok { log.Printf("Error: WriteError received error code of %d", error_code) msg = error_codes[999] } - e := Error{error_code, msg} + return &Error{error_code, msg} +} + +func WriteError(w http.ResponseWriter, error_code int) { + e := NewError(error_code) err := e.Write(w) if err != nil { diff --git a/internal/handlers/gnucash.go b/internal/handlers/gnucash.go index 24f6b7c..c787c31 100644 --- a/internal/handlers/gnucash.go +++ b/internal/handlers/gnucash.go @@ -308,42 +308,37 @@ func ImportGnucash(r io.Reader) (*GnucashImport, error) { return &gncimport, nil } -func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) { - user, err := GetUserFromSession(db, r) +func GnucashImportHandler(r *http.Request, tx *Tx) ResponseWriterWriter { + user, err := GetUserFromSession(tx, r) if err != nil { - WriteError(w, 1 /*Not Signed In*/) - return + return NewError(1 /*Not Signed In*/) } if r.Method != "POST" { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } multipartReader, err := r.MultipartReader() if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } // Assume there is only one 'part' and it's the one we care about part, err := multipartReader.NextPart() if err != nil { if err == io.EOF { - WriteError(w, 3 /*Invalid Request*/) + return NewError(3 /*Invalid Request*/) } else { - WriteError(w, 999 /*Internal Error*/) log.Print(err) + return NewError(999 /*Internal Error*/) } - return } bufread := bufio.NewReader(part) gzHeader, err := bufread.Peek(2) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } // Does this look like a gzipped file? @@ -351,9 +346,8 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) { if gzHeader[0] == 0x1f && gzHeader[1] == 0x8b { gzr, err := gzip.NewReader(bufread) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } gnucashImport, err = ImportGnucash(gzr) } else { @@ -361,15 +355,7 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) { } if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return - } - - sqltransaction, err := db.Begin() - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return + return NewError(3 /*Invalid Request*/) } // Import securities, building map from Gnucash security IDs to our @@ -377,13 +363,11 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) { securityMap := make(map[int64]int64) for _, security := range gnucashImport.Securities { securityId := security.SecurityId // save off because it could be updated - s, err := ImportGetCreateSecurity(sqltransaction, user.UserId, &security) + s, err := ImportGetCreateSecurity(tx, user.UserId, &security) if err != nil { - sqltransaction.Rollback() - WriteError(w, 6 /*Import Error*/) log.Print(err) log.Print(security) - return + return NewError(6 /*Import Error*/) } securityMap[securityId] = s.SecurityId } @@ -394,12 +378,10 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) { price.CurrencyId = securityMap[price.CurrencyId] price.PriceId = 0 - err := CreatePriceIfNotExist(sqltransaction, &price) + err := CreatePriceIfNotExist(tx, &price) if err != nil { - sqltransaction.Rollback() - WriteError(w, 6 /*Import Error*/) log.Print(err) - return + return NewError(6 /*Import Error*/) } } @@ -425,12 +407,10 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) { account.ParentAccountId = accountMap[account.ParentAccountId] } account.SecurityId = securityMap[account.SecurityId] - a, err := GetCreateAccountTx(sqltransaction, account) + a, err := GetCreateAccount(tx, account) if err != nil { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } accountMap[account.AccountId] = a.AccountId accountsRemaining-- @@ -438,10 +418,8 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) { } if accountsRemaining == accountsRemainingLast { //We didn't make any progress in importing the next level of accounts, so there must be a circular parent-child relationship, so give up and tell the user they're wrong - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print(fmt.Errorf("Circular account parent-child relationship when importing %s", part.FileName())) - return + return NewError(999 /*Internal Error*/) } accountsRemainingLast = accountsRemaining } @@ -453,41 +431,27 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) { for _, split := range transaction.Splits { acctId, ok := accountMap[split.AccountId] if !ok { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print(fmt.Errorf("Error: Split's AccountID Doesn't exist: %d\n", split.AccountId)) - return + return NewError(999 /*Internal Error*/) } split.AccountId = acctId - exists, err := split.AlreadyImportedTx(sqltransaction) + exists, err := split.AlreadyImported(tx) if err != nil { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print("Error checking if split was already imported:", err) - return + return NewError(999 /*Internal Error*/) } else if exists { already_imported = true } } if !already_imported { - err := InsertTransactionTx(sqltransaction, &transaction, user) + err := InsertTransaction(tx, &transaction, user) if err != nil { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } } } - err = sqltransaction.Commit() - if err != nil { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } - - WriteSuccess(w) + return SuccessWriter{} } diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 985d8d5..d518e92 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -2,30 +2,64 @@ package handlers import ( "gopkg.in/gorp.v1" + "log" "net/http" ) -// Create a closure over db, allowing the handlers to look like a -// http.HandlerFunc -type DB = gorp.DbMap -type DBHandler func(http.ResponseWriter, *http.Request, *DB) +// But who writes the ResponseWriterWriter? +type ResponseWriterWriter interface { + Write(http.ResponseWriter) error +} +type Tx = gorp.Transaction +type TxHandler func(*http.Request, *Tx) ResponseWriterWriter -func DBHandlerFunc(h DBHandler, db *DB) http.HandlerFunc { +func TxHandlerFunc(t TxHandler, db *gorp.DbMap) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - h(w, r, db) + tx, err := db.Begin() + if err != nil { + log.Print(err) + WriteError(w, 999 /*Internal Error*/) + return + } + defer func() { + if r := recover(); r != nil { + tx.Rollback() + WriteError(w, 999 /*Internal Error*/) + panic(r) + } + }() + + writer := t(r, tx) + + if e, ok := writer.(*Error); ok { + tx.Rollback() + e.Write(w) + } else { + err = tx.Commit() + if err != nil { + log.Print(err) + WriteError(w, 999 /*Internal Error*/) + } else { + err = writer.Write(w) + if err != nil { + log.Print(err) + WriteError(w, 999 /*Internal Error*/) + } + } + } } } -func GetHandler(db *DB) *http.ServeMux { +func GetHandler(db *gorp.DbMap) *http.ServeMux { servemux := http.NewServeMux() - servemux.HandleFunc("/session/", DBHandlerFunc(SessionHandler, db)) - servemux.HandleFunc("/user/", DBHandlerFunc(UserHandler, db)) - servemux.HandleFunc("/security/", DBHandlerFunc(SecurityHandler, db)) + servemux.HandleFunc("/session/", TxHandlerFunc(SessionHandler, db)) + servemux.HandleFunc("/user/", TxHandlerFunc(UserHandler, db)) + servemux.HandleFunc("/security/", TxHandlerFunc(SecurityHandler, db)) servemux.HandleFunc("/securitytemplate/", SecurityTemplateHandler) - servemux.HandleFunc("/account/", DBHandlerFunc(AccountHandler, db)) - servemux.HandleFunc("/transaction/", DBHandlerFunc(TransactionHandler, db)) - servemux.HandleFunc("/import/gnucash", DBHandlerFunc(GnucashImportHandler, db)) - servemux.HandleFunc("/report/", DBHandlerFunc(ReportHandler, db)) + servemux.HandleFunc("/account/", TxHandlerFunc(AccountHandler, db)) + servemux.HandleFunc("/transaction/", TxHandlerFunc(TransactionHandler, db)) + servemux.HandleFunc("/import/gnucash", TxHandlerFunc(GnucashImportHandler, db)) + servemux.HandleFunc("/report/", TxHandlerFunc(ReportHandler, db)) return servemux } diff --git a/internal/handlers/imports.go b/internal/handlers/imports.go index 28a5922..1f712af 100644 --- a/internal/handlers/imports.go +++ b/internal/handlers/imports.go @@ -22,48 +22,35 @@ func (od *OFXDownload) Read(json_str string) error { return dec.Decode(od) } -func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, accountid int64) { +func ofxImportHelper(tx *Tx, r io.Reader, user *User, accountid int64) ResponseWriterWriter { itl, err := ImportOFX(r) if err != nil { //TODO is this necessarily an invalid request (what if it was an error on our end)? - WriteError(w, 3 /*Invalid Request*/) log.Print(err) - return + return NewError(3 /*Invalid Request*/) } if len(itl.Accounts) != 1 { - WriteError(w, 3 /*Invalid Request*/) log.Printf("Found %d accounts when importing OFX, expected 1", len(itl.Accounts)) - return - } - - sqltransaction, err := db.Begin() - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return + return NewError(3 /*Invalid Request*/) } // Return Account with this Id - account, err := GetAccountTx(sqltransaction, accountid, user.UserId) + account, err := GetAccount(tx, accountid, user.UserId) if err != nil { - sqltransaction.Rollback() - WriteError(w, 3 /*Invalid Request*/) log.Print(err) - return + return NewError(3 /*Invalid Request*/) } importedAccount := itl.Accounts[0] if len(account.ExternalAccountId) > 0 && account.ExternalAccountId != importedAccount.ExternalAccountId { - sqltransaction.Rollback() - WriteError(w, 3 /*Invalid Request*/) log.Printf("OFX import has \"%s\" as ExternalAccountId, but the account being imported to has\"%s\"", importedAccount.ExternalAccountId, account.ExternalAccountId) - return + return NewError(3 /*Invalid Request*/) } // Find matching existing securities or create new ones for those @@ -74,21 +61,17 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc // save off since ImportGetCreateSecurity overwrites SecurityId on // ofxsecurity oldsecurityid := ofxsecurity.SecurityId - security, err := ImportGetCreateSecurity(sqltransaction, user.UserId, &ofxsecurity) + security, err := ImportGetCreateSecurity(tx, user.UserId, &ofxsecurity) if err != nil { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } securitymap[oldsecurityid] = *security } if account.SecurityId != securitymap[importedAccount.SecurityId].SecurityId { - sqltransaction.Rollback() - WriteError(w, 3 /*Invalid Request*/) log.Printf("OFX import account's SecurityId (%d) does not match this account's (%d)", securitymap[importedAccount.SecurityId].SecurityId, account.SecurityId) - return + return NewError(3 /*Invalid Request*/) } // TODO Ensure all transactions have at least one split in the account @@ -99,10 +82,8 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc transaction.UserId = user.UserId if !transaction.Valid() { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print("Unexpected invalid transaction from OFX import") - return + return NewError(999 /*Internal Error*/) } // Ensure that either AccountId or SecurityId is set for this split, @@ -112,10 +93,8 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc split.Status = Imported if split.AccountId != -1 { if split.AccountId != importedAccount.AccountId { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print("Imported split's AccountId wasn't -1 but also didn't match the account") - return + return NewError(999 /*Internal Error*/) } split.AccountId = account.AccountId } else if split.SecurityId != -1 { @@ -123,12 +102,10 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc // TODO try to auto-match splits to existing accounts based on past transactions that look like this one if split.ImportSplitType == TradingAccount { // Find/make trading account if we're that type of split - trading_account, err := GetTradingAccount(sqltransaction, user.UserId, sec.SecurityId) + trading_account, err := GetTradingAccount(tx, user.UserId, sec.SecurityId) if err != nil { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print("Couldn't find split's SecurityId in map during OFX import") - return + return NewError(999 /*Internal Error*/) } split.AccountId = trading_account.AccountId split.SecurityId = -1 @@ -140,12 +117,10 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc SecurityId: sec.SecurityId, Type: account.Type, } - subaccount, err := GetCreateAccountTx(sqltransaction, *subaccount) + subaccount, err := GetCreateAccount(tx, *subaccount) if err != nil { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } split.AccountId = subaccount.AccountId split.SecurityId = -1 @@ -153,49 +128,39 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc split.SecurityId = sec.SecurityId } } else { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print("Couldn't find split's SecurityId in map during OFX import") - return + return NewError(999 /*Internal Error*/) } } else { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print("Neither Split.AccountId Split.SecurityId was set during OFX import") - return + return NewError(999 /*Internal Error*/) } } - imbalances, err := transaction.GetImbalancesTx(sqltransaction) + imbalances, err := transaction.GetImbalances(tx) if err != nil { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } // Fixup any imbalances in transactions var zero big.Rat for imbalanced_security, imbalance := range imbalances { if imbalance.Cmp(&zero) != 0 { - imbalanced_account, err := GetImbalanceAccount(sqltransaction, user.UserId, imbalanced_security) + imbalanced_account, err := GetImbalanceAccount(tx, user.UserId, imbalanced_security) if err != nil { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } // Add new split to fixup imbalance split := new(Split) r := new(big.Rat) r.Neg(&imbalance) - security, err := GetSecurityTx(sqltransaction, imbalanced_security, user.UserId) + security, err := GetSecurity(tx, imbalanced_security, user.UserId) if err != nil { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } split.Amount = r.FloatString(security.Precision) split.SecurityId = -1 @@ -210,24 +175,20 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc var already_imported bool for _, split := range transaction.Splits { if split.SecurityId != -1 || split.AccountId == -1 { - imbalanced_account, err := GetImbalanceAccount(sqltransaction, user.UserId, split.SecurityId) + imbalanced_account, err := GetImbalanceAccount(tx, user.UserId, split.SecurityId) if err != nil { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } split.AccountId = imbalanced_account.AccountId split.SecurityId = -1 } - exists, err := split.AlreadyImportedTx(sqltransaction) + exists, err := split.AlreadyImported(tx) if err != nil { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print("Error checking if split was already imported:", err) - return + return NewError(999 /*Internal Error*/) } else if exists { already_imported = true } @@ -239,55 +200,38 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc } for _, transaction := range transactions { - err := InsertTransactionTx(sqltransaction, &transaction, user) + err := InsertTransaction(tx, &transaction, user) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } } - err = sqltransaction.Commit() - if err != nil { - sqltransaction.Rollback() - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } - - WriteSuccess(w) + return SuccessWriter{} } -func OFXImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64) { +func OFXImportHandler(tx *Tx, r *http.Request, user *User, accountid int64) ResponseWriterWriter { download_json := r.PostFormValue("ofxdownload") if download_json == "" { - log.Print("download_json") - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } var ofxdownload OFXDownload err := ofxdownload.Read(download_json) if err != nil { - log.Print("ofxdownload.Read") - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } - account, err := GetAccount(db, accountid, user.UserId) + account, err := GetAccount(tx, accountid, user.UserId) if err != nil { - log.Print("GetAccount") - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } ofxver := ofxgo.OfxVersion203 if len(account.OFXVersion) != 0 { ofxver, err = ofxgo.NewOfxVersion(account.OFXVersion) if err != nil { - log.Print("NewOfxVersion") - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } } @@ -308,9 +252,8 @@ func OFXImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User transactionuid, err := ofxgo.RandomUID() if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Println("Error creating uid for transaction:", err) - return + return NewError(999 /*Internal Error*/) } if account.Type == Investment { @@ -343,8 +286,7 @@ func OFXImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User // Import generic bank transactions acctTypeEnum, err := ofxgo.NewAcctType(account.OFXAcctType) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } statementRequest := ofxgo.StatementRequest{ TrnUID: *transactionuid, @@ -361,49 +303,46 @@ func OFXImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User response, err := client.RequestNoParse(&query) if err != nil { // TODO this could be an error talking with the OFX server... - WriteError(w, 3 /*Invalid Request*/) log.Print(err) - return + return NewError(3 /*Invalid Request*/) } defer response.Body.Close() - ofxImportHelper(db, response.Body, w, user, accountid) + return ofxImportHelper(tx, response.Body, user, accountid) } -func OFXFileImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64) { +func OFXFileImportHandler(tx *Tx, r *http.Request, user *User, accountid int64) ResponseWriterWriter { multipartReader, err := r.MultipartReader() if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } // assume there is only one 'part' part, err := multipartReader.NextPart() if err != nil { if err == io.EOF { - WriteError(w, 3 /*Invalid Request*/) + return NewError(3 /*Invalid Request*/) log.Print("Encountered unexpected EOF") } else { - WriteError(w, 999 /*Internal Error*/) + return NewError(999 /*Internal Error*/) log.Print(err) } - return } - ofxImportHelper(db, part, w, user, accountid) + return ofxImportHelper(tx, part, user, accountid) } /* * Assumes the User is a valid, signed-in user, but accountid has not yet been validated */ -func AccountImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64, importtype string) { +func AccountImportHandler(tx *Tx, r *http.Request, user *User, accountid int64, importtype string) ResponseWriterWriter { switch importtype { case "ofx": - OFXImportHandler(db, w, r, user, accountid) + return OFXImportHandler(tx, r, user, accountid) case "ofxfile": - OFXFileImportHandler(db, w, r, user, accountid) + return OFXFileImportHandler(tx, r, user, accountid) default: - WriteError(w, 3 /*Invalid Request*/) + return NewError(3 /*Invalid Request*/) } } diff --git a/internal/handlers/prices.go b/internal/handlers/prices.go index 42fa512..81b32d0 100644 --- a/internal/handlers/prices.go +++ b/internal/handlers/prices.go @@ -1,7 +1,6 @@ package handlers import ( - "gopkg.in/gorp.v1" "time" ) @@ -14,18 +13,18 @@ type Price struct { RemoteId string // unique ID from source, for detecting duplicates } -func InsertPriceTx(transaction *gorp.Transaction, p *Price) error { - err := transaction.Insert(p) +func InsertPrice(tx *Tx, p *Price) error { + err := tx.Insert(p) if err != nil { return err } return nil } -func CreatePriceIfNotExist(transaction *gorp.Transaction, price *Price) error { +func CreatePriceIfNotExist(tx *Tx, price *Price) error { if len(price.RemoteId) == 0 { // Always create a new price if we can't match on the RemoteId - err := InsertPriceTx(transaction, price) + err := InsertPrice(tx, price) if err != nil { return err } @@ -34,7 +33,7 @@ func CreatePriceIfNotExist(transaction *gorp.Transaction, price *Price) error { var prices []*Price - _, err := transaction.Select(&prices, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date=? AND Value=?", price.SecurityId, price.CurrencyId, price.Date, price.Value) + _, err := tx.Select(&prices, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date=? AND Value=?", price.SecurityId, price.CurrencyId, price.Date, price.Value) if err != nil { return err } @@ -43,7 +42,7 @@ func CreatePriceIfNotExist(transaction *gorp.Transaction, price *Price) error { return nil // price already exists } - err = InsertPriceTx(transaction, price) + err = InsertPrice(tx, price) if err != nil { return err } @@ -51,9 +50,9 @@ func CreatePriceIfNotExist(transaction *gorp.Transaction, price *Price) error { } // Return the latest price for security in currency units before date -func GetLatestPrice(transaction *gorp.Transaction, security, currency *Security, date *time.Time) (*Price, error) { +func GetLatestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Price, error) { var p Price - err := transaction.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date) + err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date) if err != nil { return nil, err } @@ -61,9 +60,9 @@ func GetLatestPrice(transaction *gorp.Transaction, security, currency *Security, } // Return the earliest price for security in currency units after date -func GetEarliestPrice(transaction *gorp.Transaction, security, currency *Security, date *time.Time) (*Price, error) { +func GetEarliestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Price, error) { var p Price - err := transaction.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date) + err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date) if err != nil { return nil, err } @@ -71,9 +70,9 @@ func GetEarliestPrice(transaction *gorp.Transaction, security, currency *Securit } // Return the price for security in currency closest to date -func GetClosestPriceTx(transaction *gorp.Transaction, security, currency *Security, date *time.Time) (*Price, error) { - earliest, _ := GetEarliestPrice(transaction, security, currency, date) - latest, err := GetLatestPrice(transaction, security, currency, date) +func GetClosestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Price, error) { + earliest, _ := GetEarliestPrice(tx, security, currency, date) + latest, err := GetLatestPrice(tx, security, currency, date) // Return early if either earliest or latest are invalid if earliest == nil { @@ -90,24 +89,3 @@ func GetClosestPriceTx(transaction *gorp.Transaction, security, currency *Securi return earliest, nil } } - -func GetClosestPrice(db *DB, security, currency *Security, date *time.Time) (*Price, error) { - transaction, err := db.Begin() - if err != nil { - return nil, err - } - - price, err := GetClosestPriceTx(transaction, security, currency, date) - if err != nil { - transaction.Rollback() - return nil, err - } - - err = transaction.Commit() - if err != nil { - transaction.Rollback() - return nil, err - } - - return price, nil -} diff --git a/internal/handlers/reports.go b/internal/handlers/reports.go index f42d00f..7ddbb4e 100644 --- a/internal/handlers/reports.go +++ b/internal/handlers/reports.go @@ -77,36 +77,36 @@ func (r *Tabulation) Write(w http.ResponseWriter) error { return enc.Encode(r) } -func GetReport(db *DB, reportid int64, userid int64) (*Report, error) { +func GetReport(tx *Tx, reportid int64, userid int64) (*Report, error) { var r Report - err := db.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid) + err := tx.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid) if err != nil { return nil, err } return &r, nil } -func GetReports(db *DB, userid int64) (*[]Report, error) { +func GetReports(tx *Tx, userid int64) (*[]Report, error) { var reports []Report - _, err := db.Select(&reports, "SELECT * from reports where UserId=?", userid) + _, err := tx.Select(&reports, "SELECT * from reports where UserId=?", userid) if err != nil { return nil, err } return &reports, nil } -func InsertReport(db *DB, r *Report) error { - err := db.Insert(r) +func InsertReport(tx *Tx, r *Report) error { + err := tx.Insert(r) if err != nil { return err } return nil } -func UpdateReport(db *DB, r *Report) error { - count, err := db.Update(r) +func UpdateReport(tx *Tx, r *Report) error { + count, err := tx.Update(r) if err != nil { return err } @@ -116,8 +116,8 @@ func UpdateReport(db *DB, r *Report) error { return nil } -func DeleteReport(db *DB, r *Report) error { - count, err := db.Delete(r) +func DeleteReport(tx *Tx, r *Report) error { + count, err := tx.Delete(r) if err != nil { return err } @@ -127,14 +127,14 @@ func DeleteReport(db *DB, r *Report) error { return nil } -func runReport(db *DB, user *User, report *Report) (*Tabulation, error) { +func runReport(tx *Tx, user *User, report *Report) (*Tabulation, error) { // Create a new LState without opening the default libs for security L := lua.NewState(lua.Options{SkipOpenLibs: true}) defer L.Close() // Create a new context holding the current user with a timeout ctx := context.WithValue(context.Background(), userContextKey, user) - ctx = context.WithValue(ctx, dbContextKey, db) + ctx = context.WithValue(ctx, dbContextKey, tx) ctx, cancel := context.WithTimeout(ctx, luaTimeoutSeconds*time.Second) defer cancel() L.SetContext(ctx) @@ -191,79 +191,60 @@ func runReport(db *DB, user *User, report *Report) (*Tabulation, error) { } } -func ReportTabulationHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, reportid int64) { - report, err := GetReport(db, reportid, user.UserId) +func ReportTabulationHandler(tx *Tx, r *http.Request, user *User, reportid int64) ResponseWriterWriter { + report, err := GetReport(tx, reportid, user.UserId) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } - tabulation, err := runReport(db, user, report) + tabulation, err := runReport(tx, user, report) if err != nil { // TODO handle different failure cases differently log.Print("runReport returned:", err) - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } tabulation.ReportId = reportid - err = tabulation.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return tabulation } -func ReportHandler(w http.ResponseWriter, r *http.Request, db *DB) { - user, err := GetUserFromSession(db, r) +func ReportHandler(r *http.Request, tx *Tx) ResponseWriterWriter { + user, err := GetUserFromSession(tx, r) if err != nil { - WriteError(w, 1 /*Not Signed In*/) - return + return NewError(1 /*Not Signed In*/) } if r.Method == "POST" { report_json := r.PostFormValue("report") if report_json == "" { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } var report Report err := report.Read(report_json) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } report.ReportId = -1 report.UserId = user.UserId - err = InsertReport(db, &report) + err = InsertReport(tx, &report) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - w.WriteHeader(201 /*Created*/) - err = report.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return ResponseWrapper{201, &report} } else if r.Method == "GET" { if reportTabulationRE.MatchString(r.URL.Path) { var reportid int64 n, err := GetURLPieces(r.URL.Path, "/report/%d/tabulation", &reportid) if err != nil || n != 1 { - WriteError(w, 999 /*InternalError*/) log.Print(err) - return + return NewError(999 /*InternalError*/) } - ReportTabulationHandler(db, w, r, user, reportid) - return + return ReportTabulationHandler(tx, r, user, reportid) } var reportid int64 @@ -271,84 +252,62 @@ func ReportHandler(w http.ResponseWriter, r *http.Request, db *DB) { if err != nil || n != 1 { //Return all Reports var rl ReportList - reports, err := GetReports(db, user.UserId) + reports, err := GetReports(tx, user.UserId) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } rl.Reports = reports - err = (&rl).Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return &rl } else { // Return Report with this Id - report, err := GetReport(db, reportid, user.UserId) + report, err := GetReport(tx, reportid, user.UserId) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } - err = report.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return report } } else { reportid, err := GetURLID(r.URL.Path) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } if r.Method == "PUT" { report_json := r.PostFormValue("report") if report_json == "" { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } var report Report err := report.Read(report_json) if err != nil || report.ReportId != reportid { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } report.UserId = user.UserId - err = UpdateReport(db, &report) + err = UpdateReport(tx, &report) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - err = report.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return &report } else if r.Method == "DELETE" { - report, err := GetReport(db, reportid, user.UserId) + report, err := GetReport(tx, reportid, user.UserId) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } - err = DeleteReport(db, report) + err = DeleteReport(tx, report) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - WriteSuccess(w) + return SuccessWriter{} } } + return NewError(3 /*Invalid Request*/) } diff --git a/internal/handlers/securities.go b/internal/handlers/securities.go index 7333d65..9d047d3 100644 --- a/internal/handlers/securities.go +++ b/internal/handlers/securities.go @@ -5,7 +5,6 @@ package handlers import ( "encoding/json" "errors" - "gopkg.in/gorp.v1" "log" "net/http" "net/url" @@ -103,83 +102,50 @@ func FindCurrencyTemplate(iso4217 int64) *Security { return nil } -func GetSecurity(db *DB, securityid int64, userid int64) (*Security, error) { +func GetSecurity(tx *Tx, securityid int64, userid int64) (*Security, error) { var s Security - err := db.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid) + err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid) if err != nil { return nil, err } return &s, nil } -func GetSecurityTx(transaction *gorp.Transaction, securityid int64, userid int64) (*Security, error) { - var s Security - - err := transaction.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid) - if err != nil { - return nil, err - } - return &s, nil -} - -func GetSecurities(db *DB, userid int64) (*[]*Security, error) { +func GetSecurities(tx *Tx, userid int64) (*[]*Security, error) { var securities []*Security - _, err := db.Select(&securities, "SELECT * from securities where UserId=?", userid) + _, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid) if err != nil { return nil, err } return &securities, nil } -func InsertSecurity(db *DB, s *Security) error { - err := db.Insert(s) +func InsertSecurity(tx *Tx, s *Security) error { + err := tx.Insert(s) if err != nil { return err } return nil } -func InsertSecurityTx(transaction *gorp.Transaction, s *Security) error { - err := transaction.Insert(s) +func UpdateSecurity(tx *Tx, s *Security) (err error) { + user, err := GetUser(tx, s.UserId) if err != nil { - return err - } - return nil -} - -func UpdateSecurity(db *DB, s *Security) error { - transaction, err := db.Begin() - if err != nil { - return err - } - - user, err := GetUserTx(transaction, s.UserId) - if err != nil { - transaction.Rollback() - return err + return } else if user.DefaultCurrency == s.SecurityId && s.Type != Currency { - transaction.Rollback() return errors.New("Cannot change security which is user's default currency to be non-currency") } - count, err := transaction.Update(s) + count, err := tx.Update(s) if err != nil { - transaction.Rollback() - return err + return } if count != 1 { - transaction.Rollback() return errors.New("Updated more than one security") } - err = transaction.Commit() - if err != nil { - transaction.Rollback() - return err - } - return nil } @@ -191,61 +157,44 @@ func (e SecurityInUseError) Error() string { return e.message } -func DeleteSecurity(db *DB, s *Security) error { - transaction, err := db.Begin() - if err != nil { - return err - } - +func DeleteSecurity(tx *Tx, s *Security) error { // First, ensure no accounts are using this security - accounts, err := transaction.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId) + accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId) if accounts != 0 { - transaction.Rollback() return SecurityInUseError{"One or more accounts still use this security"} } - user, err := GetUserTx(transaction, s.UserId) + user, err := GetUser(tx, s.UserId) if err != nil { - transaction.Rollback() return err } else if user.DefaultCurrency == s.SecurityId { - transaction.Rollback() return SecurityInUseError{"Cannot delete security which is user's default currency"} } // Remove all prices involving this security (either of this security, or // using it as a currency) - _, err = transaction.Exec("DELETE FROM prices WHERE SecurityId=? OR CurrencyId=?", s.SecurityId, s.SecurityId) + _, err = tx.Exec("DELETE FROM prices WHERE SecurityId=? OR CurrencyId=?", s.SecurityId, s.SecurityId) if err != nil { - transaction.Rollback() return err } - count, err := transaction.Delete(s) + count, err := tx.Delete(s) if err != nil { - transaction.Rollback() return err } if count != 1 { - transaction.Rollback() return errors.New("Deleted more than one security") } - err = transaction.Commit() - if err != nil { - transaction.Rollback() - return err - } - return nil } -func ImportGetCreateSecurity(transaction *gorp.Transaction, userid int64, security *Security) (*Security, error) { +func ImportGetCreateSecurity(tx *Tx, userid int64, security *Security) (*Security, error) { security.UserId = userid if len(security.AlternateId) == 0 { // Always create a new local security if we can't match on the AlternateId - err := InsertSecurityTx(transaction, security) + err := InsertSecurity(tx, security) if err != nil { return nil, err } @@ -254,7 +203,7 @@ func ImportGetCreateSecurity(transaction *gorp.Transaction, userid int64, securi var securities []*Security - _, err := transaction.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Precision=?", userid, security.Type, security.AlternateId, security.Precision) + _, err := tx.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Precision=?", userid, security.Type, security.AlternateId, security.Precision) if err != nil { return nil, err } @@ -286,7 +235,7 @@ func ImportGetCreateSecurity(transaction *gorp.Transaction, userid int64, securi } // If there wasn't even one security in the list, make a new one - err = InsertSecurityTx(transaction, security) + err = InsertSecurity(tx, security) if err != nil { return nil, err } @@ -294,43 +243,33 @@ func ImportGetCreateSecurity(transaction *gorp.Transaction, userid int64, securi return security, nil } -func SecurityHandler(w http.ResponseWriter, r *http.Request, db *DB) { - user, err := GetUserFromSession(db, r) +func SecurityHandler(r *http.Request, tx *Tx) ResponseWriterWriter { + user, err := GetUserFromSession(tx, r) if err != nil { - WriteError(w, 1 /*Not Signed In*/) - return + return NewError(1 /*Not Signed In*/) } if r.Method == "POST" { security_json := r.PostFormValue("security") if security_json == "" { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } var security Security err := security.Read(security_json) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } security.SecurityId = -1 security.UserId = user.UserId - err = InsertSecurity(db, &security) + err = InsertSecurity(tx, &security) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - w.WriteHeader(201 /*Created*/) - err = security.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return ResponseWrapper{201, &security} } else if r.Method == "GET" { var securityid int64 n, err := GetURLPieces(r.URL.Path, "/security/%d", &securityid) @@ -339,87 +278,65 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request, db *DB) { //Return all securities var sl SecurityList - securities, err := GetSecurities(db, user.UserId) + securities, err := GetSecurities(tx, user.UserId) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } sl.Securities = securities - err = (&sl).Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return &sl } else { - security, err := GetSecurity(db, securityid, user.UserId) + security, err := GetSecurity(tx, securityid, user.UserId) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } - err = security.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return security } } else { securityid, err := GetURLID(r.URL.Path) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } if r.Method == "PUT" { security_json := r.PostFormValue("security") if security_json == "" { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } var security Security err := security.Read(security_json) if err != nil || security.SecurityId != securityid { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } security.UserId = user.UserId - err = UpdateSecurity(db, &security) + err = UpdateSecurity(tx, &security) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - err = security.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return &security } else if r.Method == "DELETE" { - security, err := GetSecurity(db, securityid, user.UserId) + security, err := GetSecurity(tx, securityid, user.UserId) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } - err = DeleteSecurity(db, security) + err = DeleteSecurity(tx, security) if _, ok := err.(SecurityInUseError); ok { - WriteError(w, 7 /*In Use Error*/) + return NewError(7 /*In Use Error*/) } else if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - WriteSuccess(w) + return SuccessWriter{} } } + return NewError(3 /*Invalid Request*/) } func SecurityTemplateHandler(w http.ResponseWriter, r *http.Request) { diff --git a/internal/handlers/securities_lua.go b/internal/handlers/securities_lua.go index a13614b..4acd979 100644 --- a/internal/handlers/securities_lua.go +++ b/internal/handlers/securities_lua.go @@ -13,9 +13,9 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) { ctx := L.Context() - db, ok := ctx.Value(dbContextKey).(*DB) + tx, ok := ctx.Value(dbContextKey).(*Tx) if !ok { - return nil, errors.New("Couldn't find DB in lua's Context") + return nil, errors.New("Couldn't find tx in lua's Context") } security_map, ok = ctx.Value(securitiesContextKey).(map[int64]*Security) @@ -25,7 +25,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) { return nil, errors.New("Couldn't find User in lua's Context") } - securities, err := GetSecurities(db, user.UserId) + securities, err := GetSecurities(tx, user.UserId) if err != nil { return nil, err } @@ -155,12 +155,12 @@ func luaClosestPrice(L *lua.LState) int { date := luaCheckTime(L, 3) ctx := L.Context() - db, ok := ctx.Value(dbContextKey).(*DB) + tx, ok := ctx.Value(dbContextKey).(*Tx) if !ok { - panic("Couldn't find DB in lua's Context") + panic("Couldn't find tx in lua's Context") } - p, err := GetClosestPrice(db, s, c, date) + p, err := GetClosestPrice(tx, s, c, date) if err != nil { L.Push(lua.LNil) } else { diff --git a/internal/handlers/sessions.go b/internal/handlers/sessions.go index 2cfa56a..8f61a7f 100644 --- a/internal/handlers/sessions.go +++ b/internal/handlers/sessions.go @@ -28,7 +28,7 @@ func (s *Session) Read(json_str string) error { return dec.Decode(s) } -func GetSession(db *DB, r *http.Request) (*Session, error) { +func GetSession(tx *Tx, r *http.Request) (*Session, error) { var s Session cookie, err := r.Cookie("moneygo-session") @@ -37,18 +37,17 @@ func GetSession(db *DB, r *http.Request) (*Session, error) { } s.SessionSecret = cookie.Value - err = db.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret) + err = tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret) if err != nil { return nil, err } return &s, nil } -func DeleteSessionIfExists(db *DB, r *http.Request) error { - // TODO do this in one transaction - session, err := GetSession(db, r) +func DeleteSessionIfExists(tx *Tx, r *http.Request) error { + session, err := GetSession(tx, r) if err == nil { - _, err := db.Delete(session) + _, err := tx.Delete(session) if err != nil { return err } @@ -64,7 +63,17 @@ func NewSessionCookie() (string, error) { return base64.StdEncoding.EncodeToString(bits), nil } -func NewSession(db *DB, w http.ResponseWriter, r *http.Request, userid int64) (*Session, error) { +type NewSessionWriter struct { + session *Session + cookie *http.Cookie +} + +func (n *NewSessionWriter) Write(w http.ResponseWriter) error { + http.SetCookie(w, n.cookie) + return n.session.Write(w) +} + +func NewSession(tx *Tx, r *http.Request, userid int64) (*NewSessionWriter, error) { s := Session{} session_secret, err := NewSessionCookie() @@ -81,79 +90,66 @@ func NewSession(db *DB, w http.ResponseWriter, r *http.Request, userid int64) (* Secure: true, HttpOnly: true, } - http.SetCookie(w, &cookie) s.SessionSecret = session_secret s.UserId = userid - err = db.Insert(&s) + err = tx.Insert(&s) if err != nil { return nil, err } - return &s, nil + return &NewSessionWriter{&s, &cookie}, nil } -func SessionHandler(w http.ResponseWriter, r *http.Request, db *DB) { +func SessionHandler(r *http.Request, tx *Tx) ResponseWriterWriter { if r.Method == "POST" || r.Method == "PUT" { user_json := r.PostFormValue("user") if user_json == "" { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } user := User{} err := user.Read(user_json) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } - dbuser, err := GetUserByUsername(db, user.Username) + dbuser, err := GetUserByUsername(tx, user.Username) if err != nil { - WriteError(w, 2 /*Unauthorized Access*/) - return + return NewError(2 /*Unauthorized Access*/) } user.HashPassword() if user.PasswordHash != dbuser.PasswordHash { - WriteError(w, 2 /*Unauthorized Access*/) - return + return NewError(2 /*Unauthorized Access*/) } - err = DeleteSessionIfExists(db, r) + err = DeleteSessionIfExists(tx, r) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - session, err := NewSession(db, w, r, dbuser.UserId) + sessionwriter, err := NewSession(tx, r, dbuser.UserId) if err != nil { - WriteError(w, 999 /*Internal Error*/) - return - } - - err = session.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } + return sessionwriter } else if r.Method == "GET" { - s, err := GetSession(db, r) + s, err := GetSession(tx, r) if err != nil { - WriteError(w, 1 /*Not Signed In*/) - return + return NewError(1 /*Not Signed In*/) } - s.Write(w) + return s } else if r.Method == "DELETE" { - err := DeleteSessionIfExists(db, r) + err := DeleteSessionIfExists(tx, r) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - WriteSuccess(w) + return SuccessWriter{} } + return NewError(3 /*Invalid Request*/) } diff --git a/internal/handlers/transactions.go b/internal/handlers/transactions.go index 72c5ab2..d7879cb 100644 --- a/internal/handlers/transactions.go +++ b/internal/handlers/transactions.go @@ -4,7 +4,6 @@ import ( "encoding/json" "errors" "fmt" - "gopkg.in/gorp.v1" "log" "math/big" "net/http" @@ -78,8 +77,8 @@ func (s *Split) Valid() bool { return err == nil } -func (s *Split) AlreadyImportedTx(transaction *gorp.Transaction) (bool, error) { - count, err := transaction.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId) +func (s *Split) AlreadyImported(tx *Tx) (bool, error) { + count, err := tx.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId) return count == 1, err } @@ -134,7 +133,7 @@ func (t *Transaction) Valid() bool { // Return a map of security ID's to big.Rat's containing the amount that // security is imbalanced by -func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64]big.Rat, error) { +func (t *Transaction) GetImbalances(tx *Tx) (map[int64]big.Rat, error) { sums := make(map[int64]big.Rat) if !t.Valid() { @@ -146,7 +145,7 @@ func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64] if t.Splits[i].AccountId != -1 { var err error var account *Account - account, err = GetAccountTx(transaction, t.Splits[i].AccountId, t.UserId) + account, err = GetAccount(tx, t.Splits[i].AccountId, t.UserId) if err != nil { return nil, err } @@ -162,10 +161,10 @@ func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64] // Returns true if all securities contained in this transaction are balanced, // false otherwise -func (t *Transaction) Balanced(transaction *gorp.Transaction) (bool, error) { +func (t *Transaction) Balanced(tx *Tx) (bool, error) { var zero big.Rat - sums, err := t.GetImbalancesTx(transaction) + sums, err := t.GetImbalances(tx) if err != nil { return false, err } @@ -178,72 +177,48 @@ func (t *Transaction) Balanced(transaction *gorp.Transaction) (bool, error) { return true, nil } -func GetTransaction(db *DB, transactionid int64, userid int64) (*Transaction, error) { +func GetTransaction(tx *Tx, transactionid int64, userid int64) (*Transaction, error) { var t Transaction - transaction, err := db.Begin() + err := tx.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid) if err != nil { return nil, err } - err = transaction.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid) + _, err = tx.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid) if err != nil { - transaction.Rollback() - return nil, err - } - - _, err = transaction.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid) - if err != nil { - transaction.Rollback() - return nil, err - } - - err = transaction.Commit() - if err != nil { - transaction.Rollback() return nil, err } return &t, nil } -func GetTransactions(db *DB, userid int64) (*[]Transaction, error) { +func GetTransactions(tx *Tx, userid int64) (*[]Transaction, error) { var transactions []Transaction - transaction, err := db.Begin() - if err != nil { - return nil, err - } - - _, err = transaction.Select(&transactions, "SELECT * from transactions where UserId=?", userid) + _, err := tx.Select(&transactions, "SELECT * from transactions where UserId=?", userid) if err != nil { return nil, err } for i := range transactions { - _, err := transaction.Select(&transactions[i].Splits, "SELECT * from splits where TransactionId=?", transactions[i].TransactionId) + _, err := tx.Select(&transactions[i].Splits, "SELECT * from splits where TransactionId=?", transactions[i].TransactionId) if err != nil { return nil, err } } - err = transaction.Commit() - if err != nil { - transaction.Rollback() - return nil, err - } - return &transactions, nil } -func incrementAccountVersions(transaction *gorp.Transaction, user *User, accountids []int64) error { +func incrementAccountVersions(tx *Tx, user *User, accountids []int64) error { for i := range accountids { - account, err := GetAccountTx(transaction, accountids[i], user.UserId) + account, err := GetAccount(tx, accountids[i], user.UserId) if err != nil { return err } account.AccountVersion++ - count, err := transaction.Update(account) + count, err := tx.Update(account) if err != nil { return err } @@ -260,12 +235,12 @@ func (ame AccountMissingError) Error() string { return "Account missing" } -func InsertTransactionTx(transaction *gorp.Transaction, t *Transaction, user *User) error { +func InsertTransaction(tx *Tx, t *Transaction, user *User) error { // Map of any accounts with transaction splits being added a_map := make(map[int64]bool) for i := range t.Splits { if t.Splits[i].AccountId != -1 { - existing, err := transaction.SelectInt("SELECT count(*) from accounts where AccountId=?", t.Splits[i].AccountId) + existing, err := tx.SelectInt("SELECT count(*) from accounts where AccountId=?", t.Splits[i].AccountId) if err != nil { return err } @@ -287,13 +262,13 @@ func InsertTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us if len(a_ids) < 1 { return AccountMissingError{} } - err := incrementAccountVersions(transaction, user, a_ids) + err := incrementAccountVersions(tx, user, a_ids) if err != nil { return err } t.UserId = user.UserId - err = transaction.Insert(t) + err = tx.Insert(t) if err != nil { return err } @@ -301,7 +276,7 @@ func InsertTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us for i := range t.Splits { t.Splits[i].TransactionId = t.TransactionId t.Splits[i].SplitId = -1 - err = transaction.Insert(t.Splits[i]) + err = tx.Insert(t.Splits[i]) if err != nil { return err } @@ -310,31 +285,10 @@ func InsertTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us return nil } -func InsertTransaction(db *DB, t *Transaction, user *User) error { - transaction, err := db.Begin() - if err != nil { - return err - } - - err = InsertTransactionTx(transaction, t, user) - if err != nil { - transaction.Rollback() - return err - } - - err = transaction.Commit() - if err != nil { - transaction.Rollback() - return err - } - - return nil -} - -func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *User) error { +func UpdateTransaction(tx *Tx, t *Transaction, user *User) error { var existing_splits []*Split - _, err := transaction.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId) + _, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId) if err != nil { return err } @@ -353,7 +307,7 @@ func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us t.Splits[i].TransactionId = t.TransactionId _, ok := s_map[t.Splits[i].SplitId] if ok { - count, err := transaction.Update(t.Splits[i]) + count, err := tx.Update(t.Splits[i]) if err != nil { return err } @@ -363,7 +317,7 @@ func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us delete(s_map, t.Splits[i].SplitId) } else { t.Splits[i].SplitId = -1 - err := transaction.Insert(t.Splits[i]) + err := tx.Insert(t.Splits[i]) if err != nil { return err } @@ -380,7 +334,7 @@ func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us a_map[existing_splits[i].AccountId] = true } if ok { - _, err := transaction.Delete(existing_splits[i]) + _, err := tx.Delete(existing_splits[i]) if err != nil { return err } @@ -392,12 +346,12 @@ func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us for id := range a_map { a_ids = append(a_ids, id) } - err = incrementAccountVersions(transaction, user, a_ids) + err = incrementAccountVersions(tx, user, a_ids) if err != nil { return err } - count, err := transaction.Update(t) + count, err := tx.Update(t) if err != nil { return err } @@ -408,263 +362,171 @@ func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us return nil } -func DeleteTransaction(db *DB, t *Transaction, user *User) error { - transaction, err := db.Begin() - if err != nil { - return err - } - +func DeleteTransaction(tx *Tx, t *Transaction, user *User) error { var accountids []int64 - _, err = transaction.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId) + _, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId) if err != nil { - transaction.Rollback() return err } - _, err = transaction.Exec("DELETE FROM splits WHERE TransactionId=?", t.TransactionId) + _, err = tx.Exec("DELETE FROM splits WHERE TransactionId=?", t.TransactionId) if err != nil { - transaction.Rollback() return err } - count, err := transaction.Delete(t) + count, err := tx.Delete(t) if err != nil { - transaction.Rollback() return err } if count != 1 { - transaction.Rollback() return errors.New("Deleted more than one transaction") } - err = incrementAccountVersions(transaction, user, accountids) + err = incrementAccountVersions(tx, user, accountids) if err != nil { - transaction.Rollback() - return err - } - - err = transaction.Commit() - if err != nil { - transaction.Rollback() return err } return nil } -func TransactionHandler(w http.ResponseWriter, r *http.Request, db *DB) { - user, err := GetUserFromSession(db, r) +func TransactionHandler(r *http.Request, tx *Tx) ResponseWriterWriter { + user, err := GetUserFromSession(tx, r) if err != nil { - WriteError(w, 1 /*Not Signed In*/) - return + return NewError(1 /*Not Signed In*/) } if r.Method == "POST" { transaction_json := r.PostFormValue("transaction") if transaction_json == "" { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } var transaction Transaction err := transaction.Read(transaction_json) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } transaction.TransactionId = -1 transaction.UserId = user.UserId - sqltx, err := db.Begin() + balanced, err := transaction.Balanced(tx) if err != nil { - WriteError(w, 999 /*Internal Error*/) + return NewError(999 /*Internal Error*/) log.Print(err) - return - } - - balanced, err := transaction.Balanced(sqltx) - if err != nil { - sqltx.Rollback() - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return } if !transaction.Valid() || !balanced { - sqltx.Rollback() - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } for i := range transaction.Splits { transaction.Splits[i].SplitId = -1 - _, err := GetAccountTx(sqltx, transaction.Splits[i].AccountId, user.UserId) + _, err := GetAccount(tx, transaction.Splits[i].AccountId, user.UserId) if err != nil { - sqltx.Rollback() - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } } - err = InsertTransactionTx(sqltx, &transaction, user) + err = InsertTransaction(tx, &transaction, user) if err != nil { if _, ok := err.(AccountMissingError); ok { - WriteError(w, 3 /*Invalid Request*/) + return NewError(3 /*Invalid Request*/) } else { - WriteError(w, 999 /*Internal Error*/) log.Print(err) + return NewError(999 /*Internal Error*/) } - sqltx.Rollback() - return } - err = sqltx.Commit() - if err != nil { - sqltx.Rollback() - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } - - err = transaction.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return &transaction } else if r.Method == "GET" { transactionid, err := GetURLID(r.URL.Path) if err != nil { //Return all Transactions var al TransactionList - transactions, err := GetTransactions(db, user.UserId) + transactions, err := GetTransactions(tx, user.UserId) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } al.Transactions = transactions - err = (&al).Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return &al } else { //Return Transaction with this Id - transaction, err := GetTransaction(db, transactionid, user.UserId) + transaction, err := GetTransaction(tx, transactionid, user.UserId) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return - } - err = transaction.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return + return NewError(3 /*Invalid Request*/) } + return transaction } } else { transactionid, err := GetURLID(r.URL.Path) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } if r.Method == "PUT" { transaction_json := r.PostFormValue("transaction") if transaction_json == "" { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } var transaction Transaction err := transaction.Read(transaction_json) if err != nil || transaction.TransactionId != transactionid { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } transaction.UserId = user.UserId - sqltx, err := db.Begin() + balanced, err := transaction.Balanced(tx) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return - } - - balanced, err := transaction.Balanced(sqltx) - if err != nil { - sqltx.Rollback() - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return + return NewError(999 /*Internal Error*/) } if !transaction.Valid() || !balanced { - sqltx.Rollback() - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } for i := range transaction.Splits { - _, err := GetAccountTx(sqltx, transaction.Splits[i].AccountId, user.UserId) + _, err := GetAccount(tx, transaction.Splits[i].AccountId, user.UserId) if err != nil { - sqltx.Rollback() - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } } - err = UpdateTransactionTx(sqltx, &transaction, user) + err = UpdateTransaction(tx, &transaction, user) if err != nil { - sqltx.Rollback() - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - err = sqltx.Commit() - if err != nil { - sqltx.Rollback() - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } - - err = transaction.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return &transaction } else if r.Method == "DELETE" { transactionid, err := GetURLID(r.URL.Path) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } - transaction, err := GetTransaction(db, transactionid, user.UserId) + transaction, err := GetTransaction(tx, transactionid, user.UserId) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } - err = DeleteTransaction(db, transaction, user) + err = DeleteTransaction(tx, transaction, user) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - WriteSuccess(w) + return SuccessWriter{} } } + return NewError(3 /*Invalid Request*/) } -func TransactionsBalanceDifference(transaction *gorp.Transaction, accountid int64, transactions []Transaction) (*big.Rat, error) { +func TransactionsBalanceDifference(tx *Tx, accountid int64, transactions []Transaction) (*big.Rat, error) { var pageDifference, tmp big.Rat for i := range transactions { - _, err := transaction.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId) + _, err := tx.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId) if err != nil { return nil, err } @@ -685,17 +547,12 @@ func TransactionsBalanceDifference(transaction *gorp.Transaction, accountid int6 return &pageDifference, nil } -func GetAccountBalance(db *DB, user *User, accountid int64) (*big.Rat, error) { +func GetAccountBalance(tx *Tx, user *User, accountid int64) (*big.Rat, error) { var splits []Split - transaction, err := db.Begin() - if err != nil { - return nil, err - } sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?" - _, err = transaction.Select(&splits, sql, accountid, user.UserId) + _, err := tx.Select(&splits, sql, accountid, user.UserId) if err != nil { - transaction.Rollback() return nil, err } @@ -703,34 +560,22 @@ func GetAccountBalance(db *DB, user *User, accountid int64) (*big.Rat, error) { for _, s := range splits { rat_amount, err := GetBigAmount(s.Amount) if err != nil { - transaction.Rollback() return nil, err } tmp.Add(&balance, rat_amount) balance.Set(&tmp) } - err = transaction.Commit() - if err != nil { - transaction.Rollback() - return nil, err - } - return &balance, nil } // Assumes accountid is valid and is owned by the current user -func GetAccountBalanceDate(db *DB, user *User, accountid int64, date *time.Time) (*big.Rat, error) { +func GetAccountBalanceDate(tx *Tx, user *User, accountid int64, date *time.Time) (*big.Rat, error) { var splits []Split - transaction, err := db.Begin() - if err != nil { - return nil, err - } sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?" - _, err = transaction.Select(&splits, sql, accountid, user.UserId, date) + _, err := tx.Select(&splits, sql, accountid, user.UserId, date) if err != nil { - transaction.Rollback() return nil, err } @@ -738,33 +583,21 @@ func GetAccountBalanceDate(db *DB, user *User, accountid int64, date *time.Time) for _, s := range splits { rat_amount, err := GetBigAmount(s.Amount) if err != nil { - transaction.Rollback() return nil, err } tmp.Add(&balance, rat_amount) balance.Set(&tmp) } - err = transaction.Commit() - if err != nil { - transaction.Rollback() - return nil, err - } - return &balance, nil } -func GetAccountBalanceDateRange(db *DB, user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) { +func GetAccountBalanceDateRange(tx *Tx, user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) { var splits []Split - transaction, err := db.Begin() - if err != nil { - return nil, err - } sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?" - _, err = transaction.Select(&splits, sql, accountid, user.UserId, begin, end) + _, err := tx.Select(&splits, sql, accountid, user.UserId, begin, end) if err != nil { - transaction.Rollback() return nil, err } @@ -772,31 +605,19 @@ func GetAccountBalanceDateRange(db *DB, user *User, accountid int64, begin, end for _, s := range splits { rat_amount, err := GetBigAmount(s.Amount) if err != nil { - transaction.Rollback() return nil, err } tmp.Add(&balance, rat_amount) balance.Set(&tmp) } - err = transaction.Commit() - if err != nil { - transaction.Rollback() - return nil, err - } - return &balance, nil } -func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) { +func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) { var transactions []Transaction var atl AccountTransactionsList - transaction, err := db.Begin() - if err != nil { - return nil, err - } - var sqlsort, balanceLimitOffset string var balanceLimitOffsetArg uint64 if sort == "date-asc" { @@ -804,9 +625,8 @@ func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, pa balanceLimitOffset = " LIMIT ?" balanceLimitOffsetArg = page * limit } else if sort == "date-desc" { - numSplits, err := transaction.SelectInt("SELECT count(*) FROM splits") + numSplits, err := tx.SelectInt("SELECT count(*) FROM splits") if err != nil { - transaction.Rollback() return nil, err } sqlsort = " ORDER BY transactions.Date DESC" @@ -819,41 +639,35 @@ func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, pa sqloffset = fmt.Sprintf(" OFFSET %d", page*limit) } - account, err := GetAccountTx(transaction, accountid, user.UserId) + account, err := GetAccount(tx, accountid, user.UserId) if err != nil { - transaction.Rollback() return nil, err } atl.Account = account sql := "SELECT DISTINCT transactions.* FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + " LIMIT ?" + sqloffset - _, err = transaction.Select(&transactions, sql, user.UserId, accountid, limit) + _, err = tx.Select(&transactions, sql, user.UserId, accountid, limit) if err != nil { - transaction.Rollback() return nil, err } atl.Transactions = &transactions - pageDifference, err := TransactionsBalanceDifference(transaction, accountid, transactions) + pageDifference, err := TransactionsBalanceDifference(tx, accountid, transactions) if err != nil { - transaction.Rollback() return nil, err } - count, err := transaction.SelectInt("SELECT count(DISTINCT transactions.TransactionId) FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?", user.UserId, accountid) + count, err := tx.SelectInt("SELECT count(DISTINCT transactions.TransactionId) FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?", user.UserId, accountid) if err != nil { - transaction.Rollback() return nil, err } atl.TotalTransactions = count - security, err := GetSecurityTx(transaction, atl.Account.SecurityId, user.UserId) + security, err := GetSecurity(tx, atl.Account.SecurityId, user.UserId) if err != nil { - transaction.Rollback() return nil, err } if security == nil { - transaction.Rollback() return nil, errors.New("Security not found") } @@ -861,9 +675,8 @@ func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, pa // occurred before the page we're returning var amounts []string sql = "SELECT splits.Amount FROM splits WHERE splits.AccountId=? AND splits.TransactionId IN (SELECT DISTINCT transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ")" - _, err = transaction.Select(&amounts, sql, accountid, user.UserId, accountid, balanceLimitOffsetArg) + _, err = tx.Select(&amounts, sql, accountid, user.UserId, accountid, balanceLimitOffsetArg) if err != nil { - transaction.Rollback() return nil, err } @@ -871,7 +684,6 @@ func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, pa for _, amount := range amounts { rat_amount, err := GetBigAmount(amount) if err != nil { - transaction.Rollback() return nil, err } tmp.Add(&balance, rat_amount) @@ -880,20 +692,12 @@ func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, pa atl.BeginningBalance = balance.FloatString(security.Precision) atl.EndingBalance = tmp.Add(&balance, pageDifference).FloatString(security.Precision) - err = transaction.Commit() - if err != nil { - transaction.Rollback() - return nil, err - } - return &atl, nil } // Return only those transactions which have at least one split pertaining to // an account -func AccountTransactionsHandler(db *DB, w http.ResponseWriter, r *http.Request, - user *User, accountid int64) { - +func AccountTransactionsHandler(tx *Tx, r *http.Request, user *User, accountid int64) ResponseWriterWriter { var page uint64 = 0 var limit uint64 = 50 var sort string = "date-desc" @@ -904,8 +708,7 @@ func AccountTransactionsHandler(db *DB, w http.ResponseWriter, r *http.Request, if pagestring != "" { p, err := strconv.ParseUint(pagestring, 10, 0) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } page = p } @@ -914,8 +717,7 @@ func AccountTransactionsHandler(db *DB, w http.ResponseWriter, r *http.Request, if limitstring != "" { l, err := strconv.ParseUint(limitstring, 10, 0) if err != nil || l > 100 { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } limit = l } @@ -923,23 +725,16 @@ func AccountTransactionsHandler(db *DB, w http.ResponseWriter, r *http.Request, sortstring := query.Get("sort") if sortstring != "" { if sortstring != "date-asc" && sortstring != "date-desc" { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } sort = sortstring } - accountTransactions, err := GetAccountTransactions(db, user, accountid, sort, page, limit) + accountTransactions, err := GetAccountTransactions(tx, user, accountid, sort, page, limit) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - err = accountTransactions.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return accountTransactions } diff --git a/internal/handlers/users.go b/internal/handlers/users.go index 99ea60a..63f3ccb 100644 --- a/internal/handlers/users.go +++ b/internal/handlers/users.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "gopkg.in/gorp.v1" "io" "log" "net/http" @@ -47,61 +46,42 @@ func (u *User) HashPassword() { u.Password = "" } -func GetUser(db *DB, userid int64) (*User, error) { +func GetUser(tx *Tx, userid int64) (*User, error) { var u User - err := db.SelectOne(&u, "SELECT * from users where UserId=?", userid) + err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid) if err != nil { return nil, err } return &u, nil } -func GetUserTx(transaction *gorp.Transaction, userid int64) (*User, error) { +func GetUserByUsername(tx *Tx, username string) (*User, error) { var u User - err := transaction.SelectOne(&u, "SELECT * from users where UserId=?", userid) + err := tx.SelectOne(&u, "SELECT * from users where Username=?", username) if err != nil { return nil, err } return &u, nil } -func GetUserByUsername(db *DB, username string) (*User, error) { - var u User - - err := db.SelectOne(&u, "SELECT * from users where Username=?", username) - if err != nil { - return nil, err - } - return &u, nil -} - -func InsertUser(db *DB, u *User) error { - transaction, err := db.Begin() - if err != nil { - return err - } - +func InsertUser(tx *Tx, u *User) error { security_template := FindCurrencyTemplate(u.DefaultCurrency) if security_template == nil { - transaction.Rollback() return errors.New("Invalid ISO4217 Default Currency") } - existing, err := transaction.SelectInt("SELECT count(*) from users where Username=?", u.Username) + existing, err := tx.SelectInt("SELECT count(*) from users where Username=?", u.Username) if err != nil { - transaction.Rollback() return err } if existing > 0 { - transaction.Rollback() return UserExistsError{} } - err = transaction.Insert(u) + err = tx.Insert(u) if err != nil { - transaction.Rollback() return err } @@ -110,201 +90,138 @@ func InsertUser(db *DB, u *User) error { security = *security_template security.UserId = u.UserId - err = InsertSecurityTx(transaction, &security) + err = InsertSecurity(tx, &security) if err != nil { - transaction.Rollback() return err } // Update the user's DefaultCurrency to our new SecurityId u.DefaultCurrency = security.SecurityId - count, err := transaction.Update(u) + count, err := tx.Update(u) if err != nil { - transaction.Rollback() return err } else if count != 1 { - transaction.Rollback() return errors.New("Would have updated more than one user") } - err = transaction.Commit() - if err != nil { - transaction.Rollback() - return err - } - return nil } -func GetUserFromSession(db *DB, r *http.Request) (*User, error) { - s, err := GetSession(db, r) +func GetUserFromSession(tx *Tx, r *http.Request) (*User, error) { + s, err := GetSession(tx, r) if err != nil { return nil, err } - return GetUser(db, s.UserId) + return GetUser(tx, s.UserId) } -func UpdateUser(db *DB, u *User) error { - transaction, err := db.Begin() +func UpdateUser(tx *Tx, u *User) error { + security, err := GetSecurity(tx, u.DefaultCurrency, u.UserId) if err != nil { return err - } - - security, err := GetSecurityTx(transaction, u.DefaultCurrency, u.UserId) - if err != nil { - transaction.Rollback() - return err } else if security.UserId != u.UserId || security.SecurityId != u.DefaultCurrency { - transaction.Rollback() return errors.New("UserId and DefaultCurrency don't match the fetched security") } else if security.Type != Currency { - transaction.Rollback() return errors.New("New DefaultCurrency security is not a currency") } - count, err := transaction.Update(u) + count, err := tx.Update(u) if err != nil { - transaction.Rollback() return err } else if count != 1 { - transaction.Rollback() return errors.New("Would have updated more than one user") } - err = transaction.Commit() - if err != nil { - transaction.Rollback() - return err - } - return nil } -func DeleteUser(db *DB, u *User) error { - transaction, err := db.Begin() +func DeleteUser(tx *Tx, u *User) error { + count, err := tx.Delete(u) if err != nil { return err } - - count, err := transaction.Delete(u) - if err != nil { - transaction.Rollback() - return err - } if count != 1 { - transaction.Rollback() return fmt.Errorf("No user to delete") } - _, err = transaction.Exec("DELETE FROM prices WHERE prices.PriceId IN (SELECT prices.PriceId FROM prices INNER JOIN securities ON prices.SecurityId=securities.SecurityId WHERE securities.UserId=?)", u.UserId) + _, err = tx.Exec("DELETE FROM prices WHERE prices.PriceId IN (SELECT prices.PriceId FROM prices INNER JOIN securities ON prices.SecurityId=securities.SecurityId WHERE securities.UserId=?)", u.UserId) if err != nil { - transaction.Rollback() return err } - _, err = transaction.Exec("DELETE FROM splits WHERE splits.SplitId IN (SELECT splits.SplitId FROM splits INNER JOIN transactions ON splits.TransactionId=transactions.TransactionId WHERE transactions.UserId=?)", u.UserId) + _, err = tx.Exec("DELETE FROM splits WHERE splits.SplitId IN (SELECT splits.SplitId FROM splits INNER JOIN transactions ON splits.TransactionId=transactions.TransactionId WHERE transactions.UserId=?)", u.UserId) if err != nil { - transaction.Rollback() return err } - _, err = transaction.Exec("DELETE FROM transactions WHERE transactions.UserId=?", u.UserId) + _, err = tx.Exec("DELETE FROM transactions WHERE transactions.UserId=?", u.UserId) if err != nil { - transaction.Rollback() return err } - _, err = transaction.Exec("DELETE FROM securities WHERE securities.UserId=?", u.UserId) + _, err = tx.Exec("DELETE FROM securities WHERE securities.UserId=?", u.UserId) if err != nil { - transaction.Rollback() return err } - _, err = transaction.Exec("DELETE FROM accounts WHERE accounts.UserId=?", u.UserId) + _, err = tx.Exec("DELETE FROM accounts WHERE accounts.UserId=?", u.UserId) if err != nil { - transaction.Rollback() return err } - _, err = transaction.Exec("DELETE FROM reports WHERE reports.UserId=?", u.UserId) + _, err = tx.Exec("DELETE FROM reports WHERE reports.UserId=?", u.UserId) if err != nil { - transaction.Rollback() return err } - _, err = transaction.Exec("DELETE FROM sessions WHERE sessions.UserId=?", u.UserId) + _, err = tx.Exec("DELETE FROM sessions WHERE sessions.UserId=?", u.UserId) if err != nil { - transaction.Rollback() - return err - } - - err = transaction.Commit() - if err != nil { - transaction.Rollback() return err } return nil } -func UserHandler(w http.ResponseWriter, r *http.Request, db *DB) { +func UserHandler(r *http.Request, tx *Tx) ResponseWriterWriter { if r.Method == "POST" { user_json := r.PostFormValue("user") if user_json == "" { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } var user User err := user.Read(user_json) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } user.UserId = -1 user.HashPassword() - err = InsertUser(db, &user) + err = InsertUser(tx, &user) if err != nil { if _, ok := err.(UserExistsError); ok { - WriteError(w, 4 /*User Exists*/) + return NewError(4 /*User Exists*/) } else { - WriteError(w, 999 /*Internal Error*/) log.Print(err) + return NewError(999 /*Internal Error*/) } - return } - w.WriteHeader(201 /*Created*/) - err = user.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return ResponseWrapper{201, &user} } else { - user, err := GetUserFromSession(db, r) + user, err := GetUserFromSession(tx, r) if err != nil { - WriteError(w, 1 /*Not Signed In*/) - return + return NewError(1 /*Not Signed In*/) } userid, err := GetURLID(r.URL.Path) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } if userid != user.UserId { - WriteError(w, 2 /*Unauthorized Access*/) - return + return NewError(2 /*Unauthorized Access*/) } if r.Method == "GET" { - err = user.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return user } else if r.Method == "PUT" { user_json := r.PostFormValue("user") if user_json == "" { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } // Save old PWHash in case the new password is bogus @@ -312,8 +229,7 @@ func UserHandler(w http.ResponseWriter, r *http.Request, db *DB) { err = user.Read(user_json) if err != nil || user.UserId != userid { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } // If the user didn't create a new password, keep their old one @@ -324,27 +240,21 @@ func UserHandler(w http.ResponseWriter, r *http.Request, db *DB) { user.PasswordHash = old_pwhash } - err = UpdateUser(db, user) + err = UpdateUser(tx, user) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - err = user.Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return user } else if r.Method == "DELETE" { - err := DeleteUser(db, user) + err := DeleteUser(tx, user) if err != nil { - WriteError(w, 999 /*Internal Error*/) log.Print(err) - return + return NewError(999 /*Internal Error*/) } - WriteSuccess(w) + return SuccessWriter{} } } + return NewError(3 /*Invalid Request*/) } diff --git a/internal/handlers/util.go b/internal/handlers/util.go index 3579a13..b3e9fa7 100644 --- a/internal/handlers/util.go +++ b/internal/handlers/util.go @@ -18,6 +18,23 @@ func GetURLPieces(url string, format string, a ...interface{}) (int, error) { return fmt.Sscanf(url, format, a...) } +type ResponseWrapper struct { + Code int + Writer ResponseWriterWriter +} + +func (r ResponseWrapper) Write(w http.ResponseWriter) error { + w.WriteHeader(r.Code) + return r.Writer.Write(w) +} + +type SuccessWriter struct{} + +func (s SuccessWriter) Write(w http.ResponseWriter) error { + fmt.Fprint(w, "{}") + return nil +} + func WriteSuccess(w http.ResponseWriter) { fmt.Fprint(w, "{}") }