mirror of
https://github.com/aclindsa/moneygo.git
synced 2024-12-25 23:23:21 -05:00
Lay groundwork and move sessions to 'store'
This commit is contained in:
parent
6bdde8e83b
commit
c452984f23
@ -3,11 +3,12 @@ package handlers
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/aclindsa/moneygo/internal/models"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store/db"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetAccount(tx *Tx, accountid int64, userid int64) (*models.Account, error) {
|
func GetAccount(tx *db.Tx, accountid int64, userid int64) (*models.Account, error) {
|
||||||
var a models.Account
|
var a models.Account
|
||||||
|
|
||||||
err := tx.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid)
|
err := tx.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid)
|
||||||
@ -17,7 +18,7 @@ func GetAccount(tx *Tx, accountid int64, userid int64) (*models.Account, error)
|
|||||||
return &a, nil
|
return &a, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAccounts(tx *Tx, userid int64) (*[]models.Account, error) {
|
func GetAccounts(tx *db.Tx, userid int64) (*[]models.Account, error) {
|
||||||
var accounts []models.Account
|
var accounts []models.Account
|
||||||
|
|
||||||
_, err := tx.Select(&accounts, "SELECT * from accounts where UserId=?", userid)
|
_, err := tx.Select(&accounts, "SELECT * from accounts where UserId=?", userid)
|
||||||
@ -29,7 +30,7 @@ func GetAccounts(tx *Tx, userid int64) (*[]models.Account, error) {
|
|||||||
|
|
||||||
// Get (and attempt to create if it doesn't exist). Matches on UserId,
|
// Get (and attempt to create if it doesn't exist). Matches on UserId,
|
||||||
// SecurityId, Type, Name, and ParentAccountId
|
// SecurityId, Type, Name, and ParentAccountId
|
||||||
func GetCreateAccount(tx *Tx, a models.Account) (*models.Account, error) {
|
func GetCreateAccount(tx *db.Tx, a models.Account) (*models.Account, error) {
|
||||||
var accounts []models.Account
|
var accounts []models.Account
|
||||||
var account models.Account
|
var account models.Account
|
||||||
|
|
||||||
@ -57,7 +58,7 @@ func GetCreateAccount(tx *Tx, a models.Account) (*models.Account, error) {
|
|||||||
|
|
||||||
// Get (and attempt to create if it doesn't exist) the security/currency
|
// Get (and attempt to create if it doesn't exist) the security/currency
|
||||||
// trading account for the supplied security/currency
|
// trading account for the supplied security/currency
|
||||||
func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*models.Account, error) {
|
func GetTradingAccount(tx *db.Tx, userid int64, securityid int64) (*models.Account, error) {
|
||||||
var tradingAccount models.Account
|
var tradingAccount models.Account
|
||||||
var account models.Account
|
var account models.Account
|
||||||
|
|
||||||
@ -99,7 +100,7 @@ func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*models.Account,
|
|||||||
|
|
||||||
// Get (and attempt to create if it doesn't exist) the security/currency
|
// Get (and attempt to create if it doesn't exist) the security/currency
|
||||||
// imbalance account for the supplied security/currency
|
// imbalance account for the supplied security/currency
|
||||||
func GetImbalanceAccount(tx *Tx, userid int64, securityid int64) (*models.Account, error) {
|
func GetImbalanceAccount(tx *db.Tx, userid int64, securityid int64) (*models.Account, error) {
|
||||||
var imbalanceAccount models.Account
|
var imbalanceAccount models.Account
|
||||||
var account models.Account
|
var account models.Account
|
||||||
xxxtemplate := FindSecurityTemplate("XXX", models.Currency)
|
xxxtemplate := FindSecurityTemplate("XXX", models.Currency)
|
||||||
@ -160,7 +161,7 @@ func (cae CircularAccountsError) Error() string {
|
|||||||
return "Would result in circular account relationship"
|
return "Would result in circular account relationship"
|
||||||
}
|
}
|
||||||
|
|
||||||
func insertUpdateAccount(tx *Tx, a *models.Account, insert bool) error {
|
func insertUpdateAccount(tx *db.Tx, a *models.Account, insert bool) error {
|
||||||
found := make(map[int64]bool)
|
found := make(map[int64]bool)
|
||||||
if !insert {
|
if !insert {
|
||||||
found[a.AccountId] = true
|
found[a.AccountId] = true
|
||||||
@ -216,15 +217,15 @@ func insertUpdateAccount(tx *Tx, a *models.Account, insert bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func InsertAccount(tx *Tx, a *models.Account) error {
|
func InsertAccount(tx *db.Tx, a *models.Account) error {
|
||||||
return insertUpdateAccount(tx, a, true)
|
return insertUpdateAccount(tx, a, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateAccount(tx *Tx, a *models.Account) error {
|
func UpdateAccount(tx *db.Tx, a *models.Account) error {
|
||||||
return insertUpdateAccount(tx, a, false)
|
return insertUpdateAccount(tx, a, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteAccount(tx *Tx, a *models.Account) error {
|
func DeleteAccount(tx *db.Tx, a *models.Account) error {
|
||||||
if a.ParentAccountId != -1 {
|
if a.ParentAccountId != -1 {
|
||||||
// Re-parent splits to this account's parent account if this account isn't a root account
|
// Re-parent splits to this account's parent account if this account isn't a root account
|
||||||
_, err := tx.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", a.ParentAccountId, a.AccountId)
|
_, err := tx.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", a.ParentAccountId, a.AccountId)
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/aclindsa/moneygo/internal/models"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store/db"
|
||||||
"github.com/yuin/gopher-lua"
|
"github.com/yuin/gopher-lua"
|
||||||
"math/big"
|
"math/big"
|
||||||
"strings"
|
"strings"
|
||||||
@ -16,7 +17,7 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*models.Account, error) {
|
|||||||
|
|
||||||
ctx := L.Context()
|
ctx := L.Context()
|
||||||
|
|
||||||
tx, ok := ctx.Value(dbContextKey).(*Tx)
|
tx, ok := ctx.Value(dbContextKey).(*db.Tx)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("Couldn't find tx in lua's Context")
|
return nil, errors.New("Couldn't find tx in lua's Context")
|
||||||
}
|
}
|
||||||
@ -150,7 +151,7 @@ func luaAccountBalance(L *lua.LState) int {
|
|||||||
a := luaCheckAccount(L, 1)
|
a := luaCheckAccount(L, 1)
|
||||||
|
|
||||||
ctx := L.Context()
|
ctx := L.Context()
|
||||||
tx, ok := ctx.Value(dbContextKey).(*Tx)
|
tx, ok := ctx.Value(dbContextKey).(*db.Tx)
|
||||||
if !ok {
|
if !ok {
|
||||||
panic("Couldn't find tx in lua's Context")
|
panic("Couldn't find tx in lua's Context")
|
||||||
}
|
}
|
||||||
|
@ -2,12 +2,11 @@ package handlers_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"database/sql"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"github.com/aclindsa/moneygo/internal/config"
|
"github.com/aclindsa/moneygo/internal/config"
|
||||||
"github.com/aclindsa/moneygo/internal/db"
|
|
||||||
"github.com/aclindsa/moneygo/internal/handlers"
|
"github.com/aclindsa/moneygo/internal/handlers"
|
||||||
"github.com/aclindsa/moneygo/internal/models"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store/db"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
@ -253,24 +252,18 @@ func RunTests(m *testing.M) int {
|
|||||||
dsn = envDSN
|
dsn = envDSN
|
||||||
}
|
}
|
||||||
|
|
||||||
dsn = db.GetDSN(dbType, dsn)
|
db, err := db.GetStore(dbType, dsn)
|
||||||
database, err := sql.Open(dbType.String(), dsn)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
defer database.Close()
|
defer db.Close()
|
||||||
|
|
||||||
dbmap, err := db.GetDbMap(database, dbType)
|
err = db.DbMap.TruncateTables()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = dbmap.TruncateTables()
|
server = httptest.NewTLSServer(&handlers.APIHandler{Store: db})
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
server = httptest.NewTLSServer(&handlers.APIHandler{DB: dbmap})
|
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
return m.Run()
|
return m.Run()
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
package handlers
|
package handlers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/aclindsa/gorp"
|
|
||||||
"github.com/aclindsa/moneygo/internal/models"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store/db"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"path"
|
"path"
|
||||||
@ -16,7 +17,8 @@ type ResponseWriterWriter interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Context struct {
|
type Context struct {
|
||||||
Tx *Tx
|
Tx *db.Tx
|
||||||
|
StoreTx store.Tx
|
||||||
User *models.User
|
User *models.User
|
||||||
remainingURL string // portion of URL path not yet reached in the hierarchy
|
remainingURL string // portion of URL path not yet reached in the hierarchy
|
||||||
}
|
}
|
||||||
@ -46,11 +48,11 @@ func (c *Context) LastLevel() bool {
|
|||||||
type Handler func(*http.Request, *Context) ResponseWriterWriter
|
type Handler func(*http.Request, *Context) ResponseWriterWriter
|
||||||
|
|
||||||
type APIHandler struct {
|
type APIHandler struct {
|
||||||
DB *gorp.DbMap
|
Store *db.DbStore
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ah *APIHandler) txWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) {
|
func (ah *APIHandler) txWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) {
|
||||||
tx, err := GetTx(ah.DB)
|
tx, err := GetTx(ah.Store.DbMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
@ -72,6 +74,33 @@ func (ah *APIHandler) txWrapper(h Handler, r *http.Request, context *Context) (w
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
context.Tx = tx
|
context.Tx = tx
|
||||||
|
context.StoreTx = tx
|
||||||
|
return h(r, context)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ah *APIHandler) storeTxWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) {
|
||||||
|
tx, err := ah.Store.Begin()
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return NewError(999 /*Internal Error*/)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
panic(r)
|
||||||
|
}
|
||||||
|
if _, ok := writer.(*Error); ok {
|
||||||
|
tx.Rollback()
|
||||||
|
} else {
|
||||||
|
err = tx.Commit()
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
writer = NewError(999 /*Internal Error*/)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
context.StoreTx = tx
|
||||||
return h(r, context)
|
return h(r, context)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ package handlers
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"github.com/aclindsa/moneygo/internal/models"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store/db"
|
||||||
"github.com/aclindsa/ofxgo"
|
"github.com/aclindsa/ofxgo"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
@ -23,7 +24,7 @@ func (od *OFXDownload) Read(json_str string) error {
|
|||||||
return dec.Decode(od)
|
return dec.Decode(od)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) ResponseWriterWriter {
|
func ofxImportHelper(tx *db.Tx, r io.Reader, user *models.User, accountid int64) ResponseWriterWriter {
|
||||||
itl, err := ImportOFX(r)
|
itl, err := ImportOFX(r)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -2,12 +2,13 @@ package handlers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/aclindsa/moneygo/internal/models"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store/db"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func CreatePriceIfNotExist(tx *Tx, price *models.Price) error {
|
func CreatePriceIfNotExist(tx *db.Tx, price *models.Price) error {
|
||||||
if len(price.RemoteId) == 0 {
|
if len(price.RemoteId) == 0 {
|
||||||
// Always create a new price if we can't match on the RemoteId
|
// Always create a new price if we can't match on the RemoteId
|
||||||
err := tx.Insert(price)
|
err := tx.Insert(price)
|
||||||
@ -35,7 +36,7 @@ func CreatePriceIfNotExist(tx *Tx, price *models.Price) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetPrice(tx *Tx, priceid, securityid int64) (*models.Price, error) {
|
func GetPrice(tx *db.Tx, priceid, securityid int64) (*models.Price, error) {
|
||||||
var p models.Price
|
var p models.Price
|
||||||
err := tx.SelectOne(&p, "SELECT * from prices where PriceId=? AND SecurityId=?", priceid, securityid)
|
err := tx.SelectOne(&p, "SELECT * from prices where PriceId=? AND SecurityId=?", priceid, securityid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -44,7 +45,7 @@ func GetPrice(tx *Tx, priceid, securityid int64) (*models.Price, error) {
|
|||||||
return &p, nil
|
return &p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetPrices(tx *Tx, securityid int64) (*[]*models.Price, error) {
|
func GetPrices(tx *db.Tx, securityid int64) (*[]*models.Price, error) {
|
||||||
var prices []*models.Price
|
var prices []*models.Price
|
||||||
|
|
||||||
_, err := tx.Select(&prices, "SELECT * from prices where SecurityId=?", securityid)
|
_, err := tx.Select(&prices, "SELECT * from prices where SecurityId=?", securityid)
|
||||||
@ -55,7 +56,7 @@ func GetPrices(tx *Tx, securityid int64) (*[]*models.Price, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Return the latest price for security in currency units before date
|
// Return the latest price for security in currency units before date
|
||||||
func GetLatestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
|
func GetLatestPrice(tx *db.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
|
||||||
var p models.Price
|
var p models.Price
|
||||||
err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date)
|
err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -65,7 +66,7 @@ func GetLatestPrice(tx *Tx, security, currency *models.Security, date *time.Time
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Return the earliest price for security in currency units after date
|
// Return the earliest price for security in currency units after date
|
||||||
func GetEarliestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
|
func GetEarliestPrice(tx *db.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
|
||||||
var p models.Price
|
var p models.Price
|
||||||
err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date)
|
err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -75,7 +76,7 @@ func GetEarliestPrice(tx *Tx, security, currency *models.Security, date *time.Ti
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Return the price for security in currency closest to date
|
// Return the price for security in currency closest to date
|
||||||
func GetClosestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
|
func GetClosestPrice(tx *db.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
|
||||||
earliest, _ := GetEarliestPrice(tx, security, currency, date)
|
earliest, _ := GetEarliestPrice(tx, security, currency, date)
|
||||||
latest, err := GetLatestPrice(tx, security, currency, date)
|
latest, err := GetLatestPrice(tx, security, currency, date)
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/aclindsa/moneygo/internal/models"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store/db"
|
||||||
"github.com/yuin/gopher-lua"
|
"github.com/yuin/gopher-lua"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -24,7 +25,7 @@ const (
|
|||||||
|
|
||||||
const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for
|
const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for
|
||||||
|
|
||||||
func GetReport(tx *Tx, reportid int64, userid int64) (*models.Report, error) {
|
func GetReport(tx *db.Tx, reportid int64, userid int64) (*models.Report, error) {
|
||||||
var r models.Report
|
var r models.Report
|
||||||
|
|
||||||
err := tx.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid)
|
err := tx.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid)
|
||||||
@ -34,7 +35,7 @@ func GetReport(tx *Tx, reportid int64, userid int64) (*models.Report, error) {
|
|||||||
return &r, nil
|
return &r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetReports(tx *Tx, userid int64) (*[]models.Report, error) {
|
func GetReports(tx *db.Tx, userid int64) (*[]models.Report, error) {
|
||||||
var reports []models.Report
|
var reports []models.Report
|
||||||
|
|
||||||
_, err := tx.Select(&reports, "SELECT * from reports where UserId=?", userid)
|
_, err := tx.Select(&reports, "SELECT * from reports where UserId=?", userid)
|
||||||
@ -44,7 +45,7 @@ func GetReports(tx *Tx, userid int64) (*[]models.Report, error) {
|
|||||||
return &reports, nil
|
return &reports, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func InsertReport(tx *Tx, r *models.Report) error {
|
func InsertReport(tx *db.Tx, r *models.Report) error {
|
||||||
err := tx.Insert(r)
|
err := tx.Insert(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -52,7 +53,7 @@ func InsertReport(tx *Tx, r *models.Report) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateReport(tx *Tx, r *models.Report) error {
|
func UpdateReport(tx *db.Tx, r *models.Report) error {
|
||||||
count, err := tx.Update(r)
|
count, err := tx.Update(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -63,7 +64,7 @@ func UpdateReport(tx *Tx, r *models.Report) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteReport(tx *Tx, r *models.Report) error {
|
func DeleteReport(tx *db.Tx, r *models.Report) error {
|
||||||
count, err := tx.Delete(r)
|
count, err := tx.Delete(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -74,7 +75,7 @@ func DeleteReport(tx *Tx, r *models.Report) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func runReport(tx *Tx, user *models.User, report *models.Report) (*models.Tabulation, error) {
|
func runReport(tx *db.Tx, user *models.User, report *models.Report) (*models.Tabulation, error) {
|
||||||
// Create a new LState without opening the default libs for security
|
// Create a new LState without opening the default libs for security
|
||||||
L := lua.NewState(lua.Options{SkipOpenLibs: true})
|
L := lua.NewState(lua.Options{SkipOpenLibs: true})
|
||||||
defer L.Close()
|
defer L.Close()
|
||||||
@ -138,7 +139,7 @@ func runReport(tx *Tx, user *models.User, report *models.Report) (*models.Tabula
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReportTabulationHandler(tx *Tx, r *http.Request, user *models.User, reportid int64) ResponseWriterWriter {
|
func ReportTabulationHandler(tx *db.Tx, r *http.Request, user *models.User, reportid int64) ResponseWriterWriter {
|
||||||
report, err := GetReport(tx, reportid, user.UserId)
|
report, err := GetReport(tx, reportid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/aclindsa/moneygo/internal/models"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store/db"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -50,7 +51,7 @@ func FindCurrencyTemplate(iso4217 int64) *models.Security {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetSecurity(tx *Tx, securityid int64, userid int64) (*models.Security, error) {
|
func GetSecurity(tx *db.Tx, securityid int64, userid int64) (*models.Security, error) {
|
||||||
var s models.Security
|
var s models.Security
|
||||||
|
|
||||||
err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
|
err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
|
||||||
@ -60,7 +61,7 @@ func GetSecurity(tx *Tx, securityid int64, userid int64) (*models.Security, erro
|
|||||||
return &s, nil
|
return &s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetSecurities(tx *Tx, userid int64) (*[]*models.Security, error) {
|
func GetSecurities(tx *db.Tx, userid int64) (*[]*models.Security, error) {
|
||||||
var securities []*models.Security
|
var securities []*models.Security
|
||||||
|
|
||||||
_, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid)
|
_, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid)
|
||||||
@ -70,7 +71,7 @@ func GetSecurities(tx *Tx, userid int64) (*[]*models.Security, error) {
|
|||||||
return &securities, nil
|
return &securities, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func InsertSecurity(tx *Tx, s *models.Security) error {
|
func InsertSecurity(tx *db.Tx, s *models.Security) error {
|
||||||
err := tx.Insert(s)
|
err := tx.Insert(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -78,7 +79,7 @@ func InsertSecurity(tx *Tx, s *models.Security) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateSecurity(tx *Tx, s *models.Security) (err error) {
|
func UpdateSecurity(tx *db.Tx, s *models.Security) (err error) {
|
||||||
user, err := GetUser(tx, s.UserId)
|
user, err := GetUser(tx, s.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@ -105,7 +106,7 @@ func (e SecurityInUseError) Error() string {
|
|||||||
return e.message
|
return e.message
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteSecurity(tx *Tx, s *models.Security) error {
|
func DeleteSecurity(tx *db.Tx, s *models.Security) error {
|
||||||
// First, ensure no accounts are using this security
|
// First, ensure no accounts are using this security
|
||||||
accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId)
|
accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId)
|
||||||
|
|
||||||
@ -138,7 +139,7 @@ func DeleteSecurity(tx *Tx, s *models.Security) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ImportGetCreateSecurity(tx *Tx, userid int64, security *models.Security) (*models.Security, error) {
|
func ImportGetCreateSecurity(tx *db.Tx, userid int64, security *models.Security) (*models.Security, error) {
|
||||||
security.UserId = userid
|
security.UserId = userid
|
||||||
if len(security.AlternateId) == 0 {
|
if len(security.AlternateId) == 0 {
|
||||||
// Always create a new local security if we can't match on the AlternateId
|
// Always create a new local security if we can't match on the AlternateId
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/aclindsa/moneygo/internal/models"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store/db"
|
||||||
"github.com/yuin/gopher-lua"
|
"github.com/yuin/gopher-lua"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -14,7 +15,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*models.Security, error)
|
|||||||
|
|
||||||
ctx := L.Context()
|
ctx := L.Context()
|
||||||
|
|
||||||
tx, ok := ctx.Value(dbContextKey).(*Tx)
|
tx, ok := ctx.Value(dbContextKey).(*db.Tx)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("Couldn't find tx in lua's Context")
|
return nil, errors.New("Couldn't find tx in lua's Context")
|
||||||
}
|
}
|
||||||
@ -158,7 +159,7 @@ func luaClosestPrice(L *lua.LState) int {
|
|||||||
date := luaCheckTime(L, 3)
|
date := luaCheckTime(L, 3)
|
||||||
|
|
||||||
ctx := L.Context()
|
ctx := L.Context()
|
||||||
tx, ok := ctx.Value(dbContextKey).(*Tx)
|
tx, ok := ctx.Value(dbContextKey).(*db.Tx)
|
||||||
if !ok {
|
if !ok {
|
||||||
panic("Couldn't find tx in lua's Context")
|
panic("Couldn't find tx in lua's Context")
|
||||||
}
|
}
|
||||||
|
@ -3,36 +3,37 @@ package handlers
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/aclindsa/moneygo/internal/models"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetSession(tx *Tx, r *http.Request) (*models.Session, error) {
|
func GetSession(tx store.Tx, r *http.Request) (*models.Session, error) {
|
||||||
var s models.Session
|
|
||||||
|
|
||||||
cookie, err := r.Cookie("moneygo-session")
|
cookie, err := r.Cookie("moneygo-session")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("moneygo-session cookie not set")
|
return nil, fmt.Errorf("moneygo-session cookie not set")
|
||||||
}
|
}
|
||||||
s.SessionSecret = cookie.Value
|
|
||||||
|
|
||||||
err = tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret)
|
s, err := tx.GetSession(cookie.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.Expires.Before(time.Now()) {
|
if s.Expires.Before(time.Now()) {
|
||||||
tx.Delete(&s)
|
err := tx.DeleteSession(s)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Unexpected error when attempting to delete expired session: %s", err)
|
||||||
|
}
|
||||||
return nil, fmt.Errorf("Session has expired")
|
return nil, fmt.Errorf("Session has expired")
|
||||||
}
|
}
|
||||||
return &s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteSessionIfExists(tx *Tx, r *http.Request) error {
|
func DeleteSessionIfExists(tx store.Tx, r *http.Request) error {
|
||||||
session, err := GetSession(tx, r)
|
session, err := GetSession(tx, r)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
_, err := tx.Delete(session)
|
err := tx.DeleteSession(session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -50,21 +51,26 @@ func (n *NewSessionWriter) Write(w http.ResponseWriter) error {
|
|||||||
return n.session.Write(w)
|
return n.session.Write(w)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSession(tx *Tx, r *http.Request, userid int64) (*NewSessionWriter, error) {
|
func NewSession(tx store.Tx, r *http.Request, userid int64) (*NewSessionWriter, error) {
|
||||||
|
err := DeleteSessionIfExists(tx, r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
s, err := models.NewSession(userid)
|
s, err := models.NewSession(userid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
existing, err := tx.SelectInt("SELECT count(*) from sessions where SessionSecret=?", s.SessionSecret)
|
exists, err := tx.SessionExists(s.SessionSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if existing > 0 {
|
if exists {
|
||||||
return nil, fmt.Errorf("%d session(s) exist with the generated session_secret", existing)
|
return nil, fmt.Errorf("Session already exists with the generated session_secret")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tx.Insert(s)
|
err = tx.InsertSession(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -89,27 +95,21 @@ func SessionHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
|||||||
return NewError(2 /*Unauthorized Access*/)
|
return NewError(2 /*Unauthorized Access*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = DeleteSessionIfExists(context.Tx, r)
|
sessionwriter, err := NewSession(context.StoreTx, r, dbuser.UserId)
|
||||||
if err != nil {
|
|
||||||
log.Print(err)
|
|
||||||
return NewError(999 /*Internal Error*/)
|
|
||||||
}
|
|
||||||
|
|
||||||
sessionwriter, err := NewSession(context.Tx, r, dbuser.UserId)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
return sessionwriter
|
return sessionwriter
|
||||||
} else if r.Method == "GET" {
|
} else if r.Method == "GET" {
|
||||||
s, err := GetSession(context.Tx, r)
|
s, err := GetSession(context.StoreTx, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(1 /*Not Signed In*/)
|
return NewError(1 /*Not Signed In*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
return s
|
return s
|
||||||
} else if r.Method == "DELETE" {
|
} else if r.Method == "DELETE" {
|
||||||
err := DeleteSessionIfExists(context.Tx, r)
|
err := DeleteSessionIfExists(context.StoreTx, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/aclindsa/moneygo/internal/models"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store/db"
|
||||||
"log"
|
"log"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -12,14 +13,14 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func SplitAlreadyImported(tx *Tx, s *models.Split) (bool, error) {
|
func SplitAlreadyImported(tx *db.Tx, s *models.Split) (bool, error) {
|
||||||
count, err := tx.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId)
|
count, err := tx.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId)
|
||||||
return count == 1, err
|
return count == 1, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return a map of security ID's to big.Rat's containing the amount that
|
// Return a map of security ID's to big.Rat's containing the amount that
|
||||||
// security is imbalanced by
|
// security is imbalanced by
|
||||||
func GetTransactionImbalances(tx *Tx, t *models.Transaction) (map[int64]big.Rat, error) {
|
func GetTransactionImbalances(tx *db.Tx, t *models.Transaction) (map[int64]big.Rat, error) {
|
||||||
sums := make(map[int64]big.Rat)
|
sums := make(map[int64]big.Rat)
|
||||||
|
|
||||||
if !t.Valid() {
|
if !t.Valid() {
|
||||||
@ -47,7 +48,7 @@ func GetTransactionImbalances(tx *Tx, t *models.Transaction) (map[int64]big.Rat,
|
|||||||
|
|
||||||
// Returns true if all securities contained in this transaction are balanced,
|
// Returns true if all securities contained in this transaction are balanced,
|
||||||
// false otherwise
|
// false otherwise
|
||||||
func TransactionBalanced(tx *Tx, t *models.Transaction) (bool, error) {
|
func TransactionBalanced(tx *db.Tx, t *models.Transaction) (bool, error) {
|
||||||
var zero big.Rat
|
var zero big.Rat
|
||||||
|
|
||||||
sums, err := GetTransactionImbalances(tx, t)
|
sums, err := GetTransactionImbalances(tx, t)
|
||||||
@ -63,7 +64,7 @@ func TransactionBalanced(tx *Tx, t *models.Transaction) (bool, error) {
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetTransaction(tx *Tx, transactionid int64, userid int64) (*models.Transaction, error) {
|
func GetTransaction(tx *db.Tx, transactionid int64, userid int64) (*models.Transaction, error) {
|
||||||
var t models.Transaction
|
var t models.Transaction
|
||||||
|
|
||||||
err := tx.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid)
|
err := tx.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid)
|
||||||
@ -79,7 +80,7 @@ func GetTransaction(tx *Tx, transactionid int64, userid int64) (*models.Transact
|
|||||||
return &t, nil
|
return &t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetTransactions(tx *Tx, userid int64) (*[]models.Transaction, error) {
|
func GetTransactions(tx *db.Tx, userid int64) (*[]models.Transaction, error) {
|
||||||
var transactions []models.Transaction
|
var transactions []models.Transaction
|
||||||
|
|
||||||
_, err := tx.Select(&transactions, "SELECT * from transactions where UserId=?", userid)
|
_, err := tx.Select(&transactions, "SELECT * from transactions where UserId=?", userid)
|
||||||
@ -97,7 +98,7 @@ func GetTransactions(tx *Tx, userid int64) (*[]models.Transaction, error) {
|
|||||||
return &transactions, nil
|
return &transactions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func incrementAccountVersions(tx *Tx, user *models.User, accountids []int64) error {
|
func incrementAccountVersions(tx *db.Tx, user *models.User, accountids []int64) error {
|
||||||
for i := range accountids {
|
for i := range accountids {
|
||||||
account, err := GetAccount(tx, accountids[i], user.UserId)
|
account, err := GetAccount(tx, accountids[i], user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -121,7 +122,7 @@ func (ame AccountMissingError) Error() string {
|
|||||||
return "Account missing"
|
return "Account missing"
|
||||||
}
|
}
|
||||||
|
|
||||||
func InsertTransaction(tx *Tx, t *models.Transaction, user *models.User) error {
|
func InsertTransaction(tx *db.Tx, t *models.Transaction, user *models.User) error {
|
||||||
// Map of any accounts with transaction splits being added
|
// Map of any accounts with transaction splits being added
|
||||||
a_map := make(map[int64]bool)
|
a_map := make(map[int64]bool)
|
||||||
for i := range t.Splits {
|
for i := range t.Splits {
|
||||||
@ -171,7 +172,7 @@ func InsertTransaction(tx *Tx, t *models.Transaction, user *models.User) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateTransaction(tx *Tx, t *models.Transaction, user *models.User) error {
|
func UpdateTransaction(tx *db.Tx, t *models.Transaction, user *models.User) error {
|
||||||
var existing_splits []*models.Split
|
var existing_splits []*models.Split
|
||||||
|
|
||||||
_, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
|
_, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
|
||||||
@ -248,7 +249,7 @@ func UpdateTransaction(tx *Tx, t *models.Transaction, user *models.User) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteTransaction(tx *Tx, t *models.Transaction, user *models.User) error {
|
func DeleteTransaction(tx *db.Tx, t *models.Transaction, user *models.User) error {
|
||||||
var accountids []int64
|
var accountids []int64
|
||||||
_, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId)
|
_, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -401,7 +402,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
|
|||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TransactionsBalanceDifference(tx *Tx, accountid int64, transactions []models.Transaction) (*big.Rat, error) {
|
func TransactionsBalanceDifference(tx *db.Tx, accountid int64, transactions []models.Transaction) (*big.Rat, error) {
|
||||||
var pageDifference, tmp big.Rat
|
var pageDifference, tmp big.Rat
|
||||||
for i := range transactions {
|
for i := range transactions {
|
||||||
_, err := tx.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId)
|
_, err := tx.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId)
|
||||||
@ -425,7 +426,7 @@ func TransactionsBalanceDifference(tx *Tx, accountid int64, transactions []model
|
|||||||
return &pageDifference, nil
|
return &pageDifference, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAccountBalance(tx *Tx, user *models.User, accountid int64) (*big.Rat, error) {
|
func GetAccountBalance(tx *db.Tx, user *models.User, accountid int64) (*big.Rat, error) {
|
||||||
var splits []models.Split
|
var splits []models.Split
|
||||||
|
|
||||||
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?"
|
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?"
|
||||||
@ -448,7 +449,7 @@ func GetAccountBalance(tx *Tx, user *models.User, accountid int64) (*big.Rat, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Assumes accountid is valid and is owned by the current user
|
// Assumes accountid is valid and is owned by the current user
|
||||||
func GetAccountBalanceDate(tx *Tx, user *models.User, accountid int64, date *time.Time) (*big.Rat, error) {
|
func GetAccountBalanceDate(tx *db.Tx, user *models.User, accountid int64, date *time.Time) (*big.Rat, error) {
|
||||||
var splits []models.Split
|
var splits []models.Split
|
||||||
|
|
||||||
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?"
|
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?"
|
||||||
@ -470,7 +471,7 @@ func GetAccountBalanceDate(tx *Tx, user *models.User, accountid int64, date *tim
|
|||||||
return &balance, nil
|
return &balance, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAccountBalanceDateRange(tx *Tx, user *models.User, accountid int64, begin, end *time.Time) (*big.Rat, error) {
|
func GetAccountBalanceDateRange(tx *db.Tx, user *models.User, accountid int64, begin, end *time.Time) (*big.Rat, error) {
|
||||||
var splits []models.Split
|
var splits []models.Split
|
||||||
|
|
||||||
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?"
|
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?"
|
||||||
@ -492,7 +493,7 @@ func GetAccountBalanceDateRange(tx *Tx, user *models.User, accountid int64, begi
|
|||||||
return &balance, nil
|
return &balance, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAccountTransactions(tx *Tx, user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) {
|
func GetAccountTransactions(tx *db.Tx, user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) {
|
||||||
var transactions []models.Transaction
|
var transactions []models.Transaction
|
||||||
var atl models.AccountTransactionsList
|
var atl models.AccountTransactionsList
|
||||||
|
|
||||||
|
@ -1,65 +1,14 @@
|
|||||||
package handlers
|
package handlers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"github.com/aclindsa/gorp"
|
"github.com/aclindsa/gorp"
|
||||||
"strings"
|
"github.com/aclindsa/moneygo/internal/store/db"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Tx struct {
|
func GetTx(gdb *gorp.DbMap) (*db.Tx, error) {
|
||||||
Dialect gorp.Dialect
|
tx, err := gdb.Begin()
|
||||||
Tx *gorp.Transaction
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tx *Tx) Rebind(query string) string {
|
|
||||||
chunks := strings.Split(query, "?")
|
|
||||||
str := chunks[0]
|
|
||||||
for i := 1; i < len(chunks); i++ {
|
|
||||||
str += tx.Dialect.BindVar(i-1) + chunks[i]
|
|
||||||
}
|
|
||||||
return str
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tx *Tx) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) {
|
|
||||||
return tx.Tx.Select(i, tx.Rebind(query), args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
|
|
||||||
return tx.Tx.Exec(tx.Rebind(query), args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tx *Tx) SelectInt(query string, args ...interface{}) (int64, error) {
|
|
||||||
return tx.Tx.SelectInt(tx.Rebind(query), args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tx *Tx) SelectOne(holder interface{}, query string, args ...interface{}) error {
|
|
||||||
return tx.Tx.SelectOne(holder, tx.Rebind(query), args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tx *Tx) Insert(list ...interface{}) error {
|
|
||||||
return tx.Tx.Insert(list...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tx *Tx) Update(list ...interface{}) (int64, error) {
|
|
||||||
return tx.Tx.Update(list...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tx *Tx) Delete(list ...interface{}) (int64, error) {
|
|
||||||
return tx.Tx.Delete(list...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tx *Tx) Commit() error {
|
|
||||||
return tx.Tx.Commit()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tx *Tx) Rollback() error {
|
|
||||||
return tx.Tx.Rollback()
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetTx(db *gorp.DbMap) (*Tx, error) {
|
|
||||||
tx, err := db.Begin()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &Tx{db.Dialect, tx}, nil
|
return &db.Tx{gdb.Dialect, tx}, nil
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/aclindsa/moneygo/internal/models"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store/db"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
@ -14,7 +15,7 @@ func (ueu UserExistsError) Error() string {
|
|||||||
return "User exists"
|
return "User exists"
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUser(tx *Tx, userid int64) (*models.User, error) {
|
func GetUser(tx *db.Tx, userid int64) (*models.User, error) {
|
||||||
var u models.User
|
var u models.User
|
||||||
|
|
||||||
err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid)
|
err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid)
|
||||||
@ -24,7 +25,7 @@ func GetUser(tx *Tx, userid int64) (*models.User, error) {
|
|||||||
return &u, nil
|
return &u, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserByUsername(tx *Tx, username string) (*models.User, error) {
|
func GetUserByUsername(tx *db.Tx, username string) (*models.User, error) {
|
||||||
var u models.User
|
var u models.User
|
||||||
|
|
||||||
err := tx.SelectOne(&u, "SELECT * from users where Username=?", username)
|
err := tx.SelectOne(&u, "SELECT * from users where Username=?", username)
|
||||||
@ -34,7 +35,7 @@ func GetUserByUsername(tx *Tx, username string) (*models.User, error) {
|
|||||||
return &u, nil
|
return &u, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func InsertUser(tx *Tx, u *models.User) error {
|
func InsertUser(tx *db.Tx, u *models.User) error {
|
||||||
security_template := FindCurrencyTemplate(u.DefaultCurrency)
|
security_template := FindCurrencyTemplate(u.DefaultCurrency)
|
||||||
if security_template == nil {
|
if security_template == nil {
|
||||||
return errors.New("Invalid ISO4217 Default Currency")
|
return errors.New("Invalid ISO4217 Default Currency")
|
||||||
@ -75,7 +76,7 @@ func InsertUser(tx *Tx, u *models.User) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserFromSession(tx *Tx, r *http.Request) (*models.User, error) {
|
func GetUserFromSession(tx *db.Tx, r *http.Request) (*models.User, error) {
|
||||||
s, err := GetSession(tx, r)
|
s, err := GetSession(tx, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -83,7 +84,7 @@ func GetUserFromSession(tx *Tx, r *http.Request) (*models.User, error) {
|
|||||||
return GetUser(tx, s.UserId)
|
return GetUser(tx, s.UserId)
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateUser(tx *Tx, u *models.User) error {
|
func UpdateUser(tx *db.Tx, u *models.User) error {
|
||||||
security, err := GetSecurity(tx, u.DefaultCurrency, u.UserId)
|
security, err := GetSecurity(tx, u.DefaultCurrency, u.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -103,7 +104,7 @@ func UpdateUser(tx *Tx, u *models.User) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteUser(tx *Tx, u *models.User) error {
|
func DeleteUser(tx *db.Tx, u *models.User) error {
|
||||||
count, err := tx.Delete(u)
|
count, err := tx.Delete(u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"github.com/aclindsa/gorp"
|
"github.com/aclindsa/gorp"
|
||||||
"github.com/aclindsa/moneygo/internal/config"
|
"github.com/aclindsa/moneygo/internal/config"
|
||||||
"github.com/aclindsa/moneygo/internal/models"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store"
|
||||||
_ "github.com/go-sql-driver/mysql"
|
_ "github.com/go-sql-driver/mysql"
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
@ -60,3 +61,40 @@ func GetDSN(dbtype config.DbType, dsn string) string {
|
|||||||
}
|
}
|
||||||
return dsn
|
return dsn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type DbStore struct {
|
||||||
|
DbMap *gorp.DbMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DbStore) Begin() (store.Tx, error) {
|
||||||
|
tx, err := db.DbMap.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Tx{db.DbMap.Dialect, tx}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DbStore) Close() error {
|
||||||
|
err := db.DbMap.Db.Close()
|
||||||
|
db.DbMap = nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetStore(dbtype config.DbType, dsn string) (store *DbStore, err error) {
|
||||||
|
dsn = GetDSN(dbtype, dsn)
|
||||||
|
database, err := sql.Open(dbtype.String(), dsn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
database.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
dbmap, err := GetDbMap(database, dbtype)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &DbStore{dbmap}, nil
|
||||||
|
}
|
36
internal/store/db/sessions.go
Normal file
36
internal/store/db/sessions.go
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (tx *Tx) InsertSession(session *models.Session) error {
|
||||||
|
return tx.Insert(session)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) GetSession(secret string) (*models.Session, error) {
|
||||||
|
var s models.Session
|
||||||
|
|
||||||
|
err := tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", secret)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.Expires.Before(time.Now()) {
|
||||||
|
tx.Delete(&s)
|
||||||
|
return nil, fmt.Errorf("Session has expired")
|
||||||
|
}
|
||||||
|
return &s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) SessionExists(secret string) (bool, error) {
|
||||||
|
existing, err := tx.SelectInt("SELECT count(*) from sessions where SessionSecret=?", secret)
|
||||||
|
return existing != 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) DeleteSession(session *models.Session) error {
|
||||||
|
_, err := tx.Delete(session)
|
||||||
|
return err
|
||||||
|
}
|
57
internal/store/db/tx.go
Normal file
57
internal/store/db/tx.go
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"github.com/aclindsa/gorp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Tx struct {
|
||||||
|
Dialect gorp.Dialect
|
||||||
|
Tx *gorp.Transaction
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) Rebind(query string) string {
|
||||||
|
chunks := strings.Split(query, "?")
|
||||||
|
str := chunks[0]
|
||||||
|
for i := 1; i < len(chunks); i++ {
|
||||||
|
str += tx.Dialect.BindVar(i-1) + chunks[i]
|
||||||
|
}
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) {
|
||||||
|
return tx.Tx.Select(i, tx.Rebind(query), args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||||
|
return tx.Tx.Exec(tx.Rebind(query), args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) SelectInt(query string, args ...interface{}) (int64, error) {
|
||||||
|
return tx.Tx.SelectInt(tx.Rebind(query), args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) SelectOne(holder interface{}, query string, args ...interface{}) error {
|
||||||
|
return tx.Tx.SelectOne(holder, tx.Rebind(query), args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) Insert(list ...interface{}) error {
|
||||||
|
return tx.Tx.Insert(list...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) Update(list ...interface{}) (int64, error) {
|
||||||
|
return tx.Tx.Update(list...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) Delete(list ...interface{}) (int64, error) {
|
||||||
|
return tx.Tx.Delete(list...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) Commit() error {
|
||||||
|
return tx.Tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) Rollback() error {
|
||||||
|
return tx.Tx.Rollback()
|
||||||
|
}
|
24
internal/store/store.go
Normal file
24
internal/store/store.go
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SessionStore interface {
|
||||||
|
InsertSession(session *models.Session) error
|
||||||
|
GetSession(secret string) (*models.Session, error)
|
||||||
|
SessionExists(secret string) (bool, error)
|
||||||
|
DeleteSession(session *models.Session) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tx interface {
|
||||||
|
Commit() error
|
||||||
|
Rollback() error
|
||||||
|
|
||||||
|
SessionStore
|
||||||
|
}
|
||||||
|
|
||||||
|
type Store interface {
|
||||||
|
Begin() (Tx, error)
|
||||||
|
Close() error
|
||||||
|
}
|
15
main.go
15
main.go
@ -3,11 +3,10 @@ package main
|
|||||||
//go:generate make
|
//go:generate make
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"flag"
|
"flag"
|
||||||
"github.com/aclindsa/moneygo/internal/config"
|
"github.com/aclindsa/moneygo/internal/config"
|
||||||
"github.com/aclindsa/moneygo/internal/db"
|
|
||||||
"github.com/aclindsa/moneygo/internal/handlers"
|
"github.com/aclindsa/moneygo/internal/handlers"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store/db"
|
||||||
"github.com/kabukky/httpscerts"
|
"github.com/kabukky/httpscerts"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
@ -67,21 +66,15 @@ func staticHandler(w http.ResponseWriter, r *http.Request, basedir string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
dsn := db.GetDSN(cfg.MoneyGo.DBType, cfg.MoneyGo.DSN)
|
db, err := db.GetStore(cfg.MoneyGo.DBType, cfg.MoneyGo.DSN)
|
||||||
database, err := sql.Open(cfg.MoneyGo.DBType.String(), dsn)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
defer database.Close()
|
|
||||||
|
|
||||||
dbmap, err := db.GetDbMap(database, cfg.MoneyGo.DBType)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
// Get ServeMux for API and add our own handlers for files
|
// Get ServeMux for API and add our own handlers for files
|
||||||
servemux := http.NewServeMux()
|
servemux := http.NewServeMux()
|
||||||
servemux.Handle("/v1/", &handlers.APIHandler{DB: dbmap})
|
servemux.Handle("/v1/", &handlers.APIHandler{Store: db})
|
||||||
servemux.HandleFunc("/", FileHandlerFunc(rootHandler, cfg.MoneyGo.Basedir))
|
servemux.HandleFunc("/", FileHandlerFunc(rootHandler, cfg.MoneyGo.Basedir))
|
||||||
servemux.HandleFunc("/static/", FileHandlerFunc(staticHandler, cfg.MoneyGo.Basedir))
|
servemux.HandleFunc("/static/", FileHandlerFunc(staticHandler, cfg.MoneyGo.Basedir))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user