Lay groundwork and move sessions to 'store'

This commit is contained in:
Aaron Lindsay 2017-12-06 21:09:47 -05:00
parent 6bdde8e83b
commit c452984f23
18 changed files with 286 additions and 158 deletions

View File

@ -3,11 +3,12 @@ package handlers
import (
"errors"
"github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store/db"
"log"
"net/http"
)
func GetAccount(tx *Tx, accountid int64, userid int64) (*models.Account, error) {
func GetAccount(tx *db.Tx, accountid int64, userid int64) (*models.Account, error) {
var a models.Account
err := tx.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid)
@ -17,7 +18,7 @@ func GetAccount(tx *Tx, accountid int64, userid int64) (*models.Account, error)
return &a, nil
}
func GetAccounts(tx *Tx, userid int64) (*[]models.Account, error) {
func GetAccounts(tx *db.Tx, userid int64) (*[]models.Account, error) {
var accounts []models.Account
_, err := tx.Select(&accounts, "SELECT * from accounts where UserId=?", userid)
@ -29,7 +30,7 @@ func GetAccounts(tx *Tx, userid int64) (*[]models.Account, error) {
// Get (and attempt to create if it doesn't exist). Matches on UserId,
// SecurityId, Type, Name, and ParentAccountId
func GetCreateAccount(tx *Tx, a models.Account) (*models.Account, error) {
func GetCreateAccount(tx *db.Tx, a models.Account) (*models.Account, error) {
var accounts []models.Account
var account models.Account
@ -57,7 +58,7 @@ func GetCreateAccount(tx *Tx, a models.Account) (*models.Account, error) {
// Get (and attempt to create if it doesn't exist) the security/currency
// trading account for the supplied security/currency
func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*models.Account, error) {
func GetTradingAccount(tx *db.Tx, userid int64, securityid int64) (*models.Account, error) {
var tradingAccount models.Account
var account models.Account
@ -99,7 +100,7 @@ func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*models.Account,
// Get (and attempt to create if it doesn't exist) the security/currency
// imbalance account for the supplied security/currency
func GetImbalanceAccount(tx *Tx, userid int64, securityid int64) (*models.Account, error) {
func GetImbalanceAccount(tx *db.Tx, userid int64, securityid int64) (*models.Account, error) {
var imbalanceAccount models.Account
var account models.Account
xxxtemplate := FindSecurityTemplate("XXX", models.Currency)
@ -160,7 +161,7 @@ func (cae CircularAccountsError) Error() string {
return "Would result in circular account relationship"
}
func insertUpdateAccount(tx *Tx, a *models.Account, insert bool) error {
func insertUpdateAccount(tx *db.Tx, a *models.Account, insert bool) error {
found := make(map[int64]bool)
if !insert {
found[a.AccountId] = true
@ -216,15 +217,15 @@ func insertUpdateAccount(tx *Tx, a *models.Account, insert bool) error {
return nil
}
func InsertAccount(tx *Tx, a *models.Account) error {
func InsertAccount(tx *db.Tx, a *models.Account) error {
return insertUpdateAccount(tx, a, true)
}
func UpdateAccount(tx *Tx, a *models.Account) error {
func UpdateAccount(tx *db.Tx, a *models.Account) error {
return insertUpdateAccount(tx, a, false)
}
func DeleteAccount(tx *Tx, a *models.Account) error {
func DeleteAccount(tx *db.Tx, a *models.Account) error {
if a.ParentAccountId != -1 {
// Re-parent splits to this account's parent account if this account isn't a root account
_, err := tx.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", a.ParentAccountId, a.AccountId)

View File

@ -4,6 +4,7 @@ import (
"context"
"errors"
"github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store/db"
"github.com/yuin/gopher-lua"
"math/big"
"strings"
@ -16,7 +17,7 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*models.Account, error) {
ctx := L.Context()
tx, ok := ctx.Value(dbContextKey).(*Tx)
tx, ok := ctx.Value(dbContextKey).(*db.Tx)
if !ok {
return nil, errors.New("Couldn't find tx in lua's Context")
}
@ -150,7 +151,7 @@ func luaAccountBalance(L *lua.LState) int {
a := luaCheckAccount(L, 1)
ctx := L.Context()
tx, ok := ctx.Value(dbContextKey).(*Tx)
tx, ok := ctx.Value(dbContextKey).(*db.Tx)
if !ok {
panic("Couldn't find tx in lua's Context")
}

View File

@ -2,12 +2,11 @@ package handlers_test
import (
"bytes"
"database/sql"
"encoding/json"
"github.com/aclindsa/moneygo/internal/config"
"github.com/aclindsa/moneygo/internal/db"
"github.com/aclindsa/moneygo/internal/handlers"
"github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store/db"
"io"
"io/ioutil"
"log"
@ -253,24 +252,18 @@ func RunTests(m *testing.M) int {
dsn = envDSN
}
dsn = db.GetDSN(dbType, dsn)
database, err := sql.Open(dbType.String(), dsn)
db, err := db.GetStore(dbType, dsn)
if err != nil {
log.Fatal(err)
}
defer database.Close()
defer db.Close()
dbmap, err := db.GetDbMap(database, dbType)
err = db.DbMap.TruncateTables()
if err != nil {
log.Fatal(err)
}
err = dbmap.TruncateTables()
if err != nil {
log.Fatal(err)
}
server = httptest.NewTLSServer(&handlers.APIHandler{DB: dbmap})
server = httptest.NewTLSServer(&handlers.APIHandler{Store: db})
defer server.Close()
return m.Run()

View File

@ -1,8 +1,9 @@
package handlers
import (
"github.com/aclindsa/gorp"
"github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store"
"github.com/aclindsa/moneygo/internal/store/db"
"log"
"net/http"
"path"
@ -16,7 +17,8 @@ type ResponseWriterWriter interface {
}
type Context struct {
Tx *Tx
Tx *db.Tx
StoreTx store.Tx
User *models.User
remainingURL string // portion of URL path not yet reached in the hierarchy
}
@ -46,11 +48,11 @@ func (c *Context) LastLevel() bool {
type Handler func(*http.Request, *Context) ResponseWriterWriter
type APIHandler struct {
DB *gorp.DbMap
Store *db.DbStore
}
func (ah *APIHandler) txWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) {
tx, err := GetTx(ah.DB)
tx, err := GetTx(ah.Store.DbMap)
if err != nil {
log.Print(err)
return NewError(999 /*Internal Error*/)
@ -72,6 +74,33 @@ func (ah *APIHandler) txWrapper(h Handler, r *http.Request, context *Context) (w
}()
context.Tx = tx
context.StoreTx = tx
return h(r, context)
}
func (ah *APIHandler) storeTxWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) {
tx, err := ah.Store.Begin()
if err != nil {
log.Print(err)
return NewError(999 /*Internal Error*/)
}
defer func() {
if r := recover(); r != nil {
tx.Rollback()
panic(r)
}
if _, ok := writer.(*Error); ok {
tx.Rollback()
} else {
err = tx.Commit()
if err != nil {
log.Print(err)
writer = NewError(999 /*Internal Error*/)
}
}
}()
context.StoreTx = tx
return h(r, context)
}

View File

@ -3,6 +3,7 @@ package handlers
import (
"encoding/json"
"github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store/db"
"github.com/aclindsa/ofxgo"
"io"
"log"
@ -23,7 +24,7 @@ func (od *OFXDownload) Read(json_str string) error {
return dec.Decode(od)
}
func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) ResponseWriterWriter {
func ofxImportHelper(tx *db.Tx, r io.Reader, user *models.User, accountid int64) ResponseWriterWriter {
itl, err := ImportOFX(r)
if err != nil {

View File

@ -2,12 +2,13 @@ package handlers
import (
"github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store/db"
"log"
"net/http"
"time"
)
func CreatePriceIfNotExist(tx *Tx, price *models.Price) error {
func CreatePriceIfNotExist(tx *db.Tx, price *models.Price) error {
if len(price.RemoteId) == 0 {
// Always create a new price if we can't match on the RemoteId
err := tx.Insert(price)
@ -35,7 +36,7 @@ func CreatePriceIfNotExist(tx *Tx, price *models.Price) error {
return nil
}
func GetPrice(tx *Tx, priceid, securityid int64) (*models.Price, error) {
func GetPrice(tx *db.Tx, priceid, securityid int64) (*models.Price, error) {
var p models.Price
err := tx.SelectOne(&p, "SELECT * from prices where PriceId=? AND SecurityId=?", priceid, securityid)
if err != nil {
@ -44,7 +45,7 @@ func GetPrice(tx *Tx, priceid, securityid int64) (*models.Price, error) {
return &p, nil
}
func GetPrices(tx *Tx, securityid int64) (*[]*models.Price, error) {
func GetPrices(tx *db.Tx, securityid int64) (*[]*models.Price, error) {
var prices []*models.Price
_, err := tx.Select(&prices, "SELECT * from prices where SecurityId=?", securityid)
@ -55,7 +56,7 @@ func GetPrices(tx *Tx, securityid int64) (*[]*models.Price, error) {
}
// Return the latest price for security in currency units before date
func GetLatestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
func GetLatestPrice(tx *db.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
var p models.Price
err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date)
if err != nil {
@ -65,7 +66,7 @@ func GetLatestPrice(tx *Tx, security, currency *models.Security, date *time.Time
}
// Return the earliest price for security in currency units after date
func GetEarliestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
func GetEarliestPrice(tx *db.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
var p models.Price
err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date)
if err != nil {
@ -75,7 +76,7 @@ func GetEarliestPrice(tx *Tx, security, currency *models.Security, date *time.Ti
}
// Return the price for security in currency closest to date
func GetClosestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
func GetClosestPrice(tx *db.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
earliest, _ := GetEarliestPrice(tx, security, currency, date)
latest, err := GetLatestPrice(tx, security, currency, date)

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store/db"
"github.com/yuin/gopher-lua"
"log"
"net/http"
@ -24,7 +25,7 @@ const (
const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for
func GetReport(tx *Tx, reportid int64, userid int64) (*models.Report, error) {
func GetReport(tx *db.Tx, reportid int64, userid int64) (*models.Report, error) {
var r models.Report
err := tx.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid)
@ -34,7 +35,7 @@ func GetReport(tx *Tx, reportid int64, userid int64) (*models.Report, error) {
return &r, nil
}
func GetReports(tx *Tx, userid int64) (*[]models.Report, error) {
func GetReports(tx *db.Tx, userid int64) (*[]models.Report, error) {
var reports []models.Report
_, err := tx.Select(&reports, "SELECT * from reports where UserId=?", userid)
@ -44,7 +45,7 @@ func GetReports(tx *Tx, userid int64) (*[]models.Report, error) {
return &reports, nil
}
func InsertReport(tx *Tx, r *models.Report) error {
func InsertReport(tx *db.Tx, r *models.Report) error {
err := tx.Insert(r)
if err != nil {
return err
@ -52,7 +53,7 @@ func InsertReport(tx *Tx, r *models.Report) error {
return nil
}
func UpdateReport(tx *Tx, r *models.Report) error {
func UpdateReport(tx *db.Tx, r *models.Report) error {
count, err := tx.Update(r)
if err != nil {
return err
@ -63,7 +64,7 @@ func UpdateReport(tx *Tx, r *models.Report) error {
return nil
}
func DeleteReport(tx *Tx, r *models.Report) error {
func DeleteReport(tx *db.Tx, r *models.Report) error {
count, err := tx.Delete(r)
if err != nil {
return err
@ -74,7 +75,7 @@ func DeleteReport(tx *Tx, r *models.Report) error {
return nil
}
func runReport(tx *Tx, user *models.User, report *models.Report) (*models.Tabulation, error) {
func runReport(tx *db.Tx, user *models.User, report *models.Report) (*models.Tabulation, error) {
// Create a new LState without opening the default libs for security
L := lua.NewState(lua.Options{SkipOpenLibs: true})
defer L.Close()
@ -138,7 +139,7 @@ func runReport(tx *Tx, user *models.User, report *models.Report) (*models.Tabula
}
}
func ReportTabulationHandler(tx *Tx, r *http.Request, user *models.User, reportid int64) ResponseWriterWriter {
func ReportTabulationHandler(tx *db.Tx, r *http.Request, user *models.User, reportid int64) ResponseWriterWriter {
report, err := GetReport(tx, reportid, user.UserId)
if err != nil {
return NewError(3 /*Invalid Request*/)

View File

@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store/db"
"log"
"net/http"
"net/url"
@ -50,7 +51,7 @@ func FindCurrencyTemplate(iso4217 int64) *models.Security {
return nil
}
func GetSecurity(tx *Tx, securityid int64, userid int64) (*models.Security, error) {
func GetSecurity(tx *db.Tx, securityid int64, userid int64) (*models.Security, error) {
var s models.Security
err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
@ -60,7 +61,7 @@ func GetSecurity(tx *Tx, securityid int64, userid int64) (*models.Security, erro
return &s, nil
}
func GetSecurities(tx *Tx, userid int64) (*[]*models.Security, error) {
func GetSecurities(tx *db.Tx, userid int64) (*[]*models.Security, error) {
var securities []*models.Security
_, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid)
@ -70,7 +71,7 @@ func GetSecurities(tx *Tx, userid int64) (*[]*models.Security, error) {
return &securities, nil
}
func InsertSecurity(tx *Tx, s *models.Security) error {
func InsertSecurity(tx *db.Tx, s *models.Security) error {
err := tx.Insert(s)
if err != nil {
return err
@ -78,7 +79,7 @@ func InsertSecurity(tx *Tx, s *models.Security) error {
return nil
}
func UpdateSecurity(tx *Tx, s *models.Security) (err error) {
func UpdateSecurity(tx *db.Tx, s *models.Security) (err error) {
user, err := GetUser(tx, s.UserId)
if err != nil {
return
@ -105,7 +106,7 @@ func (e SecurityInUseError) Error() string {
return e.message
}
func DeleteSecurity(tx *Tx, s *models.Security) error {
func DeleteSecurity(tx *db.Tx, s *models.Security) error {
// First, ensure no accounts are using this security
accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId)
@ -138,7 +139,7 @@ func DeleteSecurity(tx *Tx, s *models.Security) error {
return nil
}
func ImportGetCreateSecurity(tx *Tx, userid int64, security *models.Security) (*models.Security, error) {
func ImportGetCreateSecurity(tx *db.Tx, userid int64, security *models.Security) (*models.Security, error) {
security.UserId = userid
if len(security.AlternateId) == 0 {
// Always create a new local security if we can't match on the AlternateId

View File

@ -4,6 +4,7 @@ import (
"context"
"errors"
"github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store/db"
"github.com/yuin/gopher-lua"
)
@ -14,7 +15,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*models.Security, error)
ctx := L.Context()
tx, ok := ctx.Value(dbContextKey).(*Tx)
tx, ok := ctx.Value(dbContextKey).(*db.Tx)
if !ok {
return nil, errors.New("Couldn't find tx in lua's Context")
}
@ -158,7 +159,7 @@ func luaClosestPrice(L *lua.LState) int {
date := luaCheckTime(L, 3)
ctx := L.Context()
tx, ok := ctx.Value(dbContextKey).(*Tx)
tx, ok := ctx.Value(dbContextKey).(*db.Tx)
if !ok {
panic("Couldn't find tx in lua's Context")
}

View File

@ -3,36 +3,37 @@ package handlers
import (
"fmt"
"github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store"
"log"
"net/http"
"time"
)
func GetSession(tx *Tx, r *http.Request) (*models.Session, error) {
var s models.Session
func GetSession(tx store.Tx, r *http.Request) (*models.Session, error) {
cookie, err := r.Cookie("moneygo-session")
if err != nil {
return nil, fmt.Errorf("moneygo-session cookie not set")
}
s.SessionSecret = cookie.Value
err = tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret)
s, err := tx.GetSession(cookie.Value)
if err != nil {
return nil, err
}
if s.Expires.Before(time.Now()) {
tx.Delete(&s)
err := tx.DeleteSession(s)
if err != nil {
log.Printf("Unexpected error when attempting to delete expired session: %s", err)
}
return nil, fmt.Errorf("Session has expired")
}
return &s, nil
return s, nil
}
func DeleteSessionIfExists(tx *Tx, r *http.Request) error {
func DeleteSessionIfExists(tx store.Tx, r *http.Request) error {
session, err := GetSession(tx, r)
if err == nil {
_, err := tx.Delete(session)
err := tx.DeleteSession(session)
if err != nil {
return err
}
@ -50,21 +51,26 @@ func (n *NewSessionWriter) Write(w http.ResponseWriter) error {
return n.session.Write(w)
}
func NewSession(tx *Tx, r *http.Request, userid int64) (*NewSessionWriter, error) {
func NewSession(tx store.Tx, r *http.Request, userid int64) (*NewSessionWriter, error) {
err := DeleteSessionIfExists(tx, r)
if err != nil {
return nil, err
}
s, err := models.NewSession(userid)
if err != nil {
return nil, err
}
existing, err := tx.SelectInt("SELECT count(*) from sessions where SessionSecret=?", s.SessionSecret)
exists, err := tx.SessionExists(s.SessionSecret)
if err != nil {
return nil, err
}
if existing > 0 {
return nil, fmt.Errorf("%d session(s) exist with the generated session_secret", existing)
if exists {
return nil, fmt.Errorf("Session already exists with the generated session_secret")
}
err = tx.Insert(s)
err = tx.InsertSession(s)
if err != nil {
return nil, err
}
@ -89,27 +95,21 @@ func SessionHandler(r *http.Request, context *Context) ResponseWriterWriter {
return NewError(2 /*Unauthorized Access*/)
}
err = DeleteSessionIfExists(context.Tx, r)
if err != nil {
log.Print(err)
return NewError(999 /*Internal Error*/)
}
sessionwriter, err := NewSession(context.Tx, r, dbuser.UserId)
sessionwriter, err := NewSession(context.StoreTx, r, dbuser.UserId)
if err != nil {
log.Print(err)
return NewError(999 /*Internal Error*/)
}
return sessionwriter
} else if r.Method == "GET" {
s, err := GetSession(context.Tx, r)
s, err := GetSession(context.StoreTx, r)
if err != nil {
return NewError(1 /*Not Signed In*/)
}
return s
} else if r.Method == "DELETE" {
err := DeleteSessionIfExists(context.Tx, r)
err := DeleteSessionIfExists(context.StoreTx, r)
if err != nil {
log.Print(err)
return NewError(999 /*Internal Error*/)

View File

@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store/db"
"log"
"math/big"
"net/http"
@ -12,14 +13,14 @@ import (
"time"
)
func SplitAlreadyImported(tx *Tx, s *models.Split) (bool, error) {
func SplitAlreadyImported(tx *db.Tx, s *models.Split) (bool, error) {
count, err := tx.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId)
return count == 1, err
}
// Return a map of security ID's to big.Rat's containing the amount that
// security is imbalanced by
func GetTransactionImbalances(tx *Tx, t *models.Transaction) (map[int64]big.Rat, error) {
func GetTransactionImbalances(tx *db.Tx, t *models.Transaction) (map[int64]big.Rat, error) {
sums := make(map[int64]big.Rat)
if !t.Valid() {
@ -47,7 +48,7 @@ func GetTransactionImbalances(tx *Tx, t *models.Transaction) (map[int64]big.Rat,
// Returns true if all securities contained in this transaction are balanced,
// false otherwise
func TransactionBalanced(tx *Tx, t *models.Transaction) (bool, error) {
func TransactionBalanced(tx *db.Tx, t *models.Transaction) (bool, error) {
var zero big.Rat
sums, err := GetTransactionImbalances(tx, t)
@ -63,7 +64,7 @@ func TransactionBalanced(tx *Tx, t *models.Transaction) (bool, error) {
return true, nil
}
func GetTransaction(tx *Tx, transactionid int64, userid int64) (*models.Transaction, error) {
func GetTransaction(tx *db.Tx, transactionid int64, userid int64) (*models.Transaction, error) {
var t models.Transaction
err := tx.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid)
@ -79,7 +80,7 @@ func GetTransaction(tx *Tx, transactionid int64, userid int64) (*models.Transact
return &t, nil
}
func GetTransactions(tx *Tx, userid int64) (*[]models.Transaction, error) {
func GetTransactions(tx *db.Tx, userid int64) (*[]models.Transaction, error) {
var transactions []models.Transaction
_, err := tx.Select(&transactions, "SELECT * from transactions where UserId=?", userid)
@ -97,7 +98,7 @@ func GetTransactions(tx *Tx, userid int64) (*[]models.Transaction, error) {
return &transactions, nil
}
func incrementAccountVersions(tx *Tx, user *models.User, accountids []int64) error {
func incrementAccountVersions(tx *db.Tx, user *models.User, accountids []int64) error {
for i := range accountids {
account, err := GetAccount(tx, accountids[i], user.UserId)
if err != nil {
@ -121,7 +122,7 @@ func (ame AccountMissingError) Error() string {
return "Account missing"
}
func InsertTransaction(tx *Tx, t *models.Transaction, user *models.User) error {
func InsertTransaction(tx *db.Tx, t *models.Transaction, user *models.User) error {
// Map of any accounts with transaction splits being added
a_map := make(map[int64]bool)
for i := range t.Splits {
@ -171,7 +172,7 @@ func InsertTransaction(tx *Tx, t *models.Transaction, user *models.User) error {
return nil
}
func UpdateTransaction(tx *Tx, t *models.Transaction, user *models.User) error {
func UpdateTransaction(tx *db.Tx, t *models.Transaction, user *models.User) error {
var existing_splits []*models.Split
_, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
@ -248,7 +249,7 @@ func UpdateTransaction(tx *Tx, t *models.Transaction, user *models.User) error {
return nil
}
func DeleteTransaction(tx *Tx, t *models.Transaction, user *models.User) error {
func DeleteTransaction(tx *db.Tx, t *models.Transaction, user *models.User) error {
var accountids []int64
_, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId)
if err != nil {
@ -401,7 +402,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
return NewError(3 /*Invalid Request*/)
}
func TransactionsBalanceDifference(tx *Tx, accountid int64, transactions []models.Transaction) (*big.Rat, error) {
func TransactionsBalanceDifference(tx *db.Tx, accountid int64, transactions []models.Transaction) (*big.Rat, error) {
var pageDifference, tmp big.Rat
for i := range transactions {
_, err := tx.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId)
@ -425,7 +426,7 @@ func TransactionsBalanceDifference(tx *Tx, accountid int64, transactions []model
return &pageDifference, nil
}
func GetAccountBalance(tx *Tx, user *models.User, accountid int64) (*big.Rat, error) {
func GetAccountBalance(tx *db.Tx, user *models.User, accountid int64) (*big.Rat, error) {
var splits []models.Split
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?"
@ -448,7 +449,7 @@ func GetAccountBalance(tx *Tx, user *models.User, accountid int64) (*big.Rat, er
}
// Assumes accountid is valid and is owned by the current user
func GetAccountBalanceDate(tx *Tx, user *models.User, accountid int64, date *time.Time) (*big.Rat, error) {
func GetAccountBalanceDate(tx *db.Tx, user *models.User, accountid int64, date *time.Time) (*big.Rat, error) {
var splits []models.Split
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?"
@ -470,7 +471,7 @@ func GetAccountBalanceDate(tx *Tx, user *models.User, accountid int64, date *tim
return &balance, nil
}
func GetAccountBalanceDateRange(tx *Tx, user *models.User, accountid int64, begin, end *time.Time) (*big.Rat, error) {
func GetAccountBalanceDateRange(tx *db.Tx, user *models.User, accountid int64, begin, end *time.Time) (*big.Rat, error) {
var splits []models.Split
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?"
@ -492,7 +493,7 @@ func GetAccountBalanceDateRange(tx *Tx, user *models.User, accountid int64, begi
return &balance, nil
}
func GetAccountTransactions(tx *Tx, user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) {
func GetAccountTransactions(tx *db.Tx, user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) {
var transactions []models.Transaction
var atl models.AccountTransactionsList

View File

@ -1,65 +1,14 @@
package handlers
import (
"database/sql"
"github.com/aclindsa/gorp"
"strings"
"github.com/aclindsa/moneygo/internal/store/db"
)
type Tx struct {
Dialect gorp.Dialect
Tx *gorp.Transaction
}
func (tx *Tx) Rebind(query string) string {
chunks := strings.Split(query, "?")
str := chunks[0]
for i := 1; i < len(chunks); i++ {
str += tx.Dialect.BindVar(i-1) + chunks[i]
}
return str
}
func (tx *Tx) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) {
return tx.Tx.Select(i, tx.Rebind(query), args...)
}
func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
return tx.Tx.Exec(tx.Rebind(query), args...)
}
func (tx *Tx) SelectInt(query string, args ...interface{}) (int64, error) {
return tx.Tx.SelectInt(tx.Rebind(query), args...)
}
func (tx *Tx) SelectOne(holder interface{}, query string, args ...interface{}) error {
return tx.Tx.SelectOne(holder, tx.Rebind(query), args...)
}
func (tx *Tx) Insert(list ...interface{}) error {
return tx.Tx.Insert(list...)
}
func (tx *Tx) Update(list ...interface{}) (int64, error) {
return tx.Tx.Update(list...)
}
func (tx *Tx) Delete(list ...interface{}) (int64, error) {
return tx.Tx.Delete(list...)
}
func (tx *Tx) Commit() error {
return tx.Tx.Commit()
}
func (tx *Tx) Rollback() error {
return tx.Tx.Rollback()
}
func GetTx(db *gorp.DbMap) (*Tx, error) {
tx, err := db.Begin()
func GetTx(gdb *gorp.DbMap) (*db.Tx, error) {
tx, err := gdb.Begin()
if err != nil {
return nil, err
}
return &Tx{db.Dialect, tx}, nil
return &db.Tx{gdb.Dialect, tx}, nil
}

View File

@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store/db"
"log"
"net/http"
)
@ -14,7 +15,7 @@ func (ueu UserExistsError) Error() string {
return "User exists"
}
func GetUser(tx *Tx, userid int64) (*models.User, error) {
func GetUser(tx *db.Tx, userid int64) (*models.User, error) {
var u models.User
err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid)
@ -24,7 +25,7 @@ func GetUser(tx *Tx, userid int64) (*models.User, error) {
return &u, nil
}
func GetUserByUsername(tx *Tx, username string) (*models.User, error) {
func GetUserByUsername(tx *db.Tx, username string) (*models.User, error) {
var u models.User
err := tx.SelectOne(&u, "SELECT * from users where Username=?", username)
@ -34,7 +35,7 @@ func GetUserByUsername(tx *Tx, username string) (*models.User, error) {
return &u, nil
}
func InsertUser(tx *Tx, u *models.User) error {
func InsertUser(tx *db.Tx, u *models.User) error {
security_template := FindCurrencyTemplate(u.DefaultCurrency)
if security_template == nil {
return errors.New("Invalid ISO4217 Default Currency")
@ -75,7 +76,7 @@ func InsertUser(tx *Tx, u *models.User) error {
return nil
}
func GetUserFromSession(tx *Tx, r *http.Request) (*models.User, error) {
func GetUserFromSession(tx *db.Tx, r *http.Request) (*models.User, error) {
s, err := GetSession(tx, r)
if err != nil {
return nil, err
@ -83,7 +84,7 @@ func GetUserFromSession(tx *Tx, r *http.Request) (*models.User, error) {
return GetUser(tx, s.UserId)
}
func UpdateUser(tx *Tx, u *models.User) error {
func UpdateUser(tx *db.Tx, u *models.User) error {
security, err := GetSecurity(tx, u.DefaultCurrency, u.UserId)
if err != nil {
return err
@ -103,7 +104,7 @@ func UpdateUser(tx *Tx, u *models.User) error {
return nil
}
func DeleteUser(tx *Tx, u *models.User) error {
func DeleteUser(tx *db.Tx, u *models.User) error {
count, err := tx.Delete(u)
if err != nil {
return err

View File

@ -6,6 +6,7 @@ import (
"github.com/aclindsa/gorp"
"github.com/aclindsa/moneygo/internal/config"
"github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
@ -60,3 +61,40 @@ func GetDSN(dbtype config.DbType, dsn string) string {
}
return dsn
}
type DbStore struct {
DbMap *gorp.DbMap
}
func (db *DbStore) Begin() (store.Tx, error) {
tx, err := db.DbMap.Begin()
if err != nil {
return nil, err
}
return &Tx{db.DbMap.Dialect, tx}, nil
}
func (db *DbStore) Close() error {
err := db.DbMap.Db.Close()
db.DbMap = nil
return err
}
func GetStore(dbtype config.DbType, dsn string) (store *DbStore, err error) {
dsn = GetDSN(dbtype, dsn)
database, err := sql.Open(dbtype.String(), dsn)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
database.Close()
}
}()
dbmap, err := GetDbMap(database, dbtype)
if err != nil {
return nil, err
}
return &DbStore{dbmap}, nil
}

View File

@ -0,0 +1,36 @@
package db
import (
"fmt"
"github.com/aclindsa/moneygo/internal/models"
"time"
)
func (tx *Tx) InsertSession(session *models.Session) error {
return tx.Insert(session)
}
func (tx *Tx) GetSession(secret string) (*models.Session, error) {
var s models.Session
err := tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", secret)
if err != nil {
return nil, err
}
if s.Expires.Before(time.Now()) {
tx.Delete(&s)
return nil, fmt.Errorf("Session has expired")
}
return &s, nil
}
func (tx *Tx) SessionExists(secret string) (bool, error) {
existing, err := tx.SelectInt("SELECT count(*) from sessions where SessionSecret=?", secret)
return existing != 0, err
}
func (tx *Tx) DeleteSession(session *models.Session) error {
_, err := tx.Delete(session)
return err
}

57
internal/store/db/tx.go Normal file
View File

@ -0,0 +1,57 @@
package db
import (
"database/sql"
"github.com/aclindsa/gorp"
"strings"
)
type Tx struct {
Dialect gorp.Dialect
Tx *gorp.Transaction
}
func (tx *Tx) Rebind(query string) string {
chunks := strings.Split(query, "?")
str := chunks[0]
for i := 1; i < len(chunks); i++ {
str += tx.Dialect.BindVar(i-1) + chunks[i]
}
return str
}
func (tx *Tx) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) {
return tx.Tx.Select(i, tx.Rebind(query), args...)
}
func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
return tx.Tx.Exec(tx.Rebind(query), args...)
}
func (tx *Tx) SelectInt(query string, args ...interface{}) (int64, error) {
return tx.Tx.SelectInt(tx.Rebind(query), args...)
}
func (tx *Tx) SelectOne(holder interface{}, query string, args ...interface{}) error {
return tx.Tx.SelectOne(holder, tx.Rebind(query), args...)
}
func (tx *Tx) Insert(list ...interface{}) error {
return tx.Tx.Insert(list...)
}
func (tx *Tx) Update(list ...interface{}) (int64, error) {
return tx.Tx.Update(list...)
}
func (tx *Tx) Delete(list ...interface{}) (int64, error) {
return tx.Tx.Delete(list...)
}
func (tx *Tx) Commit() error {
return tx.Tx.Commit()
}
func (tx *Tx) Rollback() error {
return tx.Tx.Rollback()
}

24
internal/store/store.go Normal file
View File

@ -0,0 +1,24 @@
package store
import (
"github.com/aclindsa/moneygo/internal/models"
)
type SessionStore interface {
InsertSession(session *models.Session) error
GetSession(secret string) (*models.Session, error)
SessionExists(secret string) (bool, error)
DeleteSession(session *models.Session) error
}
type Tx interface {
Commit() error
Rollback() error
SessionStore
}
type Store interface {
Begin() (Tx, error)
Close() error
}

15
main.go
View File

@ -3,11 +3,10 @@ package main
//go:generate make
import (
"database/sql"
"flag"
"github.com/aclindsa/moneygo/internal/config"
"github.com/aclindsa/moneygo/internal/db"
"github.com/aclindsa/moneygo/internal/handlers"
"github.com/aclindsa/moneygo/internal/store/db"
"github.com/kabukky/httpscerts"
"log"
"net"
@ -67,21 +66,15 @@ func staticHandler(w http.ResponseWriter, r *http.Request, basedir string) {
}
func main() {
dsn := db.GetDSN(cfg.MoneyGo.DBType, cfg.MoneyGo.DSN)
database, err := sql.Open(cfg.MoneyGo.DBType.String(), dsn)
if err != nil {
log.Fatal(err)
}
defer database.Close()
dbmap, err := db.GetDbMap(database, cfg.MoneyGo.DBType)
db, err := db.GetStore(cfg.MoneyGo.DBType, cfg.MoneyGo.DSN)
if err != nil {
log.Fatal(err)
}
defer db.Close()
// Get ServeMux for API and add our own handlers for files
servemux := http.NewServeMux()
servemux.Handle("/v1/", &handlers.APIHandler{DB: dbmap})
servemux.Handle("/v1/", &handlers.APIHandler{Store: db})
servemux.HandleFunc("/", FileHandlerFunc(rootHandler, cfg.MoneyGo.Basedir))
servemux.HandleFunc("/static/", FileHandlerFunc(staticHandler, cfg.MoneyGo.Basedir))