1
0
mirror of https://github.com/aclindsa/moneygo.git synced 2024-12-25 23:23:21 -05:00

Lay groundwork and move sessions to 'store'

This commit is contained in:
Aaron Lindsay 2017-12-06 21:09:47 -05:00
parent 6bdde8e83b
commit c452984f23
18 changed files with 286 additions and 158 deletions

View File

@ -3,11 +3,12 @@ package handlers
import ( import (
"errors" "errors"
"github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store/db"
"log" "log"
"net/http" "net/http"
) )
func GetAccount(tx *Tx, accountid int64, userid int64) (*models.Account, error) { func GetAccount(tx *db.Tx, accountid int64, userid int64) (*models.Account, error) {
var a models.Account var a models.Account
err := tx.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid) err := tx.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid)
@ -17,7 +18,7 @@ func GetAccount(tx *Tx, accountid int64, userid int64) (*models.Account, error)
return &a, nil return &a, nil
} }
func GetAccounts(tx *Tx, userid int64) (*[]models.Account, error) { func GetAccounts(tx *db.Tx, userid int64) (*[]models.Account, error) {
var accounts []models.Account var accounts []models.Account
_, err := tx.Select(&accounts, "SELECT * from accounts where UserId=?", userid) _, err := tx.Select(&accounts, "SELECT * from accounts where UserId=?", userid)
@ -29,7 +30,7 @@ func GetAccounts(tx *Tx, userid int64) (*[]models.Account, error) {
// Get (and attempt to create if it doesn't exist). Matches on UserId, // Get (and attempt to create if it doesn't exist). Matches on UserId,
// SecurityId, Type, Name, and ParentAccountId // SecurityId, Type, Name, and ParentAccountId
func GetCreateAccount(tx *Tx, a models.Account) (*models.Account, error) { func GetCreateAccount(tx *db.Tx, a models.Account) (*models.Account, error) {
var accounts []models.Account var accounts []models.Account
var account models.Account var account models.Account
@ -57,7 +58,7 @@ func GetCreateAccount(tx *Tx, a models.Account) (*models.Account, error) {
// Get (and attempt to create if it doesn't exist) the security/currency // Get (and attempt to create if it doesn't exist) the security/currency
// trading account for the supplied security/currency // trading account for the supplied security/currency
func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*models.Account, error) { func GetTradingAccount(tx *db.Tx, userid int64, securityid int64) (*models.Account, error) {
var tradingAccount models.Account var tradingAccount models.Account
var account models.Account var account models.Account
@ -99,7 +100,7 @@ func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*models.Account,
// Get (and attempt to create if it doesn't exist) the security/currency // Get (and attempt to create if it doesn't exist) the security/currency
// imbalance account for the supplied security/currency // imbalance account for the supplied security/currency
func GetImbalanceAccount(tx *Tx, userid int64, securityid int64) (*models.Account, error) { func GetImbalanceAccount(tx *db.Tx, userid int64, securityid int64) (*models.Account, error) {
var imbalanceAccount models.Account var imbalanceAccount models.Account
var account models.Account var account models.Account
xxxtemplate := FindSecurityTemplate("XXX", models.Currency) xxxtemplate := FindSecurityTemplate("XXX", models.Currency)
@ -160,7 +161,7 @@ func (cae CircularAccountsError) Error() string {
return "Would result in circular account relationship" return "Would result in circular account relationship"
} }
func insertUpdateAccount(tx *Tx, a *models.Account, insert bool) error { func insertUpdateAccount(tx *db.Tx, a *models.Account, insert bool) error {
found := make(map[int64]bool) found := make(map[int64]bool)
if !insert { if !insert {
found[a.AccountId] = true found[a.AccountId] = true
@ -216,15 +217,15 @@ func insertUpdateAccount(tx *Tx, a *models.Account, insert bool) error {
return nil return nil
} }
func InsertAccount(tx *Tx, a *models.Account) error { func InsertAccount(tx *db.Tx, a *models.Account) error {
return insertUpdateAccount(tx, a, true) return insertUpdateAccount(tx, a, true)
} }
func UpdateAccount(tx *Tx, a *models.Account) error { func UpdateAccount(tx *db.Tx, a *models.Account) error {
return insertUpdateAccount(tx, a, false) return insertUpdateAccount(tx, a, false)
} }
func DeleteAccount(tx *Tx, a *models.Account) error { func DeleteAccount(tx *db.Tx, a *models.Account) error {
if a.ParentAccountId != -1 { if a.ParentAccountId != -1 {
// Re-parent splits to this account's parent account if this account isn't a root account // Re-parent splits to this account's parent account if this account isn't a root account
_, err := tx.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", a.ParentAccountId, a.AccountId) _, err := tx.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", a.ParentAccountId, a.AccountId)

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store/db"
"github.com/yuin/gopher-lua" "github.com/yuin/gopher-lua"
"math/big" "math/big"
"strings" "strings"
@ -16,7 +17,7 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*models.Account, error) {
ctx := L.Context() ctx := L.Context()
tx, ok := ctx.Value(dbContextKey).(*Tx) tx, ok := ctx.Value(dbContextKey).(*db.Tx)
if !ok { if !ok {
return nil, errors.New("Couldn't find tx in lua's Context") return nil, errors.New("Couldn't find tx in lua's Context")
} }
@ -150,7 +151,7 @@ func luaAccountBalance(L *lua.LState) int {
a := luaCheckAccount(L, 1) a := luaCheckAccount(L, 1)
ctx := L.Context() ctx := L.Context()
tx, ok := ctx.Value(dbContextKey).(*Tx) tx, ok := ctx.Value(dbContextKey).(*db.Tx)
if !ok { if !ok {
panic("Couldn't find tx in lua's Context") panic("Couldn't find tx in lua's Context")
} }

View File

@ -2,12 +2,11 @@ package handlers_test
import ( import (
"bytes" "bytes"
"database/sql"
"encoding/json" "encoding/json"
"github.com/aclindsa/moneygo/internal/config" "github.com/aclindsa/moneygo/internal/config"
"github.com/aclindsa/moneygo/internal/db"
"github.com/aclindsa/moneygo/internal/handlers" "github.com/aclindsa/moneygo/internal/handlers"
"github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store/db"
"io" "io"
"io/ioutil" "io/ioutil"
"log" "log"
@ -253,24 +252,18 @@ func RunTests(m *testing.M) int {
dsn = envDSN dsn = envDSN
} }
dsn = db.GetDSN(dbType, dsn) db, err := db.GetStore(dbType, dsn)
database, err := sql.Open(dbType.String(), dsn)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
defer database.Close() defer db.Close()
dbmap, err := db.GetDbMap(database, dbType) err = db.DbMap.TruncateTables()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
err = dbmap.TruncateTables() server = httptest.NewTLSServer(&handlers.APIHandler{Store: db})
if err != nil {
log.Fatal(err)
}
server = httptest.NewTLSServer(&handlers.APIHandler{DB: dbmap})
defer server.Close() defer server.Close()
return m.Run() return m.Run()

View File

@ -1,8 +1,9 @@
package handlers package handlers
import ( import (
"github.com/aclindsa/gorp"
"github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store"
"github.com/aclindsa/moneygo/internal/store/db"
"log" "log"
"net/http" "net/http"
"path" "path"
@ -16,7 +17,8 @@ type ResponseWriterWriter interface {
} }
type Context struct { type Context struct {
Tx *Tx Tx *db.Tx
StoreTx store.Tx
User *models.User User *models.User
remainingURL string // portion of URL path not yet reached in the hierarchy remainingURL string // portion of URL path not yet reached in the hierarchy
} }
@ -46,11 +48,11 @@ func (c *Context) LastLevel() bool {
type Handler func(*http.Request, *Context) ResponseWriterWriter type Handler func(*http.Request, *Context) ResponseWriterWriter
type APIHandler struct { type APIHandler struct {
DB *gorp.DbMap Store *db.DbStore
} }
func (ah *APIHandler) txWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) { func (ah *APIHandler) txWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) {
tx, err := GetTx(ah.DB) tx, err := GetTx(ah.Store.DbMap)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
@ -72,6 +74,33 @@ func (ah *APIHandler) txWrapper(h Handler, r *http.Request, context *Context) (w
}() }()
context.Tx = tx context.Tx = tx
context.StoreTx = tx
return h(r, context)
}
func (ah *APIHandler) storeTxWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) {
tx, err := ah.Store.Begin()
if err != nil {
log.Print(err)
return NewError(999 /*Internal Error*/)
}
defer func() {
if r := recover(); r != nil {
tx.Rollback()
panic(r)
}
if _, ok := writer.(*Error); ok {
tx.Rollback()
} else {
err = tx.Commit()
if err != nil {
log.Print(err)
writer = NewError(999 /*Internal Error*/)
}
}
}()
context.StoreTx = tx
return h(r, context) return h(r, context)
} }

View File

@ -3,6 +3,7 @@ package handlers
import ( import (
"encoding/json" "encoding/json"
"github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store/db"
"github.com/aclindsa/ofxgo" "github.com/aclindsa/ofxgo"
"io" "io"
"log" "log"
@ -23,7 +24,7 @@ func (od *OFXDownload) Read(json_str string) error {
return dec.Decode(od) return dec.Decode(od)
} }
func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) ResponseWriterWriter { func ofxImportHelper(tx *db.Tx, r io.Reader, user *models.User, accountid int64) ResponseWriterWriter {
itl, err := ImportOFX(r) itl, err := ImportOFX(r)
if err != nil { if err != nil {

View File

@ -2,12 +2,13 @@ package handlers
import ( import (
"github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store/db"
"log" "log"
"net/http" "net/http"
"time" "time"
) )
func CreatePriceIfNotExist(tx *Tx, price *models.Price) error { func CreatePriceIfNotExist(tx *db.Tx, price *models.Price) error {
if len(price.RemoteId) == 0 { if len(price.RemoteId) == 0 {
// Always create a new price if we can't match on the RemoteId // Always create a new price if we can't match on the RemoteId
err := tx.Insert(price) err := tx.Insert(price)
@ -35,7 +36,7 @@ func CreatePriceIfNotExist(tx *Tx, price *models.Price) error {
return nil return nil
} }
func GetPrice(tx *Tx, priceid, securityid int64) (*models.Price, error) { func GetPrice(tx *db.Tx, priceid, securityid int64) (*models.Price, error) {
var p models.Price var p models.Price
err := tx.SelectOne(&p, "SELECT * from prices where PriceId=? AND SecurityId=?", priceid, securityid) err := tx.SelectOne(&p, "SELECT * from prices where PriceId=? AND SecurityId=?", priceid, securityid)
if err != nil { if err != nil {
@ -44,7 +45,7 @@ func GetPrice(tx *Tx, priceid, securityid int64) (*models.Price, error) {
return &p, nil return &p, nil
} }
func GetPrices(tx *Tx, securityid int64) (*[]*models.Price, error) { func GetPrices(tx *db.Tx, securityid int64) (*[]*models.Price, error) {
var prices []*models.Price var prices []*models.Price
_, err := tx.Select(&prices, "SELECT * from prices where SecurityId=?", securityid) _, err := tx.Select(&prices, "SELECT * from prices where SecurityId=?", securityid)
@ -55,7 +56,7 @@ func GetPrices(tx *Tx, securityid int64) (*[]*models.Price, error) {
} }
// Return the latest price for security in currency units before date // Return the latest price for security in currency units before date
func GetLatestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { func GetLatestPrice(tx *db.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
var p models.Price var p models.Price
err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date) err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date)
if err != nil { if err != nil {
@ -65,7 +66,7 @@ func GetLatestPrice(tx *Tx, security, currency *models.Security, date *time.Time
} }
// Return the earliest price for security in currency units after date // Return the earliest price for security in currency units after date
func GetEarliestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { func GetEarliestPrice(tx *db.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
var p models.Price var p models.Price
err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date) err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date)
if err != nil { if err != nil {
@ -75,7 +76,7 @@ func GetEarliestPrice(tx *Tx, security, currency *models.Security, date *time.Ti
} }
// Return the price for security in currency closest to date // Return the price for security in currency closest to date
func GetClosestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { func GetClosestPrice(tx *db.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
earliest, _ := GetEarliestPrice(tx, security, currency, date) earliest, _ := GetEarliestPrice(tx, security, currency, date)
latest, err := GetLatestPrice(tx, security, currency, date) latest, err := GetLatestPrice(tx, security, currency, date)

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store/db"
"github.com/yuin/gopher-lua" "github.com/yuin/gopher-lua"
"log" "log"
"net/http" "net/http"
@ -24,7 +25,7 @@ const (
const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for
func GetReport(tx *Tx, reportid int64, userid int64) (*models.Report, error) { func GetReport(tx *db.Tx, reportid int64, userid int64) (*models.Report, error) {
var r models.Report var r models.Report
err := tx.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid) err := tx.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid)
@ -34,7 +35,7 @@ func GetReport(tx *Tx, reportid int64, userid int64) (*models.Report, error) {
return &r, nil return &r, nil
} }
func GetReports(tx *Tx, userid int64) (*[]models.Report, error) { func GetReports(tx *db.Tx, userid int64) (*[]models.Report, error) {
var reports []models.Report var reports []models.Report
_, err := tx.Select(&reports, "SELECT * from reports where UserId=?", userid) _, err := tx.Select(&reports, "SELECT * from reports where UserId=?", userid)
@ -44,7 +45,7 @@ func GetReports(tx *Tx, userid int64) (*[]models.Report, error) {
return &reports, nil return &reports, nil
} }
func InsertReport(tx *Tx, r *models.Report) error { func InsertReport(tx *db.Tx, r *models.Report) error {
err := tx.Insert(r) err := tx.Insert(r)
if err != nil { if err != nil {
return err return err
@ -52,7 +53,7 @@ func InsertReport(tx *Tx, r *models.Report) error {
return nil return nil
} }
func UpdateReport(tx *Tx, r *models.Report) error { func UpdateReport(tx *db.Tx, r *models.Report) error {
count, err := tx.Update(r) count, err := tx.Update(r)
if err != nil { if err != nil {
return err return err
@ -63,7 +64,7 @@ func UpdateReport(tx *Tx, r *models.Report) error {
return nil return nil
} }
func DeleteReport(tx *Tx, r *models.Report) error { func DeleteReport(tx *db.Tx, r *models.Report) error {
count, err := tx.Delete(r) count, err := tx.Delete(r)
if err != nil { if err != nil {
return err return err
@ -74,7 +75,7 @@ func DeleteReport(tx *Tx, r *models.Report) error {
return nil return nil
} }
func runReport(tx *Tx, user *models.User, report *models.Report) (*models.Tabulation, error) { func runReport(tx *db.Tx, user *models.User, report *models.Report) (*models.Tabulation, error) {
// Create a new LState without opening the default libs for security // Create a new LState without opening the default libs for security
L := lua.NewState(lua.Options{SkipOpenLibs: true}) L := lua.NewState(lua.Options{SkipOpenLibs: true})
defer L.Close() defer L.Close()
@ -138,7 +139,7 @@ func runReport(tx *Tx, user *models.User, report *models.Report) (*models.Tabula
} }
} }
func ReportTabulationHandler(tx *Tx, r *http.Request, user *models.User, reportid int64) ResponseWriterWriter { func ReportTabulationHandler(tx *db.Tx, r *http.Request, user *models.User, reportid int64) ResponseWriterWriter {
report, err := GetReport(tx, reportid, user.UserId) report, err := GetReport(tx, reportid, user.UserId)
if err != nil { if err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)

View File

@ -6,6 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store/db"
"log" "log"
"net/http" "net/http"
"net/url" "net/url"
@ -50,7 +51,7 @@ func FindCurrencyTemplate(iso4217 int64) *models.Security {
return nil return nil
} }
func GetSecurity(tx *Tx, securityid int64, userid int64) (*models.Security, error) { func GetSecurity(tx *db.Tx, securityid int64, userid int64) (*models.Security, error) {
var s models.Security var s models.Security
err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid) err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
@ -60,7 +61,7 @@ func GetSecurity(tx *Tx, securityid int64, userid int64) (*models.Security, erro
return &s, nil return &s, nil
} }
func GetSecurities(tx *Tx, userid int64) (*[]*models.Security, error) { func GetSecurities(tx *db.Tx, userid int64) (*[]*models.Security, error) {
var securities []*models.Security var securities []*models.Security
_, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid) _, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid)
@ -70,7 +71,7 @@ func GetSecurities(tx *Tx, userid int64) (*[]*models.Security, error) {
return &securities, nil return &securities, nil
} }
func InsertSecurity(tx *Tx, s *models.Security) error { func InsertSecurity(tx *db.Tx, s *models.Security) error {
err := tx.Insert(s) err := tx.Insert(s)
if err != nil { if err != nil {
return err return err
@ -78,7 +79,7 @@ func InsertSecurity(tx *Tx, s *models.Security) error {
return nil return nil
} }
func UpdateSecurity(tx *Tx, s *models.Security) (err error) { func UpdateSecurity(tx *db.Tx, s *models.Security) (err error) {
user, err := GetUser(tx, s.UserId) user, err := GetUser(tx, s.UserId)
if err != nil { if err != nil {
return return
@ -105,7 +106,7 @@ func (e SecurityInUseError) Error() string {
return e.message return e.message
} }
func DeleteSecurity(tx *Tx, s *models.Security) error { func DeleteSecurity(tx *db.Tx, s *models.Security) error {
// First, ensure no accounts are using this security // First, ensure no accounts are using this security
accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId) accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId)
@ -138,7 +139,7 @@ func DeleteSecurity(tx *Tx, s *models.Security) error {
return nil return nil
} }
func ImportGetCreateSecurity(tx *Tx, userid int64, security *models.Security) (*models.Security, error) { func ImportGetCreateSecurity(tx *db.Tx, userid int64, security *models.Security) (*models.Security, error) {
security.UserId = userid security.UserId = userid
if len(security.AlternateId) == 0 { if len(security.AlternateId) == 0 {
// Always create a new local security if we can't match on the AlternateId // Always create a new local security if we can't match on the AlternateId

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store/db"
"github.com/yuin/gopher-lua" "github.com/yuin/gopher-lua"
) )
@ -14,7 +15,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*models.Security, error)
ctx := L.Context() ctx := L.Context()
tx, ok := ctx.Value(dbContextKey).(*Tx) tx, ok := ctx.Value(dbContextKey).(*db.Tx)
if !ok { if !ok {
return nil, errors.New("Couldn't find tx in lua's Context") return nil, errors.New("Couldn't find tx in lua's Context")
} }
@ -158,7 +159,7 @@ func luaClosestPrice(L *lua.LState) int {
date := luaCheckTime(L, 3) date := luaCheckTime(L, 3)
ctx := L.Context() ctx := L.Context()
tx, ok := ctx.Value(dbContextKey).(*Tx) tx, ok := ctx.Value(dbContextKey).(*db.Tx)
if !ok { if !ok {
panic("Couldn't find tx in lua's Context") panic("Couldn't find tx in lua's Context")
} }

View File

@ -3,36 +3,37 @@ package handlers
import ( import (
"fmt" "fmt"
"github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store"
"log" "log"
"net/http" "net/http"
"time" "time"
) )
func GetSession(tx *Tx, r *http.Request) (*models.Session, error) { func GetSession(tx store.Tx, r *http.Request) (*models.Session, error) {
var s models.Session
cookie, err := r.Cookie("moneygo-session") cookie, err := r.Cookie("moneygo-session")
if err != nil { if err != nil {
return nil, fmt.Errorf("moneygo-session cookie not set") return nil, fmt.Errorf("moneygo-session cookie not set")
} }
s.SessionSecret = cookie.Value
err = tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret) s, err := tx.GetSession(cookie.Value)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if s.Expires.Before(time.Now()) { if s.Expires.Before(time.Now()) {
tx.Delete(&s) err := tx.DeleteSession(s)
if err != nil {
log.Printf("Unexpected error when attempting to delete expired session: %s", err)
}
return nil, fmt.Errorf("Session has expired") return nil, fmt.Errorf("Session has expired")
} }
return &s, nil return s, nil
} }
func DeleteSessionIfExists(tx *Tx, r *http.Request) error { func DeleteSessionIfExists(tx store.Tx, r *http.Request) error {
session, err := GetSession(tx, r) session, err := GetSession(tx, r)
if err == nil { if err == nil {
_, err := tx.Delete(session) err := tx.DeleteSession(session)
if err != nil { if err != nil {
return err return err
} }
@ -50,21 +51,26 @@ func (n *NewSessionWriter) Write(w http.ResponseWriter) error {
return n.session.Write(w) return n.session.Write(w)
} }
func NewSession(tx *Tx, r *http.Request, userid int64) (*NewSessionWriter, error) { func NewSession(tx store.Tx, r *http.Request, userid int64) (*NewSessionWriter, error) {
err := DeleteSessionIfExists(tx, r)
if err != nil {
return nil, err
}
s, err := models.NewSession(userid) s, err := models.NewSession(userid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
existing, err := tx.SelectInt("SELECT count(*) from sessions where SessionSecret=?", s.SessionSecret) exists, err := tx.SessionExists(s.SessionSecret)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if existing > 0 { if exists {
return nil, fmt.Errorf("%d session(s) exist with the generated session_secret", existing) return nil, fmt.Errorf("Session already exists with the generated session_secret")
} }
err = tx.Insert(s) err = tx.InsertSession(s)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -89,27 +95,21 @@ func SessionHandler(r *http.Request, context *Context) ResponseWriterWriter {
return NewError(2 /*Unauthorized Access*/) return NewError(2 /*Unauthorized Access*/)
} }
err = DeleteSessionIfExists(context.Tx, r) sessionwriter, err := NewSession(context.StoreTx, r, dbuser.UserId)
if err != nil {
log.Print(err)
return NewError(999 /*Internal Error*/)
}
sessionwriter, err := NewSession(context.Tx, r, dbuser.UserId)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)
} }
return sessionwriter return sessionwriter
} else if r.Method == "GET" { } else if r.Method == "GET" {
s, err := GetSession(context.Tx, r) s, err := GetSession(context.StoreTx, r)
if err != nil { if err != nil {
return NewError(1 /*Not Signed In*/) return NewError(1 /*Not Signed In*/)
} }
return s return s
} else if r.Method == "DELETE" { } else if r.Method == "DELETE" {
err := DeleteSessionIfExists(context.Tx, r) err := DeleteSessionIfExists(context.StoreTx, r)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return NewError(999 /*Internal Error*/) return NewError(999 /*Internal Error*/)

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store/db"
"log" "log"
"math/big" "math/big"
"net/http" "net/http"
@ -12,14 +13,14 @@ import (
"time" "time"
) )
func SplitAlreadyImported(tx *Tx, s *models.Split) (bool, error) { func SplitAlreadyImported(tx *db.Tx, s *models.Split) (bool, error) {
count, err := tx.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId) count, err := tx.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId)
return count == 1, err return count == 1, err
} }
// Return a map of security ID's to big.Rat's containing the amount that // Return a map of security ID's to big.Rat's containing the amount that
// security is imbalanced by // security is imbalanced by
func GetTransactionImbalances(tx *Tx, t *models.Transaction) (map[int64]big.Rat, error) { func GetTransactionImbalances(tx *db.Tx, t *models.Transaction) (map[int64]big.Rat, error) {
sums := make(map[int64]big.Rat) sums := make(map[int64]big.Rat)
if !t.Valid() { if !t.Valid() {
@ -47,7 +48,7 @@ func GetTransactionImbalances(tx *Tx, t *models.Transaction) (map[int64]big.Rat,
// Returns true if all securities contained in this transaction are balanced, // Returns true if all securities contained in this transaction are balanced,
// false otherwise // false otherwise
func TransactionBalanced(tx *Tx, t *models.Transaction) (bool, error) { func TransactionBalanced(tx *db.Tx, t *models.Transaction) (bool, error) {
var zero big.Rat var zero big.Rat
sums, err := GetTransactionImbalances(tx, t) sums, err := GetTransactionImbalances(tx, t)
@ -63,7 +64,7 @@ func TransactionBalanced(tx *Tx, t *models.Transaction) (bool, error) {
return true, nil return true, nil
} }
func GetTransaction(tx *Tx, transactionid int64, userid int64) (*models.Transaction, error) { func GetTransaction(tx *db.Tx, transactionid int64, userid int64) (*models.Transaction, error) {
var t models.Transaction var t models.Transaction
err := tx.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid) err := tx.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid)
@ -79,7 +80,7 @@ func GetTransaction(tx *Tx, transactionid int64, userid int64) (*models.Transact
return &t, nil return &t, nil
} }
func GetTransactions(tx *Tx, userid int64) (*[]models.Transaction, error) { func GetTransactions(tx *db.Tx, userid int64) (*[]models.Transaction, error) {
var transactions []models.Transaction var transactions []models.Transaction
_, err := tx.Select(&transactions, "SELECT * from transactions where UserId=?", userid) _, err := tx.Select(&transactions, "SELECT * from transactions where UserId=?", userid)
@ -97,7 +98,7 @@ func GetTransactions(tx *Tx, userid int64) (*[]models.Transaction, error) {
return &transactions, nil return &transactions, nil
} }
func incrementAccountVersions(tx *Tx, user *models.User, accountids []int64) error { func incrementAccountVersions(tx *db.Tx, user *models.User, accountids []int64) error {
for i := range accountids { for i := range accountids {
account, err := GetAccount(tx, accountids[i], user.UserId) account, err := GetAccount(tx, accountids[i], user.UserId)
if err != nil { if err != nil {
@ -121,7 +122,7 @@ func (ame AccountMissingError) Error() string {
return "Account missing" return "Account missing"
} }
func InsertTransaction(tx *Tx, t *models.Transaction, user *models.User) error { func InsertTransaction(tx *db.Tx, t *models.Transaction, user *models.User) error {
// Map of any accounts with transaction splits being added // Map of any accounts with transaction splits being added
a_map := make(map[int64]bool) a_map := make(map[int64]bool)
for i := range t.Splits { for i := range t.Splits {
@ -171,7 +172,7 @@ func InsertTransaction(tx *Tx, t *models.Transaction, user *models.User) error {
return nil return nil
} }
func UpdateTransaction(tx *Tx, t *models.Transaction, user *models.User) error { func UpdateTransaction(tx *db.Tx, t *models.Transaction, user *models.User) error {
var existing_splits []*models.Split var existing_splits []*models.Split
_, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId) _, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
@ -248,7 +249,7 @@ func UpdateTransaction(tx *Tx, t *models.Transaction, user *models.User) error {
return nil return nil
} }
func DeleteTransaction(tx *Tx, t *models.Transaction, user *models.User) error { func DeleteTransaction(tx *db.Tx, t *models.Transaction, user *models.User) error {
var accountids []int64 var accountids []int64
_, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId) _, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId)
if err != nil { if err != nil {
@ -401,7 +402,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
func TransactionsBalanceDifference(tx *Tx, accountid int64, transactions []models.Transaction) (*big.Rat, error) { func TransactionsBalanceDifference(tx *db.Tx, accountid int64, transactions []models.Transaction) (*big.Rat, error) {
var pageDifference, tmp big.Rat var pageDifference, tmp big.Rat
for i := range transactions { for i := range transactions {
_, err := tx.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId) _, err := tx.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId)
@ -425,7 +426,7 @@ func TransactionsBalanceDifference(tx *Tx, accountid int64, transactions []model
return &pageDifference, nil return &pageDifference, nil
} }
func GetAccountBalance(tx *Tx, user *models.User, accountid int64) (*big.Rat, error) { func GetAccountBalance(tx *db.Tx, user *models.User, accountid int64) (*big.Rat, error) {
var splits []models.Split var splits []models.Split
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?" sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?"
@ -448,7 +449,7 @@ func GetAccountBalance(tx *Tx, user *models.User, accountid int64) (*big.Rat, er
} }
// Assumes accountid is valid and is owned by the current user // Assumes accountid is valid and is owned by the current user
func GetAccountBalanceDate(tx *Tx, user *models.User, accountid int64, date *time.Time) (*big.Rat, error) { func GetAccountBalanceDate(tx *db.Tx, user *models.User, accountid int64, date *time.Time) (*big.Rat, error) {
var splits []models.Split var splits []models.Split
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?" sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?"
@ -470,7 +471,7 @@ func GetAccountBalanceDate(tx *Tx, user *models.User, accountid int64, date *tim
return &balance, nil return &balance, nil
} }
func GetAccountBalanceDateRange(tx *Tx, user *models.User, accountid int64, begin, end *time.Time) (*big.Rat, error) { func GetAccountBalanceDateRange(tx *db.Tx, user *models.User, accountid int64, begin, end *time.Time) (*big.Rat, error) {
var splits []models.Split var splits []models.Split
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?" sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?"
@ -492,7 +493,7 @@ func GetAccountBalanceDateRange(tx *Tx, user *models.User, accountid int64, begi
return &balance, nil return &balance, nil
} }
func GetAccountTransactions(tx *Tx, user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) { func GetAccountTransactions(tx *db.Tx, user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) {
var transactions []models.Transaction var transactions []models.Transaction
var atl models.AccountTransactionsList var atl models.AccountTransactionsList

View File

@ -1,65 +1,14 @@
package handlers package handlers
import ( import (
"database/sql"
"github.com/aclindsa/gorp" "github.com/aclindsa/gorp"
"strings" "github.com/aclindsa/moneygo/internal/store/db"
) )
type Tx struct { func GetTx(gdb *gorp.DbMap) (*db.Tx, error) {
Dialect gorp.Dialect tx, err := gdb.Begin()
Tx *gorp.Transaction
}
func (tx *Tx) Rebind(query string) string {
chunks := strings.Split(query, "?")
str := chunks[0]
for i := 1; i < len(chunks); i++ {
str += tx.Dialect.BindVar(i-1) + chunks[i]
}
return str
}
func (tx *Tx) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) {
return tx.Tx.Select(i, tx.Rebind(query), args...)
}
func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
return tx.Tx.Exec(tx.Rebind(query), args...)
}
func (tx *Tx) SelectInt(query string, args ...interface{}) (int64, error) {
return tx.Tx.SelectInt(tx.Rebind(query), args...)
}
func (tx *Tx) SelectOne(holder interface{}, query string, args ...interface{}) error {
return tx.Tx.SelectOne(holder, tx.Rebind(query), args...)
}
func (tx *Tx) Insert(list ...interface{}) error {
return tx.Tx.Insert(list...)
}
func (tx *Tx) Update(list ...interface{}) (int64, error) {
return tx.Tx.Update(list...)
}
func (tx *Tx) Delete(list ...interface{}) (int64, error) {
return tx.Tx.Delete(list...)
}
func (tx *Tx) Commit() error {
return tx.Tx.Commit()
}
func (tx *Tx) Rollback() error {
return tx.Tx.Rollback()
}
func GetTx(db *gorp.DbMap) (*Tx, error) {
tx, err := db.Begin()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &Tx{db.Dialect, tx}, nil return &db.Tx{gdb.Dialect, tx}, nil
} }

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store/db"
"log" "log"
"net/http" "net/http"
) )
@ -14,7 +15,7 @@ func (ueu UserExistsError) Error() string {
return "User exists" return "User exists"
} }
func GetUser(tx *Tx, userid int64) (*models.User, error) { func GetUser(tx *db.Tx, userid int64) (*models.User, error) {
var u models.User var u models.User
err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid) err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid)
@ -24,7 +25,7 @@ func GetUser(tx *Tx, userid int64) (*models.User, error) {
return &u, nil return &u, nil
} }
func GetUserByUsername(tx *Tx, username string) (*models.User, error) { func GetUserByUsername(tx *db.Tx, username string) (*models.User, error) {
var u models.User var u models.User
err := tx.SelectOne(&u, "SELECT * from users where Username=?", username) err := tx.SelectOne(&u, "SELECT * from users where Username=?", username)
@ -34,7 +35,7 @@ func GetUserByUsername(tx *Tx, username string) (*models.User, error) {
return &u, nil return &u, nil
} }
func InsertUser(tx *Tx, u *models.User) error { func InsertUser(tx *db.Tx, u *models.User) error {
security_template := FindCurrencyTemplate(u.DefaultCurrency) security_template := FindCurrencyTemplate(u.DefaultCurrency)
if security_template == nil { if security_template == nil {
return errors.New("Invalid ISO4217 Default Currency") return errors.New("Invalid ISO4217 Default Currency")
@ -75,7 +76,7 @@ func InsertUser(tx *Tx, u *models.User) error {
return nil return nil
} }
func GetUserFromSession(tx *Tx, r *http.Request) (*models.User, error) { func GetUserFromSession(tx *db.Tx, r *http.Request) (*models.User, error) {
s, err := GetSession(tx, r) s, err := GetSession(tx, r)
if err != nil { if err != nil {
return nil, err return nil, err
@ -83,7 +84,7 @@ func GetUserFromSession(tx *Tx, r *http.Request) (*models.User, error) {
return GetUser(tx, s.UserId) return GetUser(tx, s.UserId)
} }
func UpdateUser(tx *Tx, u *models.User) error { func UpdateUser(tx *db.Tx, u *models.User) error {
security, err := GetSecurity(tx, u.DefaultCurrency, u.UserId) security, err := GetSecurity(tx, u.DefaultCurrency, u.UserId)
if err != nil { if err != nil {
return err return err
@ -103,7 +104,7 @@ func UpdateUser(tx *Tx, u *models.User) error {
return nil return nil
} }
func DeleteUser(tx *Tx, u *models.User) error { func DeleteUser(tx *db.Tx, u *models.User) error {
count, err := tx.Delete(u) count, err := tx.Delete(u)
if err != nil { if err != nil {
return err return err

View File

@ -6,6 +6,7 @@ import (
"github.com/aclindsa/gorp" "github.com/aclindsa/gorp"
"github.com/aclindsa/moneygo/internal/config" "github.com/aclindsa/moneygo/internal/config"
"github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq" _ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
@ -60,3 +61,40 @@ func GetDSN(dbtype config.DbType, dsn string) string {
} }
return dsn return dsn
} }
type DbStore struct {
DbMap *gorp.DbMap
}
func (db *DbStore) Begin() (store.Tx, error) {
tx, err := db.DbMap.Begin()
if err != nil {
return nil, err
}
return &Tx{db.DbMap.Dialect, tx}, nil
}
func (db *DbStore) Close() error {
err := db.DbMap.Db.Close()
db.DbMap = nil
return err
}
func GetStore(dbtype config.DbType, dsn string) (store *DbStore, err error) {
dsn = GetDSN(dbtype, dsn)
database, err := sql.Open(dbtype.String(), dsn)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
database.Close()
}
}()
dbmap, err := GetDbMap(database, dbtype)
if err != nil {
return nil, err
}
return &DbStore{dbmap}, nil
}

View File

@ -0,0 +1,36 @@
package db
import (
"fmt"
"github.com/aclindsa/moneygo/internal/models"
"time"
)
func (tx *Tx) InsertSession(session *models.Session) error {
return tx.Insert(session)
}
func (tx *Tx) GetSession(secret string) (*models.Session, error) {
var s models.Session
err := tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", secret)
if err != nil {
return nil, err
}
if s.Expires.Before(time.Now()) {
tx.Delete(&s)
return nil, fmt.Errorf("Session has expired")
}
return &s, nil
}
func (tx *Tx) SessionExists(secret string) (bool, error) {
existing, err := tx.SelectInt("SELECT count(*) from sessions where SessionSecret=?", secret)
return existing != 0, err
}
func (tx *Tx) DeleteSession(session *models.Session) error {
_, err := tx.Delete(session)
return err
}

57
internal/store/db/tx.go Normal file
View File

@ -0,0 +1,57 @@
package db
import (
"database/sql"
"github.com/aclindsa/gorp"
"strings"
)
type Tx struct {
Dialect gorp.Dialect
Tx *gorp.Transaction
}
func (tx *Tx) Rebind(query string) string {
chunks := strings.Split(query, "?")
str := chunks[0]
for i := 1; i < len(chunks); i++ {
str += tx.Dialect.BindVar(i-1) + chunks[i]
}
return str
}
func (tx *Tx) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) {
return tx.Tx.Select(i, tx.Rebind(query), args...)
}
func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
return tx.Tx.Exec(tx.Rebind(query), args...)
}
func (tx *Tx) SelectInt(query string, args ...interface{}) (int64, error) {
return tx.Tx.SelectInt(tx.Rebind(query), args...)
}
func (tx *Tx) SelectOne(holder interface{}, query string, args ...interface{}) error {
return tx.Tx.SelectOne(holder, tx.Rebind(query), args...)
}
func (tx *Tx) Insert(list ...interface{}) error {
return tx.Tx.Insert(list...)
}
func (tx *Tx) Update(list ...interface{}) (int64, error) {
return tx.Tx.Update(list...)
}
func (tx *Tx) Delete(list ...interface{}) (int64, error) {
return tx.Tx.Delete(list...)
}
func (tx *Tx) Commit() error {
return tx.Tx.Commit()
}
func (tx *Tx) Rollback() error {
return tx.Tx.Rollback()
}

24
internal/store/store.go Normal file
View File

@ -0,0 +1,24 @@
package store
import (
"github.com/aclindsa/moneygo/internal/models"
)
type SessionStore interface {
InsertSession(session *models.Session) error
GetSession(secret string) (*models.Session, error)
SessionExists(secret string) (bool, error)
DeleteSession(session *models.Session) error
}
type Tx interface {
Commit() error
Rollback() error
SessionStore
}
type Store interface {
Begin() (Tx, error)
Close() error
}

15
main.go
View File

@ -3,11 +3,10 @@ package main
//go:generate make //go:generate make
import ( import (
"database/sql"
"flag" "flag"
"github.com/aclindsa/moneygo/internal/config" "github.com/aclindsa/moneygo/internal/config"
"github.com/aclindsa/moneygo/internal/db"
"github.com/aclindsa/moneygo/internal/handlers" "github.com/aclindsa/moneygo/internal/handlers"
"github.com/aclindsa/moneygo/internal/store/db"
"github.com/kabukky/httpscerts" "github.com/kabukky/httpscerts"
"log" "log"
"net" "net"
@ -67,21 +66,15 @@ func staticHandler(w http.ResponseWriter, r *http.Request, basedir string) {
} }
func main() { func main() {
dsn := db.GetDSN(cfg.MoneyGo.DBType, cfg.MoneyGo.DSN) db, err := db.GetStore(cfg.MoneyGo.DBType, cfg.MoneyGo.DSN)
database, err := sql.Open(cfg.MoneyGo.DBType.String(), dsn)
if err != nil {
log.Fatal(err)
}
defer database.Close()
dbmap, err := db.GetDbMap(database, cfg.MoneyGo.DBType)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
defer db.Close()
// Get ServeMux for API and add our own handlers for files // Get ServeMux for API and add our own handlers for files
servemux := http.NewServeMux() servemux := http.NewServeMux()
servemux.Handle("/v1/", &handlers.APIHandler{DB: dbmap}) servemux.Handle("/v1/", &handlers.APIHandler{Store: db})
servemux.HandleFunc("/", FileHandlerFunc(rootHandler, cfg.MoneyGo.Basedir)) servemux.HandleFunc("/", FileHandlerFunc(rootHandler, cfg.MoneyGo.Basedir))
servemux.HandleFunc("/static/", FileHandlerFunc(staticHandler, cfg.MoneyGo.Basedir)) servemux.HandleFunc("/static/", FileHandlerFunc(staticHandler, cfg.MoneyGo.Basedir))