diff --git a/internal/handlers/accounts.go b/internal/handlers/accounts.go index 2812fb4..9ef9521 100644 --- a/internal/handlers/accounts.go +++ b/internal/handlers/accounts.go @@ -3,11 +3,12 @@ package handlers import ( "errors" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store/db" "log" "net/http" ) -func GetAccount(tx *Tx, accountid int64, userid int64) (*models.Account, error) { +func GetAccount(tx *db.Tx, accountid int64, userid int64) (*models.Account, error) { var a models.Account err := tx.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid) @@ -17,7 +18,7 @@ func GetAccount(tx *Tx, accountid int64, userid int64) (*models.Account, error) return &a, nil } -func GetAccounts(tx *Tx, userid int64) (*[]models.Account, error) { +func GetAccounts(tx *db.Tx, userid int64) (*[]models.Account, error) { var accounts []models.Account _, err := tx.Select(&accounts, "SELECT * from accounts where UserId=?", userid) @@ -29,7 +30,7 @@ func GetAccounts(tx *Tx, userid int64) (*[]models.Account, error) { // Get (and attempt to create if it doesn't exist). Matches on UserId, // SecurityId, Type, Name, and ParentAccountId -func GetCreateAccount(tx *Tx, a models.Account) (*models.Account, error) { +func GetCreateAccount(tx *db.Tx, a models.Account) (*models.Account, error) { var accounts []models.Account var account models.Account @@ -57,7 +58,7 @@ func GetCreateAccount(tx *Tx, a models.Account) (*models.Account, error) { // Get (and attempt to create if it doesn't exist) the security/currency // trading account for the supplied security/currency -func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*models.Account, error) { +func GetTradingAccount(tx *db.Tx, userid int64, securityid int64) (*models.Account, error) { var tradingAccount models.Account var account models.Account @@ -99,7 +100,7 @@ func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*models.Account, // Get (and attempt to create if it doesn't exist) the security/currency // imbalance account for the supplied security/currency -func GetImbalanceAccount(tx *Tx, userid int64, securityid int64) (*models.Account, error) { +func GetImbalanceAccount(tx *db.Tx, userid int64, securityid int64) (*models.Account, error) { var imbalanceAccount models.Account var account models.Account xxxtemplate := FindSecurityTemplate("XXX", models.Currency) @@ -160,7 +161,7 @@ func (cae CircularAccountsError) Error() string { return "Would result in circular account relationship" } -func insertUpdateAccount(tx *Tx, a *models.Account, insert bool) error { +func insertUpdateAccount(tx *db.Tx, a *models.Account, insert bool) error { found := make(map[int64]bool) if !insert { found[a.AccountId] = true @@ -216,15 +217,15 @@ func insertUpdateAccount(tx *Tx, a *models.Account, insert bool) error { return nil } -func InsertAccount(tx *Tx, a *models.Account) error { +func InsertAccount(tx *db.Tx, a *models.Account) error { return insertUpdateAccount(tx, a, true) } -func UpdateAccount(tx *Tx, a *models.Account) error { +func UpdateAccount(tx *db.Tx, a *models.Account) error { return insertUpdateAccount(tx, a, false) } -func DeleteAccount(tx *Tx, a *models.Account) error { +func DeleteAccount(tx *db.Tx, a *models.Account) error { if a.ParentAccountId != -1 { // Re-parent splits to this account's parent account if this account isn't a root account _, err := tx.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", a.ParentAccountId, a.AccountId) diff --git a/internal/handlers/accounts_lua.go b/internal/handlers/accounts_lua.go index 5a2fc23..ee91eb2 100644 --- a/internal/handlers/accounts_lua.go +++ b/internal/handlers/accounts_lua.go @@ -4,6 +4,7 @@ import ( "context" "errors" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store/db" "github.com/yuin/gopher-lua" "math/big" "strings" @@ -16,7 +17,7 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*models.Account, error) { ctx := L.Context() - tx, ok := ctx.Value(dbContextKey).(*Tx) + tx, ok := ctx.Value(dbContextKey).(*db.Tx) if !ok { return nil, errors.New("Couldn't find tx in lua's Context") } @@ -150,7 +151,7 @@ func luaAccountBalance(L *lua.LState) int { a := luaCheckAccount(L, 1) ctx := L.Context() - tx, ok := ctx.Value(dbContextKey).(*Tx) + tx, ok := ctx.Value(dbContextKey).(*db.Tx) if !ok { panic("Couldn't find tx in lua's Context") } diff --git a/internal/handlers/common_test.go b/internal/handlers/common_test.go index a0ef7f8..6cbd9ef 100644 --- a/internal/handlers/common_test.go +++ b/internal/handlers/common_test.go @@ -2,12 +2,11 @@ package handlers_test import ( "bytes" - "database/sql" "encoding/json" "github.com/aclindsa/moneygo/internal/config" - "github.com/aclindsa/moneygo/internal/db" "github.com/aclindsa/moneygo/internal/handlers" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store/db" "io" "io/ioutil" "log" @@ -253,24 +252,18 @@ func RunTests(m *testing.M) int { dsn = envDSN } - dsn = db.GetDSN(dbType, dsn) - database, err := sql.Open(dbType.String(), dsn) + db, err := db.GetStore(dbType, dsn) if err != nil { log.Fatal(err) } - defer database.Close() + defer db.Close() - dbmap, err := db.GetDbMap(database, dbType) + err = db.DbMap.TruncateTables() if err != nil { log.Fatal(err) } - err = dbmap.TruncateTables() - if err != nil { - log.Fatal(err) - } - - server = httptest.NewTLSServer(&handlers.APIHandler{DB: dbmap}) + server = httptest.NewTLSServer(&handlers.APIHandler{Store: db}) defer server.Close() return m.Run() diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 309eed1..47a58b6 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -1,8 +1,9 @@ package handlers import ( - "github.com/aclindsa/gorp" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store" + "github.com/aclindsa/moneygo/internal/store/db" "log" "net/http" "path" @@ -16,7 +17,8 @@ type ResponseWriterWriter interface { } type Context struct { - Tx *Tx + Tx *db.Tx + StoreTx store.Tx User *models.User remainingURL string // portion of URL path not yet reached in the hierarchy } @@ -46,11 +48,11 @@ func (c *Context) LastLevel() bool { type Handler func(*http.Request, *Context) ResponseWriterWriter type APIHandler struct { - DB *gorp.DbMap + Store *db.DbStore } func (ah *APIHandler) txWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) { - tx, err := GetTx(ah.DB) + tx, err := GetTx(ah.Store.DbMap) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -72,6 +74,33 @@ func (ah *APIHandler) txWrapper(h Handler, r *http.Request, context *Context) (w }() context.Tx = tx + context.StoreTx = tx + return h(r, context) +} + +func (ah *APIHandler) storeTxWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) { + tx, err := ah.Store.Begin() + if err != nil { + log.Print(err) + return NewError(999 /*Internal Error*/) + } + defer func() { + if r := recover(); r != nil { + tx.Rollback() + panic(r) + } + if _, ok := writer.(*Error); ok { + tx.Rollback() + } else { + err = tx.Commit() + if err != nil { + log.Print(err) + writer = NewError(999 /*Internal Error*/) + } + } + }() + + context.StoreTx = tx return h(r, context) } diff --git a/internal/handlers/imports.go b/internal/handlers/imports.go index 78d5236..08267d3 100644 --- a/internal/handlers/imports.go +++ b/internal/handlers/imports.go @@ -3,6 +3,7 @@ package handlers import ( "encoding/json" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store/db" "github.com/aclindsa/ofxgo" "io" "log" @@ -23,7 +24,7 @@ func (od *OFXDownload) Read(json_str string) error { return dec.Decode(od) } -func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) ResponseWriterWriter { +func ofxImportHelper(tx *db.Tx, r io.Reader, user *models.User, accountid int64) ResponseWriterWriter { itl, err := ImportOFX(r) if err != nil { diff --git a/internal/handlers/prices.go b/internal/handlers/prices.go index c92eeb4..620150d 100644 --- a/internal/handlers/prices.go +++ b/internal/handlers/prices.go @@ -2,12 +2,13 @@ package handlers import ( "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store/db" "log" "net/http" "time" ) -func CreatePriceIfNotExist(tx *Tx, price *models.Price) error { +func CreatePriceIfNotExist(tx *db.Tx, price *models.Price) error { if len(price.RemoteId) == 0 { // Always create a new price if we can't match on the RemoteId err := tx.Insert(price) @@ -35,7 +36,7 @@ func CreatePriceIfNotExist(tx *Tx, price *models.Price) error { return nil } -func GetPrice(tx *Tx, priceid, securityid int64) (*models.Price, error) { +func GetPrice(tx *db.Tx, priceid, securityid int64) (*models.Price, error) { var p models.Price err := tx.SelectOne(&p, "SELECT * from prices where PriceId=? AND SecurityId=?", priceid, securityid) if err != nil { @@ -44,7 +45,7 @@ func GetPrice(tx *Tx, priceid, securityid int64) (*models.Price, error) { return &p, nil } -func GetPrices(tx *Tx, securityid int64) (*[]*models.Price, error) { +func GetPrices(tx *db.Tx, securityid int64) (*[]*models.Price, error) { var prices []*models.Price _, err := tx.Select(&prices, "SELECT * from prices where SecurityId=?", securityid) @@ -55,7 +56,7 @@ func GetPrices(tx *Tx, securityid int64) (*[]*models.Price, error) { } // Return the latest price for security in currency units before date -func GetLatestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { +func GetLatestPrice(tx *db.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { var p models.Price err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date) if err != nil { @@ -65,7 +66,7 @@ func GetLatestPrice(tx *Tx, security, currency *models.Security, date *time.Time } // Return the earliest price for security in currency units after date -func GetEarliestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { +func GetEarliestPrice(tx *db.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { var p models.Price err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date) if err != nil { @@ -75,7 +76,7 @@ func GetEarliestPrice(tx *Tx, security, currency *models.Security, date *time.Ti } // Return the price for security in currency closest to date -func GetClosestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { +func GetClosestPrice(tx *db.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { earliest, _ := GetEarliestPrice(tx, security, currency, date) latest, err := GetLatestPrice(tx, security, currency, date) diff --git a/internal/handlers/reports.go b/internal/handlers/reports.go index bec9525..46d8061 100644 --- a/internal/handlers/reports.go +++ b/internal/handlers/reports.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store/db" "github.com/yuin/gopher-lua" "log" "net/http" @@ -24,7 +25,7 @@ const ( const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for -func GetReport(tx *Tx, reportid int64, userid int64) (*models.Report, error) { +func GetReport(tx *db.Tx, reportid int64, userid int64) (*models.Report, error) { var r models.Report err := tx.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid) @@ -34,7 +35,7 @@ func GetReport(tx *Tx, reportid int64, userid int64) (*models.Report, error) { return &r, nil } -func GetReports(tx *Tx, userid int64) (*[]models.Report, error) { +func GetReports(tx *db.Tx, userid int64) (*[]models.Report, error) { var reports []models.Report _, err := tx.Select(&reports, "SELECT * from reports where UserId=?", userid) @@ -44,7 +45,7 @@ func GetReports(tx *Tx, userid int64) (*[]models.Report, error) { return &reports, nil } -func InsertReport(tx *Tx, r *models.Report) error { +func InsertReport(tx *db.Tx, r *models.Report) error { err := tx.Insert(r) if err != nil { return err @@ -52,7 +53,7 @@ func InsertReport(tx *Tx, r *models.Report) error { return nil } -func UpdateReport(tx *Tx, r *models.Report) error { +func UpdateReport(tx *db.Tx, r *models.Report) error { count, err := tx.Update(r) if err != nil { return err @@ -63,7 +64,7 @@ func UpdateReport(tx *Tx, r *models.Report) error { return nil } -func DeleteReport(tx *Tx, r *models.Report) error { +func DeleteReport(tx *db.Tx, r *models.Report) error { count, err := tx.Delete(r) if err != nil { return err @@ -74,7 +75,7 @@ func DeleteReport(tx *Tx, r *models.Report) error { return nil } -func runReport(tx *Tx, user *models.User, report *models.Report) (*models.Tabulation, error) { +func runReport(tx *db.Tx, user *models.User, report *models.Report) (*models.Tabulation, error) { // Create a new LState without opening the default libs for security L := lua.NewState(lua.Options{SkipOpenLibs: true}) defer L.Close() @@ -138,7 +139,7 @@ func runReport(tx *Tx, user *models.User, report *models.Report) (*models.Tabula } } -func ReportTabulationHandler(tx *Tx, r *http.Request, user *models.User, reportid int64) ResponseWriterWriter { +func ReportTabulationHandler(tx *db.Tx, r *http.Request, user *models.User, reportid int64) ResponseWriterWriter { report, err := GetReport(tx, reportid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) diff --git a/internal/handlers/securities.go b/internal/handlers/securities.go index ab58de4..e836822 100644 --- a/internal/handlers/securities.go +++ b/internal/handlers/securities.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store/db" "log" "net/http" "net/url" @@ -50,7 +51,7 @@ func FindCurrencyTemplate(iso4217 int64) *models.Security { return nil } -func GetSecurity(tx *Tx, securityid int64, userid int64) (*models.Security, error) { +func GetSecurity(tx *db.Tx, securityid int64, userid int64) (*models.Security, error) { var s models.Security err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid) @@ -60,7 +61,7 @@ func GetSecurity(tx *Tx, securityid int64, userid int64) (*models.Security, erro return &s, nil } -func GetSecurities(tx *Tx, userid int64) (*[]*models.Security, error) { +func GetSecurities(tx *db.Tx, userid int64) (*[]*models.Security, error) { var securities []*models.Security _, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid) @@ -70,7 +71,7 @@ func GetSecurities(tx *Tx, userid int64) (*[]*models.Security, error) { return &securities, nil } -func InsertSecurity(tx *Tx, s *models.Security) error { +func InsertSecurity(tx *db.Tx, s *models.Security) error { err := tx.Insert(s) if err != nil { return err @@ -78,7 +79,7 @@ func InsertSecurity(tx *Tx, s *models.Security) error { return nil } -func UpdateSecurity(tx *Tx, s *models.Security) (err error) { +func UpdateSecurity(tx *db.Tx, s *models.Security) (err error) { user, err := GetUser(tx, s.UserId) if err != nil { return @@ -105,7 +106,7 @@ func (e SecurityInUseError) Error() string { return e.message } -func DeleteSecurity(tx *Tx, s *models.Security) error { +func DeleteSecurity(tx *db.Tx, s *models.Security) error { // First, ensure no accounts are using this security accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId) @@ -138,7 +139,7 @@ func DeleteSecurity(tx *Tx, s *models.Security) error { return nil } -func ImportGetCreateSecurity(tx *Tx, userid int64, security *models.Security) (*models.Security, error) { +func ImportGetCreateSecurity(tx *db.Tx, userid int64, security *models.Security) (*models.Security, error) { security.UserId = userid if len(security.AlternateId) == 0 { // Always create a new local security if we can't match on the AlternateId diff --git a/internal/handlers/securities_lua.go b/internal/handlers/securities_lua.go index 12783ce..a5307c6 100644 --- a/internal/handlers/securities_lua.go +++ b/internal/handlers/securities_lua.go @@ -4,6 +4,7 @@ import ( "context" "errors" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store/db" "github.com/yuin/gopher-lua" ) @@ -14,7 +15,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*models.Security, error) ctx := L.Context() - tx, ok := ctx.Value(dbContextKey).(*Tx) + tx, ok := ctx.Value(dbContextKey).(*db.Tx) if !ok { return nil, errors.New("Couldn't find tx in lua's Context") } @@ -158,7 +159,7 @@ func luaClosestPrice(L *lua.LState) int { date := luaCheckTime(L, 3) ctx := L.Context() - tx, ok := ctx.Value(dbContextKey).(*Tx) + tx, ok := ctx.Value(dbContextKey).(*db.Tx) if !ok { panic("Couldn't find tx in lua's Context") } diff --git a/internal/handlers/sessions.go b/internal/handlers/sessions.go index 8349613..f4d6c5e 100644 --- a/internal/handlers/sessions.go +++ b/internal/handlers/sessions.go @@ -3,36 +3,37 @@ package handlers import ( "fmt" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store" "log" "net/http" "time" ) -func GetSession(tx *Tx, r *http.Request) (*models.Session, error) { - var s models.Session - +func GetSession(tx store.Tx, r *http.Request) (*models.Session, error) { cookie, err := r.Cookie("moneygo-session") if err != nil { return nil, fmt.Errorf("moneygo-session cookie not set") } - s.SessionSecret = cookie.Value - err = tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret) + s, err := tx.GetSession(cookie.Value) if err != nil { return nil, err } if s.Expires.Before(time.Now()) { - tx.Delete(&s) + err := tx.DeleteSession(s) + if err != nil { + log.Printf("Unexpected error when attempting to delete expired session: %s", err) + } return nil, fmt.Errorf("Session has expired") } - return &s, nil + return s, nil } -func DeleteSessionIfExists(tx *Tx, r *http.Request) error { +func DeleteSessionIfExists(tx store.Tx, r *http.Request) error { session, err := GetSession(tx, r) if err == nil { - _, err := tx.Delete(session) + err := tx.DeleteSession(session) if err != nil { return err } @@ -50,21 +51,26 @@ func (n *NewSessionWriter) Write(w http.ResponseWriter) error { return n.session.Write(w) } -func NewSession(tx *Tx, r *http.Request, userid int64) (*NewSessionWriter, error) { +func NewSession(tx store.Tx, r *http.Request, userid int64) (*NewSessionWriter, error) { + err := DeleteSessionIfExists(tx, r) + if err != nil { + return nil, err + } + s, err := models.NewSession(userid) if err != nil { return nil, err } - existing, err := tx.SelectInt("SELECT count(*) from sessions where SessionSecret=?", s.SessionSecret) + exists, err := tx.SessionExists(s.SessionSecret) if err != nil { return nil, err } - if existing > 0 { - return nil, fmt.Errorf("%d session(s) exist with the generated session_secret", existing) + if exists { + return nil, fmt.Errorf("Session already exists with the generated session_secret") } - err = tx.Insert(s) + err = tx.InsertSession(s) if err != nil { return nil, err } @@ -89,27 +95,21 @@ func SessionHandler(r *http.Request, context *Context) ResponseWriterWriter { return NewError(2 /*Unauthorized Access*/) } - err = DeleteSessionIfExists(context.Tx, r) - if err != nil { - log.Print(err) - return NewError(999 /*Internal Error*/) - } - - sessionwriter, err := NewSession(context.Tx, r, dbuser.UserId) + sessionwriter, err := NewSession(context.StoreTx, r, dbuser.UserId) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) } return sessionwriter } else if r.Method == "GET" { - s, err := GetSession(context.Tx, r) + s, err := GetSession(context.StoreTx, r) if err != nil { return NewError(1 /*Not Signed In*/) } return s } else if r.Method == "DELETE" { - err := DeleteSessionIfExists(context.Tx, r) + err := DeleteSessionIfExists(context.StoreTx, r) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) diff --git a/internal/handlers/transactions.go b/internal/handlers/transactions.go index 3795ce7..4ec94d1 100644 --- a/internal/handlers/transactions.go +++ b/internal/handlers/transactions.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store/db" "log" "math/big" "net/http" @@ -12,14 +13,14 @@ import ( "time" ) -func SplitAlreadyImported(tx *Tx, s *models.Split) (bool, error) { +func SplitAlreadyImported(tx *db.Tx, s *models.Split) (bool, error) { count, err := tx.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId) return count == 1, err } // Return a map of security ID's to big.Rat's containing the amount that // security is imbalanced by -func GetTransactionImbalances(tx *Tx, t *models.Transaction) (map[int64]big.Rat, error) { +func GetTransactionImbalances(tx *db.Tx, t *models.Transaction) (map[int64]big.Rat, error) { sums := make(map[int64]big.Rat) if !t.Valid() { @@ -47,7 +48,7 @@ func GetTransactionImbalances(tx *Tx, t *models.Transaction) (map[int64]big.Rat, // Returns true if all securities contained in this transaction are balanced, // false otherwise -func TransactionBalanced(tx *Tx, t *models.Transaction) (bool, error) { +func TransactionBalanced(tx *db.Tx, t *models.Transaction) (bool, error) { var zero big.Rat sums, err := GetTransactionImbalances(tx, t) @@ -63,7 +64,7 @@ func TransactionBalanced(tx *Tx, t *models.Transaction) (bool, error) { return true, nil } -func GetTransaction(tx *Tx, transactionid int64, userid int64) (*models.Transaction, error) { +func GetTransaction(tx *db.Tx, transactionid int64, userid int64) (*models.Transaction, error) { var t models.Transaction err := tx.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid) @@ -79,7 +80,7 @@ func GetTransaction(tx *Tx, transactionid int64, userid int64) (*models.Transact return &t, nil } -func GetTransactions(tx *Tx, userid int64) (*[]models.Transaction, error) { +func GetTransactions(tx *db.Tx, userid int64) (*[]models.Transaction, error) { var transactions []models.Transaction _, err := tx.Select(&transactions, "SELECT * from transactions where UserId=?", userid) @@ -97,7 +98,7 @@ func GetTransactions(tx *Tx, userid int64) (*[]models.Transaction, error) { return &transactions, nil } -func incrementAccountVersions(tx *Tx, user *models.User, accountids []int64) error { +func incrementAccountVersions(tx *db.Tx, user *models.User, accountids []int64) error { for i := range accountids { account, err := GetAccount(tx, accountids[i], user.UserId) if err != nil { @@ -121,7 +122,7 @@ func (ame AccountMissingError) Error() string { return "Account missing" } -func InsertTransaction(tx *Tx, t *models.Transaction, user *models.User) error { +func InsertTransaction(tx *db.Tx, t *models.Transaction, user *models.User) error { // Map of any accounts with transaction splits being added a_map := make(map[int64]bool) for i := range t.Splits { @@ -171,7 +172,7 @@ func InsertTransaction(tx *Tx, t *models.Transaction, user *models.User) error { return nil } -func UpdateTransaction(tx *Tx, t *models.Transaction, user *models.User) error { +func UpdateTransaction(tx *db.Tx, t *models.Transaction, user *models.User) error { var existing_splits []*models.Split _, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId) @@ -248,7 +249,7 @@ func UpdateTransaction(tx *Tx, t *models.Transaction, user *models.User) error { return nil } -func DeleteTransaction(tx *Tx, t *models.Transaction, user *models.User) error { +func DeleteTransaction(tx *db.Tx, t *models.Transaction, user *models.User) error { var accountids []int64 _, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId) if err != nil { @@ -401,7 +402,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter return NewError(3 /*Invalid Request*/) } -func TransactionsBalanceDifference(tx *Tx, accountid int64, transactions []models.Transaction) (*big.Rat, error) { +func TransactionsBalanceDifference(tx *db.Tx, accountid int64, transactions []models.Transaction) (*big.Rat, error) { var pageDifference, tmp big.Rat for i := range transactions { _, err := tx.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId) @@ -425,7 +426,7 @@ func TransactionsBalanceDifference(tx *Tx, accountid int64, transactions []model return &pageDifference, nil } -func GetAccountBalance(tx *Tx, user *models.User, accountid int64) (*big.Rat, error) { +func GetAccountBalance(tx *db.Tx, user *models.User, accountid int64) (*big.Rat, error) { var splits []models.Split sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?" @@ -448,7 +449,7 @@ func GetAccountBalance(tx *Tx, user *models.User, accountid int64) (*big.Rat, er } // Assumes accountid is valid and is owned by the current user -func GetAccountBalanceDate(tx *Tx, user *models.User, accountid int64, date *time.Time) (*big.Rat, error) { +func GetAccountBalanceDate(tx *db.Tx, user *models.User, accountid int64, date *time.Time) (*big.Rat, error) { var splits []models.Split sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?" @@ -470,7 +471,7 @@ func GetAccountBalanceDate(tx *Tx, user *models.User, accountid int64, date *tim return &balance, nil } -func GetAccountBalanceDateRange(tx *Tx, user *models.User, accountid int64, begin, end *time.Time) (*big.Rat, error) { +func GetAccountBalanceDateRange(tx *db.Tx, user *models.User, accountid int64, begin, end *time.Time) (*big.Rat, error) { var splits []models.Split sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?" @@ -492,7 +493,7 @@ func GetAccountBalanceDateRange(tx *Tx, user *models.User, accountid int64, begi return &balance, nil } -func GetAccountTransactions(tx *Tx, user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) { +func GetAccountTransactions(tx *db.Tx, user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) { var transactions []models.Transaction var atl models.AccountTransactionsList diff --git a/internal/handlers/tx.go b/internal/handlers/tx.go index c0db452..750220a 100644 --- a/internal/handlers/tx.go +++ b/internal/handlers/tx.go @@ -1,65 +1,14 @@ package handlers import ( - "database/sql" "github.com/aclindsa/gorp" - "strings" + "github.com/aclindsa/moneygo/internal/store/db" ) -type Tx struct { - Dialect gorp.Dialect - Tx *gorp.Transaction -} - -func (tx *Tx) Rebind(query string) string { - chunks := strings.Split(query, "?") - str := chunks[0] - for i := 1; i < len(chunks); i++ { - str += tx.Dialect.BindVar(i-1) + chunks[i] - } - return str -} - -func (tx *Tx) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) { - return tx.Tx.Select(i, tx.Rebind(query), args...) -} - -func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) { - return tx.Tx.Exec(tx.Rebind(query), args...) -} - -func (tx *Tx) SelectInt(query string, args ...interface{}) (int64, error) { - return tx.Tx.SelectInt(tx.Rebind(query), args...) -} - -func (tx *Tx) SelectOne(holder interface{}, query string, args ...interface{}) error { - return tx.Tx.SelectOne(holder, tx.Rebind(query), args...) -} - -func (tx *Tx) Insert(list ...interface{}) error { - return tx.Tx.Insert(list...) -} - -func (tx *Tx) Update(list ...interface{}) (int64, error) { - return tx.Tx.Update(list...) -} - -func (tx *Tx) Delete(list ...interface{}) (int64, error) { - return tx.Tx.Delete(list...) -} - -func (tx *Tx) Commit() error { - return tx.Tx.Commit() -} - -func (tx *Tx) Rollback() error { - return tx.Tx.Rollback() -} - -func GetTx(db *gorp.DbMap) (*Tx, error) { - tx, err := db.Begin() +func GetTx(gdb *gorp.DbMap) (*db.Tx, error) { + tx, err := gdb.Begin() if err != nil { return nil, err } - return &Tx{db.Dialect, tx}, nil + return &db.Tx{gdb.Dialect, tx}, nil } diff --git a/internal/handlers/users.go b/internal/handlers/users.go index ba1a9d0..7225fe5 100644 --- a/internal/handlers/users.go +++ b/internal/handlers/users.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store/db" "log" "net/http" ) @@ -14,7 +15,7 @@ func (ueu UserExistsError) Error() string { return "User exists" } -func GetUser(tx *Tx, userid int64) (*models.User, error) { +func GetUser(tx *db.Tx, userid int64) (*models.User, error) { var u models.User err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid) @@ -24,7 +25,7 @@ func GetUser(tx *Tx, userid int64) (*models.User, error) { return &u, nil } -func GetUserByUsername(tx *Tx, username string) (*models.User, error) { +func GetUserByUsername(tx *db.Tx, username string) (*models.User, error) { var u models.User err := tx.SelectOne(&u, "SELECT * from users where Username=?", username) @@ -34,7 +35,7 @@ func GetUserByUsername(tx *Tx, username string) (*models.User, error) { return &u, nil } -func InsertUser(tx *Tx, u *models.User) error { +func InsertUser(tx *db.Tx, u *models.User) error { security_template := FindCurrencyTemplate(u.DefaultCurrency) if security_template == nil { return errors.New("Invalid ISO4217 Default Currency") @@ -75,7 +76,7 @@ func InsertUser(tx *Tx, u *models.User) error { return nil } -func GetUserFromSession(tx *Tx, r *http.Request) (*models.User, error) { +func GetUserFromSession(tx *db.Tx, r *http.Request) (*models.User, error) { s, err := GetSession(tx, r) if err != nil { return nil, err @@ -83,7 +84,7 @@ func GetUserFromSession(tx *Tx, r *http.Request) (*models.User, error) { return GetUser(tx, s.UserId) } -func UpdateUser(tx *Tx, u *models.User) error { +func UpdateUser(tx *db.Tx, u *models.User) error { security, err := GetSecurity(tx, u.DefaultCurrency, u.UserId) if err != nil { return err @@ -103,7 +104,7 @@ func UpdateUser(tx *Tx, u *models.User) error { return nil } -func DeleteUser(tx *Tx, u *models.User) error { +func DeleteUser(tx *db.Tx, u *models.User) error { count, err := tx.Delete(u) if err != nil { return err diff --git a/internal/db/db.go b/internal/store/db/db.go similarity index 74% rename from internal/db/db.go rename to internal/store/db/db.go index a33fb08..3a4031e 100644 --- a/internal/db/db.go +++ b/internal/store/db/db.go @@ -6,6 +6,7 @@ import ( "github.com/aclindsa/gorp" "github.com/aclindsa/moneygo/internal/config" "github.com/aclindsa/moneygo/internal/models" + "github.com/aclindsa/moneygo/internal/store" _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" @@ -60,3 +61,40 @@ func GetDSN(dbtype config.DbType, dsn string) string { } return dsn } + +type DbStore struct { + DbMap *gorp.DbMap +} + +func (db *DbStore) Begin() (store.Tx, error) { + tx, err := db.DbMap.Begin() + if err != nil { + return nil, err + } + return &Tx{db.DbMap.Dialect, tx}, nil +} + +func (db *DbStore) Close() error { + err := db.DbMap.Db.Close() + db.DbMap = nil + return err +} + +func GetStore(dbtype config.DbType, dsn string) (store *DbStore, err error) { + dsn = GetDSN(dbtype, dsn) + database, err := sql.Open(dbtype.String(), dsn) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + database.Close() + } + }() + + dbmap, err := GetDbMap(database, dbtype) + if err != nil { + return nil, err + } + return &DbStore{dbmap}, nil +} diff --git a/internal/store/db/sessions.go b/internal/store/db/sessions.go new file mode 100644 index 0000000..671c55a --- /dev/null +++ b/internal/store/db/sessions.go @@ -0,0 +1,36 @@ +package db + +import ( + "fmt" + "github.com/aclindsa/moneygo/internal/models" + "time" +) + +func (tx *Tx) InsertSession(session *models.Session) error { + return tx.Insert(session) +} + +func (tx *Tx) GetSession(secret string) (*models.Session, error) { + var s models.Session + + err := tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", secret) + if err != nil { + return nil, err + } + + if s.Expires.Before(time.Now()) { + tx.Delete(&s) + return nil, fmt.Errorf("Session has expired") + } + return &s, nil +} + +func (tx *Tx) SessionExists(secret string) (bool, error) { + existing, err := tx.SelectInt("SELECT count(*) from sessions where SessionSecret=?", secret) + return existing != 0, err +} + +func (tx *Tx) DeleteSession(session *models.Session) error { + _, err := tx.Delete(session) + return err +} diff --git a/internal/store/db/tx.go b/internal/store/db/tx.go new file mode 100644 index 0000000..5187201 --- /dev/null +++ b/internal/store/db/tx.go @@ -0,0 +1,57 @@ +package db + +import ( + "database/sql" + "github.com/aclindsa/gorp" + "strings" +) + +type Tx struct { + Dialect gorp.Dialect + Tx *gorp.Transaction +} + +func (tx *Tx) Rebind(query string) string { + chunks := strings.Split(query, "?") + str := chunks[0] + for i := 1; i < len(chunks); i++ { + str += tx.Dialect.BindVar(i-1) + chunks[i] + } + return str +} + +func (tx *Tx) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) { + return tx.Tx.Select(i, tx.Rebind(query), args...) +} + +func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) { + return tx.Tx.Exec(tx.Rebind(query), args...) +} + +func (tx *Tx) SelectInt(query string, args ...interface{}) (int64, error) { + return tx.Tx.SelectInt(tx.Rebind(query), args...) +} + +func (tx *Tx) SelectOne(holder interface{}, query string, args ...interface{}) error { + return tx.Tx.SelectOne(holder, tx.Rebind(query), args...) +} + +func (tx *Tx) Insert(list ...interface{}) error { + return tx.Tx.Insert(list...) +} + +func (tx *Tx) Update(list ...interface{}) (int64, error) { + return tx.Tx.Update(list...) +} + +func (tx *Tx) Delete(list ...interface{}) (int64, error) { + return tx.Tx.Delete(list...) +} + +func (tx *Tx) Commit() error { + return tx.Tx.Commit() +} + +func (tx *Tx) Rollback() error { + return tx.Tx.Rollback() +} diff --git a/internal/store/store.go b/internal/store/store.go new file mode 100644 index 0000000..9823236 --- /dev/null +++ b/internal/store/store.go @@ -0,0 +1,24 @@ +package store + +import ( + "github.com/aclindsa/moneygo/internal/models" +) + +type SessionStore interface { + InsertSession(session *models.Session) error + GetSession(secret string) (*models.Session, error) + SessionExists(secret string) (bool, error) + DeleteSession(session *models.Session) error +} + +type Tx interface { + Commit() error + Rollback() error + + SessionStore +} + +type Store interface { + Begin() (Tx, error) + Close() error +} diff --git a/main.go b/main.go index 2baf7c0..d87b5fa 100644 --- a/main.go +++ b/main.go @@ -3,11 +3,10 @@ package main //go:generate make import ( - "database/sql" "flag" "github.com/aclindsa/moneygo/internal/config" - "github.com/aclindsa/moneygo/internal/db" "github.com/aclindsa/moneygo/internal/handlers" + "github.com/aclindsa/moneygo/internal/store/db" "github.com/kabukky/httpscerts" "log" "net" @@ -67,21 +66,15 @@ func staticHandler(w http.ResponseWriter, r *http.Request, basedir string) { } func main() { - dsn := db.GetDSN(cfg.MoneyGo.DBType, cfg.MoneyGo.DSN) - database, err := sql.Open(cfg.MoneyGo.DBType.String(), dsn) - if err != nil { - log.Fatal(err) - } - defer database.Close() - - dbmap, err := db.GetDbMap(database, cfg.MoneyGo.DBType) + db, err := db.GetStore(cfg.MoneyGo.DBType, cfg.MoneyGo.DSN) if err != nil { log.Fatal(err) } + defer db.Close() // Get ServeMux for API and add our own handlers for files servemux := http.NewServeMux() - servemux.Handle("/v1/", &handlers.APIHandler{DB: dbmap}) + servemux.Handle("/v1/", &handlers.APIHandler{Store: db}) servemux.HandleFunc("/", FileHandlerFunc(rootHandler, cfg.MoneyGo.Basedir)) servemux.HandleFunc("/static/", FileHandlerFunc(staticHandler, cfg.MoneyGo.Basedir))