From c452984f23538e3d0b14e9482f29a52c1a86d8a7 Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Wed, 6 Dec 2017 21:09:47 -0500 Subject: [PATCH 1/9] Lay groundwork and move sessions to 'store' --- internal/handlers/accounts.go | 19 +++++----- internal/handlers/accounts_lua.go | 5 ++- internal/handlers/common_test.go | 17 +++------ internal/handlers/handlers.go | 37 ++++++++++++++++-- internal/handlers/imports.go | 3 +- internal/handlers/prices.go | 13 ++++--- internal/handlers/reports.go | 15 ++++---- internal/handlers/securities.go | 13 ++++--- internal/handlers/securities_lua.go | 5 ++- internal/handlers/sessions.go | 46 +++++++++++----------- internal/handlers/transactions.go | 29 +++++++------- internal/handlers/tx.go | 59 ++--------------------------- internal/handlers/users.go | 13 ++++--- internal/{ => store}/db/db.go | 38 +++++++++++++++++++ internal/store/db/sessions.go | 36 ++++++++++++++++++ internal/store/db/tx.go | 57 ++++++++++++++++++++++++++++ internal/store/store.go | 24 ++++++++++++ main.go | 15 ++------ 18 files changed, 286 insertions(+), 158 deletions(-) rename internal/{ => store}/db/db.go (74%) create mode 100644 internal/store/db/sessions.go create mode 100644 internal/store/db/tx.go create mode 100644 internal/store/store.go diff --git a/internal/handlers/accounts.go b/internal/handlers/accounts.go index 2812fb4..9ef9521 100644 --- a/internal/handlers/accounts.go +++ b/internal/handlers/accounts.go @@ -3,11 +3,12 @@ package handlers import ( "errors" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store/db" "log" "net/http" ) -func GetAccount(tx *Tx, accountid int64, userid int64) (*models.Account, error) { +func GetAccount(tx *db.Tx, accountid int64, userid int64) (*models.Account, error) { var a models.Account err := tx.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid) @@ -17,7 +18,7 @@ func GetAccount(tx *Tx, accountid int64, userid int64) (*models.Account, error) return &a, nil } -func GetAccounts(tx *Tx, userid int64) (*[]models.Account, error) { +func GetAccounts(tx *db.Tx, userid int64) (*[]models.Account, error) { var accounts []models.Account _, err := tx.Select(&accounts, "SELECT * from accounts where UserId=?", userid) @@ -29,7 +30,7 @@ func GetAccounts(tx *Tx, userid int64) (*[]models.Account, error) { // Get (and attempt to create if it doesn't exist). Matches on UserId, // SecurityId, Type, Name, and ParentAccountId -func GetCreateAccount(tx *Tx, a models.Account) (*models.Account, error) { +func GetCreateAccount(tx *db.Tx, a models.Account) (*models.Account, error) { var accounts []models.Account var account models.Account @@ -57,7 +58,7 @@ func GetCreateAccount(tx *Tx, a models.Account) (*models.Account, error) { // Get (and attempt to create if it doesn't exist) the security/currency // trading account for the supplied security/currency -func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*models.Account, error) { +func GetTradingAccount(tx *db.Tx, userid int64, securityid int64) (*models.Account, error) { var tradingAccount models.Account var account models.Account @@ -99,7 +100,7 @@ func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*models.Account, // Get (and attempt to create if it doesn't exist) the security/currency // imbalance account for the supplied security/currency -func GetImbalanceAccount(tx *Tx, userid int64, securityid int64) (*models.Account, error) { +func GetImbalanceAccount(tx *db.Tx, userid int64, securityid int64) (*models.Account, error) { var imbalanceAccount models.Account var account models.Account xxxtemplate := FindSecurityTemplate("XXX", models.Currency) @@ -160,7 +161,7 @@ func (cae CircularAccountsError) Error() string { return "Would result in circular account relationship" } -func insertUpdateAccount(tx *Tx, a *models.Account, insert bool) error { +func insertUpdateAccount(tx *db.Tx, a *models.Account, insert bool) error { found := make(map[int64]bool) if !insert { found[a.AccountId] = true @@ -216,15 +217,15 @@ func insertUpdateAccount(tx *Tx, a *models.Account, insert bool) error { return nil } -func InsertAccount(tx *Tx, a *models.Account) error { +func InsertAccount(tx *db.Tx, a *models.Account) error { return insertUpdateAccount(tx, a, true) } -func UpdateAccount(tx *Tx, a *models.Account) error { +func UpdateAccount(tx *db.Tx, a *models.Account) error { return insertUpdateAccount(tx, a, false) } -func DeleteAccount(tx *Tx, a *models.Account) error { +func DeleteAccount(tx *db.Tx, a *models.Account) error { if a.ParentAccountId != -1 { // Re-parent splits to this account's parent account if this account isn't a root account _, err := tx.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", a.ParentAccountId, a.AccountId) diff --git a/internal/handlers/accounts_lua.go b/internal/handlers/accounts_lua.go index 5a2fc23..ee91eb2 100644 --- a/internal/handlers/accounts_lua.go +++ b/internal/handlers/accounts_lua.go @@ -4,6 +4,7 @@ import ( "context" "errors" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store/db" "github.com/yuin/gopher-lua" "math/big" "strings" @@ -16,7 +17,7 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*models.Account, error) { ctx := L.Context() - tx, ok := ctx.Value(dbContextKey).(*Tx) + tx, ok := ctx.Value(dbContextKey).(*db.Tx) if !ok { return nil, errors.New("Couldn't find tx in lua's Context") } @@ -150,7 +151,7 @@ func luaAccountBalance(L *lua.LState) int { a := luaCheckAccount(L, 1) ctx := L.Context() - tx, ok := ctx.Value(dbContextKey).(*Tx) + tx, ok := ctx.Value(dbContextKey).(*db.Tx) if !ok { panic("Couldn't find tx in lua's Context") } diff --git a/internal/handlers/common_test.go b/internal/handlers/common_test.go index a0ef7f8..6cbd9ef 100644 --- a/internal/handlers/common_test.go +++ b/internal/handlers/common_test.go @@ -2,12 +2,11 @@ package handlers_test import ( "bytes" - "database/sql" "encoding/json" "github.com/aclindsa/moneygo/internal/config" - "github.com/aclindsa/moneygo/internal/db" "github.com/aclindsa/moneygo/internal/handlers" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store/db" "io" "io/ioutil" "log" @@ -253,24 +252,18 @@ func RunTests(m *testing.M) int { dsn = envDSN } - dsn = db.GetDSN(dbType, dsn) - database, err := sql.Open(dbType.String(), dsn) + db, err := db.GetStore(dbType, dsn) if err != nil { log.Fatal(err) } - defer database.Close() + defer db.Close() - dbmap, err := db.GetDbMap(database, dbType) + err = db.DbMap.TruncateTables() if err != nil { log.Fatal(err) } - err = dbmap.TruncateTables() - if err != nil { - log.Fatal(err) - } - - server = httptest.NewTLSServer(&handlers.APIHandler{DB: dbmap}) + server = httptest.NewTLSServer(&handlers.APIHandler{Store: db}) defer server.Close() return m.Run() diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 309eed1..47a58b6 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -1,8 +1,9 @@ package handlers import ( - "github.com/aclindsa/gorp" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store" + "github.com/aclindsa/moneygo/internal/store/db" "log" "net/http" "path" @@ -16,7 +17,8 @@ type ResponseWriterWriter interface { } type Context struct { - Tx *Tx + Tx *db.Tx + StoreTx store.Tx User *models.User remainingURL string // portion of URL path not yet reached in the hierarchy } @@ -46,11 +48,11 @@ func (c *Context) LastLevel() bool { type Handler func(*http.Request, *Context) ResponseWriterWriter type APIHandler struct { - DB *gorp.DbMap + Store *db.DbStore } func (ah *APIHandler) txWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) { - tx, err := GetTx(ah.DB) + tx, err := GetTx(ah.Store.DbMap) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -72,6 +74,33 @@ func (ah *APIHandler) txWrapper(h Handler, r *http.Request, context *Context) (w }() context.Tx = tx + context.StoreTx = tx + return h(r, context) +} + +func (ah *APIHandler) storeTxWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) { + tx, err := ah.Store.Begin() + if err != nil { + log.Print(err) + return NewError(999 /*Internal Error*/) + } + defer func() { + if r := recover(); r != nil { + tx.Rollback() + panic(r) + } + if _, ok := writer.(*Error); ok { + tx.Rollback() + } else { + err = tx.Commit() + if err != nil { + log.Print(err) + writer = NewError(999 /*Internal Error*/) + } + } + }() + + context.StoreTx = tx return h(r, context) } diff --git a/internal/handlers/imports.go b/internal/handlers/imports.go index 78d5236..08267d3 100644 --- a/internal/handlers/imports.go +++ b/internal/handlers/imports.go @@ -3,6 +3,7 @@ package handlers import ( "encoding/json" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store/db" "github.com/aclindsa/ofxgo" "io" "log" @@ -23,7 +24,7 @@ func (od *OFXDownload) Read(json_str string) error { return dec.Decode(od) } -func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) ResponseWriterWriter { +func ofxImportHelper(tx *db.Tx, r io.Reader, user *models.User, accountid int64) ResponseWriterWriter { itl, err := ImportOFX(r) if err != nil { diff --git a/internal/handlers/prices.go b/internal/handlers/prices.go index c92eeb4..620150d 100644 --- a/internal/handlers/prices.go +++ b/internal/handlers/prices.go @@ -2,12 +2,13 @@ package handlers import ( "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store/db" "log" "net/http" "time" ) -func CreatePriceIfNotExist(tx *Tx, price *models.Price) error { +func CreatePriceIfNotExist(tx *db.Tx, price *models.Price) error { if len(price.RemoteId) == 0 { // Always create a new price if we can't match on the RemoteId err := tx.Insert(price) @@ -35,7 +36,7 @@ func CreatePriceIfNotExist(tx *Tx, price *models.Price) error { return nil } -func GetPrice(tx *Tx, priceid, securityid int64) (*models.Price, error) { +func GetPrice(tx *db.Tx, priceid, securityid int64) (*models.Price, error) { var p models.Price err := tx.SelectOne(&p, "SELECT * from prices where PriceId=? AND SecurityId=?", priceid, securityid) if err != nil { @@ -44,7 +45,7 @@ func GetPrice(tx *Tx, priceid, securityid int64) (*models.Price, error) { return &p, nil } -func GetPrices(tx *Tx, securityid int64) (*[]*models.Price, error) { +func GetPrices(tx *db.Tx, securityid int64) (*[]*models.Price, error) { var prices []*models.Price _, err := tx.Select(&prices, "SELECT * from prices where SecurityId=?", securityid) @@ -55,7 +56,7 @@ func GetPrices(tx *Tx, securityid int64) (*[]*models.Price, error) { } // Return the latest price for security in currency units before date -func GetLatestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { +func GetLatestPrice(tx *db.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { var p models.Price err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date) if err != nil { @@ -65,7 +66,7 @@ func GetLatestPrice(tx *Tx, security, currency *models.Security, date *time.Time } // Return the earliest price for security in currency units after date -func GetEarliestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { +func GetEarliestPrice(tx *db.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { var p models.Price err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date) if err != nil { @@ -75,7 +76,7 @@ func GetEarliestPrice(tx *Tx, security, currency *models.Security, date *time.Ti } // Return the price for security in currency closest to date -func GetClosestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { +func GetClosestPrice(tx *db.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { earliest, _ := GetEarliestPrice(tx, security, currency, date) latest, err := GetLatestPrice(tx, security, currency, date) diff --git a/internal/handlers/reports.go b/internal/handlers/reports.go index bec9525..46d8061 100644 --- a/internal/handlers/reports.go +++ b/internal/handlers/reports.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store/db" "github.com/yuin/gopher-lua" "log" "net/http" @@ -24,7 +25,7 @@ const ( const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for -func GetReport(tx *Tx, reportid int64, userid int64) (*models.Report, error) { +func GetReport(tx *db.Tx, reportid int64, userid int64) (*models.Report, error) { var r models.Report err := tx.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid) @@ -34,7 +35,7 @@ func GetReport(tx *Tx, reportid int64, userid int64) (*models.Report, error) { return &r, nil } -func GetReports(tx *Tx, userid int64) (*[]models.Report, error) { +func GetReports(tx *db.Tx, userid int64) (*[]models.Report, error) { var reports []models.Report _, err := tx.Select(&reports, "SELECT * from reports where UserId=?", userid) @@ -44,7 +45,7 @@ func GetReports(tx *Tx, userid int64) (*[]models.Report, error) { return &reports, nil } -func InsertReport(tx *Tx, r *models.Report) error { +func InsertReport(tx *db.Tx, r *models.Report) error { err := tx.Insert(r) if err != nil { return err @@ -52,7 +53,7 @@ func InsertReport(tx *Tx, r *models.Report) error { return nil } -func UpdateReport(tx *Tx, r *models.Report) error { +func UpdateReport(tx *db.Tx, r *models.Report) error { count, err := tx.Update(r) if err != nil { return err @@ -63,7 +64,7 @@ func UpdateReport(tx *Tx, r *models.Report) error { return nil } -func DeleteReport(tx *Tx, r *models.Report) error { +func DeleteReport(tx *db.Tx, r *models.Report) error { count, err := tx.Delete(r) if err != nil { return err @@ -74,7 +75,7 @@ func DeleteReport(tx *Tx, r *models.Report) error { return nil } -func runReport(tx *Tx, user *models.User, report *models.Report) (*models.Tabulation, error) { +func runReport(tx *db.Tx, user *models.User, report *models.Report) (*models.Tabulation, error) { // Create a new LState without opening the default libs for security L := lua.NewState(lua.Options{SkipOpenLibs: true}) defer L.Close() @@ -138,7 +139,7 @@ func runReport(tx *Tx, user *models.User, report *models.Report) (*models.Tabula } } -func ReportTabulationHandler(tx *Tx, r *http.Request, user *models.User, reportid int64) ResponseWriterWriter { +func ReportTabulationHandler(tx *db.Tx, r *http.Request, user *models.User, reportid int64) ResponseWriterWriter { report, err := GetReport(tx, reportid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) diff --git a/internal/handlers/securities.go b/internal/handlers/securities.go index ab58de4..e836822 100644 --- a/internal/handlers/securities.go +++ b/internal/handlers/securities.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store/db" "log" "net/http" "net/url" @@ -50,7 +51,7 @@ func FindCurrencyTemplate(iso4217 int64) *models.Security { return nil } -func GetSecurity(tx *Tx, securityid int64, userid int64) (*models.Security, error) { +func GetSecurity(tx *db.Tx, securityid int64, userid int64) (*models.Security, error) { var s models.Security err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid) @@ -60,7 +61,7 @@ func GetSecurity(tx *Tx, securityid int64, userid int64) (*models.Security, erro return &s, nil } -func GetSecurities(tx *Tx, userid int64) (*[]*models.Security, error) { +func GetSecurities(tx *db.Tx, userid int64) (*[]*models.Security, error) { var securities []*models.Security _, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid) @@ -70,7 +71,7 @@ func GetSecurities(tx *Tx, userid int64) (*[]*models.Security, error) { return &securities, nil } -func InsertSecurity(tx *Tx, s *models.Security) error { +func InsertSecurity(tx *db.Tx, s *models.Security) error { err := tx.Insert(s) if err != nil { return err @@ -78,7 +79,7 @@ func InsertSecurity(tx *Tx, s *models.Security) error { return nil } -func UpdateSecurity(tx *Tx, s *models.Security) (err error) { +func UpdateSecurity(tx *db.Tx, s *models.Security) (err error) { user, err := GetUser(tx, s.UserId) if err != nil { return @@ -105,7 +106,7 @@ func (e SecurityInUseError) Error() string { return e.message } -func DeleteSecurity(tx *Tx, s *models.Security) error { +func DeleteSecurity(tx *db.Tx, s *models.Security) error { // First, ensure no accounts are using this security accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId) @@ -138,7 +139,7 @@ func DeleteSecurity(tx *Tx, s *models.Security) error { return nil } -func ImportGetCreateSecurity(tx *Tx, userid int64, security *models.Security) (*models.Security, error) { +func ImportGetCreateSecurity(tx *db.Tx, userid int64, security *models.Security) (*models.Security, error) { security.UserId = userid if len(security.AlternateId) == 0 { // Always create a new local security if we can't match on the AlternateId diff --git a/internal/handlers/securities_lua.go b/internal/handlers/securities_lua.go index 12783ce..a5307c6 100644 --- a/internal/handlers/securities_lua.go +++ b/internal/handlers/securities_lua.go @@ -4,6 +4,7 @@ import ( "context" "errors" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store/db" "github.com/yuin/gopher-lua" ) @@ -14,7 +15,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*models.Security, error) ctx := L.Context() - tx, ok := ctx.Value(dbContextKey).(*Tx) + tx, ok := ctx.Value(dbContextKey).(*db.Tx) if !ok { return nil, errors.New("Couldn't find tx in lua's Context") } @@ -158,7 +159,7 @@ func luaClosestPrice(L *lua.LState) int { date := luaCheckTime(L, 3) ctx := L.Context() - tx, ok := ctx.Value(dbContextKey).(*Tx) + tx, ok := ctx.Value(dbContextKey).(*db.Tx) if !ok { panic("Couldn't find tx in lua's Context") } diff --git a/internal/handlers/sessions.go b/internal/handlers/sessions.go index 8349613..f4d6c5e 100644 --- a/internal/handlers/sessions.go +++ b/internal/handlers/sessions.go @@ -3,36 +3,37 @@ package handlers import ( "fmt" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store" "log" "net/http" "time" ) -func GetSession(tx *Tx, r *http.Request) (*models.Session, error) { - var s models.Session - +func GetSession(tx store.Tx, r *http.Request) (*models.Session, error) { cookie, err := r.Cookie("moneygo-session") if err != nil { return nil, fmt.Errorf("moneygo-session cookie not set") } - s.SessionSecret = cookie.Value - err = tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret) + s, err := tx.GetSession(cookie.Value) if err != nil { return nil, err } if s.Expires.Before(time.Now()) { - tx.Delete(&s) + err := tx.DeleteSession(s) + if err != nil { + log.Printf("Unexpected error when attempting to delete expired session: %s", err) + } return nil, fmt.Errorf("Session has expired") } - return &s, nil + return s, nil } -func DeleteSessionIfExists(tx *Tx, r *http.Request) error { +func DeleteSessionIfExists(tx store.Tx, r *http.Request) error { session, err := GetSession(tx, r) if err == nil { - _, err := tx.Delete(session) + err := tx.DeleteSession(session) if err != nil { return err } @@ -50,21 +51,26 @@ func (n *NewSessionWriter) Write(w http.ResponseWriter) error { return n.session.Write(w) } -func NewSession(tx *Tx, r *http.Request, userid int64) (*NewSessionWriter, error) { +func NewSession(tx store.Tx, r *http.Request, userid int64) (*NewSessionWriter, error) { + err := DeleteSessionIfExists(tx, r) + if err != nil { + return nil, err + } + s, err := models.NewSession(userid) if err != nil { return nil, err } - existing, err := tx.SelectInt("SELECT count(*) from sessions where SessionSecret=?", s.SessionSecret) + exists, err := tx.SessionExists(s.SessionSecret) if err != nil { return nil, err } - if existing > 0 { - return nil, fmt.Errorf("%d session(s) exist with the generated session_secret", existing) + if exists { + return nil, fmt.Errorf("Session already exists with the generated session_secret") } - err = tx.Insert(s) + err = tx.InsertSession(s) if err != nil { return nil, err } @@ -89,27 +95,21 @@ func SessionHandler(r *http.Request, context *Context) ResponseWriterWriter { return NewError(2 /*Unauthorized Access*/) } - err = DeleteSessionIfExists(context.Tx, r) - if err != nil { - log.Print(err) - return NewError(999 /*Internal Error*/) - } - - sessionwriter, err := NewSession(context.Tx, r, dbuser.UserId) + sessionwriter, err := NewSession(context.StoreTx, r, dbuser.UserId) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) } return sessionwriter } else if r.Method == "GET" { - s, err := GetSession(context.Tx, r) + s, err := GetSession(context.StoreTx, r) if err != nil { return NewError(1 /*Not Signed In*/) } return s } else if r.Method == "DELETE" { - err := DeleteSessionIfExists(context.Tx, r) + err := DeleteSessionIfExists(context.StoreTx, r) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/handlers/transactions.go b/internal/handlers/transactions.go index 3795ce7..4ec94d1 100644 --- a/internal/handlers/transactions.go +++ b/internal/handlers/transactions.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store/db" "log" "math/big" "net/http" @@ -12,14 +13,14 @@ import ( "time" ) -func SplitAlreadyImported(tx *Tx, s *models.Split) (bool, error) { +func SplitAlreadyImported(tx *db.Tx, s *models.Split) (bool, error) { count, err := tx.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId) return count == 1, err } // Return a map of security ID's to big.Rat's containing the amount that // security is imbalanced by -func GetTransactionImbalances(tx *Tx, t *models.Transaction) (map[int64]big.Rat, error) { +func GetTransactionImbalances(tx *db.Tx, t *models.Transaction) (map[int64]big.Rat, error) { sums := make(map[int64]big.Rat) if !t.Valid() { @@ -47,7 +48,7 @@ func GetTransactionImbalances(tx *Tx, t *models.Transaction) (map[int64]big.Rat, // Returns true if all securities contained in this transaction are balanced, // false otherwise -func TransactionBalanced(tx *Tx, t *models.Transaction) (bool, error) { +func TransactionBalanced(tx *db.Tx, t *models.Transaction) (bool, error) { var zero big.Rat sums, err := GetTransactionImbalances(tx, t) @@ -63,7 +64,7 @@ func TransactionBalanced(tx *Tx, t *models.Transaction) (bool, error) { return true, nil } -func GetTransaction(tx *Tx, transactionid int64, userid int64) (*models.Transaction, error) { +func GetTransaction(tx *db.Tx, transactionid int64, userid int64) (*models.Transaction, error) { var t models.Transaction err := tx.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid) @@ -79,7 +80,7 @@ func GetTransaction(tx *Tx, transactionid int64, userid int64) (*models.Transact return &t, nil } -func GetTransactions(tx *Tx, userid int64) (*[]models.Transaction, error) { +func GetTransactions(tx *db.Tx, userid int64) (*[]models.Transaction, error) { var transactions []models.Transaction _, err := tx.Select(&transactions, "SELECT * from transactions where UserId=?", userid) @@ -97,7 +98,7 @@ func GetTransactions(tx *Tx, userid int64) (*[]models.Transaction, error) { return &transactions, nil } -func incrementAccountVersions(tx *Tx, user *models.User, accountids []int64) error { +func incrementAccountVersions(tx *db.Tx, user *models.User, accountids []int64) error { for i := range accountids { account, err := GetAccount(tx, accountids[i], user.UserId) if err != nil { @@ -121,7 +122,7 @@ func (ame AccountMissingError) Error() string { return "Account missing" } -func InsertTransaction(tx *Tx, t *models.Transaction, user *models.User) error { +func InsertTransaction(tx *db.Tx, t *models.Transaction, user *models.User) error { // Map of any accounts with transaction splits being added a_map := make(map[int64]bool) for i := range t.Splits { @@ -171,7 +172,7 @@ func InsertTransaction(tx *Tx, t *models.Transaction, user *models.User) error { return nil } -func UpdateTransaction(tx *Tx, t *models.Transaction, user *models.User) error { +func UpdateTransaction(tx *db.Tx, t *models.Transaction, user *models.User) error { var existing_splits []*models.Split _, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId) @@ -248,7 +249,7 @@ func UpdateTransaction(tx *Tx, t *models.Transaction, user *models.User) error { return nil } -func DeleteTransaction(tx *Tx, t *models.Transaction, user *models.User) error { +func DeleteTransaction(tx *db.Tx, t *models.Transaction, user *models.User) error { var accountids []int64 _, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId) if err != nil { @@ -401,7 +402,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter return NewError(3 /*Invalid Request*/) } -func TransactionsBalanceDifference(tx *Tx, accountid int64, transactions []models.Transaction) (*big.Rat, error) { +func TransactionsBalanceDifference(tx *db.Tx, accountid int64, transactions []models.Transaction) (*big.Rat, error) { var pageDifference, tmp big.Rat for i := range transactions { _, err := tx.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId) @@ -425,7 +426,7 @@ func TransactionsBalanceDifference(tx *Tx, accountid int64, transactions []model return &pageDifference, nil } -func GetAccountBalance(tx *Tx, user *models.User, accountid int64) (*big.Rat, error) { +func GetAccountBalance(tx *db.Tx, user *models.User, accountid int64) (*big.Rat, error) { var splits []models.Split sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?" @@ -448,7 +449,7 @@ func GetAccountBalance(tx *Tx, user *models.User, accountid int64) (*big.Rat, er } // Assumes accountid is valid and is owned by the current user -func GetAccountBalanceDate(tx *Tx, user *models.User, accountid int64, date *time.Time) (*big.Rat, error) { +func GetAccountBalanceDate(tx *db.Tx, user *models.User, accountid int64, date *time.Time) (*big.Rat, error) { var splits []models.Split sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?" @@ -470,7 +471,7 @@ func GetAccountBalanceDate(tx *Tx, user *models.User, accountid int64, date *tim return &balance, nil } -func GetAccountBalanceDateRange(tx *Tx, user *models.User, accountid int64, begin, end *time.Time) (*big.Rat, error) { +func GetAccountBalanceDateRange(tx *db.Tx, user *models.User, accountid int64, begin, end *time.Time) (*big.Rat, error) { var splits []models.Split sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?" @@ -492,7 +493,7 @@ func GetAccountBalanceDateRange(tx *Tx, user *models.User, accountid int64, begi return &balance, nil } -func GetAccountTransactions(tx *Tx, user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) { +func GetAccountTransactions(tx *db.Tx, user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) { var transactions []models.Transaction var atl models.AccountTransactionsList diff --git a/internal/handlers/tx.go b/internal/handlers/tx.go index c0db452..750220a 100644 --- a/internal/handlers/tx.go +++ b/internal/handlers/tx.go @@ -1,65 +1,14 @@ package handlers import ( - "database/sql" "github.com/aclindsa/gorp" - "strings" + "github.com/aclindsa/moneygo/internal/store/db" ) -type Tx struct { - Dialect gorp.Dialect - Tx *gorp.Transaction -} - -func (tx *Tx) Rebind(query string) string { - chunks := strings.Split(query, "?") - str := chunks[0] - for i := 1; i < len(chunks); i++ { - str += tx.Dialect.BindVar(i-1) + chunks[i] - } - return str -} - -func (tx *Tx) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) { - return tx.Tx.Select(i, tx.Rebind(query), args...) -} - -func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) { - return tx.Tx.Exec(tx.Rebind(query), args...) -} - -func (tx *Tx) SelectInt(query string, args ...interface{}) (int64, error) { - return tx.Tx.SelectInt(tx.Rebind(query), args...) -} - -func (tx *Tx) SelectOne(holder interface{}, query string, args ...interface{}) error { - return tx.Tx.SelectOne(holder, tx.Rebind(query), args...) -} - -func (tx *Tx) Insert(list ...interface{}) error { - return tx.Tx.Insert(list...) -} - -func (tx *Tx) Update(list ...interface{}) (int64, error) { - return tx.Tx.Update(list...) -} - -func (tx *Tx) Delete(list ...interface{}) (int64, error) { - return tx.Tx.Delete(list...) -} - -func (tx *Tx) Commit() error { - return tx.Tx.Commit() -} - -func (tx *Tx) Rollback() error { - return tx.Tx.Rollback() -} - -func GetTx(db *gorp.DbMap) (*Tx, error) { - tx, err := db.Begin() +func GetTx(gdb *gorp.DbMap) (*db.Tx, error) { + tx, err := gdb.Begin() if err != nil { return nil, err } - return &Tx{db.Dialect, tx}, nil + return &db.Tx{gdb.Dialect, tx}, nil } diff --git a/internal/handlers/users.go b/internal/handlers/users.go index ba1a9d0..7225fe5 100644 --- a/internal/handlers/users.go +++ b/internal/handlers/users.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store/db" "log" "net/http" ) @@ -14,7 +15,7 @@ func (ueu UserExistsError) Error() string { return "User exists" } -func GetUser(tx *Tx, userid int64) (*models.User, error) { +func GetUser(tx *db.Tx, userid int64) (*models.User, error) { var u models.User err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid) @@ -24,7 +25,7 @@ func GetUser(tx *Tx, userid int64) (*models.User, error) { return &u, nil } -func GetUserByUsername(tx *Tx, username string) (*models.User, error) { +func GetUserByUsername(tx *db.Tx, username string) (*models.User, error) { var u models.User err := tx.SelectOne(&u, "SELECT * from users where Username=?", username) @@ -34,7 +35,7 @@ func GetUserByUsername(tx *Tx, username string) (*models.User, error) { return &u, nil } -func InsertUser(tx *Tx, u *models.User) error { +func InsertUser(tx *db.Tx, u *models.User) error { security_template := FindCurrencyTemplate(u.DefaultCurrency) if security_template == nil { return errors.New("Invalid ISO4217 Default Currency") @@ -75,7 +76,7 @@ func InsertUser(tx *Tx, u *models.User) error { return nil } -func GetUserFromSession(tx *Tx, r *http.Request) (*models.User, error) { +func GetUserFromSession(tx *db.Tx, r *http.Request) (*models.User, error) { s, err := GetSession(tx, r) if err != nil { return nil, err @@ -83,7 +84,7 @@ func GetUserFromSession(tx *Tx, r *http.Request) (*models.User, error) { return GetUser(tx, s.UserId) } -func UpdateUser(tx *Tx, u *models.User) error { +func UpdateUser(tx *db.Tx, u *models.User) error { security, err := GetSecurity(tx, u.DefaultCurrency, u.UserId) if err != nil { return err @@ -103,7 +104,7 @@ func UpdateUser(tx *Tx, u *models.User) error { return nil } -func DeleteUser(tx *Tx, u *models.User) error { +func DeleteUser(tx *db.Tx, u *models.User) error { count, err := tx.Delete(u) if err != nil { return err diff --git a/internal/db/db.go b/internal/store/db/db.go similarity index 74% rename from internal/db/db.go rename to internal/store/db/db.go index a33fb08..3a4031e 100644 --- a/internal/db/db.go +++ b/internal/store/db/db.go @@ -6,6 +6,7 @@ import ( "github.com/aclindsa/gorp" "github.com/aclindsa/moneygo/internal/config" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store" _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" @@ -60,3 +61,40 @@ func GetDSN(dbtype config.DbType, dsn string) string { } return dsn } + +type DbStore struct { + DbMap *gorp.DbMap +} + +func (db *DbStore) Begin() (store.Tx, error) { + tx, err := db.DbMap.Begin() + if err != nil { + return nil, err + } + return &Tx{db.DbMap.Dialect, tx}, nil +} + +func (db *DbStore) Close() error { + err := db.DbMap.Db.Close() + db.DbMap = nil + return err +} + +func GetStore(dbtype config.DbType, dsn string) (store *DbStore, err error) { + dsn = GetDSN(dbtype, dsn) + database, err := sql.Open(dbtype.String(), dsn) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + database.Close() + } + }() + + dbmap, err := GetDbMap(database, dbtype) + if err != nil { + return nil, err + } + return &DbStore{dbmap}, nil +} diff --git a/internal/store/db/sessions.go b/internal/store/db/sessions.go new file mode 100644 index 0000000..671c55a --- /dev/null +++ b/internal/store/db/sessions.go @@ -0,0 +1,36 @@ +package db + +import ( + "fmt" + "github.com/aclindsa/moneygo/internal/models" + "time" +) + +func (tx *Tx) InsertSession(session *models.Session) error { + return tx.Insert(session) +} + +func (tx *Tx) GetSession(secret string) (*models.Session, error) { + var s models.Session + + err := tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", secret) + if err != nil { + return nil, err + } + + if s.Expires.Before(time.Now()) { + tx.Delete(&s) + return nil, fmt.Errorf("Session has expired") + } + return &s, nil +} + +func (tx *Tx) SessionExists(secret string) (bool, error) { + existing, err := tx.SelectInt("SELECT count(*) from sessions where SessionSecret=?", secret) + return existing != 0, err +} + +func (tx *Tx) DeleteSession(session *models.Session) error { + _, err := tx.Delete(session) + return err +} diff --git a/internal/store/db/tx.go b/internal/store/db/tx.go new file mode 100644 index 0000000..5187201 --- /dev/null +++ b/internal/store/db/tx.go @@ -0,0 +1,57 @@ +package db + +import ( + "database/sql" + "github.com/aclindsa/gorp" + "strings" +) + +type Tx struct { + Dialect gorp.Dialect + Tx *gorp.Transaction +} + +func (tx *Tx) Rebind(query string) string { + chunks := strings.Split(query, "?") + str := chunks[0] + for i := 1; i < len(chunks); i++ { + str += tx.Dialect.BindVar(i-1) + chunks[i] + } + return str +} + +func (tx *Tx) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) { + return tx.Tx.Select(i, tx.Rebind(query), args...) +} + +func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) { + return tx.Tx.Exec(tx.Rebind(query), args...) +} + +func (tx *Tx) SelectInt(query string, args ...interface{}) (int64, error) { + return tx.Tx.SelectInt(tx.Rebind(query), args...) +} + +func (tx *Tx) SelectOne(holder interface{}, query string, args ...interface{}) error { + return tx.Tx.SelectOne(holder, tx.Rebind(query), args...) +} + +func (tx *Tx) Insert(list ...interface{}) error { + return tx.Tx.Insert(list...) +} + +func (tx *Tx) Update(list ...interface{}) (int64, error) { + return tx.Tx.Update(list...) +} + +func (tx *Tx) Delete(list ...interface{}) (int64, error) { + return tx.Tx.Delete(list...) +} + +func (tx *Tx) Commit() error { + return tx.Tx.Commit() +} + +func (tx *Tx) Rollback() error { + return tx.Tx.Rollback() +} diff --git a/internal/store/store.go b/internal/store/store.go new file mode 100644 index 0000000..9823236 --- /dev/null +++ b/internal/store/store.go @@ -0,0 +1,24 @@ +package store + +import ( + "github.com/aclindsa/moneygo/internal/models" +) + +type SessionStore interface { + InsertSession(session *models.Session) error + GetSession(secret string) (*models.Session, error) + SessionExists(secret string) (bool, error) + DeleteSession(session *models.Session) error +} + +type Tx interface { + Commit() error + Rollback() error + + SessionStore +} + +type Store interface { + Begin() (Tx, error) + Close() error +} diff --git a/main.go b/main.go index 2baf7c0..d87b5fa 100644 --- a/main.go +++ b/main.go @@ -3,11 +3,10 @@ package main //go:generate make import ( - "database/sql" "flag" "github.com/aclindsa/moneygo/internal/config" - "github.com/aclindsa/moneygo/internal/db" "github.com/aclindsa/moneygo/internal/handlers" + "github.com/aclindsa/moneygo/internal/store/db" "github.com/kabukky/httpscerts" "log" "net" @@ -67,21 +66,15 @@ func staticHandler(w http.ResponseWriter, r *http.Request, basedir string) { } func main() { - dsn := db.GetDSN(cfg.MoneyGo.DBType, cfg.MoneyGo.DSN) - database, err := sql.Open(cfg.MoneyGo.DBType.String(), dsn) - if err != nil { - log.Fatal(err) - } - defer database.Close() - - dbmap, err := db.GetDbMap(database, cfg.MoneyGo.DBType) + db, err := db.GetStore(cfg.MoneyGo.DBType, cfg.MoneyGo.DSN) if err != nil { log.Fatal(err) } + defer db.Close() // Get ServeMux for API and add our own handlers for files servemux := http.NewServeMux() - servemux.Handle("/v1/", &handlers.APIHandler{DB: dbmap}) + servemux.Handle("/v1/", &handlers.APIHandler{Store: db}) servemux.HandleFunc("/", FileHandlerFunc(rootHandler, cfg.MoneyGo.Basedir)) servemux.HandleFunc("/static/", FileHandlerFunc(staticHandler, cfg.MoneyGo.Basedir)) From bec5152e53b9c8b9eaef22505832764bed17af76 Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Thu, 7 Dec 2017 20:08:43 -0500 Subject: [PATCH 2/9] Move users and securities to store --- internal/handlers/accounts.go | 10 +-- internal/handlers/imports.go | 2 +- internal/handlers/prices.go | 8 +-- internal/handlers/securities.go | 105 ++++------------------------ internal/handlers/securities_lua.go | 2 +- internal/handlers/sessions.go | 7 +- internal/handlers/transactions.go | 2 +- internal/handlers/users.go | 91 ++++-------------------- internal/store/db/securities.go | 95 +++++++++++++++++++++++++ internal/store/db/sessions.go | 10 ++- internal/store/db/users.go | 86 +++++++++++++++++++++++ internal/store/store.go | 22 +++++- 12 files changed, 255 insertions(+), 185 deletions(-) create mode 100644 internal/store/db/securities.go create mode 100644 internal/store/db/users.go diff --git a/internal/handlers/accounts.go b/internal/handlers/accounts.go index 9ef9521..d5ddcd8 100644 --- a/internal/handlers/accounts.go +++ b/internal/handlers/accounts.go @@ -62,7 +62,7 @@ func GetTradingAccount(tx *db.Tx, userid int64, securityid int64) (*models.Accou var tradingAccount models.Account var account models.Account - user, err := GetUser(tx, userid) + user, err := tx.GetUser(userid) if err != nil { return nil, err } @@ -79,7 +79,7 @@ func GetTradingAccount(tx *db.Tx, userid int64, securityid int64) (*models.Accou return nil, err } - security, err := GetSecurity(tx, securityid, userid) + security, err := tx.GetSecurity(securityid, userid) if err != nil { return nil, err } @@ -124,7 +124,7 @@ func GetImbalanceAccount(tx *db.Tx, userid int64, securityid int64) (*models.Acc return nil, err } - security, err := GetSecurity(tx, securityid, userid) + security, err := tx.GetSecurity(securityid, userid) if err != nil { return nil, err } @@ -280,7 +280,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { account.UserId = user.UserId account.AccountVersion = 0 - security, err := GetSecurity(context.Tx, account.SecurityId, user.UserId) + security, err := context.Tx.GetSecurity(account.SecurityId, user.UserId) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -341,7 +341,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { } account.UserId = user.UserId - security, err := GetSecurity(context.Tx, account.SecurityId, user.UserId) + security, err := context.Tx.GetSecurity(account.SecurityId, user.UserId) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/handlers/imports.go b/internal/handlers/imports.go index 08267d3..ea74042 100644 --- a/internal/handlers/imports.go +++ b/internal/handlers/imports.go @@ -159,7 +159,7 @@ func ofxImportHelper(tx *db.Tx, r io.Reader, user *models.User, accountid int64) split := new(models.Split) r := new(big.Rat) r.Neg(&imbalance) - security, err := GetSecurity(tx, imbalanced_security, user.UserId) + security, err := tx.GetSecurity(imbalanced_security, user.UserId) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/handlers/prices.go b/internal/handlers/prices.go index 620150d..f08d37e 100644 --- a/internal/handlers/prices.go +++ b/internal/handlers/prices.go @@ -97,7 +97,7 @@ func GetClosestPrice(tx *db.Tx, security, currency *models.Security, date *time. } func PriceHandler(r *http.Request, context *Context, user *models.User, securityid int64) ResponseWriterWriter { - security, err := GetSecurity(context.Tx, securityid, user.UserId) + security, err := context.Tx.GetSecurity(securityid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } @@ -112,7 +112,7 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security if price.SecurityId != security.SecurityId { return NewError(3 /*Invalid Request*/) } - _, err = GetSecurity(context.Tx, price.CurrencyId, user.UserId) + _, err = context.Tx.GetSecurity(price.CurrencyId, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } @@ -161,11 +161,11 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security return NewError(3 /*Invalid Request*/) } - _, err = GetSecurity(context.Tx, price.SecurityId, user.UserId) + _, err = context.Tx.GetSecurity(price.SecurityId, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } - _, err = GetSecurity(context.Tx, price.CurrencyId, user.UserId) + _, err = context.Tx.GetSecurity(price.CurrencyId, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } diff --git a/internal/handlers/securities.go b/internal/handlers/securities.go index e836822..ec9544b 100644 --- a/internal/handlers/securities.go +++ b/internal/handlers/securities.go @@ -4,7 +4,6 @@ package handlers import ( "errors" - "fmt" "github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/store/db" "log" @@ -51,90 +50,18 @@ func FindCurrencyTemplate(iso4217 int64) *models.Security { return nil } -func GetSecurity(tx *db.Tx, securityid int64, userid int64) (*models.Security, error) { - var s models.Security - - err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid) - if err != nil { - return nil, err - } - return &s, nil -} - -func GetSecurities(tx *db.Tx, userid int64) (*[]*models.Security, error) { - var securities []*models.Security - - _, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid) - if err != nil { - return nil, err - } - return &securities, nil -} - -func InsertSecurity(tx *db.Tx, s *models.Security) error { - err := tx.Insert(s) - if err != nil { - return err - } - return nil -} - func UpdateSecurity(tx *db.Tx, s *models.Security) (err error) { - user, err := GetUser(tx, s.UserId) + user, err := tx.GetUser(s.UserId) if err != nil { return } else if user.DefaultCurrency == s.SecurityId && s.Type != models.Currency { return errors.New("Cannot change security which is user's default currency to be non-currency") } - count, err := tx.Update(s) + err = tx.UpdateSecurity(s) if err != nil { return } - if count > 1 { - return fmt.Errorf("Updated %d securities (expected 1)", count) - } - - return nil -} - -type SecurityInUseError struct { - message string -} - -func (e SecurityInUseError) Error() string { - return e.message -} - -func DeleteSecurity(tx *db.Tx, s *models.Security) error { - // First, ensure no accounts are using this security - accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId) - - if accounts != 0 { - return SecurityInUseError{"One or more accounts still use this security"} - } - - user, err := GetUser(tx, s.UserId) - if err != nil { - return err - } else if user.DefaultCurrency == s.SecurityId { - return SecurityInUseError{"Cannot delete security which is user's default currency"} - } - - // Remove all prices involving this security (either of this security, or - // using it as a currency) - _, err = tx.Exec("DELETE FROM prices WHERE SecurityId=? OR CurrencyId=?", s.SecurityId, s.SecurityId) - if err != nil { - return err - } - - count, err := tx.Delete(s) - if err != nil { - return err - } - if count != 1 { - return errors.New("Deleted more than one security") - } return nil } @@ -143,16 +70,14 @@ func ImportGetCreateSecurity(tx *db.Tx, userid int64, security *models.Security) security.UserId = userid if len(security.AlternateId) == 0 { // Always create a new local security if we can't match on the AlternateId - err := InsertSecurity(tx, security) + err := tx.InsertSecurity(security) if err != nil { return nil, err } return security, nil } - var securities []*models.Security - - _, err := tx.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Preciseness=?", userid, security.Type, security.AlternateId, security.Precision) + securities, err := tx.FindMatchingSecurities(userid, security) if err != nil { return nil, err } @@ -160,7 +85,7 @@ func ImportGetCreateSecurity(tx *db.Tx, userid int64, security *models.Security) // First try to find a case insensitive match on the name or symbol upperName := strings.ToUpper(security.Name) upperSymbol := strings.ToUpper(security.Symbol) - for _, s := range securities { + for _, s := range *securities { if (len(s.Name) > 0 && strings.ToUpper(s.Name) == upperName) || (len(s.Symbol) > 0 && strings.ToUpper(s.Symbol) == upperSymbol) { return s, nil @@ -169,7 +94,7 @@ func ImportGetCreateSecurity(tx *db.Tx, userid int64, security *models.Security) // if strings.Contains(strings.ToUpper(security.Name), upperSearch) || // Try to find a partial string match on the name or symbol - for _, s := range securities { + for _, s := range *securities { sUpperName := strings.ToUpper(s.Name) sUpperSymbol := strings.ToUpper(s.Symbol) if (len(upperName) > 0 && len(s.Name) > 0 && (strings.Contains(upperName, sUpperName) || strings.Contains(sUpperName, upperName))) || @@ -179,12 +104,12 @@ func ImportGetCreateSecurity(tx *db.Tx, userid int64, security *models.Security) } // Give up and return the first security in the list - if len(securities) > 0 { - return securities[0], nil + if len(*securities) > 0 { + return (*securities)[0], nil } // If there wasn't even one security in the list, make a new one - err = InsertSecurity(tx, security) + err = tx.InsertSecurity(security) if err != nil { return nil, err } @@ -217,7 +142,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { security.SecurityId = -1 security.UserId = user.UserId - err = InsertSecurity(context.Tx, &security) + err = context.Tx.InsertSecurity(&security) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -229,7 +154,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { //Return all securities var sl models.SecurityList - securities, err := GetSecurities(context.Tx, user.UserId) + securities, err := context.Tx.GetSecurities(user.UserId) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -250,7 +175,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { return PriceHandler(r, context, user, securityid) } - security, err := GetSecurity(context.Tx, securityid, user.UserId) + security, err := context.Tx.GetSecurity(securityid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } @@ -284,13 +209,13 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { return &security } else if r.Method == "DELETE" { - security, err := GetSecurity(context.Tx, securityid, user.UserId) + security, err := context.Tx.GetSecurity(securityid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } - err = DeleteSecurity(context.Tx, security) - if _, ok := err.(SecurityInUseError); ok { + err = context.Tx.DeleteSecurity(security) + if _, ok := err.(db.SecurityInUseError); ok { return NewError(7 /*In Use Error*/) } else if err != nil { log.Print(err) diff --git a/internal/handlers/securities_lua.go b/internal/handlers/securities_lua.go index a5307c6..78716f2 100644 --- a/internal/handlers/securities_lua.go +++ b/internal/handlers/securities_lua.go @@ -27,7 +27,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*models.Security, error) return nil, errors.New("Couldn't find User in lua's Context") } - securities, err := GetSecurities(tx, user.UserId) + securities, err := tx.GetSecurities(user.UserId) if err != nil { return nil, err } diff --git a/internal/handlers/sessions.go b/internal/handlers/sessions.go index f4d6c5e..71deff4 100644 --- a/internal/handlers/sessions.go +++ b/internal/handlers/sessions.go @@ -85,12 +85,15 @@ func SessionHandler(r *http.Request, context *Context) ResponseWriterWriter { return NewError(3 /*Invalid Request*/) } - dbuser, err := GetUserByUsername(context.Tx, user.Username) + // Hash password before checking username to help mitigate timing + // attacks + user.HashPassword() + + dbuser, err := context.StoreTx.GetUserByUsername(user.Username) if err != nil { return NewError(2 /*Unauthorized Access*/) } - user.HashPassword() if user.PasswordHash != dbuser.PasswordHash { return NewError(2 /*Unauthorized Access*/) } diff --git a/internal/handlers/transactions.go b/internal/handlers/transactions.go index 4ec94d1..707d29a 100644 --- a/internal/handlers/transactions.go +++ b/internal/handlers/transactions.go @@ -542,7 +542,7 @@ func GetAccountTransactions(tx *db.Tx, user *models.User, accountid int64, sort } atl.TotalTransactions = count - security, err := GetSecurity(tx, atl.Account.SecurityId, user.UserId) + security, err := tx.GetSecurity(atl.Account.SecurityId, user.UserId) if err != nil { return nil, err } diff --git a/internal/handlers/users.go b/internal/handlers/users.go index 7225fe5..e9a468d 100644 --- a/internal/handlers/users.go +++ b/internal/handlers/users.go @@ -2,9 +2,8 @@ package handlers import ( "errors" - "fmt" "github.com/aclindsa/moneygo/internal/models" - "github.com/aclindsa/moneygo/internal/store/db" + "github.com/aclindsa/moneygo/internal/store" "log" "net/http" ) @@ -15,41 +14,21 @@ func (ueu UserExistsError) Error() string { return "User exists" } -func GetUser(tx *db.Tx, userid int64) (*models.User, error) { - var u models.User - - err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid) - if err != nil { - return nil, err - } - return &u, nil -} - -func GetUserByUsername(tx *db.Tx, username string) (*models.User, error) { - var u models.User - - err := tx.SelectOne(&u, "SELECT * from users where Username=?", username) - if err != nil { - return nil, err - } - return &u, nil -} - -func InsertUser(tx *db.Tx, u *models.User) error { +func InsertUser(tx store.Tx, u *models.User) error { security_template := FindCurrencyTemplate(u.DefaultCurrency) if security_template == nil { return errors.New("Invalid ISO4217 Default Currency") } - existing, err := tx.SelectInt("SELECT count(*) from users where Username=?", u.Username) + exists, err := tx.UsernameExists(u.Username) if err != nil { return err } - if existing > 0 { + if exists { return UserExistsError{} } - err = tx.Insert(u) + err = tx.InsertUser(u) if err != nil { return err } @@ -59,33 +38,31 @@ func InsertUser(tx *db.Tx, u *models.User) error { security = *security_template security.UserId = u.UserId - err = InsertSecurity(tx, &security) + err = tx.InsertSecurity(&security) if err != nil { return err } // Update the user's DefaultCurrency to our new SecurityId u.DefaultCurrency = security.SecurityId - count, err := tx.Update(u) + err = tx.UpdateUser(u) if err != nil { return err - } else if count != 1 { - return errors.New("Would have updated more than one user") } return nil } -func GetUserFromSession(tx *db.Tx, r *http.Request) (*models.User, error) { +func GetUserFromSession(tx store.Tx, r *http.Request) (*models.User, error) { s, err := GetSession(tx, r) if err != nil { return nil, err } - return GetUser(tx, s.UserId) + return tx.GetUser(s.UserId) } -func UpdateUser(tx *db.Tx, u *models.User) error { - security, err := GetSecurity(tx, u.DefaultCurrency, u.UserId) +func UpdateUser(tx store.Tx, u *models.User) error { + security, err := tx.GetSecurity(u.DefaultCurrency, u.UserId) if err != nil { return err } else if security.UserId != u.UserId || security.SecurityId != u.DefaultCurrency { @@ -94,49 +71,7 @@ func UpdateUser(tx *db.Tx, u *models.User) error { return errors.New("New DefaultCurrency security is not a currency") } - count, err := tx.Update(u) - if err != nil { - return err - } else if count != 1 { - return errors.New("Would have updated more than one user") - } - - return nil -} - -func DeleteUser(tx *db.Tx, u *models.User) error { - count, err := tx.Delete(u) - if err != nil { - return err - } - if count != 1 { - return fmt.Errorf("No user to delete") - } - _, err = tx.Exec("DELETE FROM prices WHERE prices.SecurityId IN (SELECT securities.SecurityId FROM securities WHERE securities.UserId=?)", u.UserId) - if err != nil { - return err - } - _, err = tx.Exec("DELETE FROM splits WHERE splits.TransactionId IN (SELECT transactions.TransactionId FROM transactions WHERE transactions.UserId=?)", u.UserId) - if err != nil { - return err - } - _, err = tx.Exec("DELETE FROM transactions WHERE transactions.UserId=?", u.UserId) - if err != nil { - return err - } - _, err = tx.Exec("DELETE FROM securities WHERE securities.UserId=?", u.UserId) - if err != nil { - return err - } - _, err = tx.Exec("DELETE FROM accounts WHERE accounts.UserId=?", u.UserId) - if err != nil { - return err - } - _, err = tx.Exec("DELETE FROM reports WHERE reports.UserId=?", u.UserId) - if err != nil { - return err - } - _, err = tx.Exec("DELETE FROM sessions WHERE sessions.UserId=?", u.UserId) + err = tx.UpdateUser(u) if err != nil { return err } @@ -205,7 +140,7 @@ func UserHandler(r *http.Request, context *Context) ResponseWriterWriter { return user } else if r.Method == "DELETE" { - err := DeleteUser(context.Tx, user) + err := context.StoreTx.DeleteUser(user) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/store/db/securities.go b/internal/store/db/securities.go new file mode 100644 index 0000000..0acba29 --- /dev/null +++ b/internal/store/db/securities.go @@ -0,0 +1,95 @@ +package db + +import ( + "fmt" + "github.com/aclindsa/moneygo/internal/models" +) + +type SecurityInUseError struct { + message string +} + +func (e SecurityInUseError) Error() string { + return e.message +} + +func (tx *Tx) GetSecurity(securityid int64, userid int64) (*models.Security, error) { + var s models.Security + + err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid) + if err != nil { + return nil, err + } + return &s, nil +} + +func (tx *Tx) GetSecurities(userid int64) (*[]*models.Security, error) { + var securities []*models.Security + + _, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid) + if err != nil { + return nil, err + } + return &securities, nil +} + +func (tx *Tx) FindMatchingSecurities(userid int64, security *models.Security) (*[]*models.Security, error) { + var securities []*models.Security + + _, err := tx.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Preciseness=?", userid, security.Type, security.AlternateId, security.Precision) + if err != nil { + return nil, err + } + return &securities, nil +} + +func (tx *Tx) InsertSecurity(s *models.Security) error { + err := tx.Insert(s) + if err != nil { + return err + } + return nil +} + +func (tx *Tx) UpdateSecurity(security *models.Security) error { + count, err := tx.Update(security) + if err != nil { + return err + } + if count != 1 { + return fmt.Errorf("Expected to update 1 security, was going to update %d", count) + } + return nil +} + +func (tx *Tx) DeleteSecurity(s *models.Security) error { + // First, ensure no accounts are using this security + accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId) + + if accounts != 0 { + return SecurityInUseError{"One or more accounts still use this security"} + } + + user, err := tx.GetUser(s.UserId) + if err != nil { + return err + } else if user.DefaultCurrency == s.SecurityId { + return SecurityInUseError{"Cannot delete security which is user's default currency"} + } + + // Remove all prices involving this security (either of this security, or + // using it as a currency) + _, err = tx.Exec("DELETE FROM prices WHERE SecurityId=? OR CurrencyId=?", s.SecurityId, s.SecurityId) + if err != nil { + return err + } + + count, err := tx.Delete(s) + if err != nil { + return err + } + if count != 1 { + return fmt.Errorf("Expected to delete 1 security, was going to delete %d", count) + } + return nil +} diff --git a/internal/store/db/sessions.go b/internal/store/db/sessions.go index 671c55a..57a5ad5 100644 --- a/internal/store/db/sessions.go +++ b/internal/store/db/sessions.go @@ -31,6 +31,12 @@ func (tx *Tx) SessionExists(secret string) (bool, error) { } func (tx *Tx) DeleteSession(session *models.Session) error { - _, err := tx.Delete(session) - return err + count, err := tx.Delete(session) + if err != nil { + return err + } + if count != 1 { + return fmt.Errorf("Expected to delete 1 user, was going to delete %d", count) + } + return nil } diff --git a/internal/store/db/users.go b/internal/store/db/users.go new file mode 100644 index 0000000..2a44b23 --- /dev/null +++ b/internal/store/db/users.go @@ -0,0 +1,86 @@ +package db + +import ( + "fmt" + "github.com/aclindsa/moneygo/internal/models" +) + +func (tx *Tx) UsernameExists(username string) (bool, error) { + existing, err := tx.SelectInt("SELECT count(*) from users where Username=?", username) + return existing != 0, err +} + +func (tx *Tx) InsertUser(user *models.User) error { + return tx.Insert(user) +} + +func (tx *Tx) GetUser(userid int64) (*models.User, error) { + var u models.User + + err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid) + if err != nil { + return nil, err + } + return &u, nil +} + +func (tx *Tx) GetUserByUsername(username string) (*models.User, error) { + var u models.User + + err := tx.SelectOne(&u, "SELECT * from users where Username=?", username) + if err != nil { + return nil, err + } + return &u, nil +} + +func (tx *Tx) UpdateUser(user *models.User) error { + count, err := tx.Update(user) + if err != nil { + return err + } + if count != 1 { + return fmt.Errorf("Expected to update 1 user, was going to update %d", count) + } + return nil +} + +func (tx *Tx) DeleteUser(user *models.User) error { + count, err := tx.Delete(user) + if err != nil { + return err + } + if count != 1 { + return fmt.Errorf("Expected to delete 1 user, was going to delete %d", count) + } + _, err = tx.Exec("DELETE FROM prices WHERE prices.SecurityId IN (SELECT securities.SecurityId FROM securities WHERE securities.UserId=?)", user.UserId) + if err != nil { + return err + } + _, err = tx.Exec("DELETE FROM splits WHERE splits.TransactionId IN (SELECT transactions.TransactionId FROM transactions WHERE transactions.UserId=?)", user.UserId) + if err != nil { + return err + } + _, err = tx.Exec("DELETE FROM transactions WHERE transactions.UserId=?", user.UserId) + if err != nil { + return err + } + _, err = tx.Exec("DELETE FROM securities WHERE securities.UserId=?", user.UserId) + if err != nil { + return err + } + _, err = tx.Exec("DELETE FROM accounts WHERE accounts.UserId=?", user.UserId) + if err != nil { + return err + } + _, err = tx.Exec("DELETE FROM reports WHERE reports.UserId=?", user.UserId) + if err != nil { + return err + } + _, err = tx.Exec("DELETE FROM sessions WHERE sessions.UserId=?", user.UserId) + if err != nil { + return err + } + + return nil +} diff --git a/internal/store/store.go b/internal/store/store.go index 9823236..86d6c66 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -5,17 +5,37 @@ import ( ) type SessionStore interface { + SessionExists(secret string) (bool, error) InsertSession(session *models.Session) error GetSession(secret string) (*models.Session, error) - SessionExists(secret string) (bool, error) DeleteSession(session *models.Session) error } +type UserStore interface { + UsernameExists(username string) (bool, error) + InsertUser(user *models.User) error + GetUser(userid int64) (*models.User, error) + GetUserByUsername(username string) (*models.User, error) + UpdateUser(user *models.User) error + DeleteUser(user *models.User) error +} + +type SecurityStore interface { + InsertSecurity(security *models.Security) error + GetSecurity(securityid int64, userid int64) (*models.Security, error) + GetSecurities(userid int64) (*[]*models.Security, error) + FindMatchingSecurities(userid int64, security *models.Security) (*[]*models.Security, error) + UpdateSecurity(security *models.Security) error + DeleteSecurity(security *models.Security) error +} + type Tx interface { Commit() error Rollback() error SessionStore + UserStore + SecurityStore } type Store interface { From 3326c3b29284ded6ffd7b8ca61660077bdb5ae7f Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Thu, 7 Dec 2017 20:47:55 -0500 Subject: [PATCH 3/9] Move accounts to store --- internal/handlers/accounts.go | 161 +++--------------------------- internal/handlers/accounts_lua.go | 4 +- internal/handlers/gnucash_test.go | 18 ++-- internal/handlers/imports.go | 4 +- internal/handlers/ofx_test.go | 2 +- internal/handlers/securities.go | 5 +- internal/handlers/transactions.go | 10 +- internal/models/accounts.go | 2 +- internal/store/db/accounts.go | 133 ++++++++++++++++++++++++ internal/store/db/securities.go | 17 +--- internal/store/store.go | 38 ++++++- 11 files changed, 211 insertions(+), 183 deletions(-) create mode 100644 internal/store/db/accounts.go diff --git a/internal/handlers/accounts.go b/internal/handlers/accounts.go index d5ddcd8..60a2199 100644 --- a/internal/handlers/accounts.go +++ b/internal/handlers/accounts.go @@ -3,44 +3,23 @@ package handlers import ( "errors" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store" "github.com/aclindsa/moneygo/internal/store/db" "log" "net/http" ) -func GetAccount(tx *db.Tx, accountid int64, userid int64) (*models.Account, error) { - var a models.Account - - err := tx.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid) - if err != nil { - return nil, err - } - return &a, nil -} - -func GetAccounts(tx *db.Tx, userid int64) (*[]models.Account, error) { - var accounts []models.Account - - _, err := tx.Select(&accounts, "SELECT * from accounts where UserId=?", userid) - if err != nil { - return nil, err - } - return &accounts, nil -} - // Get (and attempt to create if it doesn't exist). Matches on UserId, // SecurityId, Type, Name, and ParentAccountId func GetCreateAccount(tx *db.Tx, a models.Account) (*models.Account, error) { - var accounts []models.Account var account models.Account - // Try to find the top-level trading account - _, err := tx.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND Type=? AND Name=? AND ParentAccountId=? ORDER BY AccountId ASC LIMIT 1", a.UserId, a.SecurityId, a.Type, a.Name, a.ParentAccountId) + accounts, err := tx.FindMatchingAccounts(&a) if err != nil { return nil, err } - if len(accounts) == 1 { - account = accounts[0] + if len(*accounts) > 0 { + account = *(*accounts)[0] } else { account.UserId = a.UserId account.SecurityId = a.SecurityId @@ -143,120 +122,6 @@ func GetImbalanceAccount(tx *db.Tx, userid int64, securityid int64) (*models.Acc return a, nil } -type ParentAccountMissingError struct{} - -func (pame ParentAccountMissingError) Error() string { - return "Parent account missing" -} - -type TooMuchNestingError struct{} - -func (tmne TooMuchNestingError) Error() string { - return "Too much nesting" -} - -type CircularAccountsError struct{} - -func (cae CircularAccountsError) Error() string { - return "Would result in circular account relationship" -} - -func insertUpdateAccount(tx *db.Tx, a *models.Account, insert bool) error { - found := make(map[int64]bool) - if !insert { - found[a.AccountId] = true - } - parentid := a.ParentAccountId - depth := 0 - for parentid != -1 { - depth += 1 - if depth > 100 { - return TooMuchNestingError{} - } - - var a models.Account - err := tx.SelectOne(&a, "SELECT * from accounts where AccountId=?", parentid) - if err != nil { - return ParentAccountMissingError{} - } - - // Insertion by itself can never result in circular dependencies - if insert { - break - } - - found[parentid] = true - parentid = a.ParentAccountId - if _, ok := found[parentid]; ok { - return CircularAccountsError{} - } - } - - if insert { - err := tx.Insert(a) - if err != nil { - return err - } - } else { - oldacct, err := GetAccount(tx, a.AccountId, a.UserId) - if err != nil { - return err - } - - a.AccountVersion = oldacct.AccountVersion + 1 - - count, err := tx.Update(a) - if err != nil { - return err - } - if count != 1 { - return errors.New("Updated more than one account") - } - } - - return nil -} - -func InsertAccount(tx *db.Tx, a *models.Account) error { - return insertUpdateAccount(tx, a, true) -} - -func UpdateAccount(tx *db.Tx, a *models.Account) error { - return insertUpdateAccount(tx, a, false) -} - -func DeleteAccount(tx *db.Tx, a *models.Account) error { - if a.ParentAccountId != -1 { - // Re-parent splits to this account's parent account if this account isn't a root account - _, err := tx.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", a.ParentAccountId, a.AccountId) - if err != nil { - return err - } - } else { - // Delete splits if this account is a root account - _, err := tx.Exec("DELETE FROM splits WHERE AccountId=?", a.AccountId) - if err != nil { - return err - } - } - - // Re-parent child accounts to this account's parent account - _, err := tx.Exec("UPDATE accounts SET ParentAccountId=? WHERE ParentAccountId=?", a.ParentAccountId, a.AccountId) - if err != nil { - return err - } - - count, err := tx.Delete(a) - if err != nil { - return err - } - if count != 1 { - return errors.New("Was going to delete more than one account") - } - - return nil -} - func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { user, err := GetUserFromSession(context.Tx, r) if err != nil { @@ -289,9 +154,9 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { return NewError(3 /*Invalid Request*/) } - err = InsertAccount(context.Tx, &account) + err = context.Tx.InsertAccount(&account) if err != nil { - if _, ok := err.(ParentAccountMissingError); ok { + if _, ok := err.(store.ParentAccountMissingError); ok { return NewError(3 /*Invalid Request*/) } else { log.Print(err) @@ -304,7 +169,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { if context.LastLevel() { //Return all Accounts var al models.AccountList - accounts, err := GetAccounts(context.Tx, user.UserId) + accounts, err := context.Tx.GetAccounts(user.UserId) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -320,7 +185,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { if context.LastLevel() { // Return Account with this Id - account, err := GetAccount(context.Tx, accountid, user.UserId) + account, err := context.Tx.GetAccount(accountid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } @@ -354,11 +219,11 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { return NewError(3 /*Invalid Request*/) } - err = UpdateAccount(context.Tx, &account) + err = context.Tx.UpdateAccount(&account) if err != nil { - if _, ok := err.(ParentAccountMissingError); ok { + if _, ok := err.(store.ParentAccountMissingError); ok { return NewError(3 /*Invalid Request*/) - } else if _, ok := err.(CircularAccountsError); ok { + } else if _, ok := err.(store.CircularAccountsError); ok { return NewError(3 /*Invalid Request*/) } else { log.Print(err) @@ -368,12 +233,12 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { return &account } else if r.Method == "DELETE" { - account, err := GetAccount(context.Tx, accountid, user.UserId) + account, err := context.Tx.GetAccount(accountid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } - err = DeleteAccount(context.Tx, account) + err = context.Tx.DeleteAccount(account) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/handlers/accounts_lua.go b/internal/handlers/accounts_lua.go index ee91eb2..2036e62 100644 --- a/internal/handlers/accounts_lua.go +++ b/internal/handlers/accounts_lua.go @@ -29,14 +29,14 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*models.Account, error) { return nil, errors.New("Couldn't find User in lua's Context") } - accounts, err := GetAccounts(tx, user.UserId) + accounts, err := tx.GetAccounts(user.UserId) if err != nil { return nil, err } account_map = make(map[int64]*models.Account) for i := range *accounts { - account_map[(*accounts)[i].AccountId] = &(*accounts)[i] + account_map[(*accounts)[i].AccountId] = (*accounts)[i] } ctx = context.WithValue(ctx, accountsContextKey, account_map) diff --git a/internal/handlers/gnucash_test.go b/internal/handlers/gnucash_test.go index 960078f..3bbd2df 100644 --- a/internal/handlers/gnucash_test.go +++ b/internal/handlers/gnucash_test.go @@ -38,13 +38,13 @@ func TestImportGnucash(t *testing.T) { } for i, account := range *accounts.Accounts { if account.Name == "Income" && account.Type == models.Income && account.ParentAccountId == -1 { - income = &(*accounts.Accounts)[i] + income = (*accounts.Accounts)[i] } else if account.Name == "Equity" && account.Type == models.Equity && account.ParentAccountId == -1 { - equity = &(*accounts.Accounts)[i] + equity = (*accounts.Accounts)[i] } else if account.Name == "Liabilities" && account.Type == models.Liability && account.ParentAccountId == -1 { - liabilities = &(*accounts.Accounts)[i] + liabilities = (*accounts.Accounts)[i] } else if account.Name == "Expenses" && account.Type == models.Expense && account.ParentAccountId == -1 { - expenses = &(*accounts.Accounts)[i] + expenses = (*accounts.Accounts)[i] } } if income == nil { @@ -61,15 +61,15 @@ func TestImportGnucash(t *testing.T) { } for i, account := range *accounts.Accounts { if account.Name == "Salary" && account.Type == models.Income && account.ParentAccountId == income.AccountId { - salary = &(*accounts.Accounts)[i] + salary = (*accounts.Accounts)[i] } else if account.Name == "Opening Balances" && account.Type == models.Equity && account.ParentAccountId == equity.AccountId { - openingbalances = &(*accounts.Accounts)[i] + openingbalances = (*accounts.Accounts)[i] } else if account.Name == "Credit Card" && account.Type == models.Liability && account.ParentAccountId == liabilities.AccountId { - creditcard = &(*accounts.Accounts)[i] + creditcard = (*accounts.Accounts)[i] } else if account.Name == "Groceries" && account.Type == models.Expense && account.ParentAccountId == expenses.AccountId { - groceries = &(*accounts.Accounts)[i] + groceries = (*accounts.Accounts)[i] } else if account.Name == "Cable" && account.Type == models.Expense && account.ParentAccountId == expenses.AccountId { - cable = &(*accounts.Accounts)[i] + cable = (*accounts.Accounts)[i] } } if salary == nil { diff --git a/internal/handlers/imports.go b/internal/handlers/imports.go index ea74042..78cee60 100644 --- a/internal/handlers/imports.go +++ b/internal/handlers/imports.go @@ -39,7 +39,7 @@ func ofxImportHelper(tx *db.Tx, r io.Reader, user *models.User, accountid int64) } // Return Account with this Id - account, err := GetAccount(tx, accountid, user.UserId) + account, err := tx.GetAccount(accountid, user.UserId) if err != nil { log.Print(err) return NewError(3 /*Invalid Request*/) @@ -218,7 +218,7 @@ func OFXImportHandler(context *Context, r *http.Request, user *models.User, acco return NewError(3 /*Invalid Request*/) } - account, err := GetAccount(context.Tx, accountid, user.UserId) + account, err := context.Tx.GetAccount(accountid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } diff --git a/internal/handlers/ofx_test.go b/internal/handlers/ofx_test.go index 4d78f25..af8712f 100644 --- a/internal/handlers/ofx_test.go +++ b/internal/handlers/ofx_test.go @@ -83,7 +83,7 @@ func findAccount(client *http.Client, name string, tipe models.AccountType, secu } for _, account := range *accounts.Accounts { if account.Name == name && account.Type == tipe && account.SecurityId == securityid { - return &account, nil + return account, nil } } return nil, fmt.Errorf("Unable to find account: \"%s\"", name) diff --git a/internal/handlers/securities.go b/internal/handlers/securities.go index ec9544b..7c720eb 100644 --- a/internal/handlers/securities.go +++ b/internal/handlers/securities.go @@ -5,6 +5,7 @@ package handlers import ( "errors" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store" "github.com/aclindsa/moneygo/internal/store/db" "log" "net/http" @@ -77,7 +78,7 @@ func ImportGetCreateSecurity(tx *db.Tx, userid int64, security *models.Security) return security, nil } - securities, err := tx.FindMatchingSecurities(userid, security) + securities, err := tx.FindMatchingSecurities(security) if err != nil { return nil, err } @@ -215,7 +216,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { } err = context.Tx.DeleteSecurity(security) - if _, ok := err.(db.SecurityInUseError); ok { + if _, ok := err.(store.SecurityInUseError); ok { return NewError(7 /*In Use Error*/) } else if err != nil { log.Print(err) diff --git a/internal/handlers/transactions.go b/internal/handlers/transactions.go index 707d29a..9ba4952 100644 --- a/internal/handlers/transactions.go +++ b/internal/handlers/transactions.go @@ -32,7 +32,7 @@ func GetTransactionImbalances(tx *db.Tx, t *models.Transaction) (map[int64]big.R if t.Splits[i].AccountId != -1 { var err error var account *models.Account - account, err = GetAccount(tx, t.Splits[i].AccountId, t.UserId) + account, err = tx.GetAccount(t.Splits[i].AccountId, t.UserId) if err != nil { return nil, err } @@ -100,7 +100,7 @@ func GetTransactions(tx *db.Tx, userid int64) (*[]models.Transaction, error) { func incrementAccountVersions(tx *db.Tx, user *models.User, accountids []int64) error { for i := range accountids { - account, err := GetAccount(tx, accountids[i], user.UserId) + account, err := tx.GetAccount(accountids[i], user.UserId) if err != nil { return err } @@ -297,7 +297,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter for i := range transaction.Splits { transaction.Splits[i].SplitId = -1 - _, err := GetAccount(context.Tx, transaction.Splits[i].AccountId, user.UserId) + _, err := context.Tx.GetAccount(transaction.Splits[i].AccountId, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } @@ -371,7 +371,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter } for i := range transaction.Splits { - _, err := GetAccount(context.Tx, transaction.Splits[i].AccountId, user.UserId) + _, err := context.Tx.GetAccount(transaction.Splits[i].AccountId, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } @@ -518,7 +518,7 @@ func GetAccountTransactions(tx *db.Tx, user *models.User, accountid int64, sort sqloffset = fmt.Sprintf(" OFFSET %d", page*limit) } - account, err := GetAccount(tx, accountid, user.UserId) + account, err := tx.GetAccount(accountid, user.UserId) if err != nil { return nil, err } diff --git a/internal/models/accounts.go b/internal/models/accounts.go index fdfac98..7941ed7 100644 --- a/internal/models/accounts.go +++ b/internal/models/accounts.go @@ -94,7 +94,7 @@ type Account struct { } type AccountList struct { - Accounts *[]Account `json:"accounts"` + Accounts *[]*Account `json:"accounts"` } func (a *Account) Write(w http.ResponseWriter) error { diff --git a/internal/store/db/accounts.go b/internal/store/db/accounts.go new file mode 100644 index 0000000..07d0208 --- /dev/null +++ b/internal/store/db/accounts.go @@ -0,0 +1,133 @@ +package db + +import ( + "errors" + "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store" +) + +func (tx *Tx) GetAccount(accountid int64, userid int64) (*models.Account, error) { + var account models.Account + + err := tx.SelectOne(&account, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid) + if err != nil { + return nil, err + } + return &account, nil +} + +func (tx *Tx) GetAccounts(userid int64) (*[]*models.Account, error) { + var accounts []*models.Account + + _, err := tx.Select(&accounts, "SELECT * from accounts where UserId=?", userid) + if err != nil { + return nil, err + } + return &accounts, nil +} + +func (tx *Tx) FindMatchingAccounts(account *models.Account) (*[]*models.Account, error) { + var accounts []*models.Account + + _, err := tx.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND Type=? AND Name=? AND ParentAccountId=? ORDER BY AccountId ASC", account.UserId, account.SecurityId, account.Type, account.Name, account.ParentAccountId) + if err != nil { + return nil, err + } + return &accounts, nil +} + +func (tx *Tx) insertUpdateAccount(account *models.Account, insert bool) error { + found := make(map[int64]bool) + if !insert { + found[account.AccountId] = true + } + parentid := account.ParentAccountId + depth := 0 + for parentid != -1 { + depth += 1 + if depth > 100 { + return store.TooMuchNestingError{} + } + + var a models.Account + err := tx.SelectOne(&a, "SELECT * from accounts where AccountId=?", parentid) + if err != nil { + return store.ParentAccountMissingError{} + } + + // Insertion by itself can never result in circular dependencies + if insert { + break + } + + found[parentid] = true + parentid = a.ParentAccountId + if _, ok := found[parentid]; ok { + return store.CircularAccountsError{} + } + } + + if insert { + err := tx.Insert(account) + if err != nil { + return err + } + } else { + oldacct, err := tx.GetAccount(account.AccountId, account.UserId) + if err != nil { + return err + } + + account.AccountVersion = oldacct.AccountVersion + 1 + + count, err := tx.Update(account) + if err != nil { + return err + } + if count != 1 { + return errors.New("Updated more than one account") + } + } + + return nil +} + +func (tx *Tx) InsertAccount(account *models.Account) error { + return tx.insertUpdateAccount(account, true) +} + +func (tx *Tx) UpdateAccount(account *models.Account) error { + return tx.insertUpdateAccount(account, false) +} + +func (tx *Tx) DeleteAccount(account *models.Account) error { + if account.ParentAccountId != -1 { + // Re-parent splits to this account's parent account if this account isn't a root account + _, err := tx.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", account.ParentAccountId, account.AccountId) + if err != nil { + return err + } + } else { + // Delete splits if this account is a root account + _, err := tx.Exec("DELETE FROM splits WHERE AccountId=?", account.AccountId) + if err != nil { + return err + } + } + + // Re-parent child accounts to this account's parent account + _, err := tx.Exec("UPDATE accounts SET ParentAccountId=? WHERE ParentAccountId=?", account.ParentAccountId, account.AccountId) + if err != nil { + return err + } + + count, err := tx.Delete(account) + if err != nil { + return err + } + if count != 1 { + return errors.New("Was going to delete more than one account") + } + + return nil +} diff --git a/internal/store/db/securities.go b/internal/store/db/securities.go index 0acba29..83a5659 100644 --- a/internal/store/db/securities.go +++ b/internal/store/db/securities.go @@ -3,16 +3,9 @@ package db import ( "fmt" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store" ) -type SecurityInUseError struct { - message string -} - -func (e SecurityInUseError) Error() string { - return e.message -} - func (tx *Tx) GetSecurity(securityid int64, userid int64) (*models.Security, error) { var s models.Security @@ -33,10 +26,10 @@ func (tx *Tx) GetSecurities(userid int64) (*[]*models.Security, error) { return &securities, nil } -func (tx *Tx) FindMatchingSecurities(userid int64, security *models.Security) (*[]*models.Security, error) { +func (tx *Tx) FindMatchingSecurities(security *models.Security) (*[]*models.Security, error) { var securities []*models.Security - _, err := tx.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Preciseness=?", userid, security.Type, security.AlternateId, security.Precision) + _, err := tx.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Preciseness=?", security.UserId, security.Type, security.AlternateId, security.Precision) if err != nil { return nil, err } @@ -67,14 +60,14 @@ func (tx *Tx) DeleteSecurity(s *models.Security) error { accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId) if accounts != 0 { - return SecurityInUseError{"One or more accounts still use this security"} + return store.SecurityInUseError{"One or more accounts still use this security"} } user, err := tx.GetUser(s.UserId) if err != nil { return err } else if user.DefaultCurrency == s.SecurityId { - return SecurityInUseError{"Cannot delete security which is user's default currency"} + return store.SecurityInUseError{"Cannot delete security which is user's default currency"} } // Remove all prices involving this security (either of this security, or diff --git a/internal/store/store.go b/internal/store/store.go index 86d6c66..fe99c9c 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -20,15 +20,50 @@ type UserStore interface { DeleteUser(user *models.User) error } +type SecurityInUseError struct { + Message string +} + +func (e SecurityInUseError) Error() string { + return e.Message +} + type SecurityStore interface { InsertSecurity(security *models.Security) error GetSecurity(securityid int64, userid int64) (*models.Security, error) GetSecurities(userid int64) (*[]*models.Security, error) - FindMatchingSecurities(userid int64, security *models.Security) (*[]*models.Security, error) + FindMatchingSecurities(security *models.Security) (*[]*models.Security, error) UpdateSecurity(security *models.Security) error DeleteSecurity(security *models.Security) error } +type ParentAccountMissingError struct{} + +func (pame ParentAccountMissingError) Error() string { + return "Parent account missing" +} + +type TooMuchNestingError struct{} + +func (tmne TooMuchNestingError) Error() string { + return "Too much account nesting" +} + +type CircularAccountsError struct{} + +func (cae CircularAccountsError) Error() string { + return "Would result in circular account relationship" +} + +type AccountStore interface { + InsertAccount(account *models.Account) error + GetAccount(accountid int64, userid int64) (*models.Account, error) + GetAccounts(userid int64) (*[]*models.Account, error) + FindMatchingAccounts(account *models.Account) (*[]*models.Account, error) + UpdateAccount(account *models.Account) error + DeleteAccount(account *models.Account) error +} + type Tx interface { Commit() error Rollback() error @@ -36,6 +71,7 @@ type Tx interface { SessionStore UserStore SecurityStore + AccountStore } type Store interface { From 61676598dd1f657064c0abcb2786bf3703919ec7 Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Thu, 7 Dec 2017 21:05:55 -0500 Subject: [PATCH 4/9] Move prices to store --- internal/handlers/prices.go | 66 ++++++------------------------- internal/store/db/db.go | 4 +- internal/store/db/prices.go | 78 +++++++++++++++++++++++++++++++++++++ internal/store/store.go | 37 ++++++++++++++---- 4 files changed, 121 insertions(+), 64 deletions(-) create mode 100644 internal/store/db/prices.go diff --git a/internal/handlers/prices.go b/internal/handlers/prices.go index f08d37e..c51d5c8 100644 --- a/internal/handlers/prices.go +++ b/internal/handlers/prices.go @@ -18,67 +18,25 @@ func CreatePriceIfNotExist(tx *db.Tx, price *models.Price) error { return nil } - var prices []*models.Price - - _, err := tx.Select(&prices, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date=? AND Value=?", price.SecurityId, price.CurrencyId, price.Date, price.Value) + exists, err := tx.PriceExists(price) if err != nil { return err } - - if len(prices) > 0 { + if exists { return nil // price already exists } - err = tx.Insert(price) + err = tx.InsertPrice(price) if err != nil { return err } return nil } -func GetPrice(tx *db.Tx, priceid, securityid int64) (*models.Price, error) { - var p models.Price - err := tx.SelectOne(&p, "SELECT * from prices where PriceId=? AND SecurityId=?", priceid, securityid) - if err != nil { - return nil, err - } - return &p, nil -} - -func GetPrices(tx *db.Tx, securityid int64) (*[]*models.Price, error) { - var prices []*models.Price - - _, err := tx.Select(&prices, "SELECT * from prices where SecurityId=?", securityid) - if err != nil { - return nil, err - } - return &prices, nil -} - -// Return the latest price for security in currency units before date -func GetLatestPrice(tx *db.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { - var p models.Price - err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date) - if err != nil { - return nil, err - } - return &p, nil -} - -// Return the earliest price for security in currency units after date -func GetEarliestPrice(tx *db.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { - var p models.Price - err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date) - if err != nil { - return nil, err - } - return &p, nil -} - // Return the price for security in currency closest to date func GetClosestPrice(tx *db.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { - earliest, _ := GetEarliestPrice(tx, security, currency, date) - latest, err := GetLatestPrice(tx, security, currency, date) + earliest, _ := tx.GetEarliestPrice(security, currency, date) + latest, err := tx.GetLatestPrice(security, currency, date) // Return early if either earliest or latest are invalid if earliest == nil { @@ -129,7 +87,7 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security //Return all this security's prices var pl models.PriceList - prices, err := GetPrices(context.Tx, security.SecurityId) + prices, err := context.Tx.GetPrices(security.SecurityId) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -144,7 +102,7 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security return NewError(3 /*Invalid Request*/) } - price, err := GetPrice(context.Tx, priceid, security.SecurityId) + price, err := context.Tx.GetPrice(priceid, security.SecurityId) if err != nil { return NewError(3 /*Invalid Request*/) } @@ -170,21 +128,21 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security return NewError(3 /*Invalid Request*/) } - count, err := context.Tx.Update(&price) - if err != nil || count != 1 { + err = context.Tx.UpdatePrice(&price) + if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) } return &price } else if r.Method == "DELETE" { - price, err := GetPrice(context.Tx, priceid, security.SecurityId) + price, err := context.Tx.GetPrice(priceid, security.SecurityId) if err != nil { return NewError(3 /*Invalid Request*/) } - count, err := context.Tx.Delete(price) - if err != nil || count != 1 { + err = context.Tx.DeletePrice(price) + if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) } diff --git a/internal/store/db/db.go b/internal/store/db/db.go index 3a4031e..d8b043b 100644 --- a/internal/store/db/db.go +++ b/internal/store/db/db.go @@ -39,11 +39,11 @@ func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { dbmap := &gorp.DbMap{Db: db, Dialect: dialect} dbmap.AddTableWithName(models.User{}, "users").SetKeys(true, "UserId") dbmap.AddTableWithName(models.Session{}, "sessions").SetKeys(true, "SessionId") - dbmap.AddTableWithName(models.Account{}, "accounts").SetKeys(true, "AccountId") dbmap.AddTableWithName(models.Security{}, "securities").SetKeys(true, "SecurityId") + dbmap.AddTableWithName(models.Price{}, "prices").SetKeys(true, "PriceId") + dbmap.AddTableWithName(models.Account{}, "accounts").SetKeys(true, "AccountId") dbmap.AddTableWithName(models.Transaction{}, "transactions").SetKeys(true, "TransactionId") dbmap.AddTableWithName(models.Split{}, "splits").SetKeys(true, "SplitId") - dbmap.AddTableWithName(models.Price{}, "prices").SetKeys(true, "PriceId") rtable := dbmap.AddTableWithName(models.Report{}, "reports").SetKeys(true, "ReportId") rtable.ColMap("Lua").SetMaxSize(models.LuaMaxLength + luaMaxLengthBuffer) diff --git a/internal/store/db/prices.go b/internal/store/db/prices.go new file mode 100644 index 0000000..2df2ab9 --- /dev/null +++ b/internal/store/db/prices.go @@ -0,0 +1,78 @@ +package db + +import ( + "fmt" + "github.com/aclindsa/moneygo/internal/models" + "time" +) + +func (tx *Tx) PriceExists(price *models.Price) (bool, error) { + var prices []*models.Price + _, err := tx.Select(&prices, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date=? AND Value=?", price.SecurityId, price.CurrencyId, price.Date, price.Value) + return len(prices) > 0, err +} + +func (tx *Tx) InsertPrice(price *models.Price) error { + return tx.Insert(price) +} + +func (tx *Tx) GetPrice(priceid, securityid int64) (*models.Price, error) { + var price models.Price + err := tx.SelectOne(&price, "SELECT * from prices where PriceId=? AND SecurityId=?", priceid, securityid) + if err != nil { + return nil, err + } + return &price, nil +} + +func (tx *Tx) GetPrices(securityid int64) (*[]*models.Price, error) { + var prices []*models.Price + + _, err := tx.Select(&prices, "SELECT * from prices where SecurityId=?", securityid) + if err != nil { + return nil, err + } + return &prices, nil +} + +// Return the latest price for security in currency units before date +func (tx *Tx) GetLatestPrice(security, currency *models.Security, date *time.Time) (*models.Price, error) { + var price models.Price + err := tx.SelectOne(&price, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date) + if err != nil { + return nil, err + } + return &price, nil +} + +// Return the earliest price for security in currency units after date +func (tx *Tx) GetEarliestPrice(security, currency *models.Security, date *time.Time) (*models.Price, error) { + var price models.Price + err := tx.SelectOne(&price, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date) + if err != nil { + return nil, err + } + return &price, nil +} + +func (tx *Tx) UpdatePrice(price *models.Price) error { + count, err := tx.Update(price) + if err != nil { + return err + } + if count != 1 { + return fmt.Errorf("Expected to update 1 price, was going to update %d", count) + } + return nil +} + +func (tx *Tx) DeletePrice(price *models.Price) error { + count, err := tx.Delete(price) + if err != nil { + return err + } + if count != 1 { + return fmt.Errorf("Expected to delete 1 price, was going to delete %d", count) + } + return nil +} diff --git a/internal/store/store.go b/internal/store/store.go index fe99c9c..b1ffc8a 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -2,15 +2,9 @@ package store import ( "github.com/aclindsa/moneygo/internal/models" + "time" ) -type SessionStore interface { - SessionExists(secret string) (bool, error) - InsertSession(session *models.Session) error - GetSession(secret string) (*models.Session, error) - DeleteSession(session *models.Session) error -} - type UserStore interface { UsernameExists(username string) (bool, error) InsertUser(user *models.User) error @@ -20,6 +14,13 @@ type UserStore interface { DeleteUser(user *models.User) error } +type SessionStore interface { + SessionExists(secret string) (bool, error) + InsertSession(session *models.Session) error + GetSession(secret string) (*models.Session, error) + DeleteSession(session *models.Session) error +} + type SecurityInUseError struct { Message string } @@ -37,6 +38,17 @@ type SecurityStore interface { DeleteSecurity(security *models.Security) error } +type PriceStore interface { + PriceExists(price *models.Price) (bool, error) + InsertPrice(price *models.Price) error + GetPrice(priceid, securityid int64) (*models.Price, error) + GetPrices(securityid int64) (*[]*models.Price, error) + GetLatestPrice(security, currency *models.Security, date *time.Time) (*models.Price, error) + GetEarliestPrice(security, currency *models.Security, date *time.Time) (*models.Price, error) + UpdatePrice(price *models.Price) error + DeletePrice(price *models.Price) error +} + type ParentAccountMissingError struct{} func (pame ParentAccountMissingError) Error() string { @@ -64,14 +76,23 @@ type AccountStore interface { DeleteAccount(account *models.Account) error } +type TransactionStore interface { +} + +type ReportStore interface { +} + type Tx interface { Commit() error Rollback() error - SessionStore UserStore + SessionStore SecurityStore + PriceStore AccountStore + TransactionStore + ReportStore } type Store interface { From da7e025509838e8bf45e51054d235a30f316d1dd Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Fri, 8 Dec 2017 21:27:03 -0500 Subject: [PATCH 5/9] Move splits/transactions to store --- internal/handlers/accounts_lua.go | 24 +- internal/handlers/gnucash.go | 4 +- internal/handlers/imports.go | 4 +- internal/handlers/transactions.go | 399 +------------------------ internal/handlers/transactions_test.go | 4 +- internal/models/transactions.go | 4 +- internal/store/db/transactions.go | 361 ++++++++++++++++++++++ internal/store/store.go | 16 + 8 files changed, 410 insertions(+), 406 deletions(-) create mode 100644 internal/store/db/transactions.go diff --git a/internal/handlers/accounts_lua.go b/internal/handlers/accounts_lua.go index 2036e62..2985ff5 100644 --- a/internal/handlers/accounts_lua.go +++ b/internal/handlers/accounts_lua.go @@ -6,7 +6,6 @@ import ( "github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/store/db" "github.com/yuin/gopher-lua" - "math/big" "strings" ) @@ -168,24 +167,29 @@ func luaAccountBalance(L *lua.LState) int { panic("SecurityId not in lua security_map") } date := luaWeakCheckTime(L, 2) - var b Balance - var rat *big.Rat + var splits *[]*models.Split if date != nil { end := luaWeakCheckTime(L, 3) if end != nil { - rat, err = GetAccountBalanceDateRange(tx, user, a.AccountId, date, end) + splits, err = tx.GetAccountSplitsDateRange(user, a.AccountId, date, end) } else { - rat, err = GetAccountBalanceDate(tx, user, a.AccountId, date) + splits, err = tx.GetAccountSplitsDate(user, a.AccountId, date) } } else { - rat, err = GetAccountBalance(tx, user, a.AccountId) + splits, err = tx.GetAccountSplits(user, a.AccountId) } if err != nil { - panic("Failed to GetAccountBalance:" + err.Error()) + panic("Failed to fetch splits for account:" + err.Error()) } - b.Amount = rat - b.Security = security - L.Push(BalanceToLua(L, &b)) + rat, err := BalanceFromSplits(splits) + if err != nil { + panic("Failed to calculate balance for account:" + err.Error()) + } + b := &Balance{ + Amount: rat, + Security: security, + } + L.Push(BalanceToLua(L, b)) return 1 } diff --git a/internal/handlers/gnucash.go b/internal/handlers/gnucash.go index 2399a6b..05fa0af 100644 --- a/internal/handlers/gnucash.go +++ b/internal/handlers/gnucash.go @@ -437,7 +437,7 @@ func GnucashImportHandler(r *http.Request, context *Context) ResponseWriterWrite } split.AccountId = acctId - exists, err := SplitAlreadyImported(context.Tx, split) + exists, err := context.Tx.SplitExists(split) if err != nil { log.Print("Error checking if split was already imported:", err) return NewError(999 /*Internal Error*/) @@ -446,7 +446,7 @@ func GnucashImportHandler(r *http.Request, context *Context) ResponseWriterWrite } } if !already_imported { - err := InsertTransaction(context.Tx, &transaction, user) + err := context.Tx.InsertTransaction(&transaction, user) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/handlers/imports.go b/internal/handlers/imports.go index 78cee60..b76ae68 100644 --- a/internal/handlers/imports.go +++ b/internal/handlers/imports.go @@ -187,7 +187,7 @@ func ofxImportHelper(tx *db.Tx, r io.Reader, user *models.User, accountid int64) split.SecurityId = -1 } - exists, err := SplitAlreadyImported(tx, split) + exists, err := tx.SplitExists(split) if err != nil { log.Print("Error checking if split was already imported:", err) return NewError(999 /*Internal Error*/) @@ -202,7 +202,7 @@ func ofxImportHelper(tx *db.Tx, r io.Reader, user *models.User, accountid int64) } for _, transaction := range transactions { - err := InsertTransaction(tx, &transaction, user) + err := tx.InsertTransaction(&transaction, user) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/handlers/transactions.go b/internal/handlers/transactions.go index 9ba4952..2d16da1 100644 --- a/internal/handlers/transactions.go +++ b/internal/handlers/transactions.go @@ -2,22 +2,16 @@ package handlers import ( "errors" - "fmt" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store" "github.com/aclindsa/moneygo/internal/store/db" "log" "math/big" "net/http" "net/url" "strconv" - "time" ) -func SplitAlreadyImported(tx *db.Tx, s *models.Split) (bool, error) { - count, err := tx.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId) - return count == 1, err -} - // Return a map of security ID's to big.Rat's containing the amount that // security is imbalanced by func GetTransactionImbalances(tx *db.Tx, t *models.Transaction) (map[int64]big.Rat, error) { @@ -64,219 +58,6 @@ func TransactionBalanced(tx *db.Tx, t *models.Transaction) (bool, error) { return true, nil } -func GetTransaction(tx *db.Tx, transactionid int64, userid int64) (*models.Transaction, error) { - var t models.Transaction - - err := tx.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid) - if err != nil { - return nil, err - } - - _, err = tx.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid) - if err != nil { - return nil, err - } - - return &t, nil -} - -func GetTransactions(tx *db.Tx, userid int64) (*[]models.Transaction, error) { - var transactions []models.Transaction - - _, err := tx.Select(&transactions, "SELECT * from transactions where UserId=?", userid) - if err != nil { - return nil, err - } - - for i := range transactions { - _, err := tx.Select(&transactions[i].Splits, "SELECT * from splits where TransactionId=?", transactions[i].TransactionId) - if err != nil { - return nil, err - } - } - - return &transactions, nil -} - -func incrementAccountVersions(tx *db.Tx, user *models.User, accountids []int64) error { - for i := range accountids { - account, err := tx.GetAccount(accountids[i], user.UserId) - if err != nil { - return err - } - account.AccountVersion++ - count, err := tx.Update(account) - if err != nil { - return err - } - if count != 1 { - return errors.New("Updated more than one account") - } - } - return nil -} - -type AccountMissingError struct{} - -func (ame AccountMissingError) Error() string { - return "Account missing" -} - -func InsertTransaction(tx *db.Tx, t *models.Transaction, user *models.User) error { - // Map of any accounts with transaction splits being added - a_map := make(map[int64]bool) - for i := range t.Splits { - if t.Splits[i].AccountId != -1 { - existing, err := tx.SelectInt("SELECT count(*) from accounts where AccountId=?", t.Splits[i].AccountId) - if err != nil { - return err - } - if existing != 1 { - return AccountMissingError{} - } - a_map[t.Splits[i].AccountId] = true - } else if t.Splits[i].SecurityId == -1 { - return AccountMissingError{} - } - } - - //increment versions for all accounts - var a_ids []int64 - for id := range a_map { - a_ids = append(a_ids, id) - } - // ensure at least one of the splits is associated with an actual account - if len(a_ids) < 1 { - return AccountMissingError{} - } - err := incrementAccountVersions(tx, user, a_ids) - if err != nil { - return err - } - - t.UserId = user.UserId - err = tx.Insert(t) - if err != nil { - return err - } - - for i := range t.Splits { - t.Splits[i].TransactionId = t.TransactionId - t.Splits[i].SplitId = -1 - err = tx.Insert(t.Splits[i]) - if err != nil { - return err - } - } - - return nil -} - -func UpdateTransaction(tx *db.Tx, t *models.Transaction, user *models.User) error { - var existing_splits []*models.Split - - _, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId) - if err != nil { - return err - } - - // Map of any accounts with transaction splits being added - a_map := make(map[int64]bool) - - // Make a map with any existing splits for this transaction - s_map := make(map[int64]bool) - for i := range existing_splits { - s_map[existing_splits[i].SplitId] = true - } - - // Insert splits, updating any pre-existing ones - for i := range t.Splits { - t.Splits[i].TransactionId = t.TransactionId - _, ok := s_map[t.Splits[i].SplitId] - if ok { - count, err := tx.Update(t.Splits[i]) - if err != nil { - return err - } - if count > 1 { - return fmt.Errorf("Updated %d transaction splits while attempting to update only 1", count) - } - delete(s_map, t.Splits[i].SplitId) - } else { - t.Splits[i].SplitId = -1 - err := tx.Insert(t.Splits[i]) - if err != nil { - return err - } - } - if t.Splits[i].AccountId != -1 { - a_map[t.Splits[i].AccountId] = true - } - } - - // Delete any remaining pre-existing splits - for i := range existing_splits { - _, ok := s_map[existing_splits[i].SplitId] - if existing_splits[i].AccountId != -1 { - a_map[existing_splits[i].AccountId] = true - } - if ok { - _, err := tx.Delete(existing_splits[i]) - if err != nil { - return err - } - } - } - - // Increment versions for all accounts with modified splits - var a_ids []int64 - for id := range a_map { - a_ids = append(a_ids, id) - } - err = incrementAccountVersions(tx, user, a_ids) - if err != nil { - return err - } - - count, err := tx.Update(t) - if err != nil { - return err - } - if count > 1 { - return fmt.Errorf("Updated %d transactions (expected 1)", count) - } - - return nil -} - -func DeleteTransaction(tx *db.Tx, t *models.Transaction, user *models.User) error { - var accountids []int64 - _, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId) - if err != nil { - return err - } - - _, err = tx.Exec("DELETE FROM splits WHERE TransactionId=?", t.TransactionId) - if err != nil { - return err - } - - count, err := tx.Delete(t) - if err != nil { - return err - } - if count != 1 { - return errors.New("Deleted more than one transaction") - } - - err = incrementAccountVersions(tx, user, accountids) - if err != nil { - return err - } - - return nil -} - func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter { user, err := GetUserFromSession(context.Tx, r) if err != nil { @@ -311,9 +92,9 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter return NewError(3 /*Invalid Request*/) } - err = InsertTransaction(context.Tx, &transaction, user) + err = context.Tx.InsertTransaction(&transaction, user) if err != nil { - if _, ok := err.(AccountMissingError); ok { + if _, ok := err.(store.AccountMissingError); ok { return NewError(3 /*Invalid Request*/) } else { log.Print(err) @@ -326,7 +107,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter if context.LastLevel() { //Return all Transactions var al models.TransactionList - transactions, err := GetTransactions(context.Tx, user.UserId) + transactions, err := context.Tx.GetTransactions(user.UserId) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -339,7 +120,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter if err != nil { return NewError(3 /*Invalid Request*/) } - transaction, err := GetTransaction(context.Tx, transactionid, user.UserId) + transaction, err := context.Tx.GetTransaction(transactionid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } @@ -377,7 +158,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter } } - err = UpdateTransaction(context.Tx, &transaction, user) + err = context.Tx.UpdateTransaction(&transaction, user) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -385,12 +166,12 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter return &transaction } else if r.Method == "DELETE" { - transaction, err := GetTransaction(context.Tx, transactionid, user.UserId) + transaction, err := context.Tx.GetTransaction(transactionid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } - err = DeleteTransaction(context.Tx, transaction, user) + err = context.Tx.DeleteTransaction(transaction, user) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -402,41 +183,9 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter return NewError(3 /*Invalid Request*/) } -func TransactionsBalanceDifference(tx *db.Tx, accountid int64, transactions []models.Transaction) (*big.Rat, error) { - var pageDifference, tmp big.Rat - for i := range transactions { - _, err := tx.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId) - if err != nil { - return nil, err - } - - // Sum up the amounts from the splits we're returning so we can return - // an ending balance - for j := range transactions[i].Splits { - if transactions[i].Splits[j].AccountId == accountid { - rat_amount, err := models.GetBigAmount(transactions[i].Splits[j].Amount) - if err != nil { - return nil, err - } - tmp.Add(&pageDifference, rat_amount) - pageDifference.Set(&tmp) - } - } - } - return &pageDifference, nil -} - -func GetAccountBalance(tx *db.Tx, user *models.User, accountid int64) (*big.Rat, error) { - var splits []models.Split - - sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?" - _, err := tx.Select(&splits, sql, accountid, user.UserId) - if err != nil { - return nil, err - } - +func BalanceFromSplits(splits *[]*models.Split) (*big.Rat, error) { var balance, tmp big.Rat - for _, s := range splits { + for _, s := range *splits { rat_amount, err := models.GetBigAmount(s.Amount) if err != nil { return nil, err @@ -448,132 +197,6 @@ func GetAccountBalance(tx *db.Tx, user *models.User, accountid int64) (*big.Rat, return &balance, nil } -// Assumes accountid is valid and is owned by the current user -func GetAccountBalanceDate(tx *db.Tx, user *models.User, accountid int64, date *time.Time) (*big.Rat, error) { - var splits []models.Split - - sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?" - _, err := tx.Select(&splits, sql, accountid, user.UserId, date) - if err != nil { - return nil, err - } - - var balance, tmp big.Rat - for _, s := range splits { - rat_amount, err := models.GetBigAmount(s.Amount) - if err != nil { - return nil, err - } - tmp.Add(&balance, rat_amount) - balance.Set(&tmp) - } - - return &balance, nil -} - -func GetAccountBalanceDateRange(tx *db.Tx, user *models.User, accountid int64, begin, end *time.Time) (*big.Rat, error) { - var splits []models.Split - - sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?" - _, err := tx.Select(&splits, sql, accountid, user.UserId, begin, end) - if err != nil { - return nil, err - } - - var balance, tmp big.Rat - for _, s := range splits { - rat_amount, err := models.GetBigAmount(s.Amount) - if err != nil { - return nil, err - } - tmp.Add(&balance, rat_amount) - balance.Set(&tmp) - } - - return &balance, nil -} - -func GetAccountTransactions(tx *db.Tx, user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) { - var transactions []models.Transaction - var atl models.AccountTransactionsList - - var sqlsort, balanceLimitOffset string - var balanceLimitOffsetArg uint64 - if sort == "date-asc" { - sqlsort = " ORDER BY transactions.Date ASC, transactions.TransactionId ASC" - balanceLimitOffset = " LIMIT ?" - balanceLimitOffsetArg = page * limit - } else if sort == "date-desc" { - numSplits, err := tx.SelectInt("SELECT count(*) FROM splits") - if err != nil { - return nil, err - } - sqlsort = " ORDER BY transactions.Date DESC, transactions.TransactionId DESC" - balanceLimitOffset = fmt.Sprintf(" LIMIT %d OFFSET ?", numSplits) - balanceLimitOffsetArg = (page + 1) * limit - } - - var sqloffset string - if page > 0 { - sqloffset = fmt.Sprintf(" OFFSET %d", page*limit) - } - - account, err := tx.GetAccount(accountid, user.UserId) - if err != nil { - return nil, err - } - atl.Account = account - - sql := "SELECT DISTINCT transactions.* FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + " LIMIT ?" + sqloffset - _, err = tx.Select(&transactions, sql, user.UserId, accountid, limit) - if err != nil { - return nil, err - } - atl.Transactions = &transactions - - pageDifference, err := TransactionsBalanceDifference(tx, accountid, transactions) - if err != nil { - return nil, err - } - - count, err := tx.SelectInt("SELECT count(DISTINCT transactions.TransactionId) FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?", user.UserId, accountid) - if err != nil { - return nil, err - } - atl.TotalTransactions = count - - security, err := tx.GetSecurity(atl.Account.SecurityId, user.UserId) - if err != nil { - return nil, err - } - if security == nil { - return nil, errors.New("Security not found") - } - - // Sum all the splits for all transaction splits for this account that - // occurred before the page we're returning - var amounts []string - sql = "SELECT s.Amount FROM splits AS s INNER JOIN (SELECT DISTINCT transactions.Date, transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ") as t ON s.TransactionId = t.TransactionId WHERE s.AccountId=?" - _, err = tx.Select(&amounts, sql, user.UserId, accountid, balanceLimitOffsetArg, accountid) - if err != nil { - return nil, err - } - - var tmp, balance big.Rat - for _, amount := range amounts { - rat_amount, err := models.GetBigAmount(amount) - if err != nil { - return nil, err - } - tmp.Add(&balance, rat_amount) - balance.Set(&tmp) - } - atl.BeginningBalance = balance.FloatString(security.Precision) - atl.EndingBalance = tmp.Add(&balance, pageDifference).FloatString(security.Precision) - - return &atl, nil -} - // Return only those transactions which have at least one split pertaining to // an account func AccountTransactionsHandler(context *Context, r *http.Request, user *models.User, accountid int64) ResponseWriterWriter { @@ -609,7 +232,7 @@ func AccountTransactionsHandler(context *Context, r *http.Request, user *models. sort = sortstring } - accountTransactions, err := GetAccountTransactions(context.Tx, user, accountid, sort, page, limit) + accountTransactions, err := context.Tx.GetAccountTransactions(user, accountid, sort, page, limit) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/handlers/transactions_test.go b/internal/handlers/transactions_test.go index 0f3b68d..4d2ee54 100644 --- a/internal/handlers/transactions_test.go +++ b/internal/handlers/transactions_test.go @@ -276,7 +276,7 @@ func TestGetTransactions(t *testing.T) { found := false for _, tran := range *tl.Transactions { if tran.TransactionId == curr.TransactionId { - ensureTransactionsMatch(t, &curr, &tran, nil, true, true) + ensureTransactionsMatch(t, &curr, tran, nil, true, true) if _, ok := foundIds[tran.TransactionId]; ok { continue } @@ -410,7 +410,7 @@ func helperTestAccountTransactions(t *testing.T, d *TestData, account *models.Ac } if atl.Transactions != nil { for _, tran := range *atl.Transactions { - transactions = append(transactions, tran) + transactions = append(transactions, *tran) } lastFetchCount = int64(len(*atl.Transactions)) } else { diff --git a/internal/models/transactions.go b/internal/models/transactions.go index 8076995..5046e65 100644 --- a/internal/models/transactions.go +++ b/internal/models/transactions.go @@ -82,12 +82,12 @@ type Transaction struct { } type TransactionList struct { - Transactions *[]Transaction `json:"transactions"` + Transactions *[]*Transaction `json:"transactions"` } type AccountTransactionsList struct { Account *Account - Transactions *[]Transaction + Transactions *[]*Transaction TotalTransactions int64 BeginningBalance string EndingBalance string diff --git a/internal/store/db/transactions.go b/internal/store/db/transactions.go new file mode 100644 index 0000000..29168df --- /dev/null +++ b/internal/store/db/transactions.go @@ -0,0 +1,361 @@ +package db + +import ( + "errors" + "fmt" + "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store" + "math/big" + "time" +) + +func (tx *Tx) incrementAccountVersions(user *models.User, accountids []int64) error { + for i := range accountids { + account, err := tx.GetAccount(accountids[i], user.UserId) + if err != nil { + return err + } + account.AccountVersion++ + count, err := tx.Update(account) + if err != nil { + return err + } + if count != 1 { + return errors.New("Updated more than one account") + } + } + return nil +} + +func (tx *Tx) InsertTransaction(t *models.Transaction, user *models.User) error { + // Map of any accounts with transaction splits being added + a_map := make(map[int64]bool) + for i := range t.Splits { + if t.Splits[i].AccountId != -1 { + existing, err := tx.SelectInt("SELECT count(*) from accounts where AccountId=?", t.Splits[i].AccountId) + if err != nil { + return err + } + if existing != 1 { + return store.AccountMissingError{} + } + a_map[t.Splits[i].AccountId] = true + } else if t.Splits[i].SecurityId == -1 { + return store.AccountMissingError{} + } + } + + //increment versions for all accounts + var a_ids []int64 + for id := range a_map { + a_ids = append(a_ids, id) + } + // ensure at least one of the splits is associated with an actual account + if len(a_ids) < 1 { + return store.AccountMissingError{} + } + err := tx.incrementAccountVersions(user, a_ids) + if err != nil { + return err + } + + t.UserId = user.UserId + err = tx.Insert(t) + if err != nil { + return err + } + + for i := range t.Splits { + t.Splits[i].TransactionId = t.TransactionId + t.Splits[i].SplitId = -1 + err = tx.Insert(t.Splits[i]) + if err != nil { + return err + } + } + + return nil +} + +func (tx *Tx) SplitExists(s *models.Split) (bool, error) { + count, err := tx.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId) + return count == 1, err +} + +func (tx *Tx) GetTransaction(transactionid int64, userid int64) (*models.Transaction, error) { + var t models.Transaction + + err := tx.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid) + if err != nil { + return nil, err + } + + _, err = tx.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid) + if err != nil { + return nil, err + } + + return &t, nil +} + +func (tx *Tx) GetTransactions(userid int64) (*[]*models.Transaction, error) { + var transactions []*models.Transaction + + _, err := tx.Select(&transactions, "SELECT * from transactions where UserId=?", userid) + if err != nil { + return nil, err + } + + for i := range transactions { + _, err := tx.Select(&transactions[i].Splits, "SELECT * from splits where TransactionId=?", transactions[i].TransactionId) + if err != nil { + return nil, err + } + } + + return &transactions, nil +} + +func (tx *Tx) UpdateTransaction(t *models.Transaction, user *models.User) error { + var existing_splits []*models.Split + + _, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId) + if err != nil { + return err + } + + // Map of any accounts with transaction splits being added + a_map := make(map[int64]bool) + + // Make a map with any existing splits for this transaction + s_map := make(map[int64]bool) + for i := range existing_splits { + s_map[existing_splits[i].SplitId] = true + } + + // Insert splits, updating any pre-existing ones + for i := range t.Splits { + t.Splits[i].TransactionId = t.TransactionId + _, ok := s_map[t.Splits[i].SplitId] + if ok { + count, err := tx.Update(t.Splits[i]) + if err != nil { + return err + } + if count > 1 { + return fmt.Errorf("Updated %d transaction splits while attempting to update only 1", count) + } + delete(s_map, t.Splits[i].SplitId) + } else { + t.Splits[i].SplitId = -1 + err := tx.Insert(t.Splits[i]) + if err != nil { + return err + } + } + if t.Splits[i].AccountId != -1 { + a_map[t.Splits[i].AccountId] = true + } + } + + // Delete any remaining pre-existing splits + for i := range existing_splits { + _, ok := s_map[existing_splits[i].SplitId] + if existing_splits[i].AccountId != -1 { + a_map[existing_splits[i].AccountId] = true + } + if ok { + _, err := tx.Delete(existing_splits[i]) + if err != nil { + return err + } + } + } + + // Increment versions for all accounts with modified splits + var a_ids []int64 + for id := range a_map { + a_ids = append(a_ids, id) + } + err = tx.incrementAccountVersions(user, a_ids) + if err != nil { + return err + } + + count, err := tx.Update(t) + if err != nil { + return err + } + if count > 1 { + return fmt.Errorf("Updated %d transactions (expected 1)", count) + } + + return nil +} + +func (tx *Tx) DeleteTransaction(t *models.Transaction, user *models.User) error { + var accountids []int64 + _, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId) + if err != nil { + return err + } + + _, err = tx.Exec("DELETE FROM splits WHERE TransactionId=?", t.TransactionId) + if err != nil { + return err + } + + count, err := tx.Delete(t) + if err != nil { + return err + } + if count != 1 { + return errors.New("Deleted more than one transaction") + } + + err = tx.incrementAccountVersions(user, accountids) + if err != nil { + return err + } + + return nil +} + +func (tx *Tx) GetAccountSplits(user *models.User, accountid int64) (*[]*models.Split, error) { + var splits []*models.Split + + sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?" + _, err := tx.Select(&splits, sql, accountid, user.UserId) + if err != nil { + return nil, err + } + return &splits, nil +} + +// Assumes accountid is valid and is owned by the current user +func (tx *Tx) GetAccountSplitsDate(user *models.User, accountid int64, date *time.Time) (*[]*models.Split, error) { + var splits []*models.Split + + sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?" + _, err := tx.Select(&splits, sql, accountid, user.UserId, date) + if err != nil { + return nil, err + } + return &splits, err +} + +func (tx *Tx) GetAccountSplitsDateRange(user *models.User, accountid int64, begin, end *time.Time) (*[]*models.Split, error) { + var splits []*models.Split + + sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?" + _, err := tx.Select(&splits, sql, accountid, user.UserId, begin, end) + if err != nil { + return nil, err + } + return &splits, nil +} + +func (tx *Tx) transactionsBalanceDifference(accountid int64, transactions []*models.Transaction) (*big.Rat, error) { + var pageDifference, tmp big.Rat + for i := range transactions { + _, err := tx.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId) + if err != nil { + return nil, err + } + + // Sum up the amounts from the splits we're returning so we can return + // an ending balance + for j := range transactions[i].Splits { + if transactions[i].Splits[j].AccountId == accountid { + rat_amount, err := models.GetBigAmount(transactions[i].Splits[j].Amount) + if err != nil { + return nil, err + } + tmp.Add(&pageDifference, rat_amount) + pageDifference.Set(&tmp) + } + } + } + return &pageDifference, nil +} + +func (tx *Tx) GetAccountTransactions(user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) { + var transactions []*models.Transaction + var atl models.AccountTransactionsList + + var sqlsort, balanceLimitOffset string + var balanceLimitOffsetArg uint64 + if sort == "date-asc" { + sqlsort = " ORDER BY transactions.Date ASC, transactions.TransactionId ASC" + balanceLimitOffset = " LIMIT ?" + balanceLimitOffsetArg = page * limit + } else if sort == "date-desc" { + numSplits, err := tx.SelectInt("SELECT count(*) FROM splits") + if err != nil { + return nil, err + } + sqlsort = " ORDER BY transactions.Date DESC, transactions.TransactionId DESC" + balanceLimitOffset = fmt.Sprintf(" LIMIT %d OFFSET ?", numSplits) + balanceLimitOffsetArg = (page + 1) * limit + } + + var sqloffset string + if page > 0 { + sqloffset = fmt.Sprintf(" OFFSET %d", page*limit) + } + + account, err := tx.GetAccount(accountid, user.UserId) + if err != nil { + return nil, err + } + atl.Account = account + + sql := "SELECT DISTINCT transactions.* FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + " LIMIT ?" + sqloffset + _, err = tx.Select(&transactions, sql, user.UserId, accountid, limit) + if err != nil { + return nil, err + } + atl.Transactions = &transactions + + pageDifference, err := tx.transactionsBalanceDifference(accountid, transactions) + if err != nil { + return nil, err + } + + count, err := tx.SelectInt("SELECT count(DISTINCT transactions.TransactionId) FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?", user.UserId, accountid) + if err != nil { + return nil, err + } + atl.TotalTransactions = count + + security, err := tx.GetSecurity(atl.Account.SecurityId, user.UserId) + if err != nil { + return nil, err + } + if security == nil { + return nil, errors.New("Security not found") + } + + // Sum all the splits for all transaction splits for this account that + // occurred before the page we're returning + var amounts []string + sql = "SELECT s.Amount FROM splits AS s INNER JOIN (SELECT DISTINCT transactions.Date, transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ") as t ON s.TransactionId = t.TransactionId WHERE s.AccountId=?" + _, err = tx.Select(&amounts, sql, user.UserId, accountid, balanceLimitOffsetArg, accountid) + if err != nil { + return nil, err + } + + var tmp, balance big.Rat + for _, amount := range amounts { + rat_amount, err := models.GetBigAmount(amount) + if err != nil { + return nil, err + } + tmp.Add(&balance, rat_amount) + balance.Set(&tmp) + } + atl.BeginningBalance = balance.FloatString(security.Precision) + atl.EndingBalance = tmp.Add(&balance, pageDifference).FloatString(security.Precision) + + return &atl, nil +} diff --git a/internal/store/store.go b/internal/store/store.go index b1ffc8a..c890dd7 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -76,7 +76,23 @@ type AccountStore interface { DeleteAccount(account *models.Account) error } +type AccountMissingError struct{} + +func (ame AccountMissingError) Error() string { + return "Account missing" +} + type TransactionStore interface { + SplitExists(s *models.Split) (bool, error) + InsertTransaction(t *models.Transaction, user *models.User) error + GetTransaction(transactionid int64, userid int64) (*models.Transaction, error) + GetTransactions(userid int64) (*[]*models.Transaction, error) + UpdateTransaction(t *models.Transaction, user *models.User) error + DeleteTransaction(t *models.Transaction, user *models.User) error + GetAccountSplits(user *models.User, accountid int64) (*[]*models.Split, error) + GetAccountSplitsDate(user *models.User, accountid int64, date *time.Time) (*[]*models.Split, error) + GetAccountSplitsDateRange(user *models.User, accountid int64, begin, end *time.Time) (*[]*models.Split, error) + GetAccountTransactions(user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) } type ReportStore interface { From af97f92df5646866994d01bd635b59d1f44f9e11 Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Sat, 9 Dec 2017 05:47:25 -0500 Subject: [PATCH 6/9] db: Paper over MySQL returning count=0 for unchanged updates --- internal/store/db/tx.go | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/internal/store/db/tx.go b/internal/store/db/tx.go index 5187201..c899dc0 100644 --- a/internal/store/db/tx.go +++ b/internal/store/db/tx.go @@ -41,7 +41,20 @@ func (tx *Tx) Insert(list ...interface{}) error { } func (tx *Tx) Update(list ...interface{}) (int64, error) { - return tx.Tx.Update(list...) + count, err := tx.Tx.Update(list...) + if count == 0 { + switch tx.Dialect.(type) { + case gorp.MySQLDialect: + // Always return 1 for 0 if we're using MySQL because it returns + // count=0 if the row data was unchanged, even if the row existed + + // TODO Find another way to fix this without risking ignoring + // errors + + count = 1 + } + } + return count, err } func (tx *Tx) Delete(list ...interface{}) (int64, error) { From 32aef11da58b0b8171484ebd109399337d094c6c Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Sat, 9 Dec 2017 05:56:45 -0500 Subject: [PATCH 7/9] Finish 'store' separation --- internal/handlers/accounts.go | 9 ++-- internal/handlers/accounts_lua.go | 6 +-- internal/handlers/handlers.go | 32 +------------ internal/handlers/imports.go | 4 +- internal/handlers/prices.go | 10 ++--- internal/handlers/reports.go | 70 +++++------------------------ internal/handlers/securities.go | 5 +-- internal/handlers/securities_lua.go | 6 +-- internal/handlers/sessions.go | 8 ++-- internal/handlers/transactions.go | 5 +-- internal/handlers/tx.go | 14 ------ internal/handlers/users.go | 2 +- internal/models/reports.go | 2 +- internal/store/db/reports.go | 56 +++++++++++++++++++++++ internal/store/store.go | 5 +++ 15 files changed, 100 insertions(+), 134 deletions(-) delete mode 100644 internal/handlers/tx.go create mode 100644 internal/store/db/reports.go diff --git a/internal/handlers/accounts.go b/internal/handlers/accounts.go index 60a2199..7f2cc23 100644 --- a/internal/handlers/accounts.go +++ b/internal/handlers/accounts.go @@ -4,14 +4,13 @@ import ( "errors" "github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/store" - "github.com/aclindsa/moneygo/internal/store/db" "log" "net/http" ) // Get (and attempt to create if it doesn't exist). Matches on UserId, // SecurityId, Type, Name, and ParentAccountId -func GetCreateAccount(tx *db.Tx, a models.Account) (*models.Account, error) { +func GetCreateAccount(tx store.Tx, a models.Account) (*models.Account, error) { var account models.Account accounts, err := tx.FindMatchingAccounts(&a) @@ -27,7 +26,7 @@ func GetCreateAccount(tx *db.Tx, a models.Account) (*models.Account, error) { account.Name = a.Name account.ParentAccountId = a.ParentAccountId - err = tx.Insert(&account) + err = tx.InsertAccount(&account) if err != nil { return nil, err } @@ -37,7 +36,7 @@ func GetCreateAccount(tx *db.Tx, a models.Account) (*models.Account, error) { // Get (and attempt to create if it doesn't exist) the security/currency // trading account for the supplied security/currency -func GetTradingAccount(tx *db.Tx, userid int64, securityid int64) (*models.Account, error) { +func GetTradingAccount(tx store.Tx, userid int64, securityid int64) (*models.Account, error) { var tradingAccount models.Account var account models.Account @@ -79,7 +78,7 @@ func GetTradingAccount(tx *db.Tx, userid int64, securityid int64) (*models.Accou // Get (and attempt to create if it doesn't exist) the security/currency // imbalance account for the supplied security/currency -func GetImbalanceAccount(tx *db.Tx, userid int64, securityid int64) (*models.Account, error) { +func GetImbalanceAccount(tx store.Tx, userid int64, securityid int64) (*models.Account, error) { var imbalanceAccount models.Account var account models.Account xxxtemplate := FindSecurityTemplate("XXX", models.Currency) diff --git a/internal/handlers/accounts_lua.go b/internal/handlers/accounts_lua.go index 2985ff5..6d135a6 100644 --- a/internal/handlers/accounts_lua.go +++ b/internal/handlers/accounts_lua.go @@ -4,7 +4,7 @@ import ( "context" "errors" "github.com/aclindsa/moneygo/internal/models" - "github.com/aclindsa/moneygo/internal/store/db" + "github.com/aclindsa/moneygo/internal/store" "github.com/yuin/gopher-lua" "strings" ) @@ -16,7 +16,7 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*models.Account, error) { ctx := L.Context() - tx, ok := ctx.Value(dbContextKey).(*db.Tx) + tx, ok := ctx.Value(dbContextKey).(store.Tx) if !ok { return nil, errors.New("Couldn't find tx in lua's Context") } @@ -150,7 +150,7 @@ func luaAccountBalance(L *lua.LState) int { a := luaCheckAccount(L, 1) ctx := L.Context() - tx, ok := ctx.Value(dbContextKey).(*db.Tx) + tx, ok := ctx.Value(dbContextKey).(store.Tx) if !ok { panic("Couldn't find tx in lua's Context") } diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 47a58b6..88a11fd 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -17,8 +17,7 @@ type ResponseWriterWriter interface { } type Context struct { - Tx *db.Tx - StoreTx store.Tx + Tx store.Tx User *models.User remainingURL string // portion of URL path not yet reached in the hierarchy } @@ -52,33 +51,6 @@ type APIHandler struct { } func (ah *APIHandler) txWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) { - tx, err := GetTx(ah.Store.DbMap) - if err != nil { - log.Print(err) - return NewError(999 /*Internal Error*/) - } - defer func() { - if r := recover(); r != nil { - tx.Rollback() - panic(r) - } - if _, ok := writer.(*Error); ok { - tx.Rollback() - } else { - err = tx.Commit() - if err != nil { - log.Print(err) - writer = NewError(999 /*Internal Error*/) - } - } - }() - - context.Tx = tx - context.StoreTx = tx - return h(r, context) -} - -func (ah *APIHandler) storeTxWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) { tx, err := ah.Store.Begin() if err != nil { log.Print(err) @@ -100,7 +72,7 @@ func (ah *APIHandler) storeTxWrapper(h Handler, r *http.Request, context *Contex } }() - context.StoreTx = tx + context.Tx = tx return h(r, context) } diff --git a/internal/handlers/imports.go b/internal/handlers/imports.go index b76ae68..58696d0 100644 --- a/internal/handlers/imports.go +++ b/internal/handlers/imports.go @@ -3,7 +3,7 @@ package handlers import ( "encoding/json" "github.com/aclindsa/moneygo/internal/models" - "github.com/aclindsa/moneygo/internal/store/db" + "github.com/aclindsa/moneygo/internal/store" "github.com/aclindsa/ofxgo" "io" "log" @@ -24,7 +24,7 @@ func (od *OFXDownload) Read(json_str string) error { return dec.Decode(od) } -func ofxImportHelper(tx *db.Tx, r io.Reader, user *models.User, accountid int64) ResponseWriterWriter { +func ofxImportHelper(tx store.Tx, r io.Reader, user *models.User, accountid int64) ResponseWriterWriter { itl, err := ImportOFX(r) if err != nil { diff --git a/internal/handlers/prices.go b/internal/handlers/prices.go index c51d5c8..f737df1 100644 --- a/internal/handlers/prices.go +++ b/internal/handlers/prices.go @@ -2,16 +2,16 @@ package handlers import ( "github.com/aclindsa/moneygo/internal/models" - "github.com/aclindsa/moneygo/internal/store/db" + "github.com/aclindsa/moneygo/internal/store" "log" "net/http" "time" ) -func CreatePriceIfNotExist(tx *db.Tx, price *models.Price) error { +func CreatePriceIfNotExist(tx store.Tx, price *models.Price) error { if len(price.RemoteId) == 0 { // Always create a new price if we can't match on the RemoteId - err := tx.Insert(price) + err := tx.InsertPrice(price) if err != nil { return err } @@ -34,7 +34,7 @@ func CreatePriceIfNotExist(tx *db.Tx, price *models.Price) error { } // Return the price for security in currency closest to date -func GetClosestPrice(tx *db.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { +func GetClosestPrice(tx store.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { earliest, _ := tx.GetEarliestPrice(security, currency, date) latest, err := tx.GetLatestPrice(security, currency, date) @@ -75,7 +75,7 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security return NewError(3 /*Invalid Request*/) } - err = context.Tx.Insert(&price) + err = context.Tx.InsertPrice(&price) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/handlers/reports.go b/internal/handlers/reports.go index 46d8061..abf554f 100644 --- a/internal/handlers/reports.go +++ b/internal/handlers/reports.go @@ -5,7 +5,7 @@ import ( "errors" "fmt" "github.com/aclindsa/moneygo/internal/models" - "github.com/aclindsa/moneygo/internal/store/db" + "github.com/aclindsa/moneygo/internal/store" "github.com/yuin/gopher-lua" "log" "net/http" @@ -25,57 +25,7 @@ const ( const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for -func GetReport(tx *db.Tx, reportid int64, userid int64) (*models.Report, error) { - var r models.Report - - err := tx.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid) - if err != nil { - return nil, err - } - return &r, nil -} - -func GetReports(tx *db.Tx, userid int64) (*[]models.Report, error) { - var reports []models.Report - - _, err := tx.Select(&reports, "SELECT * from reports where UserId=?", userid) - if err != nil { - return nil, err - } - return &reports, nil -} - -func InsertReport(tx *db.Tx, r *models.Report) error { - err := tx.Insert(r) - if err != nil { - return err - } - return nil -} - -func UpdateReport(tx *db.Tx, r *models.Report) error { - count, err := tx.Update(r) - if err != nil { - return err - } - if count != 1 { - return errors.New("Updated more than one report") - } - return nil -} - -func DeleteReport(tx *db.Tx, r *models.Report) error { - count, err := tx.Delete(r) - if err != nil { - return err - } - if count != 1 { - return errors.New("Deleted more than one report") - } - return nil -} - -func runReport(tx *db.Tx, user *models.User, report *models.Report) (*models.Tabulation, error) { +func runReport(tx store.Tx, user *models.User, report *models.Report) (*models.Tabulation, error) { // Create a new LState without opening the default libs for security L := lua.NewState(lua.Options{SkipOpenLibs: true}) defer L.Close() @@ -139,8 +89,8 @@ func runReport(tx *db.Tx, user *models.User, report *models.Report) (*models.Tab } } -func ReportTabulationHandler(tx *db.Tx, r *http.Request, user *models.User, reportid int64) ResponseWriterWriter { - report, err := GetReport(tx, reportid, user.UserId) +func ReportTabulationHandler(tx store.Tx, r *http.Request, user *models.User, reportid int64) ResponseWriterWriter { + report, err := tx.GetReport(reportid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } @@ -175,7 +125,7 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter { return NewError(3 /*Invalid Request*/) } - err = InsertReport(context.Tx, &report) + err = context.Tx.InsertReport(&report) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -186,7 +136,7 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter { if context.LastLevel() { //Return all Reports var rl models.ReportList - reports, err := GetReports(context.Tx, user.UserId) + reports, err := context.Tx.GetReports(user.UserId) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -204,7 +154,7 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter { return ReportTabulationHandler(context.Tx, r, user, reportid) } else { // Return Report with this Id - report, err := GetReport(context.Tx, reportid, user.UserId) + report, err := context.Tx.GetReport(reportid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } @@ -228,7 +178,7 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter { return NewError(3 /*Invalid Request*/) } - err = UpdateReport(context.Tx, &report) + err = context.Tx.UpdateReport(&report) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -236,12 +186,12 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter { return &report } else if r.Method == "DELETE" { - report, err := GetReport(context.Tx, reportid, user.UserId) + report, err := context.Tx.GetReport(reportid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) } - err = DeleteReport(context.Tx, report) + err = context.Tx.DeleteReport(report) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/handlers/securities.go b/internal/handlers/securities.go index 7c720eb..f5e82ff 100644 --- a/internal/handlers/securities.go +++ b/internal/handlers/securities.go @@ -6,7 +6,6 @@ import ( "errors" "github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/store" - "github.com/aclindsa/moneygo/internal/store/db" "log" "net/http" "net/url" @@ -51,7 +50,7 @@ func FindCurrencyTemplate(iso4217 int64) *models.Security { return nil } -func UpdateSecurity(tx *db.Tx, s *models.Security) (err error) { +func UpdateSecurity(tx store.Tx, s *models.Security) (err error) { user, err := tx.GetUser(s.UserId) if err != nil { return @@ -67,7 +66,7 @@ func UpdateSecurity(tx *db.Tx, s *models.Security) (err error) { return nil } -func ImportGetCreateSecurity(tx *db.Tx, userid int64, security *models.Security) (*models.Security, error) { +func ImportGetCreateSecurity(tx store.Tx, userid int64, security *models.Security) (*models.Security, error) { security.UserId = userid if len(security.AlternateId) == 0 { // Always create a new local security if we can't match on the AlternateId diff --git a/internal/handlers/securities_lua.go b/internal/handlers/securities_lua.go index 78716f2..eaaf71d 100644 --- a/internal/handlers/securities_lua.go +++ b/internal/handlers/securities_lua.go @@ -4,7 +4,7 @@ import ( "context" "errors" "github.com/aclindsa/moneygo/internal/models" - "github.com/aclindsa/moneygo/internal/store/db" + "github.com/aclindsa/moneygo/internal/store" "github.com/yuin/gopher-lua" ) @@ -15,7 +15,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*models.Security, error) ctx := L.Context() - tx, ok := ctx.Value(dbContextKey).(*db.Tx) + tx, ok := ctx.Value(dbContextKey).(store.Tx) if !ok { return nil, errors.New("Couldn't find tx in lua's Context") } @@ -159,7 +159,7 @@ func luaClosestPrice(L *lua.LState) int { date := luaCheckTime(L, 3) ctx := L.Context() - tx, ok := ctx.Value(dbContextKey).(*db.Tx) + tx, ok := ctx.Value(dbContextKey).(store.Tx) if !ok { panic("Couldn't find tx in lua's Context") } diff --git a/internal/handlers/sessions.go b/internal/handlers/sessions.go index 71deff4..e273cb1 100644 --- a/internal/handlers/sessions.go +++ b/internal/handlers/sessions.go @@ -89,7 +89,7 @@ func SessionHandler(r *http.Request, context *Context) ResponseWriterWriter { // attacks user.HashPassword() - dbuser, err := context.StoreTx.GetUserByUsername(user.Username) + dbuser, err := context.Tx.GetUserByUsername(user.Username) if err != nil { return NewError(2 /*Unauthorized Access*/) } @@ -98,21 +98,21 @@ func SessionHandler(r *http.Request, context *Context) ResponseWriterWriter { return NewError(2 /*Unauthorized Access*/) } - sessionwriter, err := NewSession(context.StoreTx, r, dbuser.UserId) + sessionwriter, err := NewSession(context.Tx, r, dbuser.UserId) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) } return sessionwriter } else if r.Method == "GET" { - s, err := GetSession(context.StoreTx, r) + s, err := GetSession(context.Tx, r) if err != nil { return NewError(1 /*Not Signed In*/) } return s } else if r.Method == "DELETE" { - err := DeleteSessionIfExists(context.StoreTx, r) + err := DeleteSessionIfExists(context.Tx, r) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/handlers/transactions.go b/internal/handlers/transactions.go index 2d16da1..1d522d0 100644 --- a/internal/handlers/transactions.go +++ b/internal/handlers/transactions.go @@ -4,7 +4,6 @@ import ( "errors" "github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/store" - "github.com/aclindsa/moneygo/internal/store/db" "log" "math/big" "net/http" @@ -14,7 +13,7 @@ import ( // Return a map of security ID's to big.Rat's containing the amount that // security is imbalanced by -func GetTransactionImbalances(tx *db.Tx, t *models.Transaction) (map[int64]big.Rat, error) { +func GetTransactionImbalances(tx store.Tx, t *models.Transaction) (map[int64]big.Rat, error) { sums := make(map[int64]big.Rat) if !t.Valid() { @@ -42,7 +41,7 @@ func GetTransactionImbalances(tx *db.Tx, t *models.Transaction) (map[int64]big.R // Returns true if all securities contained in this transaction are balanced, // false otherwise -func TransactionBalanced(tx *db.Tx, t *models.Transaction) (bool, error) { +func TransactionBalanced(tx store.Tx, t *models.Transaction) (bool, error) { var zero big.Rat sums, err := GetTransactionImbalances(tx, t) diff --git a/internal/handlers/tx.go b/internal/handlers/tx.go deleted file mode 100644 index 750220a..0000000 --- a/internal/handlers/tx.go +++ /dev/null @@ -1,14 +0,0 @@ -package handlers - -import ( - "github.com/aclindsa/gorp" - "github.com/aclindsa/moneygo/internal/store/db" -) - -func GetTx(gdb *gorp.DbMap) (*db.Tx, error) { - tx, err := gdb.Begin() - if err != nil { - return nil, err - } - return &db.Tx{gdb.Dialect, tx}, nil -} diff --git a/internal/handlers/users.go b/internal/handlers/users.go index e9a468d..ed7275d 100644 --- a/internal/handlers/users.go +++ b/internal/handlers/users.go @@ -140,7 +140,7 @@ func UserHandler(r *http.Request, context *Context) ResponseWriterWriter { return user } else if r.Method == "DELETE" { - err := context.StoreTx.DeleteUser(user) + err := context.Tx.DeleteUser(user) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/models/reports.go b/internal/models/reports.go index 493fd21..01606ee 100644 --- a/internal/models/reports.go +++ b/internal/models/reports.go @@ -28,7 +28,7 @@ func (r *Report) Read(json_str string) error { } type ReportList struct { - Reports *[]Report `json:"reports"` + Reports *[]*Report `json:"reports"` } func (rl *ReportList) Write(w http.ResponseWriter) error { diff --git a/internal/store/db/reports.go b/internal/store/db/reports.go new file mode 100644 index 0000000..e220695 --- /dev/null +++ b/internal/store/db/reports.go @@ -0,0 +1,56 @@ +package db + +import ( + "fmt" + "github.com/aclindsa/moneygo/internal/models" +) + +func (tx *Tx) GetReport(reportid int64, userid int64) (*models.Report, error) { + var r models.Report + + err := tx.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid) + if err != nil { + return nil, err + } + return &r, nil +} + +func (tx *Tx) GetReports(userid int64) (*[]*models.Report, error) { + var reports []*models.Report + + _, err := tx.Select(&reports, "SELECT * from reports where UserId=?", userid) + if err != nil { + return nil, err + } + return &reports, nil +} + +func (tx *Tx) InsertReport(report *models.Report) error { + err := tx.Insert(report) + if err != nil { + return err + } + return nil +} + +func (tx *Tx) UpdateReport(report *models.Report) error { + count, err := tx.Update(report) + if err != nil { + return err + } + if count != 1 { + return fmt.Errorf("Expected to update 1 report, was going to update %d", count) + } + return nil +} + +func (tx *Tx) DeleteReport(report *models.Report) error { + count, err := tx.Delete(report) + if err != nil { + return err + } + if count != 1 { + return fmt.Errorf("Expected to delete 1 report, was going to delete %d", count) + } + return nil +} diff --git a/internal/store/store.go b/internal/store/store.go index c890dd7..3f87880 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -96,6 +96,11 @@ type TransactionStore interface { } type ReportStore interface { + InsertReport(report *models.Report) error + GetReport(reportid int64, userid int64) (*models.Report, error) + GetReports(userid int64) (*[]*models.Report, error) + UpdateReport(report *models.Report) error + DeleteReport(report *models.Report) error } type Tx interface { From e89198fe2e21a251db8c086d2066438e3d6b00a8 Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Sat, 9 Dec 2017 05:59:30 -0500 Subject: [PATCH 8/9] .travis.yml: Don't wait on OSX builds --- .travis.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 340436f..8ee80f7 100644 --- a/.travis.yml +++ b/.travis.yml @@ -20,12 +20,11 @@ env: - MONEYGO_TEST_DB=mysql - MONEYGO_TEST_DB=postgres -# OSX builds take too long, so don't wait for all of them +# OSX builds take too long, so don't wait for them matrix: fast_finish: true allow_failures: - os: osx - go: master before_install: # Fetch/build coverage reporting tools From c4d2fe27e6651a9faec6a7518154483f888f823d Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Sat, 9 Dec 2017 06:06:20 -0500 Subject: [PATCH 9/9] Take unexported database methods 'private' --- internal/handlers/common_test.go | 5 +---- internal/store/db/db.go | 22 +++++++++++++--------- internal/store/store.go | 1 + 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/internal/handlers/common_test.go b/internal/handlers/common_test.go index 6cbd9ef..87fd7dc 100644 --- a/internal/handlers/common_test.go +++ b/internal/handlers/common_test.go @@ -258,10 +258,7 @@ func RunTests(m *testing.M) int { } defer db.Close() - err = db.DbMap.TruncateTables() - if err != nil { - log.Fatal(err) - } + db.Empty() // clear the DB tables server = httptest.NewTLSServer(&handlers.APIHandler{Store: db}) defer server.Close() diff --git a/internal/store/db/db.go b/internal/store/db/db.go index d8b043b..ba01e61 100644 --- a/internal/store/db/db.go +++ b/internal/store/db/db.go @@ -19,7 +19,7 @@ import ( // implementation's string type specified by the same. const luaMaxLengthBuffer int = 4096 -func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { +func getDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { var dialect gorp.Dialect if dbtype == config.SQLite { dialect = gorp.SqliteDialect{} @@ -55,7 +55,7 @@ func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { return dbmap, nil } -func GetDSN(dbtype config.DbType, dsn string) string { +func getDSN(dbtype config.DbType, dsn string) string { if dbtype == config.MySQL && !strings.Contains(dsn, "parseTime=true") { log.Fatalf("The DSN for MySQL MUST contain 'parseTime=True' but does not!") } @@ -63,25 +63,29 @@ func GetDSN(dbtype config.DbType, dsn string) string { } type DbStore struct { - DbMap *gorp.DbMap + dbMap *gorp.DbMap +} + +func (db *DbStore) Empty() error { + return db.dbMap.TruncateTables() } func (db *DbStore) Begin() (store.Tx, error) { - tx, err := db.DbMap.Begin() + tx, err := db.dbMap.Begin() if err != nil { return nil, err } - return &Tx{db.DbMap.Dialect, tx}, nil + return &Tx{db.dbMap.Dialect, tx}, nil } func (db *DbStore) Close() error { - err := db.DbMap.Db.Close() - db.DbMap = nil + err := db.dbMap.Db.Close() + db.dbMap = nil return err } func GetStore(dbtype config.DbType, dsn string) (store *DbStore, err error) { - dsn = GetDSN(dbtype, dsn) + dsn = getDSN(dbtype, dsn) database, err := sql.Open(dbtype.String(), dsn) if err != nil { return nil, err @@ -92,7 +96,7 @@ func GetStore(dbtype config.DbType, dsn string) (store *DbStore, err error) { } }() - dbmap, err := GetDbMap(database, dbtype) + dbmap, err := getDbMap(database, dbtype) if err != nil { return nil, err } diff --git a/internal/store/store.go b/internal/store/store.go index 3f87880..412b7c6 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -117,6 +117,7 @@ type Tx interface { } type Store interface { + Empty() error Begin() (Tx, error) Close() error }