From 507868b7a53cdc5e284fc7b79a123fc0730be45b Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Sun, 12 Nov 2017 20:17:27 -0500 Subject: [PATCH] Begin move away from using http.ServeMux --- internal/handlers/accounts.go | 24 ++--- internal/handlers/common_test.go | 3 +- internal/handlers/gnucash.go | 14 +-- internal/handlers/handlers.go | 108 ++++++++++++------- internal/handlers/imports.go | 25 +++-- internal/handlers/prices.go | 24 ++--- internal/handlers/reports.go | 18 ++-- internal/handlers/securities.go | 33 +++--- internal/handlers/security_templates_test.go | 18 ++-- internal/handlers/sessions.go | 12 +-- internal/handlers/transactions.go | 28 ++--- internal/handlers/users.go | 10 +- main.go | 3 +- 13 files changed, 179 insertions(+), 141 deletions(-) diff --git a/internal/handlers/accounts.go b/internal/handlers/accounts.go index b113a91..379c1ce 100644 --- a/internal/handlers/accounts.go +++ b/internal/handlers/accounts.go @@ -377,8 +377,8 @@ func DeleteAccount(tx *Tx, a *Account) error { return nil } -func AccountHandler(r *http.Request, tx *Tx) ResponseWriterWriter { - user, err := GetUserFromSession(tx, r) +func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { + user, err := GetUserFromSession(context.Tx, r) if err != nil { return NewError(1 /*Not Signed In*/) } @@ -395,7 +395,7 @@ func AccountHandler(r *http.Request, tx *Tx) ResponseWriterWriter { log.Print(err) return NewError(999 /*Internal Error*/) } - return AccountImportHandler(tx, r, user, accountid, importtype) + return AccountImportHandler(context, r, user, accountid, importtype) } account_json := r.PostFormValue("account") @@ -412,7 +412,7 @@ func AccountHandler(r *http.Request, tx *Tx) ResponseWriterWriter { account.UserId = user.UserId account.AccountVersion = 0 - security, err := GetSecurity(tx, account.SecurityId, user.UserId) + security, err := GetSecurity(context.Tx, account.SecurityId, user.UserId) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -421,7 +421,7 @@ func AccountHandler(r *http.Request, tx *Tx) ResponseWriterWriter { return NewError(3 /*Invalid Request*/) } - err = InsertAccount(tx, &account) + err = InsertAccount(context.Tx, &account) if err != nil { if _, ok := err.(ParentAccountMissingError); ok { return NewError(3 /*Invalid Request*/) @@ -439,7 +439,7 @@ func AccountHandler(r *http.Request, tx *Tx) ResponseWriterWriter { if err != nil || n != 1 { //Return all Accounts var al AccountList - accounts, err := GetAccounts(tx, user.UserId) + accounts, err := GetAccounts(context.Tx, user.UserId) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -450,11 +450,11 @@ func AccountHandler(r *http.Request, tx *Tx) ResponseWriterWriter { // if URL looks like /account/[0-9]+/transactions, use the account // transaction handler if accountTransactionsRE.MatchString(r.URL.Path) { - return AccountTransactionsHandler(tx, r, user, accountid) + return AccountTransactionsHandler(context, r, user, accountid) } // Return Account with this Id - account, err := GetAccount(tx, accountid, user.UserId) + account, err := GetAccount(context.Tx, accountid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } @@ -479,7 +479,7 @@ func AccountHandler(r *http.Request, tx *Tx) ResponseWriterWriter { } account.UserId = user.UserId - security, err := GetSecurity(tx, account.SecurityId, user.UserId) + security, err := GetSecurity(context.Tx, account.SecurityId, user.UserId) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -492,7 +492,7 @@ func AccountHandler(r *http.Request, tx *Tx) ResponseWriterWriter { return NewError(3 /*Invalid Request*/) } - err = UpdateAccount(tx, &account) + err = UpdateAccount(context.Tx, &account) if err != nil { if _, ok := err.(ParentAccountMissingError); ok { return NewError(3 /*Invalid Request*/) @@ -506,12 +506,12 @@ func AccountHandler(r *http.Request, tx *Tx) ResponseWriterWriter { return &account } else if r.Method == "DELETE" { - account, err := GetAccount(tx, accountid, user.UserId) + account, err := GetAccount(context.Tx, accountid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } - err = DeleteAccount(tx, account) + err = DeleteAccount(context.Tx, account) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/handlers/common_test.go b/internal/handlers/common_test.go index fadaf1e..f664473 100644 --- a/internal/handlers/common_test.go +++ b/internal/handlers/common_test.go @@ -186,8 +186,7 @@ func RunTests(m *testing.M) int { log.Fatal(err) } - servemux := handlers.GetHandler(dbmap) - server = httptest.NewTLSServer(servemux) + server = httptest.NewTLSServer(&handlers.APIHandler{DB: dbmap}) defer server.Close() return m.Run() diff --git a/internal/handlers/gnucash.go b/internal/handlers/gnucash.go index c787c31..713dd84 100644 --- a/internal/handlers/gnucash.go +++ b/internal/handlers/gnucash.go @@ -308,8 +308,8 @@ func ImportGnucash(r io.Reader) (*GnucashImport, error) { return &gncimport, nil } -func GnucashImportHandler(r *http.Request, tx *Tx) ResponseWriterWriter { - user, err := GetUserFromSession(tx, r) +func GnucashImportHandler(r *http.Request, context *Context) ResponseWriterWriter { + user, err := GetUserFromSession(context.Tx, r) if err != nil { return NewError(1 /*Not Signed In*/) } @@ -363,7 +363,7 @@ func GnucashImportHandler(r *http.Request, tx *Tx) ResponseWriterWriter { securityMap := make(map[int64]int64) for _, security := range gnucashImport.Securities { securityId := security.SecurityId // save off because it could be updated - s, err := ImportGetCreateSecurity(tx, user.UserId, &security) + s, err := ImportGetCreateSecurity(context.Tx, user.UserId, &security) if err != nil { log.Print(err) log.Print(security) @@ -378,7 +378,7 @@ func GnucashImportHandler(r *http.Request, tx *Tx) ResponseWriterWriter { price.CurrencyId = securityMap[price.CurrencyId] price.PriceId = 0 - err := CreatePriceIfNotExist(tx, &price) + err := CreatePriceIfNotExist(context.Tx, &price) if err != nil { log.Print(err) return NewError(6 /*Import Error*/) @@ -407,7 +407,7 @@ func GnucashImportHandler(r *http.Request, tx *Tx) ResponseWriterWriter { account.ParentAccountId = accountMap[account.ParentAccountId] } account.SecurityId = securityMap[account.SecurityId] - a, err := GetCreateAccount(tx, account) + a, err := GetCreateAccount(context.Tx, account) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -436,7 +436,7 @@ func GnucashImportHandler(r *http.Request, tx *Tx) ResponseWriterWriter { } split.AccountId = acctId - exists, err := split.AlreadyImported(tx) + exists, err := split.AlreadyImported(context.Tx) if err != nil { log.Print("Error checking if split was already imported:", err) return NewError(999 /*Internal Error*/) @@ -445,7 +445,7 @@ func GnucashImportHandler(r *http.Request, tx *Tx) ResponseWriterWriter { } } if !already_imported { - err := InsertTransaction(tx, &transaction, user) + err := InsertTransaction(context.Tx, &transaction, user) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 9654ca7..2bce70d 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -4,6 +4,8 @@ import ( "gopkg.in/gorp.v1" "log" "net/http" + "path" + "strings" ) // But who writes the ResponseWriterWriter? @@ -11,56 +13,84 @@ type ResponseWriterWriter interface { Write(http.ResponseWriter) error } type Tx = gorp.Transaction -type TxHandler func(*http.Request, *Tx) ResponseWriterWriter +type Context struct { + Tx *Tx + User *User + Remaining string // portion of URL not yet reached in the hierarchy +} +type Handler func(*http.Request, *Context) ResponseWriterWriter -func TxHandlerFunc(t TxHandler, db *gorp.DbMap) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - tx, err := db.Begin() - if err != nil { - log.Print(err) - WriteError(w, 999 /*Internal Error*/) - return - } - defer func() { - if r := recover(); r != nil { - tx.Rollback() - WriteError(w, 999 /*Internal Error*/) - panic(r) - } - }() +func NextLevel(previous string) (current, remaining string) { + split := strings.SplitN(previous, "/", 2) + if len(split) == 2 { + return split[0], split[1] + } + return split[0], "" +} - writer := t(r, tx) +type APIHandler struct { + DB *gorp.DbMap +} - if e, ok := writer.(*Error); ok { +func (ah *APIHandler) txWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) { + tx, err := ah.DB.Begin() + if err != nil { + log.Print(err) + return NewError(999 /*Internal Error*/) + } + defer func() { + if r := recover(); r != nil { + tx.Rollback() + panic(r) + } + if _, ok := writer.(*Error); ok { tx.Rollback() - e.Write(w) } else { err = tx.Commit() if err != nil { log.Print(err) - WriteError(w, 999 /*Internal Error*/) - } else { - err = writer.Write(w) - if err != nil { - log.Print(err) - WriteError(w, 999 /*Internal Error*/) - } + writer = NewError(999 /*Internal Error*/) } } + }() + + context.Tx = tx + return h(r, context) +} + +func (ah *APIHandler) route(r *http.Request) ResponseWriterWriter { + current, remaining := NextLevel(path.Clean("/" + r.URL.Path)[1:]) + if current != "v1" { + return NewError(3 /*Invalid Request*/) + } + + current, remaining = NextLevel(remaining) + context := &Context{Remaining: remaining} + + switch current { + case "sessions": + return ah.txWrapper(SessionHandler, r, context) + case "users": + return ah.txWrapper(UserHandler, r, context) + case "securities": + return ah.txWrapper(SecurityHandler, r, context) + case "securitytemplates": + return SecurityTemplateHandler(r, context) + case "prices": + return ah.txWrapper(PriceHandler, r, context) + case "accounts": + return ah.txWrapper(AccountHandler, r, context) + case "transactions": + return ah.txWrapper(TransactionHandler, r, context) + case "imports": + return ah.txWrapper(ImportHandler, r, context) + case "reports": + return ah.txWrapper(ReportHandler, r, context) + default: + return NewError(3 /*Invalid Request*/) } } -func GetHandler(db *gorp.DbMap) *http.ServeMux { - servemux := http.NewServeMux() - servemux.HandleFunc("/v1/sessions/", TxHandlerFunc(SessionHandler, db)) - servemux.HandleFunc("/v1/users/", TxHandlerFunc(UserHandler, db)) - servemux.HandleFunc("/v1/securities/", TxHandlerFunc(SecurityHandler, db)) - servemux.HandleFunc("/v1/prices/", TxHandlerFunc(PriceHandler, db)) - servemux.HandleFunc("/v1/securitytemplates/", SecurityTemplateHandler) - servemux.HandleFunc("/v1/accounts/", TxHandlerFunc(AccountHandler, db)) - servemux.HandleFunc("/v1/transactions/", TxHandlerFunc(TransactionHandler, db)) - servemux.HandleFunc("/v1/imports/gnucash", TxHandlerFunc(GnucashImportHandler, db)) - servemux.HandleFunc("/v1/reports/", TxHandlerFunc(ReportHandler, db)) - - return servemux +func (ah *APIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ah.route(r).Write(w) } diff --git a/internal/handlers/imports.go b/internal/handlers/imports.go index 47ae0dc..786ca62 100644 --- a/internal/handlers/imports.go +++ b/internal/handlers/imports.go @@ -210,7 +210,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *User, accountid int64) ResponseW return SuccessWriter{} } -func OFXImportHandler(tx *Tx, r *http.Request, user *User, accountid int64) ResponseWriterWriter { +func OFXImportHandler(context *Context, r *http.Request, user *User, accountid int64) ResponseWriterWriter { download_json := r.PostFormValue("ofxdownload") if download_json == "" { return NewError(3 /*Invalid Request*/) @@ -222,7 +222,7 @@ func OFXImportHandler(tx *Tx, r *http.Request, user *User, accountid int64) Resp return NewError(3 /*Invalid Request*/) } - account, err := GetAccount(tx, accountid, user.UserId) + account, err := GetAccount(context.Tx, accountid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } @@ -308,10 +308,10 @@ func OFXImportHandler(tx *Tx, r *http.Request, user *User, accountid int64) Resp } defer response.Body.Close() - return ofxImportHelper(tx, response.Body, user, accountid) + return ofxImportHelper(context.Tx, response.Body, user, accountid) } -func OFXFileImportHandler(tx *Tx, r *http.Request, user *User, accountid int64) ResponseWriterWriter { +func OFXFileImportHandler(context *Context, r *http.Request, user *User, accountid int64) ResponseWriterWriter { multipartReader, err := r.MultipartReader() if err != nil { return NewError(3 /*Invalid Request*/) @@ -329,20 +329,29 @@ func OFXFileImportHandler(tx *Tx, r *http.Request, user *User, accountid int64) } } - return ofxImportHelper(tx, part, user, accountid) + return ofxImportHelper(context.Tx, part, user, accountid) } /* * Assumes the User is a valid, signed-in user, but accountid has not yet been validated */ -func AccountImportHandler(tx *Tx, r *http.Request, user *User, accountid int64, importtype string) ResponseWriterWriter { +func AccountImportHandler(context *Context, r *http.Request, user *User, accountid int64, importtype string) ResponseWriterWriter { switch importtype { case "ofx": - return OFXImportHandler(tx, r, user, accountid) + return OFXImportHandler(context, r, user, accountid) case "ofxfile": - return OFXFileImportHandler(tx, r, user, accountid) + return OFXFileImportHandler(context, r, user, accountid) default: return NewError(3 /*Invalid Request*/) } } + +func ImportHandler(r *http.Request, context *Context) ResponseWriterWriter { + current, remaining := NextLevel(context.Remaining) + if current != "gnucash" { + return NewError(3 /*Invalid Request*/) + } + context.Remaining = remaining + return GnucashImportHandler(r, context) +} diff --git a/internal/handlers/prices.go b/internal/handlers/prices.go index 8f0acb0..e2240fd 100644 --- a/internal/handlers/prices.go +++ b/internal/handlers/prices.go @@ -129,8 +129,8 @@ func GetClosestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Pr } } -func PriceHandler(r *http.Request, tx *Tx) ResponseWriterWriter { - user, err := GetUserFromSession(tx, r) +func PriceHandler(r *http.Request, context *Context) ResponseWriterWriter { + user, err := GetUserFromSession(context.Tx, r) if err != nil { return NewError(1 /*Not Signed In*/) } @@ -148,16 +148,16 @@ func PriceHandler(r *http.Request, tx *Tx) ResponseWriterWriter { } price.PriceId = -1 - _, err = GetSecurity(tx, price.SecurityId, user.UserId) + _, err = GetSecurity(context.Tx, price.SecurityId, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } - _, err = GetSecurity(tx, price.CurrencyId, user.UserId) + _, err = GetSecurity(context.Tx, price.CurrencyId, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } - err = tx.Insert(&price) + err = context.Tx.Insert(&price) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -172,7 +172,7 @@ func PriceHandler(r *http.Request, tx *Tx) ResponseWriterWriter { //Return all prices var pl PriceList - prices, err := GetPrices(tx, user.UserId) + prices, err := GetPrices(context.Tx, user.UserId) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -181,7 +181,7 @@ func PriceHandler(r *http.Request, tx *Tx) ResponseWriterWriter { pl.Prices = prices return &pl } else { - price, err := GetPrice(tx, priceid, user.UserId) + price, err := GetPrice(context.Tx, priceid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } @@ -205,16 +205,16 @@ func PriceHandler(r *http.Request, tx *Tx) ResponseWriterWriter { return NewError(3 /*Invalid Request*/) } - _, err = GetSecurity(tx, price.SecurityId, user.UserId) + _, err = GetSecurity(context.Tx, price.SecurityId, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } - _, err = GetSecurity(tx, price.CurrencyId, user.UserId) + _, err = GetSecurity(context.Tx, price.CurrencyId, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } - count, err := tx.Update(&price) + count, err := context.Tx.Update(&price) if err != nil || count != 1 { log.Print(err) return NewError(999 /*Internal Error*/) @@ -222,12 +222,12 @@ func PriceHandler(r *http.Request, tx *Tx) ResponseWriterWriter { return &price } else if r.Method == "DELETE" { - price, err := GetPrice(tx, priceid, user.UserId) + price, err := GetPrice(context.Tx, priceid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } - count, err := tx.Delete(price) + count, err := context.Tx.Delete(price) if err != nil || count != 1 { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/handlers/reports.go b/internal/handlers/reports.go index d6b4495..905ba01 100644 --- a/internal/handlers/reports.go +++ b/internal/handlers/reports.go @@ -223,8 +223,8 @@ func ReportTabulationHandler(tx *Tx, r *http.Request, user *User, reportid int64 return tabulation } -func ReportHandler(r *http.Request, tx *Tx) ResponseWriterWriter { - user, err := GetUserFromSession(tx, r) +func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter { + user, err := GetUserFromSession(context.Tx, r) if err != nil { return NewError(1 /*Not Signed In*/) } @@ -247,7 +247,7 @@ func ReportHandler(r *http.Request, tx *Tx) ResponseWriterWriter { return NewError(3 /*Invalid Request*/) } - err = InsertReport(tx, &report) + err = InsertReport(context.Tx, &report) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -262,7 +262,7 @@ func ReportHandler(r *http.Request, tx *Tx) ResponseWriterWriter { log.Print(err) return NewError(999 /*InternalError*/) } - return ReportTabulationHandler(tx, r, user, reportid) + return ReportTabulationHandler(context.Tx, r, user, reportid) } var reportid int64 @@ -270,7 +270,7 @@ func ReportHandler(r *http.Request, tx *Tx) ResponseWriterWriter { if err != nil || n != 1 { //Return all Reports var rl ReportList - reports, err := GetReports(tx, user.UserId) + reports, err := GetReports(context.Tx, user.UserId) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -279,7 +279,7 @@ func ReportHandler(r *http.Request, tx *Tx) ResponseWriterWriter { return &rl } else { // Return Report with this Id - report, err := GetReport(tx, reportid, user.UserId) + report, err := GetReport(context.Tx, reportid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } @@ -309,7 +309,7 @@ func ReportHandler(r *http.Request, tx *Tx) ResponseWriterWriter { return NewError(3 /*Invalid Request*/) } - err = UpdateReport(tx, &report) + err = UpdateReport(context.Tx, &report) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -317,12 +317,12 @@ func ReportHandler(r *http.Request, tx *Tx) ResponseWriterWriter { return &report } else if r.Method == "DELETE" { - report, err := GetReport(tx, reportid, user.UserId) + report, err := GetReport(context.Tx, reportid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } - err = DeleteReport(tx, report) + err = DeleteReport(context.Tx, report) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/handlers/securities.go b/internal/handlers/securities.go index 8d3d664..03ce542 100644 --- a/internal/handlers/securities.go +++ b/internal/handlers/securities.go @@ -246,8 +246,8 @@ func ImportGetCreateSecurity(tx *Tx, userid int64, security *Security) (*Securit return security, nil } -func SecurityHandler(r *http.Request, tx *Tx) ResponseWriterWriter { - user, err := GetUserFromSession(tx, r) +func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { + user, err := GetUserFromSession(context.Tx, r) if err != nil { return NewError(1 /*Not Signed In*/) } @@ -266,7 +266,7 @@ func SecurityHandler(r *http.Request, tx *Tx) ResponseWriterWriter { security.SecurityId = -1 security.UserId = user.UserId - err = InsertSecurity(tx, &security) + err = InsertSecurity(context.Tx, &security) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -281,7 +281,7 @@ func SecurityHandler(r *http.Request, tx *Tx) ResponseWriterWriter { //Return all securities var sl SecurityList - securities, err := GetSecurities(tx, user.UserId) + securities, err := GetSecurities(context.Tx, user.UserId) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -290,7 +290,7 @@ func SecurityHandler(r *http.Request, tx *Tx) ResponseWriterWriter { sl.Securities = securities return &sl } else { - security, err := GetSecurity(tx, securityid, user.UserId) + security, err := GetSecurity(context.Tx, securityid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } @@ -315,7 +315,7 @@ func SecurityHandler(r *http.Request, tx *Tx) ResponseWriterWriter { } security.UserId = user.UserId - err = UpdateSecurity(tx, &security) + err = UpdateSecurity(context.Tx, &security) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -323,12 +323,12 @@ func SecurityHandler(r *http.Request, tx *Tx) ResponseWriterWriter { return &security } else if r.Method == "DELETE" { - security, err := GetSecurity(tx, securityid, user.UserId) + security, err := GetSecurity(context.Tx, securityid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } - err = DeleteSecurity(tx, security) + err = DeleteSecurity(context.Tx, security) if _, ok := err.(SecurityInUseError); ok { return NewError(7 /*In Use Error*/) } else if err != nil { @@ -342,7 +342,7 @@ func SecurityHandler(r *http.Request, tx *Tx) ResponseWriterWriter { return NewError(3 /*Invalid Request*/) } -func SecurityTemplateHandler(w http.ResponseWriter, r *http.Request) { +func SecurityTemplateHandler(r *http.Request, context *Context) ResponseWriterWriter { if r.Method == "GET" { var sl SecurityList @@ -356,8 +356,7 @@ func SecurityTemplateHandler(w http.ResponseWriter, r *http.Request) { if len(typestring) > 0 { _type = GetSecurityType(typestring) if _type == 0 { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } } @@ -365,8 +364,7 @@ func SecurityTemplateHandler(w http.ResponseWriter, r *http.Request) { if limitstring != "" { limitint, err := strconv.ParseInt(limitstring, 10, 0) if err != nil { - WriteError(w, 3 /*Invalid Request*/) - return + return NewError(3 /*Invalid Request*/) } limit = limitint } @@ -374,13 +372,8 @@ func SecurityTemplateHandler(w http.ResponseWriter, r *http.Request) { securities := SearchSecurityTemplates(search, _type, limit) sl.Securities = &securities - err := (&sl).Write(w) - if err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } + return &sl } else { - WriteError(w, 3 /*Invalid Request*/) + return NewError(3 /*Invalid Request*/) } } diff --git a/internal/handlers/security_templates_test.go b/internal/handlers/security_templates_test.go index 82acd86..04baac6 100644 --- a/internal/handlers/security_templates_test.go +++ b/internal/handlers/security_templates_test.go @@ -28,13 +28,15 @@ func TestSecurityTemplates(t *testing.T) { } num_usd := 0 - for _, s := range *sl.Securities { - if s.Type != handlers.Currency { - t.Fatalf("Requested Currency-only security templates, received a non-Currency template for %s", s.Name) - } + if sl.Securities != nil { + for _, s := range *sl.Securities { + if s.Type != handlers.Currency { + t.Fatalf("Requested Currency-only security templates, received a non-Currency template for %s", s.Name) + } - if s.Name == "USD" && s.AlternateId == "840" { - num_usd++ + if s.Name == "USD" && s.AlternateId == "840" { + num_usd++ + } } } @@ -64,6 +66,10 @@ func TestSecurityTemplateLimit(t *testing.T) { t.Fatal(err) } + if sl.Securities == nil { + t.Fatalf("Securities was unexpectedly nil\n") + } + if len(*sl.Securities) > 5 { t.Fatalf("Requested only 5 securities, received %d\n", len(*sl.Securities)) } diff --git a/internal/handlers/sessions.go b/internal/handlers/sessions.go index 8f61a7f..d7943fa 100644 --- a/internal/handlers/sessions.go +++ b/internal/handlers/sessions.go @@ -101,7 +101,7 @@ func NewSession(tx *Tx, r *http.Request, userid int64) (*NewSessionWriter, error return &NewSessionWriter{&s, &cookie}, nil } -func SessionHandler(r *http.Request, tx *Tx) ResponseWriterWriter { +func SessionHandler(r *http.Request, context *Context) ResponseWriterWriter { if r.Method == "POST" || r.Method == "PUT" { user_json := r.PostFormValue("user") if user_json == "" { @@ -114,7 +114,7 @@ func SessionHandler(r *http.Request, tx *Tx) ResponseWriterWriter { return NewError(3 /*Invalid Request*/) } - dbuser, err := GetUserByUsername(tx, user.Username) + dbuser, err := GetUserByUsername(context.Tx, user.Username) if err != nil { return NewError(2 /*Unauthorized Access*/) } @@ -124,27 +124,27 @@ func SessionHandler(r *http.Request, tx *Tx) ResponseWriterWriter { return NewError(2 /*Unauthorized Access*/) } - err = DeleteSessionIfExists(tx, r) + err = DeleteSessionIfExists(context.Tx, r) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) } - sessionwriter, err := NewSession(tx, r, dbuser.UserId) + sessionwriter, err := NewSession(context.Tx, r, dbuser.UserId) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) } return sessionwriter } else if r.Method == "GET" { - s, err := GetSession(tx, r) + s, err := GetSession(context.Tx, r) if err != nil { return NewError(1 /*Not Signed In*/) } return s } else if r.Method == "DELETE" { - err := DeleteSessionIfExists(tx, r) + err := DeleteSessionIfExists(context.Tx, r) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/handlers/transactions.go b/internal/handlers/transactions.go index d85ea65..5752b5d 100644 --- a/internal/handlers/transactions.go +++ b/internal/handlers/transactions.go @@ -400,8 +400,8 @@ func DeleteTransaction(tx *Tx, t *Transaction, user *User) error { return nil } -func TransactionHandler(r *http.Request, tx *Tx) ResponseWriterWriter { - user, err := GetUserFromSession(tx, r) +func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter { + user, err := GetUserFromSession(context.Tx, r) if err != nil { return NewError(1 /*Not Signed In*/) } @@ -426,13 +426,13 @@ func TransactionHandler(r *http.Request, tx *Tx) ResponseWriterWriter { for i := range transaction.Splits { transaction.Splits[i].SplitId = -1 - _, err := GetAccount(tx, transaction.Splits[i].AccountId, user.UserId) + _, err := GetAccount(context.Tx, transaction.Splits[i].AccountId, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } } - balanced, err := transaction.Balanced(tx) + balanced, err := transaction.Balanced(context.Tx) if err != nil { return NewError(999 /*Internal Error*/) } @@ -440,7 +440,7 @@ func TransactionHandler(r *http.Request, tx *Tx) ResponseWriterWriter { return NewError(3 /*Invalid Request*/) } - err = InsertTransaction(tx, &transaction, user) + err = InsertTransaction(context.Tx, &transaction, user) if err != nil { if _, ok := err.(AccountMissingError); ok { return NewError(3 /*Invalid Request*/) @@ -457,7 +457,7 @@ func TransactionHandler(r *http.Request, tx *Tx) ResponseWriterWriter { if err != nil { //Return all Transactions var al TransactionList - transactions, err := GetTransactions(tx, user.UserId) + transactions, err := GetTransactions(context.Tx, user.UserId) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -466,7 +466,7 @@ func TransactionHandler(r *http.Request, tx *Tx) ResponseWriterWriter { return &al } else { //Return Transaction with this Id - transaction, err := GetTransaction(tx, transactionid, user.UserId) + transaction, err := GetTransaction(context.Tx, transactionid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } @@ -490,7 +490,7 @@ func TransactionHandler(r *http.Request, tx *Tx) ResponseWriterWriter { } transaction.UserId = user.UserId - balanced, err := transaction.Balanced(tx) + balanced, err := transaction.Balanced(context.Tx) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -504,13 +504,13 @@ func TransactionHandler(r *http.Request, tx *Tx) ResponseWriterWriter { } for i := range transaction.Splits { - _, err := GetAccount(tx, transaction.Splits[i].AccountId, user.UserId) + _, err := GetAccount(context.Tx, transaction.Splits[i].AccountId, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } } - err = UpdateTransaction(tx, &transaction, user) + err = UpdateTransaction(context.Tx, &transaction, user) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -523,12 +523,12 @@ func TransactionHandler(r *http.Request, tx *Tx) ResponseWriterWriter { return NewError(3 /*Invalid Request*/) } - transaction, err := GetTransaction(tx, transactionid, user.UserId) + transaction, err := GetTransaction(context.Tx, transactionid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } - err = DeleteTransaction(tx, transaction, user) + err = DeleteTransaction(context.Tx, transaction, user) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -714,7 +714,7 @@ func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, pa // Return only those transactions which have at least one split pertaining to // an account -func AccountTransactionsHandler(tx *Tx, r *http.Request, user *User, accountid int64) ResponseWriterWriter { +func AccountTransactionsHandler(context *Context, r *http.Request, user *User, accountid int64) ResponseWriterWriter { var page uint64 = 0 var limit uint64 = 50 var sort string = "date-desc" @@ -747,7 +747,7 @@ func AccountTransactionsHandler(tx *Tx, r *http.Request, user *User, accountid i sort = sortstring } - accountTransactions, err := GetAccountTransactions(tx, user, accountid, sort, page, limit) + accountTransactions, err := GetAccountTransactions(context.Tx, user, accountid, sort, page, limit) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/handlers/users.go b/internal/handlers/users.go index b73e5e8..5041046 100644 --- a/internal/handlers/users.go +++ b/internal/handlers/users.go @@ -175,7 +175,7 @@ func DeleteUser(tx *Tx, u *User) error { return nil } -func UserHandler(r *http.Request, tx *Tx) ResponseWriterWriter { +func UserHandler(r *http.Request, context *Context) ResponseWriterWriter { if r.Method == "POST" { user_json := r.PostFormValue("user") if user_json == "" { @@ -190,7 +190,7 @@ func UserHandler(r *http.Request, tx *Tx) ResponseWriterWriter { user.UserId = -1 user.HashPassword() - err = InsertUser(tx, &user) + err = InsertUser(context.Tx, &user) if err != nil { if _, ok := err.(UserExistsError); ok { return NewError(4 /*User Exists*/) @@ -202,7 +202,7 @@ func UserHandler(r *http.Request, tx *Tx) ResponseWriterWriter { return ResponseWrapper{201, &user} } else { - user, err := GetUserFromSession(tx, r) + user, err := GetUserFromSession(context.Tx, r) if err != nil { return NewError(1 /*Not Signed In*/) } @@ -240,7 +240,7 @@ func UserHandler(r *http.Request, tx *Tx) ResponseWriterWriter { user.PasswordHash = old_pwhash } - err = UpdateUser(tx, user) + err = UpdateUser(context.Tx, user) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -248,7 +248,7 @@ func UserHandler(r *http.Request, tx *Tx) ResponseWriterWriter { return user } else if r.Method == "DELETE" { - err := DeleteUser(tx, user) + err := DeleteUser(context.Tx, user) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/main.go b/main.go index d52689a..1a78a62 100644 --- a/main.go +++ b/main.go @@ -79,7 +79,8 @@ func main() { } // Get ServeMux for API and add our own handlers for files - servemux := handlers.GetHandler(dbmap) + servemux := http.NewServeMux() + servemux.Handle("/v1/", &handlers.APIHandler{DB: dbmap}) servemux.HandleFunc("/", FileHandlerFunc(rootHandler, cfg.MoneyGo.Basedir)) servemux.HandleFunc("/static/", FileHandlerFunc(staticHandler, cfg.MoneyGo.Basedir))