diff --git a/accounts.go b/accounts.go index 54760f1..e038322 100644 --- a/accounts.go +++ b/accounts.go @@ -124,10 +124,10 @@ func (al *AccountList) Write(w http.ResponseWriter) error { return enc.Encode(al) } -func GetAccount(accountid int64, userid int64) (*Account, error) { +func GetAccount(db *DB, accountid int64, userid int64) (*Account, error) { var a Account - err := DB.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid) + err := db.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid) if err != nil { return nil, err } @@ -145,10 +145,10 @@ func GetAccountTx(transaction *gorp.Transaction, accountid int64, userid int64) return &a, nil } -func GetAccounts(userid int64) (*[]Account, error) { +func GetAccounts(db *DB, userid int64) (*[]Account, error) { var accounts []Account - _, err := DB.Select(&accounts, "SELECT * from accounts where UserId=?", userid) + _, err := db.Select(&accounts, "SELECT * from accounts where UserId=?", userid) if err != nil { return nil, err } @@ -276,8 +276,8 @@ func (pame ParentAccountMissingError) Error() string { return "Parent account missing" } -func insertUpdateAccount(a *Account, insert bool) error { - transaction, err := DB.Begin() +func insertUpdateAccount(db *DB, a *Account, insert bool) error { + transaction, err := db.Begin() if err != nil { return err } @@ -329,16 +329,16 @@ func insertUpdateAccount(a *Account, insert bool) error { return nil } -func InsertAccount(a *Account) error { - return insertUpdateAccount(a, true) +func InsertAccount(db *DB, a *Account) error { + return insertUpdateAccount(db, a, true) } -func UpdateAccount(a *Account) error { - return insertUpdateAccount(a, false) +func UpdateAccount(db *DB, a *Account) error { + return insertUpdateAccount(db, a, false) } -func DeleteAccount(a *Account) error { - transaction, err := DB.Begin() +func DeleteAccount(db *DB, a *Account) error { + transaction, err := db.Begin() if err != nil { return err } @@ -385,8 +385,8 @@ func DeleteAccount(a *Account) error { return nil } -func AccountHandler(w http.ResponseWriter, r *http.Request) { - user, err := GetUserFromSession(r) +func AccountHandler(w http.ResponseWriter, r *http.Request, db *DB) { + user, err := GetUserFromSession(db, r) if err != nil { WriteError(w, 1 /*Not Signed In*/) return @@ -405,7 +405,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) { log.Print(err) return } - AccountImportHandler(w, r, user, accountid, importtype) + AccountImportHandler(db, w, r, user, accountid, importtype) return } @@ -425,7 +425,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) { account.UserId = user.UserId account.AccountVersion = 0 - security, err := GetSecurity(account.SecurityId, user.UserId) + security, err := GetSecurity(db, account.SecurityId, user.UserId) if err != nil { WriteError(w, 999 /*Internal Error*/) log.Print(err) @@ -436,7 +436,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) { return } - err = InsertAccount(&account) + err = InsertAccount(db, &account) if err != nil { if _, ok := err.(ParentAccountMissingError); ok { WriteError(w, 3 /*Invalid Request*/) @@ -461,7 +461,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) { if err != nil || n != 1 { //Return all Accounts var al AccountList - accounts, err := GetAccounts(user.UserId) + accounts, err := GetAccounts(db, user.UserId) if err != nil { WriteError(w, 999 /*Internal Error*/) log.Print(err) @@ -478,12 +478,12 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) { // if URL looks like /account/[0-9]+/transactions, use the account // transaction handler if accountTransactionsRE.MatchString(r.URL.Path) { - AccountTransactionsHandler(w, r, user, accountid) + AccountTransactionsHandler(db, w, r, user, accountid) return } // Return Account with this Id - account, err := GetAccount(accountid, user.UserId) + account, err := GetAccount(db, accountid, user.UserId) if err != nil { WriteError(w, 3 /*Invalid Request*/) return @@ -517,7 +517,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) { } account.UserId = user.UserId - security, err := GetSecurity(account.SecurityId, user.UserId) + security, err := GetSecurity(db, account.SecurityId, user.UserId) if err != nil { WriteError(w, 999 /*Internal Error*/) log.Print(err) @@ -528,7 +528,7 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) { return } - err = UpdateAccount(&account) + err = UpdateAccount(db, &account) if err != nil { WriteError(w, 999 /*Internal Error*/) log.Print(err) @@ -542,13 +542,13 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) { return } } else if r.Method == "DELETE" { - account, err := GetAccount(accountid, user.UserId) + account, err := GetAccount(db, accountid, user.UserId) if err != nil { WriteError(w, 3 /*Invalid Request*/) return } - err = DeleteAccount(account) + err = DeleteAccount(db, account) if err != nil { WriteError(w, 999 /*Internal Error*/) log.Print(err) diff --git a/accounts_lua.go b/accounts_lua.go index 5a076c8..52f9095 100644 --- a/accounts_lua.go +++ b/accounts_lua.go @@ -15,14 +15,19 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*Account, error) { ctx := L.Context() - account_map, ok := ctx.Value(accountsContextKey).(map[int64]*Account) + db, ok := ctx.Value(dbContextKey).(*DB) + if !ok { + return nil, errors.New("Couldn't find DB in lua's Context") + } + + account_map, ok = ctx.Value(accountsContextKey).(map[int64]*Account) if !ok { user, ok := ctx.Value(userContextKey).(*User) if !ok { return nil, errors.New("Couldn't find User in lua's Context") } - accounts, err := GetAccounts(user.UserId) + accounts, err := GetAccounts(db, user.UserId) if err != nil { return nil, err } @@ -144,6 +149,10 @@ func luaAccountBalance(L *lua.LState) int { a := luaCheckAccount(L, 1) ctx := L.Context() + db, ok := ctx.Value(dbContextKey).(*DB) + if !ok { + panic("Couldn't find DB in lua's Context") + } user, ok := ctx.Value(userContextKey).(*User) if !ok { panic("Couldn't find User in lua's Context") @@ -162,12 +171,12 @@ func luaAccountBalance(L *lua.LState) int { if date != nil { end := luaWeakCheckTime(L, 3) if end != nil { - rat, err = GetAccountBalanceDateRange(user, a.AccountId, date, end) + rat, err = GetAccountBalanceDateRange(db, user, a.AccountId, date, end) } else { - rat, err = GetAccountBalanceDate(user, a.AccountId, date) + rat, err = GetAccountBalanceDate(db, user, a.AccountId, date) } } else { - rat, err = GetAccountBalance(user, a.AccountId) + rat, err = GetAccountBalance(db, user, a.AccountId) } if err != nil { panic("Failed to GetAccountBalance:" + err.Error()) diff --git a/db.go b/db.go index 0c2fdae..60505e9 100644 --- a/db.go +++ b/db.go @@ -2,19 +2,17 @@ package main import ( "database/sql" + "fmt" _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" "gopkg.in/gorp.v1" - "log" ) -var DB *gorp.DbMap - -func initDB(cfg *Config) { +func initDB(cfg *Config) (*gorp.DbMap, error) { db, err := sql.Open(cfg.MoneyGo.DBType.String(), cfg.MoneyGo.DSN) if err != nil { - log.Fatal(err) + return nil, err } var dialect gorp.Dialect @@ -28,7 +26,7 @@ func initDB(cfg *Config) { } else if cfg.MoneyGo.DBType == Postgres { dialect = gorp.PostgresDialect{} } else { - log.Fatalf("Don't know gorp dialect to go with '%s' DB type", cfg.MoneyGo.DBType.String()) + return nil, fmt.Errorf("Don't know gorp dialect to go with '%s' DB type", cfg.MoneyGo.DBType.String()) } dbmap := &gorp.DbMap{Db: db, Dialect: dialect} @@ -43,8 +41,8 @@ func initDB(cfg *Config) { err = dbmap.CreateTablesIfNotExists() if err != nil { - log.Fatal(err) + return nil, err } - DB = dbmap + return dbmap, nil } diff --git a/gnucash.go b/gnucash.go index 1d719c1..2665194 100644 --- a/gnucash.go +++ b/gnucash.go @@ -308,8 +308,8 @@ func ImportGnucash(r io.Reader) (*GnucashImport, error) { return &gncimport, nil } -func GnucashImportHandler(w http.ResponseWriter, r *http.Request) { - user, err := GetUserFromSession(r) +func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) { + user, err := GetUserFromSession(db, r) if err != nil { WriteError(w, 1 /*Not Signed In*/) return @@ -365,7 +365,7 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request) { return } - sqltransaction, err := DB.Begin() + sqltransaction, err := db.Begin() if err != nil { WriteError(w, 999 /*Internal Error*/) log.Print(err) diff --git a/imports.go b/imports.go index fd57dc4..acd9f8c 100644 --- a/imports.go +++ b/imports.go @@ -22,7 +22,7 @@ func (od *OFXDownload) Read(json_str string) error { return dec.Decode(od) } -func ofxImportHelper(r io.Reader, w http.ResponseWriter, user *User, accountid int64) { +func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, accountid int64) { itl, err := ImportOFX(r) if err != nil { @@ -38,7 +38,7 @@ func ofxImportHelper(r io.Reader, w http.ResponseWriter, user *User, accountid i return } - sqltransaction, err := DB.Begin() + sqltransaction, err := db.Begin() if err != nil { WriteError(w, 999 /*Internal Error*/) log.Print(err) @@ -258,7 +258,7 @@ func ofxImportHelper(r io.Reader, w http.ResponseWriter, user *User, accountid i WriteSuccess(w) } -func OFXImportHandler(w http.ResponseWriter, r *http.Request, user *User, accountid int64) { +func OFXImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64) { download_json := r.PostFormValue("ofxdownload") if download_json == "" { log.Print("download_json") @@ -274,7 +274,7 @@ func OFXImportHandler(w http.ResponseWriter, r *http.Request, user *User, accoun return } - account, err := GetAccount(accountid, user.UserId) + account, err := GetAccount(db, accountid, user.UserId) if err != nil { log.Print("GetAccount") WriteError(w, 3 /*Invalid Request*/) @@ -367,10 +367,10 @@ func OFXImportHandler(w http.ResponseWriter, r *http.Request, user *User, accoun } defer response.Body.Close() - ofxImportHelper(response.Body, w, user, accountid) + ofxImportHelper(db, response.Body, w, user, accountid) } -func OFXFileImportHandler(w http.ResponseWriter, r *http.Request, user *User, accountid int64) { +func OFXFileImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64) { multipartReader, err := r.MultipartReader() if err != nil { WriteError(w, 3 /*Invalid Request*/) @@ -390,19 +390,19 @@ func OFXFileImportHandler(w http.ResponseWriter, r *http.Request, user *User, ac return } - ofxImportHelper(part, w, user, accountid) + ofxImportHelper(db, part, w, user, accountid) } /* * Assumes the User is a valid, signed-in user, but accountid has not yet been validated */ -func AccountImportHandler(w http.ResponseWriter, r *http.Request, user *User, accountid int64, importtype string) { +func AccountImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64, importtype string) { switch importtype { case "ofx": - OFXImportHandler(w, r, user, accountid) + OFXImportHandler(db, w, r, user, accountid) case "ofxfile": - OFXFileImportHandler(w, r, user, accountid) + OFXFileImportHandler(db, w, r, user, accountid) default: WriteError(w, 3 /*Invalid Request*/) } diff --git a/main.go b/main.go index 0fbd86c..0ffb0e9 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ package main import ( "flag" + "gopkg.in/gorp.v1" "log" "net" "net/http" @@ -43,8 +44,6 @@ func init() { // Setup the logging flags to be printed log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile) - - initDB(config) } func rootHandler(w http.ResponseWriter, r *http.Request) { @@ -55,18 +54,39 @@ func staticHandler(w http.ResponseWriter, r *http.Request) { http.ServeFile(w, r, path.Join(config.MoneyGo.Basedir, r.URL.Path)) } -func main() { +// Create a closure over db, allowing the handlers to look like a +// http.HandlerFunc +type DB = gorp.DbMap +type DBHandler func(http.ResponseWriter, *http.Request, *DB) + +func DBHandlerFunc(h DBHandler, db *DB) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + h(w, r, db) + } +} + +func GetHandler(db *DB) http.Handler { servemux := http.NewServeMux() servemux.HandleFunc("/", rootHandler) servemux.HandleFunc("/static/", staticHandler) - servemux.HandleFunc("/session/", SessionHandler) - servemux.HandleFunc("/user/", UserHandler) - servemux.HandleFunc("/security/", SecurityHandler) + servemux.HandleFunc("/session/", DBHandlerFunc(SessionHandler, db)) + servemux.HandleFunc("/user/", DBHandlerFunc(UserHandler, db)) + servemux.HandleFunc("/security/", DBHandlerFunc(SecurityHandler, db)) servemux.HandleFunc("/securitytemplate/", SecurityTemplateHandler) - servemux.HandleFunc("/account/", AccountHandler) - servemux.HandleFunc("/transaction/", TransactionHandler) - servemux.HandleFunc("/import/gnucash", GnucashImportHandler) - servemux.HandleFunc("/report/", ReportHandler) + servemux.HandleFunc("/account/", DBHandlerFunc(AccountHandler, db)) + servemux.HandleFunc("/transaction/", DBHandlerFunc(TransactionHandler, db)) + servemux.HandleFunc("/import/gnucash", DBHandlerFunc(GnucashImportHandler, db)) + servemux.HandleFunc("/report/", DBHandlerFunc(ReportHandler, db)) + + return servemux +} + +func main() { + database, err := initDB(config) + if err != nil { + log.Fatal(err) + } + handler := GetHandler(database) listener, err := net.Listen("tcp", ":"+strconv.Itoa(config.MoneyGo.Port)) if err != nil { @@ -75,8 +95,8 @@ func main() { log.Printf("Serving on port %d out of directory: %s", config.MoneyGo.Port, config.MoneyGo.Basedir) if config.MoneyGo.Fcgi { - fcgi.Serve(listener, servemux) + fcgi.Serve(listener, handler) } else { - http.Serve(listener, servemux) + http.Serve(listener, handler) } } diff --git a/prices.go b/prices.go index b067594..2ed133a 100644 --- a/prices.go +++ b/prices.go @@ -1,8 +1,6 @@ package main import ( - "fmt" - "github.com/FlashBoys/go-finance" "gopkg.in/gorp.v1" "time" ) @@ -93,8 +91,8 @@ func GetClosestPriceTx(transaction *gorp.Transaction, security, currency *Securi } } -func GetClosestPrice(security, currency *Security, date *time.Time) (*Price, error) { - transaction, err := DB.Begin() +func GetClosestPrice(db *DB, security, currency *Security, date *time.Time) (*Price, error) { + transaction, err := db.Begin() if err != nil { return nil, err } @@ -113,10 +111,3 @@ func GetClosestPrice(security, currency *Security, date *time.Time) (*Price, err return price, nil } - -func init() { - q, err := finance.GetQuote("BRK-A") - if err == nil { - fmt.Printf("%+v", q) - } -} diff --git a/reports.go b/reports.go index 328e068..325e613 100644 --- a/reports.go +++ b/reports.go @@ -27,6 +27,7 @@ const ( accountsContextKey securitiesContextKey balanceContextKey + dbContextKey ) const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for @@ -76,36 +77,36 @@ func (r *Tabulation) Write(w http.ResponseWriter) error { return enc.Encode(r) } -func GetReport(reportid int64, userid int64) (*Report, error) { +func GetReport(db *DB, reportid int64, userid int64) (*Report, error) { var r Report - err := DB.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid) + err := db.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid) if err != nil { return nil, err } return &r, nil } -func GetReports(userid int64) (*[]Report, error) { +func GetReports(db *DB, userid int64) (*[]Report, error) { var reports []Report - _, err := DB.Select(&reports, "SELECT * from reports where UserId=?", userid) + _, err := db.Select(&reports, "SELECT * from reports where UserId=?", userid) if err != nil { return nil, err } return &reports, nil } -func InsertReport(r *Report) error { - err := DB.Insert(r) +func InsertReport(db *DB, r *Report) error { + err := db.Insert(r) if err != nil { return err } return nil } -func UpdateReport(r *Report) error { - count, err := DB.Update(r) +func UpdateReport(db *DB, r *Report) error { + count, err := db.Update(r) if err != nil { return err } @@ -115,8 +116,8 @@ func UpdateReport(r *Report) error { return nil } -func DeleteReport(r *Report) error { - count, err := DB.Delete(r) +func DeleteReport(db *DB, r *Report) error { + count, err := db.Delete(r) if err != nil { return err } @@ -126,13 +127,14 @@ func DeleteReport(r *Report) error { return nil } -func runReport(user *User, report *Report) (*Tabulation, error) { +func runReport(db *DB, user *User, report *Report) (*Tabulation, error) { // Create a new LState without opening the default libs for security L := lua.NewState(lua.Options{SkipOpenLibs: true}) defer L.Close() // Create a new context holding the current user with a timeout ctx := context.WithValue(context.Background(), userContextKey, user) + ctx = context.WithValue(ctx, dbContextKey, db) ctx, cancel := context.WithTimeout(ctx, luaTimeoutSeconds*time.Second) defer cancel() L.SetContext(ctx) @@ -189,14 +191,14 @@ func runReport(user *User, report *Report) (*Tabulation, error) { } } -func ReportTabulationHandler(w http.ResponseWriter, r *http.Request, user *User, reportid int64) { - report, err := GetReport(reportid, user.UserId) +func ReportTabulationHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, reportid int64) { + report, err := GetReport(db, reportid, user.UserId) if err != nil { WriteError(w, 3 /*Invalid Request*/) return } - tabulation, err := runReport(user, report) + tabulation, err := runReport(db, user, report) if err != nil { // TODO handle different failure cases differently log.Print("runReport returned:", err) @@ -214,8 +216,8 @@ func ReportTabulationHandler(w http.ResponseWriter, r *http.Request, user *User, } } -func ReportHandler(w http.ResponseWriter, r *http.Request) { - user, err := GetUserFromSession(r) +func ReportHandler(w http.ResponseWriter, r *http.Request, db *DB) { + user, err := GetUserFromSession(db, r) if err != nil { WriteError(w, 1 /*Not Signed In*/) return @@ -237,7 +239,7 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) { report.ReportId = -1 report.UserId = user.UserId - err = InsertReport(&report) + err = InsertReport(db, &report) if err != nil { WriteError(w, 999 /*Internal Error*/) log.Print(err) @@ -260,7 +262,7 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) { log.Print(err) return } - ReportTabulationHandler(w, r, user, reportid) + ReportTabulationHandler(db, w, r, user, reportid) return } @@ -269,7 +271,7 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) { if err != nil || n != 1 { //Return all Reports var rl ReportList - reports, err := GetReports(user.UserId) + reports, err := GetReports(db, user.UserId) if err != nil { WriteError(w, 999 /*Internal Error*/) log.Print(err) @@ -284,7 +286,7 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) { } } else { // Return Report with this Id - report, err := GetReport(reportid, user.UserId) + report, err := GetReport(db, reportid, user.UserId) if err != nil { WriteError(w, 3 /*Invalid Request*/) return @@ -319,7 +321,7 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) { } report.UserId = user.UserId - err = UpdateReport(&report) + err = UpdateReport(db, &report) if err != nil { WriteError(w, 999 /*Internal Error*/) log.Print(err) @@ -333,13 +335,13 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) { return } } else if r.Method == "DELETE" { - report, err := GetReport(reportid, user.UserId) + report, err := GetReport(db, reportid, user.UserId) if err != nil { WriteError(w, 3 /*Invalid Request*/) return } - err = DeleteReport(report) + err = DeleteReport(db, report) if err != nil { WriteError(w, 999 /*Internal Error*/) log.Print(err) diff --git a/securities.go b/securities.go index bb65461..bd4d148 100644 --- a/securities.go +++ b/securities.go @@ -96,10 +96,10 @@ func FindCurrencyTemplate(iso4217 int64) *Security { return nil } -func GetSecurity(securityid int64, userid int64) (*Security, error) { +func GetSecurity(db *DB, securityid int64, userid int64) (*Security, error) { var s Security - err := DB.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid) + err := db.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid) if err != nil { return nil, err } @@ -116,18 +116,18 @@ func GetSecurityTx(transaction *gorp.Transaction, securityid int64, userid int64 return &s, nil } -func GetSecurities(userid int64) (*[]*Security, error) { +func GetSecurities(db *DB, userid int64) (*[]*Security, error) { var securities []*Security - _, err := DB.Select(&securities, "SELECT * from securities where UserId=?", userid) + _, err := db.Select(&securities, "SELECT * from securities where UserId=?", userid) if err != nil { return nil, err } return &securities, nil } -func InsertSecurity(s *Security) error { - err := DB.Insert(s) +func InsertSecurity(db *DB, s *Security) error { + err := db.Insert(s) if err != nil { return err } @@ -142,8 +142,8 @@ func InsertSecurityTx(transaction *gorp.Transaction, s *Security) error { return nil } -func UpdateSecurity(s *Security) error { - transaction, err := DB.Begin() +func UpdateSecurity(db *DB, s *Security) error { + transaction, err := db.Begin() if err != nil { return err } @@ -176,8 +176,8 @@ func UpdateSecurity(s *Security) error { return nil } -func DeleteSecurity(s *Security) error { - transaction, err := DB.Begin() +func DeleteSecurity(db *DB, s *Security) error { + transaction, err := db.Begin() if err != nil { return err } @@ -279,8 +279,8 @@ func ImportGetCreateSecurity(transaction *gorp.Transaction, userid int64, securi return security, nil } -func SecurityHandler(w http.ResponseWriter, r *http.Request) { - user, err := GetUserFromSession(r) +func SecurityHandler(w http.ResponseWriter, r *http.Request, db *DB) { + user, err := GetUserFromSession(db, r) if err != nil { WriteError(w, 1 /*Not Signed In*/) return @@ -302,7 +302,7 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request) { security.SecurityId = -1 security.UserId = user.UserId - err = InsertSecurity(&security) + err = InsertSecurity(db, &security) if err != nil { WriteError(w, 999 /*Internal Error*/) log.Print(err) @@ -324,7 +324,7 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request) { //Return all securities var sl SecurityList - securities, err := GetSecurities(user.UserId) + securities, err := GetSecurities(db, user.UserId) if err != nil { WriteError(w, 999 /*Internal Error*/) log.Print(err) @@ -339,7 +339,7 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request) { return } } else { - security, err := GetSecurity(securityid, user.UserId) + security, err := GetSecurity(db, securityid, user.UserId) if err != nil { WriteError(w, 3 /*Invalid Request*/) return @@ -373,7 +373,7 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request) { } security.UserId = user.UserId - err = UpdateSecurity(&security) + err = UpdateSecurity(db, &security) if err != nil { WriteError(w, 999 /*Internal Error*/) log.Print(err) @@ -387,13 +387,13 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request) { return } } else if r.Method == "DELETE" { - security, err := GetSecurity(securityid, user.UserId) + security, err := GetSecurity(db, securityid, user.UserId) if err != nil { WriteError(w, 3 /*Invalid Request*/) return } - err = DeleteSecurity(security) + err = DeleteSecurity(db, security) if err != nil { WriteError(w, 999 /*Internal Error*/) log.Print(err) diff --git a/securities_lua.go b/securities_lua.go index 6cd769a..dbd394f 100644 --- a/securities_lua.go +++ b/securities_lua.go @@ -13,14 +13,19 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) { ctx := L.Context() - security_map, ok := ctx.Value(securitiesContextKey).(map[int64]*Security) + db, ok := ctx.Value(dbContextKey).(*DB) + if !ok { + return nil, errors.New("Couldn't find DB in lua's Context") + } + + security_map, ok = ctx.Value(securitiesContextKey).(map[int64]*Security) if !ok { user, ok := ctx.Value(userContextKey).(*User) if !ok { return nil, errors.New("Couldn't find User in lua's Context") } - securities, err := GetSecurities(user.UserId) + securities, err := GetSecurities(db, user.UserId) if err != nil { return nil, err } @@ -149,7 +154,13 @@ func luaClosestPrice(L *lua.LState) int { c := luaCheckSecurity(L, 2) date := luaCheckTime(L, 3) - p, err := GetClosestPrice(s, c, date) + ctx := L.Context() + db, ok := ctx.Value(dbContextKey).(*DB) + if !ok { + panic("Couldn't find DB in lua's Context") + } + + p, err := GetClosestPrice(db, s, c, date) if err != nil { L.Push(lua.LNil) } else { diff --git a/sessions.go b/sessions.go index 880b43b..d3df3ae 100644 --- a/sessions.go +++ b/sessions.go @@ -22,7 +22,7 @@ func (s *Session) Write(w http.ResponseWriter) error { return enc.Encode(s) } -func GetSession(r *http.Request) (*Session, error) { +func GetSession(db *DB, r *http.Request) (*Session, error) { var s Session cookie, err := r.Cookie("moneygo-session") @@ -31,17 +31,18 @@ func GetSession(r *http.Request) (*Session, error) { } s.SessionSecret = cookie.Value - err = DB.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret) + err = db.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret) if err != nil { return nil, err } return &s, nil } -func DeleteSessionIfExists(r *http.Request) { - session, err := GetSession(r) +func DeleteSessionIfExists(db *DB, r *http.Request) { + // TODO do this in one transaction + session, err := GetSession(db, r) if err == nil { - DB.Delete(session) + db.Delete(session) } } @@ -53,7 +54,7 @@ func NewSessionCookie() (string, error) { return base64.StdEncoding.EncodeToString(bits), nil } -func NewSession(w http.ResponseWriter, r *http.Request, userid int64) (*Session, error) { +func NewSession(db *DB, w http.ResponseWriter, r *http.Request, userid int64) (*Session, error) { s := Session{} session_secret, err := NewSessionCookie() @@ -75,14 +76,14 @@ func NewSession(w http.ResponseWriter, r *http.Request, userid int64) (*Session, s.SessionSecret = session_secret s.UserId = userid - err = DB.Insert(&s) + err = db.Insert(&s) if err != nil { return nil, err } return &s, nil } -func SessionHandler(w http.ResponseWriter, r *http.Request) { +func SessionHandler(w http.ResponseWriter, r *http.Request, db *DB) { if r.Method == "POST" || r.Method == "PUT" { user_json := r.PostFormValue("user") if user_json == "" { @@ -97,7 +98,7 @@ func SessionHandler(w http.ResponseWriter, r *http.Request) { return } - dbuser, err := GetUserByUsername(user.Username) + dbuser, err := GetUserByUsername(db, user.Username) if err != nil { WriteError(w, 2 /*Unauthorized Access*/) return @@ -109,9 +110,9 @@ func SessionHandler(w http.ResponseWriter, r *http.Request) { return } - DeleteSessionIfExists(r) + DeleteSessionIfExists(db, r) - session, err := NewSession(w, r, dbuser.UserId) + session, err := NewSession(db, w, r, dbuser.UserId) if err != nil { WriteError(w, 999 /*Internal Error*/) return @@ -124,7 +125,7 @@ func SessionHandler(w http.ResponseWriter, r *http.Request) { return } } else if r.Method == "GET" { - s, err := GetSession(r) + s, err := GetSession(db, r) if err != nil { WriteError(w, 1 /*Not Signed In*/) return @@ -132,7 +133,7 @@ func SessionHandler(w http.ResponseWriter, r *http.Request) { s.Write(w) } else if r.Method == "DELETE" { - DeleteSessionIfExists(r) + DeleteSessionIfExists(db, r) WriteSuccess(w) } } diff --git a/transactions.go b/transactions.go index cc448c7..8ea2214 100644 --- a/transactions.go +++ b/transactions.go @@ -146,11 +146,7 @@ func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64] if t.Splits[i].AccountId != -1 { var err error var account *Account - if transaction != nil { - account, err = GetAccountTx(transaction, t.Splits[i].AccountId, t.UserId) - } else { - account, err = GetAccount(t.Splits[i].AccountId, t.UserId) - } + account, err = GetAccountTx(transaction, t.Splits[i].AccountId, t.UserId) if err != nil { return nil, err } @@ -164,16 +160,12 @@ func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64] return sums, nil } -func (t *Transaction) GetImbalances() (map[int64]big.Rat, error) { - return t.GetImbalancesTx(nil) -} - // Returns true if all securities contained in this transaction are balanced, // false otherwise -func (t *Transaction) Balanced() (bool, error) { +func (t *Transaction) Balanced(transaction *gorp.Transaction) (bool, error) { var zero big.Rat - sums, err := t.GetImbalances() + sums, err := t.GetImbalancesTx(transaction) if err != nil { return false, err } @@ -186,21 +178,23 @@ func (t *Transaction) Balanced() (bool, error) { return true, nil } -func GetTransaction(transactionid int64, userid int64) (*Transaction, error) { +func GetTransaction(db *DB, transactionid int64, userid int64) (*Transaction, error) { var t Transaction - transaction, err := DB.Begin() + transaction, err := db.Begin() if err != nil { return nil, err } err = transaction.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid) if err != nil { + transaction.Rollback() return nil, err } _, err = transaction.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid) if err != nil { + transaction.Rollback() return nil, err } @@ -213,10 +207,10 @@ func GetTransaction(transactionid int64, userid int64) (*Transaction, error) { return &t, nil } -func GetTransactions(userid int64) (*[]Transaction, error) { +func GetTransactions(db *DB, userid int64) (*[]Transaction, error) { var transactions []Transaction - transaction, err := DB.Begin() + transaction, err := db.Begin() if err != nil { return nil, err } @@ -316,8 +310,8 @@ func InsertTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us return nil } -func InsertTransaction(t *Transaction, user *User) error { - transaction, err := DB.Begin() +func InsertTransaction(db *DB, t *Transaction, user *User) error { + transaction, err := db.Begin() if err != nil { return err } @@ -337,17 +331,11 @@ func InsertTransaction(t *Transaction, user *User) error { return nil } -func UpdateTransaction(t *Transaction, user *User) error { - transaction, err := DB.Begin() - if err != nil { - return err - } - +func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *User) error { var existing_splits []*Split - _, err = transaction.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId) + _, err := transaction.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId) if err != nil { - transaction.Rollback() return err } @@ -367,11 +355,9 @@ func UpdateTransaction(t *Transaction, user *User) error { if ok { count, err := transaction.Update(t.Splits[i]) if err != nil { - transaction.Rollback() return err } if count != 1 { - transaction.Rollback() return errors.New("Updated more than one transaction split") } delete(s_map, t.Splits[i].SplitId) @@ -379,7 +365,6 @@ func UpdateTransaction(t *Transaction, user *User) error { t.Splits[i].SplitId = -1 err := transaction.Insert(t.Splits[i]) if err != nil { - transaction.Rollback() return err } } @@ -397,7 +382,6 @@ func UpdateTransaction(t *Transaction, user *User) error { if ok { _, err := transaction.Delete(existing_splits[i]) if err != nil { - transaction.Rollback() return err } } @@ -410,31 +394,22 @@ func UpdateTransaction(t *Transaction, user *User) error { } err = incrementAccountVersions(transaction, user, a_ids) if err != nil { - transaction.Rollback() return err } count, err := transaction.Update(t) if err != nil { - transaction.Rollback() return err } if count != 1 { - transaction.Rollback() return errors.New("Updated more than one transaction") } - err = transaction.Commit() - if err != nil { - transaction.Rollback() - return err - } - return nil } -func DeleteTransaction(t *Transaction, user *User) error { - transaction, err := DB.Begin() +func DeleteTransaction(db *DB, t *Transaction, user *User) error { + transaction, err := db.Begin() if err != nil { return err } @@ -477,8 +452,8 @@ func DeleteTransaction(t *Transaction, user *User) error { return nil } -func TransactionHandler(w http.ResponseWriter, r *http.Request) { - user, err := GetUserFromSession(r) +func TransactionHandler(w http.ResponseWriter, r *http.Request, db *DB) { + user, err := GetUserFromSession(db, r) if err != nil { WriteError(w, 1 /*Not Signed In*/) return @@ -500,27 +475,37 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) { transaction.TransactionId = -1 transaction.UserId = user.UserId - balanced, err := transaction.Balanced() + sqltx, err := db.Begin() if err != nil { WriteError(w, 999 /*Internal Error*/) log.Print(err) return } + + balanced, err := transaction.Balanced(sqltx) + if err != nil { + sqltx.Rollback() + WriteError(w, 999 /*Internal Error*/) + log.Print(err) + return + } if !transaction.Valid() || !balanced { + sqltx.Rollback() WriteError(w, 3 /*Invalid Request*/) return } for i := range transaction.Splits { transaction.Splits[i].SplitId = -1 - _, err := GetAccount(transaction.Splits[i].AccountId, user.UserId) + _, err := GetAccountTx(sqltx, transaction.Splits[i].AccountId, user.UserId) if err != nil { + sqltx.Rollback() WriteError(w, 3 /*Invalid Request*/) return } } - err = InsertTransaction(&transaction, user) + err = InsertTransactionTx(sqltx, &transaction, user) if err != nil { if _, ok := err.(AccountMissingError); ok { WriteError(w, 3 /*Invalid Request*/) @@ -528,6 +513,15 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) { WriteError(w, 999 /*Internal Error*/) log.Print(err) } + sqltx.Rollback() + return + } + + err = sqltx.Commit() + if err != nil { + sqltx.Rollback() + WriteError(w, 999 /*Internal Error*/) + log.Print(err) return } @@ -543,7 +537,7 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) { if err != nil { //Return all Transactions var al TransactionList - transactions, err := GetTransactions(user.UserId) + transactions, err := GetTransactions(db, user.UserId) if err != nil { WriteError(w, 999 /*Internal Error*/) log.Print(err) @@ -558,7 +552,7 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) { } } else { //Return Transaction with this Id - transaction, err := GetTransaction(transactionid, user.UserId) + transaction, err := GetTransaction(db, transactionid, user.UserId) if err != nil { WriteError(w, 3 /*Invalid Request*/) return @@ -591,27 +585,46 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) { } transaction.UserId = user.UserId - balanced, err := transaction.Balanced() + sqltx, err := db.Begin() if err != nil { WriteError(w, 999 /*Internal Error*/) log.Print(err) return } + + balanced, err := transaction.Balanced(sqltx) + if err != nil { + sqltx.Rollback() + WriteError(w, 999 /*Internal Error*/) + log.Print(err) + return + } if !transaction.Valid() || !balanced { + sqltx.Rollback() WriteError(w, 3 /*Invalid Request*/) return } for i := range transaction.Splits { - _, err := GetAccount(transaction.Splits[i].AccountId, user.UserId) + _, err := GetAccountTx(sqltx, transaction.Splits[i].AccountId, user.UserId) if err != nil { + sqltx.Rollback() WriteError(w, 3 /*Invalid Request*/) return } } - err = UpdateTransaction(&transaction, user) + err = UpdateTransactionTx(sqltx, &transaction, user) if err != nil { + sqltx.Rollback() + WriteError(w, 999 /*Internal Error*/) + log.Print(err) + return + } + + err = sqltx.Commit() + if err != nil { + sqltx.Rollback() WriteError(w, 999 /*Internal Error*/) log.Print(err) return @@ -630,13 +643,13 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) { return } - transaction, err := GetTransaction(transactionid, user.UserId) + transaction, err := GetTransaction(db, transactionid, user.UserId) if err != nil { WriteError(w, 3 /*Invalid Request*/) return } - err = DeleteTransaction(transaction, user) + err = DeleteTransaction(db, transaction, user) if err != nil { WriteError(w, 999 /*Internal Error*/) log.Print(err) @@ -672,9 +685,9 @@ func TransactionsBalanceDifference(transaction *gorp.Transaction, accountid int6 return &pageDifference, nil } -func GetAccountBalance(user *User, accountid int64) (*big.Rat, error) { +func GetAccountBalance(db *DB, user *User, accountid int64) (*big.Rat, error) { var splits []Split - transaction, err := DB.Begin() + transaction, err := db.Begin() if err != nil { return nil, err } @@ -707,9 +720,9 @@ func GetAccountBalance(user *User, accountid int64) (*big.Rat, error) { } // Assumes accountid is valid and is owned by the current user -func GetAccountBalanceDate(user *User, accountid int64, date *time.Time) (*big.Rat, error) { +func GetAccountBalanceDate(db *DB, user *User, accountid int64, date *time.Time) (*big.Rat, error) { var splits []Split - transaction, err := DB.Begin() + transaction, err := db.Begin() if err != nil { return nil, err } @@ -741,9 +754,9 @@ func GetAccountBalanceDate(user *User, accountid int64, date *time.Time) (*big.R return &balance, nil } -func GetAccountBalanceDateRange(user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) { +func GetAccountBalanceDateRange(db *DB, user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) { var splits []Split - transaction, err := DB.Begin() + transaction, err := db.Begin() if err != nil { return nil, err } @@ -775,11 +788,11 @@ func GetAccountBalanceDateRange(user *User, accountid int64, begin, end *time.Ti return &balance, nil } -func GetAccountTransactions(user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) { +func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) { var transactions []Transaction var atl AccountTransactionsList - transaction, err := DB.Begin() + transaction, err := db.Begin() if err != nil { return nil, err } @@ -878,7 +891,7 @@ func GetAccountTransactions(user *User, accountid int64, sort string, page uint6 // Return only those transactions which have at least one split pertaining to // an account -func AccountTransactionsHandler(w http.ResponseWriter, r *http.Request, +func AccountTransactionsHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64) { var page uint64 = 0 @@ -916,7 +929,7 @@ func AccountTransactionsHandler(w http.ResponseWriter, r *http.Request, sort = sortstring } - accountTransactions, err := GetAccountTransactions(user, accountid, sort, page, limit) + accountTransactions, err := GetAccountTransactions(db, user, accountid, sort, page, limit) if err != nil { WriteError(w, 999 /*Internal Error*/) log.Print(err) diff --git a/users.go b/users.go index e994188..83e7b77 100644 --- a/users.go +++ b/users.go @@ -47,10 +47,10 @@ func (u *User) HashPassword() { u.Password = "" } -func GetUser(userid int64) (*User, error) { +func GetUser(db *DB, userid int64) (*User, error) { var u User - err := DB.SelectOne(&u, "SELECT * from users where UserId=?", userid) + err := db.SelectOne(&u, "SELECT * from users where UserId=?", userid) if err != nil { return nil, err } @@ -67,18 +67,18 @@ func GetUserTx(transaction *gorp.Transaction, userid int64) (*User, error) { return &u, nil } -func GetUserByUsername(username string) (*User, error) { +func GetUserByUsername(db *DB, username string) (*User, error) { var u User - err := DB.SelectOne(&u, "SELECT * from users where Username=?", username) + err := db.SelectOne(&u, "SELECT * from users where Username=?", username) if err != nil { return nil, err } return &u, nil } -func InsertUser(u *User) error { - transaction, err := DB.Begin() +func InsertUser(db *DB, u *User) error { + transaction, err := db.Begin() if err != nil { return err } @@ -136,16 +136,16 @@ func InsertUser(u *User) error { return nil } -func GetUserFromSession(r *http.Request) (*User, error) { - s, err := GetSession(r) +func GetUserFromSession(db *DB, r *http.Request) (*User, error) { + s, err := GetSession(db, r) if err != nil { return nil, err } - return GetUser(s.UserId) + return GetUser(db, s.UserId) } -func UpdateUser(u *User) error { - transaction, err := DB.Begin() +func UpdateUser(db *DB, u *User) error { + transaction, err := db.Begin() if err != nil { return err } @@ -180,7 +180,7 @@ func UpdateUser(u *User) error { return nil } -func UserHandler(w http.ResponseWriter, r *http.Request) { +func UserHandler(w http.ResponseWriter, r *http.Request, db *DB) { if r.Method == "POST" { user_json := r.PostFormValue("user") if user_json == "" { @@ -197,7 +197,7 @@ func UserHandler(w http.ResponseWriter, r *http.Request) { user.UserId = -1 user.HashPassword() - err = InsertUser(&user) + err = InsertUser(db, &user) if err != nil { if _, ok := err.(UserExistsError); ok { WriteError(w, 4 /*User Exists*/) @@ -216,7 +216,7 @@ func UserHandler(w http.ResponseWriter, r *http.Request) { return } } else { - user, err := GetUserFromSession(r) + user, err := GetUserFromSession(db, r) if err != nil { WriteError(w, 1 /*Not Signed In*/) return @@ -264,7 +264,7 @@ func UserHandler(w http.ResponseWriter, r *http.Request) { user.PasswordHash = old_pwhash } - err = UpdateUser(user) + err = UpdateUser(db, user) if err != nil { WriteError(w, 999 /*Internal Error*/) log.Print(err) @@ -278,7 +278,7 @@ func UserHandler(w http.ResponseWriter, r *http.Request) { return } } else if r.Method == "DELETE" { - count, err := DB.Delete(&user) + count, err := db.Delete(&user) if count != 1 || err != nil { WriteError(w, 999 /*Internal Error*/) log.Print(err)