mirror of
				https://github.com/aclindsa/moneygo.git
				synced 2025-11-03 18:13:27 -05:00 
			
		
		
		
	Lay groundwork and move sessions to 'store'
This commit is contained in:
		@@ -3,11 +3,12 @@ package handlers
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/models"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/store/db"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net/http"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func GetAccount(tx *Tx, accountid int64, userid int64) (*models.Account, error) {
 | 
			
		||||
func GetAccount(tx *db.Tx, accountid int64, userid int64) (*models.Account, error) {
 | 
			
		||||
	var a models.Account
 | 
			
		||||
 | 
			
		||||
	err := tx.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid)
 | 
			
		||||
@@ -17,7 +18,7 @@ func GetAccount(tx *Tx, accountid int64, userid int64) (*models.Account, error)
 | 
			
		||||
	return &a, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAccounts(tx *Tx, userid int64) (*[]models.Account, error) {
 | 
			
		||||
func GetAccounts(tx *db.Tx, userid int64) (*[]models.Account, error) {
 | 
			
		||||
	var accounts []models.Account
 | 
			
		||||
 | 
			
		||||
	_, err := tx.Select(&accounts, "SELECT * from accounts where UserId=?", userid)
 | 
			
		||||
@@ -29,7 +30,7 @@ func GetAccounts(tx *Tx, userid int64) (*[]models.Account, error) {
 | 
			
		||||
 | 
			
		||||
// Get (and attempt to create if it doesn't exist). Matches on UserId,
 | 
			
		||||
// SecurityId, Type, Name, and ParentAccountId
 | 
			
		||||
func GetCreateAccount(tx *Tx, a models.Account) (*models.Account, error) {
 | 
			
		||||
func GetCreateAccount(tx *db.Tx, a models.Account) (*models.Account, error) {
 | 
			
		||||
	var accounts []models.Account
 | 
			
		||||
	var account models.Account
 | 
			
		||||
 | 
			
		||||
@@ -57,7 +58,7 @@ func GetCreateAccount(tx *Tx, a models.Account) (*models.Account, error) {
 | 
			
		||||
 | 
			
		||||
// Get (and attempt to create if it doesn't exist) the security/currency
 | 
			
		||||
// trading account for the supplied security/currency
 | 
			
		||||
func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*models.Account, error) {
 | 
			
		||||
func GetTradingAccount(tx *db.Tx, userid int64, securityid int64) (*models.Account, error) {
 | 
			
		||||
	var tradingAccount models.Account
 | 
			
		||||
	var account models.Account
 | 
			
		||||
 | 
			
		||||
@@ -99,7 +100,7 @@ func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*models.Account,
 | 
			
		||||
 | 
			
		||||
// Get (and attempt to create if it doesn't exist) the security/currency
 | 
			
		||||
// imbalance account for the supplied security/currency
 | 
			
		||||
func GetImbalanceAccount(tx *Tx, userid int64, securityid int64) (*models.Account, error) {
 | 
			
		||||
func GetImbalanceAccount(tx *db.Tx, userid int64, securityid int64) (*models.Account, error) {
 | 
			
		||||
	var imbalanceAccount models.Account
 | 
			
		||||
	var account models.Account
 | 
			
		||||
	xxxtemplate := FindSecurityTemplate("XXX", models.Currency)
 | 
			
		||||
@@ -160,7 +161,7 @@ func (cae CircularAccountsError) Error() string {
 | 
			
		||||
	return "Would result in circular account relationship"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func insertUpdateAccount(tx *Tx, a *models.Account, insert bool) error {
 | 
			
		||||
func insertUpdateAccount(tx *db.Tx, a *models.Account, insert bool) error {
 | 
			
		||||
	found := make(map[int64]bool)
 | 
			
		||||
	if !insert {
 | 
			
		||||
		found[a.AccountId] = true
 | 
			
		||||
@@ -216,15 +217,15 @@ func insertUpdateAccount(tx *Tx, a *models.Account, insert bool) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func InsertAccount(tx *Tx, a *models.Account) error {
 | 
			
		||||
func InsertAccount(tx *db.Tx, a *models.Account) error {
 | 
			
		||||
	return insertUpdateAccount(tx, a, true)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateAccount(tx *Tx, a *models.Account) error {
 | 
			
		||||
func UpdateAccount(tx *db.Tx, a *models.Account) error {
 | 
			
		||||
	return insertUpdateAccount(tx, a, false)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func DeleteAccount(tx *Tx, a *models.Account) error {
 | 
			
		||||
func DeleteAccount(tx *db.Tx, a *models.Account) error {
 | 
			
		||||
	if a.ParentAccountId != -1 {
 | 
			
		||||
		// Re-parent splits to this account's parent account if this account isn't a root account
 | 
			
		||||
		_, err := tx.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", a.ParentAccountId, a.AccountId)
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/models"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/store/db"
 | 
			
		||||
	"github.com/yuin/gopher-lua"
 | 
			
		||||
	"math/big"
 | 
			
		||||
	"strings"
 | 
			
		||||
@@ -16,7 +17,7 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*models.Account, error) {
 | 
			
		||||
 | 
			
		||||
	ctx := L.Context()
 | 
			
		||||
 | 
			
		||||
	tx, ok := ctx.Value(dbContextKey).(*Tx)
 | 
			
		||||
	tx, ok := ctx.Value(dbContextKey).(*db.Tx)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return nil, errors.New("Couldn't find tx in lua's Context")
 | 
			
		||||
	}
 | 
			
		||||
@@ -150,7 +151,7 @@ func luaAccountBalance(L *lua.LState) int {
 | 
			
		||||
	a := luaCheckAccount(L, 1)
 | 
			
		||||
 | 
			
		||||
	ctx := L.Context()
 | 
			
		||||
	tx, ok := ctx.Value(dbContextKey).(*Tx)
 | 
			
		||||
	tx, ok := ctx.Value(dbContextKey).(*db.Tx)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		panic("Couldn't find tx in lua's Context")
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -2,12 +2,11 @@ package handlers_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/config"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/db"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/handlers"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/models"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/store/db"
 | 
			
		||||
	"io"
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
	"log"
 | 
			
		||||
@@ -253,24 +252,18 @@ func RunTests(m *testing.M) int {
 | 
			
		||||
		dsn = envDSN
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dsn = db.GetDSN(dbType, dsn)
 | 
			
		||||
	database, err := sql.Open(dbType.String(), dsn)
 | 
			
		||||
	db, err := db.GetStore(dbType, dsn)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	defer database.Close()
 | 
			
		||||
	defer db.Close()
 | 
			
		||||
 | 
			
		||||
	dbmap, err := db.GetDbMap(database, dbType)
 | 
			
		||||
	err = db.DbMap.TruncateTables()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = dbmap.TruncateTables()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	server = httptest.NewTLSServer(&handlers.APIHandler{DB: dbmap})
 | 
			
		||||
	server = httptest.NewTLSServer(&handlers.APIHandler{Store: db})
 | 
			
		||||
	defer server.Close()
 | 
			
		||||
 | 
			
		||||
	return m.Run()
 | 
			
		||||
 
 | 
			
		||||
@@ -1,8 +1,9 @@
 | 
			
		||||
package handlers
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/aclindsa/gorp"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/models"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/store"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/store/db"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"path"
 | 
			
		||||
@@ -16,7 +17,8 @@ type ResponseWriterWriter interface {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Context struct {
 | 
			
		||||
	Tx           *Tx
 | 
			
		||||
	Tx           *db.Tx
 | 
			
		||||
	StoreTx      store.Tx
 | 
			
		||||
	User         *models.User
 | 
			
		||||
	remainingURL string // portion of URL path not yet reached in the hierarchy
 | 
			
		||||
}
 | 
			
		||||
@@ -46,11 +48,11 @@ func (c *Context) LastLevel() bool {
 | 
			
		||||
type Handler func(*http.Request, *Context) ResponseWriterWriter
 | 
			
		||||
 | 
			
		||||
type APIHandler struct {
 | 
			
		||||
	DB *gorp.DbMap
 | 
			
		||||
	Store *db.DbStore
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ah *APIHandler) txWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) {
 | 
			
		||||
	tx, err := GetTx(ah.DB)
 | 
			
		||||
	tx, err := GetTx(ah.Store.DbMap)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Print(err)
 | 
			
		||||
		return NewError(999 /*Internal Error*/)
 | 
			
		||||
@@ -72,6 +74,33 @@ func (ah *APIHandler) txWrapper(h Handler, r *http.Request, context *Context) (w
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	context.Tx = tx
 | 
			
		||||
	context.StoreTx = tx
 | 
			
		||||
	return h(r, context)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ah *APIHandler) storeTxWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) {
 | 
			
		||||
	tx, err := ah.Store.Begin()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Print(err)
 | 
			
		||||
		return NewError(999 /*Internal Error*/)
 | 
			
		||||
	}
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if r := recover(); r != nil {
 | 
			
		||||
			tx.Rollback()
 | 
			
		||||
			panic(r)
 | 
			
		||||
		}
 | 
			
		||||
		if _, ok := writer.(*Error); ok {
 | 
			
		||||
			tx.Rollback()
 | 
			
		||||
		} else {
 | 
			
		||||
			err = tx.Commit()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				log.Print(err)
 | 
			
		||||
				writer = NewError(999 /*Internal Error*/)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	context.StoreTx = tx
 | 
			
		||||
	return h(r, context)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,7 @@ package handlers
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/models"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/store/db"
 | 
			
		||||
	"github.com/aclindsa/ofxgo"
 | 
			
		||||
	"io"
 | 
			
		||||
	"log"
 | 
			
		||||
@@ -23,7 +24,7 @@ func (od *OFXDownload) Read(json_str string) error {
 | 
			
		||||
	return dec.Decode(od)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) ResponseWriterWriter {
 | 
			
		||||
func ofxImportHelper(tx *db.Tx, r io.Reader, user *models.User, accountid int64) ResponseWriterWriter {
 | 
			
		||||
	itl, err := ImportOFX(r)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -2,12 +2,13 @@ package handlers
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/models"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/store/db"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func CreatePriceIfNotExist(tx *Tx, price *models.Price) error {
 | 
			
		||||
func CreatePriceIfNotExist(tx *db.Tx, price *models.Price) error {
 | 
			
		||||
	if len(price.RemoteId) == 0 {
 | 
			
		||||
		// Always create a new price if we can't match on the RemoteId
 | 
			
		||||
		err := tx.Insert(price)
 | 
			
		||||
@@ -35,7 +36,7 @@ func CreatePriceIfNotExist(tx *Tx, price *models.Price) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetPrice(tx *Tx, priceid, securityid int64) (*models.Price, error) {
 | 
			
		||||
func GetPrice(tx *db.Tx, priceid, securityid int64) (*models.Price, error) {
 | 
			
		||||
	var p models.Price
 | 
			
		||||
	err := tx.SelectOne(&p, "SELECT * from prices where PriceId=? AND SecurityId=?", priceid, securityid)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -44,7 +45,7 @@ func GetPrice(tx *Tx, priceid, securityid int64) (*models.Price, error) {
 | 
			
		||||
	return &p, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetPrices(tx *Tx, securityid int64) (*[]*models.Price, error) {
 | 
			
		||||
func GetPrices(tx *db.Tx, securityid int64) (*[]*models.Price, error) {
 | 
			
		||||
	var prices []*models.Price
 | 
			
		||||
 | 
			
		||||
	_, err := tx.Select(&prices, "SELECT * from prices where SecurityId=?", securityid)
 | 
			
		||||
@@ -55,7 +56,7 @@ func GetPrices(tx *Tx, securityid int64) (*[]*models.Price, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Return the latest price for security in currency units before date
 | 
			
		||||
func GetLatestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
 | 
			
		||||
func GetLatestPrice(tx *db.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
 | 
			
		||||
	var p models.Price
 | 
			
		||||
	err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -65,7 +66,7 @@ func GetLatestPrice(tx *Tx, security, currency *models.Security, date *time.Time
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Return the earliest price for security in currency units after date
 | 
			
		||||
func GetEarliestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
 | 
			
		||||
func GetEarliestPrice(tx *db.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
 | 
			
		||||
	var p models.Price
 | 
			
		||||
	err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -75,7 +76,7 @@ func GetEarliestPrice(tx *Tx, security, currency *models.Security, date *time.Ti
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Return the price for security in currency closest to date
 | 
			
		||||
func GetClosestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
 | 
			
		||||
func GetClosestPrice(tx *db.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
 | 
			
		||||
	earliest, _ := GetEarliestPrice(tx, security, currency, date)
 | 
			
		||||
	latest, err := GetLatestPrice(tx, security, currency, date)
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -5,6 +5,7 @@ import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/models"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/store/db"
 | 
			
		||||
	"github.com/yuin/gopher-lua"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net/http"
 | 
			
		||||
@@ -24,7 +25,7 @@ const (
 | 
			
		||||
 | 
			
		||||
const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for
 | 
			
		||||
 | 
			
		||||
func GetReport(tx *Tx, reportid int64, userid int64) (*models.Report, error) {
 | 
			
		||||
func GetReport(tx *db.Tx, reportid int64, userid int64) (*models.Report, error) {
 | 
			
		||||
	var r models.Report
 | 
			
		||||
 | 
			
		||||
	err := tx.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid)
 | 
			
		||||
@@ -34,7 +35,7 @@ func GetReport(tx *Tx, reportid int64, userid int64) (*models.Report, error) {
 | 
			
		||||
	return &r, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetReports(tx *Tx, userid int64) (*[]models.Report, error) {
 | 
			
		||||
func GetReports(tx *db.Tx, userid int64) (*[]models.Report, error) {
 | 
			
		||||
	var reports []models.Report
 | 
			
		||||
 | 
			
		||||
	_, err := tx.Select(&reports, "SELECT * from reports where UserId=?", userid)
 | 
			
		||||
@@ -44,7 +45,7 @@ func GetReports(tx *Tx, userid int64) (*[]models.Report, error) {
 | 
			
		||||
	return &reports, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func InsertReport(tx *Tx, r *models.Report) error {
 | 
			
		||||
func InsertReport(tx *db.Tx, r *models.Report) error {
 | 
			
		||||
	err := tx.Insert(r)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
@@ -52,7 +53,7 @@ func InsertReport(tx *Tx, r *models.Report) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateReport(tx *Tx, r *models.Report) error {
 | 
			
		||||
func UpdateReport(tx *db.Tx, r *models.Report) error {
 | 
			
		||||
	count, err := tx.Update(r)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
@@ -63,7 +64,7 @@ func UpdateReport(tx *Tx, r *models.Report) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func DeleteReport(tx *Tx, r *models.Report) error {
 | 
			
		||||
func DeleteReport(tx *db.Tx, r *models.Report) error {
 | 
			
		||||
	count, err := tx.Delete(r)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
@@ -74,7 +75,7 @@ func DeleteReport(tx *Tx, r *models.Report) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func runReport(tx *Tx, user *models.User, report *models.Report) (*models.Tabulation, error) {
 | 
			
		||||
func runReport(tx *db.Tx, user *models.User, report *models.Report) (*models.Tabulation, error) {
 | 
			
		||||
	// Create a new LState without opening the default libs for security
 | 
			
		||||
	L := lua.NewState(lua.Options{SkipOpenLibs: true})
 | 
			
		||||
	defer L.Close()
 | 
			
		||||
@@ -138,7 +139,7 @@ func runReport(tx *Tx, user *models.User, report *models.Report) (*models.Tabula
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ReportTabulationHandler(tx *Tx, r *http.Request, user *models.User, reportid int64) ResponseWriterWriter {
 | 
			
		||||
func ReportTabulationHandler(tx *db.Tx, r *http.Request, user *models.User, reportid int64) ResponseWriterWriter {
 | 
			
		||||
	report, err := GetReport(tx, reportid, user.UserId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return NewError(3 /*Invalid Request*/)
 | 
			
		||||
 
 | 
			
		||||
@@ -6,6 +6,7 @@ import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/models"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/store/db"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
@@ -50,7 +51,7 @@ func FindCurrencyTemplate(iso4217 int64) *models.Security {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetSecurity(tx *Tx, securityid int64, userid int64) (*models.Security, error) {
 | 
			
		||||
func GetSecurity(tx *db.Tx, securityid int64, userid int64) (*models.Security, error) {
 | 
			
		||||
	var s models.Security
 | 
			
		||||
 | 
			
		||||
	err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
 | 
			
		||||
@@ -60,7 +61,7 @@ func GetSecurity(tx *Tx, securityid int64, userid int64) (*models.Security, erro
 | 
			
		||||
	return &s, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetSecurities(tx *Tx, userid int64) (*[]*models.Security, error) {
 | 
			
		||||
func GetSecurities(tx *db.Tx, userid int64) (*[]*models.Security, error) {
 | 
			
		||||
	var securities []*models.Security
 | 
			
		||||
 | 
			
		||||
	_, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid)
 | 
			
		||||
@@ -70,7 +71,7 @@ func GetSecurities(tx *Tx, userid int64) (*[]*models.Security, error) {
 | 
			
		||||
	return &securities, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func InsertSecurity(tx *Tx, s *models.Security) error {
 | 
			
		||||
func InsertSecurity(tx *db.Tx, s *models.Security) error {
 | 
			
		||||
	err := tx.Insert(s)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
@@ -78,7 +79,7 @@ func InsertSecurity(tx *Tx, s *models.Security) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateSecurity(tx *Tx, s *models.Security) (err error) {
 | 
			
		||||
func UpdateSecurity(tx *db.Tx, s *models.Security) (err error) {
 | 
			
		||||
	user, err := GetUser(tx, s.UserId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
@@ -105,7 +106,7 @@ func (e SecurityInUseError) Error() string {
 | 
			
		||||
	return e.message
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func DeleteSecurity(tx *Tx, s *models.Security) error {
 | 
			
		||||
func DeleteSecurity(tx *db.Tx, s *models.Security) error {
 | 
			
		||||
	// First, ensure no accounts are using this security
 | 
			
		||||
	accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId)
 | 
			
		||||
 | 
			
		||||
@@ -138,7 +139,7 @@ func DeleteSecurity(tx *Tx, s *models.Security) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ImportGetCreateSecurity(tx *Tx, userid int64, security *models.Security) (*models.Security, error) {
 | 
			
		||||
func ImportGetCreateSecurity(tx *db.Tx, userid int64, security *models.Security) (*models.Security, error) {
 | 
			
		||||
	security.UserId = userid
 | 
			
		||||
	if len(security.AlternateId) == 0 {
 | 
			
		||||
		// Always create a new local security if we can't match on the AlternateId
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/models"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/store/db"
 | 
			
		||||
	"github.com/yuin/gopher-lua"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -14,7 +15,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*models.Security, error)
 | 
			
		||||
 | 
			
		||||
	ctx := L.Context()
 | 
			
		||||
 | 
			
		||||
	tx, ok := ctx.Value(dbContextKey).(*Tx)
 | 
			
		||||
	tx, ok := ctx.Value(dbContextKey).(*db.Tx)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return nil, errors.New("Couldn't find tx in lua's Context")
 | 
			
		||||
	}
 | 
			
		||||
@@ -158,7 +159,7 @@ func luaClosestPrice(L *lua.LState) int {
 | 
			
		||||
	date := luaCheckTime(L, 3)
 | 
			
		||||
 | 
			
		||||
	ctx := L.Context()
 | 
			
		||||
	tx, ok := ctx.Value(dbContextKey).(*Tx)
 | 
			
		||||
	tx, ok := ctx.Value(dbContextKey).(*db.Tx)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		panic("Couldn't find tx in lua's Context")
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -3,36 +3,37 @@ package handlers
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/models"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/store"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func GetSession(tx *Tx, r *http.Request) (*models.Session, error) {
 | 
			
		||||
	var s models.Session
 | 
			
		||||
 | 
			
		||||
func GetSession(tx store.Tx, r *http.Request) (*models.Session, error) {
 | 
			
		||||
	cookie, err := r.Cookie("moneygo-session")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("moneygo-session cookie not set")
 | 
			
		||||
	}
 | 
			
		||||
	s.SessionSecret = cookie.Value
 | 
			
		||||
 | 
			
		||||
	err = tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret)
 | 
			
		||||
	s, err := tx.GetSession(cookie.Value)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if s.Expires.Before(time.Now()) {
 | 
			
		||||
		tx.Delete(&s)
 | 
			
		||||
		err := tx.DeleteSession(s)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Printf("Unexpected error when attempting to delete expired session: %s", err)
 | 
			
		||||
		}
 | 
			
		||||
		return nil, fmt.Errorf("Session has expired")
 | 
			
		||||
	}
 | 
			
		||||
	return &s, nil
 | 
			
		||||
	return s, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func DeleteSessionIfExists(tx *Tx, r *http.Request) error {
 | 
			
		||||
func DeleteSessionIfExists(tx store.Tx, r *http.Request) error {
 | 
			
		||||
	session, err := GetSession(tx, r)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		_, err := tx.Delete(session)
 | 
			
		||||
		err := tx.DeleteSession(session)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
@@ -50,21 +51,26 @@ func (n *NewSessionWriter) Write(w http.ResponseWriter) error {
 | 
			
		||||
	return n.session.Write(w)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewSession(tx *Tx, r *http.Request, userid int64) (*NewSessionWriter, error) {
 | 
			
		||||
func NewSession(tx store.Tx, r *http.Request, userid int64) (*NewSessionWriter, error) {
 | 
			
		||||
	err := DeleteSessionIfExists(tx, r)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	s, err := models.NewSession(userid)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	existing, err := tx.SelectInt("SELECT count(*) from sessions where SessionSecret=?", s.SessionSecret)
 | 
			
		||||
	exists, err := tx.SessionExists(s.SessionSecret)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	if existing > 0 {
 | 
			
		||||
		return nil, fmt.Errorf("%d session(s) exist with the generated session_secret", existing)
 | 
			
		||||
	if exists {
 | 
			
		||||
		return nil, fmt.Errorf("Session already exists with the generated session_secret")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = tx.Insert(s)
 | 
			
		||||
	err = tx.InsertSession(s)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@@ -89,27 +95,21 @@ func SessionHandler(r *http.Request, context *Context) ResponseWriterWriter {
 | 
			
		||||
			return NewError(2 /*Unauthorized Access*/)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		err = DeleteSessionIfExists(context.Tx, r)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Print(err)
 | 
			
		||||
			return NewError(999 /*Internal Error*/)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		sessionwriter, err := NewSession(context.Tx, r, dbuser.UserId)
 | 
			
		||||
		sessionwriter, err := NewSession(context.StoreTx, r, dbuser.UserId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Print(err)
 | 
			
		||||
			return NewError(999 /*Internal Error*/)
 | 
			
		||||
		}
 | 
			
		||||
		return sessionwriter
 | 
			
		||||
	} else if r.Method == "GET" {
 | 
			
		||||
		s, err := GetSession(context.Tx, r)
 | 
			
		||||
		s, err := GetSession(context.StoreTx, r)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return NewError(1 /*Not Signed In*/)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return s
 | 
			
		||||
	} else if r.Method == "DELETE" {
 | 
			
		||||
		err := DeleteSessionIfExists(context.Tx, r)
 | 
			
		||||
		err := DeleteSessionIfExists(context.StoreTx, r)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Print(err)
 | 
			
		||||
			return NewError(999 /*Internal Error*/)
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/models"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/store/db"
 | 
			
		||||
	"log"
 | 
			
		||||
	"math/big"
 | 
			
		||||
	"net/http"
 | 
			
		||||
@@ -12,14 +13,14 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func SplitAlreadyImported(tx *Tx, s *models.Split) (bool, error) {
 | 
			
		||||
func SplitAlreadyImported(tx *db.Tx, s *models.Split) (bool, error) {
 | 
			
		||||
	count, err := tx.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId)
 | 
			
		||||
	return count == 1, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Return a map of security ID's to big.Rat's containing the amount that
 | 
			
		||||
// security is imbalanced by
 | 
			
		||||
func GetTransactionImbalances(tx *Tx, t *models.Transaction) (map[int64]big.Rat, error) {
 | 
			
		||||
func GetTransactionImbalances(tx *db.Tx, t *models.Transaction) (map[int64]big.Rat, error) {
 | 
			
		||||
	sums := make(map[int64]big.Rat)
 | 
			
		||||
 | 
			
		||||
	if !t.Valid() {
 | 
			
		||||
@@ -47,7 +48,7 @@ func GetTransactionImbalances(tx *Tx, t *models.Transaction) (map[int64]big.Rat,
 | 
			
		||||
 | 
			
		||||
// Returns true if all securities contained in this transaction are balanced,
 | 
			
		||||
// false otherwise
 | 
			
		||||
func TransactionBalanced(tx *Tx, t *models.Transaction) (bool, error) {
 | 
			
		||||
func TransactionBalanced(tx *db.Tx, t *models.Transaction) (bool, error) {
 | 
			
		||||
	var zero big.Rat
 | 
			
		||||
 | 
			
		||||
	sums, err := GetTransactionImbalances(tx, t)
 | 
			
		||||
@@ -63,7 +64,7 @@ func TransactionBalanced(tx *Tx, t *models.Transaction) (bool, error) {
 | 
			
		||||
	return true, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetTransaction(tx *Tx, transactionid int64, userid int64) (*models.Transaction, error) {
 | 
			
		||||
func GetTransaction(tx *db.Tx, transactionid int64, userid int64) (*models.Transaction, error) {
 | 
			
		||||
	var t models.Transaction
 | 
			
		||||
 | 
			
		||||
	err := tx.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid)
 | 
			
		||||
@@ -79,7 +80,7 @@ func GetTransaction(tx *Tx, transactionid int64, userid int64) (*models.Transact
 | 
			
		||||
	return &t, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetTransactions(tx *Tx, userid int64) (*[]models.Transaction, error) {
 | 
			
		||||
func GetTransactions(tx *db.Tx, userid int64) (*[]models.Transaction, error) {
 | 
			
		||||
	var transactions []models.Transaction
 | 
			
		||||
 | 
			
		||||
	_, err := tx.Select(&transactions, "SELECT * from transactions where UserId=?", userid)
 | 
			
		||||
@@ -97,7 +98,7 @@ func GetTransactions(tx *Tx, userid int64) (*[]models.Transaction, error) {
 | 
			
		||||
	return &transactions, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func incrementAccountVersions(tx *Tx, user *models.User, accountids []int64) error {
 | 
			
		||||
func incrementAccountVersions(tx *db.Tx, user *models.User, accountids []int64) error {
 | 
			
		||||
	for i := range accountids {
 | 
			
		||||
		account, err := GetAccount(tx, accountids[i], user.UserId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
@@ -121,7 +122,7 @@ func (ame AccountMissingError) Error() string {
 | 
			
		||||
	return "Account missing"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func InsertTransaction(tx *Tx, t *models.Transaction, user *models.User) error {
 | 
			
		||||
func InsertTransaction(tx *db.Tx, t *models.Transaction, user *models.User) error {
 | 
			
		||||
	// Map of any accounts with transaction splits being added
 | 
			
		||||
	a_map := make(map[int64]bool)
 | 
			
		||||
	for i := range t.Splits {
 | 
			
		||||
@@ -171,7 +172,7 @@ func InsertTransaction(tx *Tx, t *models.Transaction, user *models.User) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateTransaction(tx *Tx, t *models.Transaction, user *models.User) error {
 | 
			
		||||
func UpdateTransaction(tx *db.Tx, t *models.Transaction, user *models.User) error {
 | 
			
		||||
	var existing_splits []*models.Split
 | 
			
		||||
 | 
			
		||||
	_, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
 | 
			
		||||
@@ -248,7 +249,7 @@ func UpdateTransaction(tx *Tx, t *models.Transaction, user *models.User) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func DeleteTransaction(tx *Tx, t *models.Transaction, user *models.User) error {
 | 
			
		||||
func DeleteTransaction(tx *db.Tx, t *models.Transaction, user *models.User) error {
 | 
			
		||||
	var accountids []int64
 | 
			
		||||
	_, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -401,7 +402,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
 | 
			
		||||
	return NewError(3 /*Invalid Request*/)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TransactionsBalanceDifference(tx *Tx, accountid int64, transactions []models.Transaction) (*big.Rat, error) {
 | 
			
		||||
func TransactionsBalanceDifference(tx *db.Tx, accountid int64, transactions []models.Transaction) (*big.Rat, error) {
 | 
			
		||||
	var pageDifference, tmp big.Rat
 | 
			
		||||
	for i := range transactions {
 | 
			
		||||
		_, err := tx.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId)
 | 
			
		||||
@@ -425,7 +426,7 @@ func TransactionsBalanceDifference(tx *Tx, accountid int64, transactions []model
 | 
			
		||||
	return &pageDifference, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAccountBalance(tx *Tx, user *models.User, accountid int64) (*big.Rat, error) {
 | 
			
		||||
func GetAccountBalance(tx *db.Tx, user *models.User, accountid int64) (*big.Rat, error) {
 | 
			
		||||
	var splits []models.Split
 | 
			
		||||
 | 
			
		||||
	sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?"
 | 
			
		||||
@@ -448,7 +449,7 @@ func GetAccountBalance(tx *Tx, user *models.User, accountid int64) (*big.Rat, er
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Assumes accountid is valid and is owned by the current user
 | 
			
		||||
func GetAccountBalanceDate(tx *Tx, user *models.User, accountid int64, date *time.Time) (*big.Rat, error) {
 | 
			
		||||
func GetAccountBalanceDate(tx *db.Tx, user *models.User, accountid int64, date *time.Time) (*big.Rat, error) {
 | 
			
		||||
	var splits []models.Split
 | 
			
		||||
 | 
			
		||||
	sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?"
 | 
			
		||||
@@ -470,7 +471,7 @@ func GetAccountBalanceDate(tx *Tx, user *models.User, accountid int64, date *tim
 | 
			
		||||
	return &balance, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAccountBalanceDateRange(tx *Tx, user *models.User, accountid int64, begin, end *time.Time) (*big.Rat, error) {
 | 
			
		||||
func GetAccountBalanceDateRange(tx *db.Tx, user *models.User, accountid int64, begin, end *time.Time) (*big.Rat, error) {
 | 
			
		||||
	var splits []models.Split
 | 
			
		||||
 | 
			
		||||
	sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?"
 | 
			
		||||
@@ -492,7 +493,7 @@ func GetAccountBalanceDateRange(tx *Tx, user *models.User, accountid int64, begi
 | 
			
		||||
	return &balance, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAccountTransactions(tx *Tx, user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) {
 | 
			
		||||
func GetAccountTransactions(tx *db.Tx, user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) {
 | 
			
		||||
	var transactions []models.Transaction
 | 
			
		||||
	var atl models.AccountTransactionsList
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1,65 +1,14 @@
 | 
			
		||||
package handlers
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"github.com/aclindsa/gorp"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/store/db"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Tx struct {
 | 
			
		||||
	Dialect gorp.Dialect
 | 
			
		||||
	Tx      *gorp.Transaction
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tx *Tx) Rebind(query string) string {
 | 
			
		||||
	chunks := strings.Split(query, "?")
 | 
			
		||||
	str := chunks[0]
 | 
			
		||||
	for i := 1; i < len(chunks); i++ {
 | 
			
		||||
		str += tx.Dialect.BindVar(i-1) + chunks[i]
 | 
			
		||||
	}
 | 
			
		||||
	return str
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tx *Tx) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) {
 | 
			
		||||
	return tx.Tx.Select(i, tx.Rebind(query), args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
 | 
			
		||||
	return tx.Tx.Exec(tx.Rebind(query), args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tx *Tx) SelectInt(query string, args ...interface{}) (int64, error) {
 | 
			
		||||
	return tx.Tx.SelectInt(tx.Rebind(query), args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tx *Tx) SelectOne(holder interface{}, query string, args ...interface{}) error {
 | 
			
		||||
	return tx.Tx.SelectOne(holder, tx.Rebind(query), args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tx *Tx) Insert(list ...interface{}) error {
 | 
			
		||||
	return tx.Tx.Insert(list...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tx *Tx) Update(list ...interface{}) (int64, error) {
 | 
			
		||||
	return tx.Tx.Update(list...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tx *Tx) Delete(list ...interface{}) (int64, error) {
 | 
			
		||||
	return tx.Tx.Delete(list...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tx *Tx) Commit() error {
 | 
			
		||||
	return tx.Tx.Commit()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tx *Tx) Rollback() error {
 | 
			
		||||
	return tx.Tx.Rollback()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetTx(db *gorp.DbMap) (*Tx, error) {
 | 
			
		||||
	tx, err := db.Begin()
 | 
			
		||||
func GetTx(gdb *gorp.DbMap) (*db.Tx, error) {
 | 
			
		||||
	tx, err := gdb.Begin()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return &Tx{db.Dialect, tx}, nil
 | 
			
		||||
	return &db.Tx{gdb.Dialect, tx}, nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/models"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/store/db"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net/http"
 | 
			
		||||
)
 | 
			
		||||
@@ -14,7 +15,7 @@ func (ueu UserExistsError) Error() string {
 | 
			
		||||
	return "User exists"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetUser(tx *Tx, userid int64) (*models.User, error) {
 | 
			
		||||
func GetUser(tx *db.Tx, userid int64) (*models.User, error) {
 | 
			
		||||
	var u models.User
 | 
			
		||||
 | 
			
		||||
	err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid)
 | 
			
		||||
@@ -24,7 +25,7 @@ func GetUser(tx *Tx, userid int64) (*models.User, error) {
 | 
			
		||||
	return &u, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetUserByUsername(tx *Tx, username string) (*models.User, error) {
 | 
			
		||||
func GetUserByUsername(tx *db.Tx, username string) (*models.User, error) {
 | 
			
		||||
	var u models.User
 | 
			
		||||
 | 
			
		||||
	err := tx.SelectOne(&u, "SELECT * from users where Username=?", username)
 | 
			
		||||
@@ -34,7 +35,7 @@ func GetUserByUsername(tx *Tx, username string) (*models.User, error) {
 | 
			
		||||
	return &u, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func InsertUser(tx *Tx, u *models.User) error {
 | 
			
		||||
func InsertUser(tx *db.Tx, u *models.User) error {
 | 
			
		||||
	security_template := FindCurrencyTemplate(u.DefaultCurrency)
 | 
			
		||||
	if security_template == nil {
 | 
			
		||||
		return errors.New("Invalid ISO4217 Default Currency")
 | 
			
		||||
@@ -75,7 +76,7 @@ func InsertUser(tx *Tx, u *models.User) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetUserFromSession(tx *Tx, r *http.Request) (*models.User, error) {
 | 
			
		||||
func GetUserFromSession(tx *db.Tx, r *http.Request) (*models.User, error) {
 | 
			
		||||
	s, err := GetSession(tx, r)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
@@ -83,7 +84,7 @@ func GetUserFromSession(tx *Tx, r *http.Request) (*models.User, error) {
 | 
			
		||||
	return GetUser(tx, s.UserId)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateUser(tx *Tx, u *models.User) error {
 | 
			
		||||
func UpdateUser(tx *db.Tx, u *models.User) error {
 | 
			
		||||
	security, err := GetSecurity(tx, u.DefaultCurrency, u.UserId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
@@ -103,7 +104,7 @@ func UpdateUser(tx *Tx, u *models.User) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func DeleteUser(tx *Tx, u *models.User) error {
 | 
			
		||||
func DeleteUser(tx *db.Tx, u *models.User) error {
 | 
			
		||||
	count, err := tx.Delete(u)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
 
 | 
			
		||||
@@ -6,6 +6,7 @@ import (
 | 
			
		||||
	"github.com/aclindsa/gorp"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/config"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/models"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/store"
 | 
			
		||||
	_ "github.com/go-sql-driver/mysql"
 | 
			
		||||
	_ "github.com/lib/pq"
 | 
			
		||||
	_ "github.com/mattn/go-sqlite3"
 | 
			
		||||
@@ -60,3 +61,40 @@ func GetDSN(dbtype config.DbType, dsn string) string {
 | 
			
		||||
	}
 | 
			
		||||
	return dsn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type DbStore struct {
 | 
			
		||||
	DbMap *gorp.DbMap
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db *DbStore) Begin() (store.Tx, error) {
 | 
			
		||||
	tx, err := db.DbMap.Begin()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return &Tx{db.DbMap.Dialect, tx}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db *DbStore) Close() error {
 | 
			
		||||
	err := db.DbMap.Db.Close()
 | 
			
		||||
	db.DbMap = nil
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetStore(dbtype config.DbType, dsn string) (store *DbStore, err error) {
 | 
			
		||||
	dsn = GetDSN(dbtype, dsn)
 | 
			
		||||
	database, err := sql.Open(dbtype.String(), dsn)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			database.Close()
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	dbmap, err := GetDbMap(database, dbtype)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return &DbStore{dbmap}, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										36
									
								
								internal/store/db/sessions.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								internal/store/db/sessions.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,36 @@
 | 
			
		||||
package db
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/models"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (tx *Tx) InsertSession(session *models.Session) error {
 | 
			
		||||
	return tx.Insert(session)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tx *Tx) GetSession(secret string) (*models.Session, error) {
 | 
			
		||||
	var s models.Session
 | 
			
		||||
 | 
			
		||||
	err := tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", secret)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if s.Expires.Before(time.Now()) {
 | 
			
		||||
		tx.Delete(&s)
 | 
			
		||||
		return nil, fmt.Errorf("Session has expired")
 | 
			
		||||
	}
 | 
			
		||||
	return &s, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tx *Tx) SessionExists(secret string) (bool, error) {
 | 
			
		||||
	existing, err := tx.SelectInt("SELECT count(*) from sessions where SessionSecret=?", secret)
 | 
			
		||||
	return existing != 0, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tx *Tx) DeleteSession(session *models.Session) error {
 | 
			
		||||
	_, err := tx.Delete(session)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										57
									
								
								internal/store/db/tx.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								internal/store/db/tx.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,57 @@
 | 
			
		||||
package db
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"github.com/aclindsa/gorp"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Tx struct {
 | 
			
		||||
	Dialect gorp.Dialect
 | 
			
		||||
	Tx      *gorp.Transaction
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tx *Tx) Rebind(query string) string {
 | 
			
		||||
	chunks := strings.Split(query, "?")
 | 
			
		||||
	str := chunks[0]
 | 
			
		||||
	for i := 1; i < len(chunks); i++ {
 | 
			
		||||
		str += tx.Dialect.BindVar(i-1) + chunks[i]
 | 
			
		||||
	}
 | 
			
		||||
	return str
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tx *Tx) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) {
 | 
			
		||||
	return tx.Tx.Select(i, tx.Rebind(query), args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
 | 
			
		||||
	return tx.Tx.Exec(tx.Rebind(query), args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tx *Tx) SelectInt(query string, args ...interface{}) (int64, error) {
 | 
			
		||||
	return tx.Tx.SelectInt(tx.Rebind(query), args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tx *Tx) SelectOne(holder interface{}, query string, args ...interface{}) error {
 | 
			
		||||
	return tx.Tx.SelectOne(holder, tx.Rebind(query), args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tx *Tx) Insert(list ...interface{}) error {
 | 
			
		||||
	return tx.Tx.Insert(list...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tx *Tx) Update(list ...interface{}) (int64, error) {
 | 
			
		||||
	return tx.Tx.Update(list...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tx *Tx) Delete(list ...interface{}) (int64, error) {
 | 
			
		||||
	return tx.Tx.Delete(list...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tx *Tx) Commit() error {
 | 
			
		||||
	return tx.Tx.Commit()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tx *Tx) Rollback() error {
 | 
			
		||||
	return tx.Tx.Rollback()
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										24
									
								
								internal/store/store.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								internal/store/store.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,24 @@
 | 
			
		||||
package store
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/aclindsa/moneygo/internal/models"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type SessionStore interface {
 | 
			
		||||
	InsertSession(session *models.Session) error
 | 
			
		||||
	GetSession(secret string) (*models.Session, error)
 | 
			
		||||
	SessionExists(secret string) (bool, error)
 | 
			
		||||
	DeleteSession(session *models.Session) error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Tx interface {
 | 
			
		||||
	Commit() error
 | 
			
		||||
	Rollback() error
 | 
			
		||||
 | 
			
		||||
	SessionStore
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Store interface {
 | 
			
		||||
	Begin() (Tx, error)
 | 
			
		||||
	Close() error
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user