mirror of
https://github.com/aclindsa/moneygo.git
synced 2024-12-26 07:33:21 -05:00
Merge pull request #35 from aclindsa/store_split
Split DB activity into 'store'
This commit is contained in:
commit
9cdf4f3c29
@ -20,12 +20,11 @@ env:
|
|||||||
- MONEYGO_TEST_DB=mysql
|
- MONEYGO_TEST_DB=mysql
|
||||||
- MONEYGO_TEST_DB=postgres
|
- MONEYGO_TEST_DB=postgres
|
||||||
|
|
||||||
# OSX builds take too long, so don't wait for all of them
|
# OSX builds take too long, so don't wait for them
|
||||||
matrix:
|
matrix:
|
||||||
fast_finish: true
|
fast_finish: true
|
||||||
allow_failures:
|
allow_failures:
|
||||||
- os: osx
|
- os: osx
|
||||||
go: master
|
|
||||||
|
|
||||||
before_install:
|
before_install:
|
||||||
# Fetch/build coverage reporting tools
|
# Fetch/build coverage reporting tools
|
||||||
|
@ -3,43 +3,22 @@ package handlers
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/aclindsa/moneygo/internal/models"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetAccount(tx *Tx, accountid int64, userid int64) (*models.Account, error) {
|
|
||||||
var a models.Account
|
|
||||||
|
|
||||||
err := tx.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &a, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetAccounts(tx *Tx, userid int64) (*[]models.Account, error) {
|
|
||||||
var accounts []models.Account
|
|
||||||
|
|
||||||
_, err := tx.Select(&accounts, "SELECT * from accounts where UserId=?", userid)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &accounts, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get (and attempt to create if it doesn't exist). Matches on UserId,
|
// Get (and attempt to create if it doesn't exist). Matches on UserId,
|
||||||
// SecurityId, Type, Name, and ParentAccountId
|
// SecurityId, Type, Name, and ParentAccountId
|
||||||
func GetCreateAccount(tx *Tx, a models.Account) (*models.Account, error) {
|
func GetCreateAccount(tx store.Tx, a models.Account) (*models.Account, error) {
|
||||||
var accounts []models.Account
|
|
||||||
var account models.Account
|
var account models.Account
|
||||||
|
|
||||||
// Try to find the top-level trading account
|
accounts, err := tx.FindMatchingAccounts(&a)
|
||||||
_, err := tx.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND Type=? AND Name=? AND ParentAccountId=? ORDER BY AccountId ASC LIMIT 1", a.UserId, a.SecurityId, a.Type, a.Name, a.ParentAccountId)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if len(accounts) == 1 {
|
if len(*accounts) > 0 {
|
||||||
account = accounts[0]
|
account = *(*accounts)[0]
|
||||||
} else {
|
} else {
|
||||||
account.UserId = a.UserId
|
account.UserId = a.UserId
|
||||||
account.SecurityId = a.SecurityId
|
account.SecurityId = a.SecurityId
|
||||||
@ -47,7 +26,7 @@ func GetCreateAccount(tx *Tx, a models.Account) (*models.Account, error) {
|
|||||||
account.Name = a.Name
|
account.Name = a.Name
|
||||||
account.ParentAccountId = a.ParentAccountId
|
account.ParentAccountId = a.ParentAccountId
|
||||||
|
|
||||||
err = tx.Insert(&account)
|
err = tx.InsertAccount(&account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -57,11 +36,11 @@ func GetCreateAccount(tx *Tx, a models.Account) (*models.Account, error) {
|
|||||||
|
|
||||||
// Get (and attempt to create if it doesn't exist) the security/currency
|
// Get (and attempt to create if it doesn't exist) the security/currency
|
||||||
// trading account for the supplied security/currency
|
// trading account for the supplied security/currency
|
||||||
func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*models.Account, error) {
|
func GetTradingAccount(tx store.Tx, userid int64, securityid int64) (*models.Account, error) {
|
||||||
var tradingAccount models.Account
|
var tradingAccount models.Account
|
||||||
var account models.Account
|
var account models.Account
|
||||||
|
|
||||||
user, err := GetUser(tx, userid)
|
user, err := tx.GetUser(userid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -78,7 +57,7 @@ func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*models.Account,
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
security, err := GetSecurity(tx, securityid, userid)
|
security, err := tx.GetSecurity(securityid, userid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -99,7 +78,7 @@ func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*models.Account,
|
|||||||
|
|
||||||
// Get (and attempt to create if it doesn't exist) the security/currency
|
// Get (and attempt to create if it doesn't exist) the security/currency
|
||||||
// imbalance account for the supplied security/currency
|
// imbalance account for the supplied security/currency
|
||||||
func GetImbalanceAccount(tx *Tx, userid int64, securityid int64) (*models.Account, error) {
|
func GetImbalanceAccount(tx store.Tx, userid int64, securityid int64) (*models.Account, error) {
|
||||||
var imbalanceAccount models.Account
|
var imbalanceAccount models.Account
|
||||||
var account models.Account
|
var account models.Account
|
||||||
xxxtemplate := FindSecurityTemplate("XXX", models.Currency)
|
xxxtemplate := FindSecurityTemplate("XXX", models.Currency)
|
||||||
@ -123,7 +102,7 @@ func GetImbalanceAccount(tx *Tx, userid int64, securityid int64) (*models.Accoun
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
security, err := GetSecurity(tx, securityid, userid)
|
security, err := tx.GetSecurity(securityid, userid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -142,120 +121,6 @@ func GetImbalanceAccount(tx *Tx, userid int64, securityid int64) (*models.Accoun
|
|||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type ParentAccountMissingError struct{}
|
|
||||||
|
|
||||||
func (pame ParentAccountMissingError) Error() string {
|
|
||||||
return "Parent account missing"
|
|
||||||
}
|
|
||||||
|
|
||||||
type TooMuchNestingError struct{}
|
|
||||||
|
|
||||||
func (tmne TooMuchNestingError) Error() string {
|
|
||||||
return "Too much nesting"
|
|
||||||
}
|
|
||||||
|
|
||||||
type CircularAccountsError struct{}
|
|
||||||
|
|
||||||
func (cae CircularAccountsError) Error() string {
|
|
||||||
return "Would result in circular account relationship"
|
|
||||||
}
|
|
||||||
|
|
||||||
func insertUpdateAccount(tx *Tx, a *models.Account, insert bool) error {
|
|
||||||
found := make(map[int64]bool)
|
|
||||||
if !insert {
|
|
||||||
found[a.AccountId] = true
|
|
||||||
}
|
|
||||||
parentid := a.ParentAccountId
|
|
||||||
depth := 0
|
|
||||||
for parentid != -1 {
|
|
||||||
depth += 1
|
|
||||||
if depth > 100 {
|
|
||||||
return TooMuchNestingError{}
|
|
||||||
}
|
|
||||||
|
|
||||||
var a models.Account
|
|
||||||
err := tx.SelectOne(&a, "SELECT * from accounts where AccountId=?", parentid)
|
|
||||||
if err != nil {
|
|
||||||
return ParentAccountMissingError{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insertion by itself can never result in circular dependencies
|
|
||||||
if insert {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
found[parentid] = true
|
|
||||||
parentid = a.ParentAccountId
|
|
||||||
if _, ok := found[parentid]; ok {
|
|
||||||
return CircularAccountsError{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if insert {
|
|
||||||
err := tx.Insert(a)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
oldacct, err := GetAccount(tx, a.AccountId, a.UserId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
a.AccountVersion = oldacct.AccountVersion + 1
|
|
||||||
|
|
||||||
count, err := tx.Update(a)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if count != 1 {
|
|
||||||
return errors.New("Updated more than one account")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func InsertAccount(tx *Tx, a *models.Account) error {
|
|
||||||
return insertUpdateAccount(tx, a, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
func UpdateAccount(tx *Tx, a *models.Account) error {
|
|
||||||
return insertUpdateAccount(tx, a, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
func DeleteAccount(tx *Tx, a *models.Account) error {
|
|
||||||
if a.ParentAccountId != -1 {
|
|
||||||
// Re-parent splits to this account's parent account if this account isn't a root account
|
|
||||||
_, err := tx.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", a.ParentAccountId, a.AccountId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Delete splits if this account is a root account
|
|
||||||
_, err := tx.Exec("DELETE FROM splits WHERE AccountId=?", a.AccountId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Re-parent child accounts to this account's parent account
|
|
||||||
_, err := tx.Exec("UPDATE accounts SET ParentAccountId=? WHERE ParentAccountId=?", a.ParentAccountId, a.AccountId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
count, err := tx.Delete(a)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if count != 1 {
|
|
||||||
return errors.New("Was going to delete more than one account")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
||||||
user, err := GetUserFromSession(context.Tx, r)
|
user, err := GetUserFromSession(context.Tx, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -279,7 +144,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
|||||||
account.UserId = user.UserId
|
account.UserId = user.UserId
|
||||||
account.AccountVersion = 0
|
account.AccountVersion = 0
|
||||||
|
|
||||||
security, err := GetSecurity(context.Tx, account.SecurityId, user.UserId)
|
security, err := context.Tx.GetSecurity(account.SecurityId, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
@ -288,9 +153,9 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
|||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = InsertAccount(context.Tx, &account)
|
err = context.Tx.InsertAccount(&account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if _, ok := err.(ParentAccountMissingError); ok {
|
if _, ok := err.(store.ParentAccountMissingError); ok {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
} else {
|
} else {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
@ -303,7 +168,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
|||||||
if context.LastLevel() {
|
if context.LastLevel() {
|
||||||
//Return all Accounts
|
//Return all Accounts
|
||||||
var al models.AccountList
|
var al models.AccountList
|
||||||
accounts, err := GetAccounts(context.Tx, user.UserId)
|
accounts, err := context.Tx.GetAccounts(user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
@ -319,7 +184,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
|||||||
|
|
||||||
if context.LastLevel() {
|
if context.LastLevel() {
|
||||||
// Return Account with this Id
|
// Return Account with this Id
|
||||||
account, err := GetAccount(context.Tx, accountid, user.UserId)
|
account, err := context.Tx.GetAccount(accountid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
@ -340,7 +205,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
|||||||
}
|
}
|
||||||
account.UserId = user.UserId
|
account.UserId = user.UserId
|
||||||
|
|
||||||
security, err := GetSecurity(context.Tx, account.SecurityId, user.UserId)
|
security, err := context.Tx.GetSecurity(account.SecurityId, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
@ -353,11 +218,11 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
|||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = UpdateAccount(context.Tx, &account)
|
err = context.Tx.UpdateAccount(&account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if _, ok := err.(ParentAccountMissingError); ok {
|
if _, ok := err.(store.ParentAccountMissingError); ok {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
} else if _, ok := err.(CircularAccountsError); ok {
|
} else if _, ok := err.(store.CircularAccountsError); ok {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
} else {
|
} else {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
@ -367,12 +232,12 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
|||||||
|
|
||||||
return &account
|
return &account
|
||||||
} else if r.Method == "DELETE" {
|
} else if r.Method == "DELETE" {
|
||||||
account, err := GetAccount(context.Tx, accountid, user.UserId)
|
account, err := context.Tx.GetAccount(accountid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = DeleteAccount(context.Tx, account)
|
err = context.Tx.DeleteAccount(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
|
@ -4,8 +4,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/aclindsa/moneygo/internal/models"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store"
|
||||||
"github.com/yuin/gopher-lua"
|
"github.com/yuin/gopher-lua"
|
||||||
"math/big"
|
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -16,7 +16,7 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*models.Account, error) {
|
|||||||
|
|
||||||
ctx := L.Context()
|
ctx := L.Context()
|
||||||
|
|
||||||
tx, ok := ctx.Value(dbContextKey).(*Tx)
|
tx, ok := ctx.Value(dbContextKey).(store.Tx)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("Couldn't find tx in lua's Context")
|
return nil, errors.New("Couldn't find tx in lua's Context")
|
||||||
}
|
}
|
||||||
@ -28,14 +28,14 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*models.Account, error) {
|
|||||||
return nil, errors.New("Couldn't find User in lua's Context")
|
return nil, errors.New("Couldn't find User in lua's Context")
|
||||||
}
|
}
|
||||||
|
|
||||||
accounts, err := GetAccounts(tx, user.UserId)
|
accounts, err := tx.GetAccounts(user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
account_map = make(map[int64]*models.Account)
|
account_map = make(map[int64]*models.Account)
|
||||||
for i := range *accounts {
|
for i := range *accounts {
|
||||||
account_map[(*accounts)[i].AccountId] = &(*accounts)[i]
|
account_map[(*accounts)[i].AccountId] = (*accounts)[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx = context.WithValue(ctx, accountsContextKey, account_map)
|
ctx = context.WithValue(ctx, accountsContextKey, account_map)
|
||||||
@ -150,7 +150,7 @@ func luaAccountBalance(L *lua.LState) int {
|
|||||||
a := luaCheckAccount(L, 1)
|
a := luaCheckAccount(L, 1)
|
||||||
|
|
||||||
ctx := L.Context()
|
ctx := L.Context()
|
||||||
tx, ok := ctx.Value(dbContextKey).(*Tx)
|
tx, ok := ctx.Value(dbContextKey).(store.Tx)
|
||||||
if !ok {
|
if !ok {
|
||||||
panic("Couldn't find tx in lua's Context")
|
panic("Couldn't find tx in lua's Context")
|
||||||
}
|
}
|
||||||
@ -167,24 +167,29 @@ func luaAccountBalance(L *lua.LState) int {
|
|||||||
panic("SecurityId not in lua security_map")
|
panic("SecurityId not in lua security_map")
|
||||||
}
|
}
|
||||||
date := luaWeakCheckTime(L, 2)
|
date := luaWeakCheckTime(L, 2)
|
||||||
var b Balance
|
var splits *[]*models.Split
|
||||||
var rat *big.Rat
|
|
||||||
if date != nil {
|
if date != nil {
|
||||||
end := luaWeakCheckTime(L, 3)
|
end := luaWeakCheckTime(L, 3)
|
||||||
if end != nil {
|
if end != nil {
|
||||||
rat, err = GetAccountBalanceDateRange(tx, user, a.AccountId, date, end)
|
splits, err = tx.GetAccountSplitsDateRange(user, a.AccountId, date, end)
|
||||||
} else {
|
} else {
|
||||||
rat, err = GetAccountBalanceDate(tx, user, a.AccountId, date)
|
splits, err = tx.GetAccountSplitsDate(user, a.AccountId, date)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
rat, err = GetAccountBalance(tx, user, a.AccountId)
|
splits, err = tx.GetAccountSplits(user, a.AccountId)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic("Failed to GetAccountBalance:" + err.Error())
|
panic("Failed to fetch splits for account:" + err.Error())
|
||||||
}
|
}
|
||||||
b.Amount = rat
|
rat, err := BalanceFromSplits(splits)
|
||||||
b.Security = security
|
if err != nil {
|
||||||
L.Push(BalanceToLua(L, &b))
|
panic("Failed to calculate balance for account:" + err.Error())
|
||||||
|
}
|
||||||
|
b := &Balance{
|
||||||
|
Amount: rat,
|
||||||
|
Security: security,
|
||||||
|
}
|
||||||
|
L.Push(BalanceToLua(L, b))
|
||||||
|
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
@ -2,12 +2,11 @@ package handlers_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"database/sql"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"github.com/aclindsa/moneygo/internal/config"
|
"github.com/aclindsa/moneygo/internal/config"
|
||||||
"github.com/aclindsa/moneygo/internal/db"
|
|
||||||
"github.com/aclindsa/moneygo/internal/handlers"
|
"github.com/aclindsa/moneygo/internal/handlers"
|
||||||
"github.com/aclindsa/moneygo/internal/models"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store/db"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
@ -253,24 +252,15 @@ func RunTests(m *testing.M) int {
|
|||||||
dsn = envDSN
|
dsn = envDSN
|
||||||
}
|
}
|
||||||
|
|
||||||
dsn = db.GetDSN(dbType, dsn)
|
db, err := db.GetStore(dbType, dsn)
|
||||||
database, err := sql.Open(dbType.String(), dsn)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
defer database.Close()
|
defer db.Close()
|
||||||
|
|
||||||
dbmap, err := db.GetDbMap(database, dbType)
|
db.Empty() // clear the DB tables
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = dbmap.TruncateTables()
|
server = httptest.NewTLSServer(&handlers.APIHandler{Store: db})
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
server = httptest.NewTLSServer(&handlers.APIHandler{DB: dbmap})
|
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
return m.Run()
|
return m.Run()
|
||||||
|
@ -437,7 +437,7 @@ func GnucashImportHandler(r *http.Request, context *Context) ResponseWriterWrite
|
|||||||
}
|
}
|
||||||
split.AccountId = acctId
|
split.AccountId = acctId
|
||||||
|
|
||||||
exists, err := SplitAlreadyImported(context.Tx, split)
|
exists, err := context.Tx.SplitExists(split)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print("Error checking if split was already imported:", err)
|
log.Print("Error checking if split was already imported:", err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
@ -446,7 +446,7 @@ func GnucashImportHandler(r *http.Request, context *Context) ResponseWriterWrite
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !already_imported {
|
if !already_imported {
|
||||||
err := InsertTransaction(context.Tx, &transaction, user)
|
err := context.Tx.InsertTransaction(&transaction, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
|
@ -38,13 +38,13 @@ func TestImportGnucash(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for i, account := range *accounts.Accounts {
|
for i, account := range *accounts.Accounts {
|
||||||
if account.Name == "Income" && account.Type == models.Income && account.ParentAccountId == -1 {
|
if account.Name == "Income" && account.Type == models.Income && account.ParentAccountId == -1 {
|
||||||
income = &(*accounts.Accounts)[i]
|
income = (*accounts.Accounts)[i]
|
||||||
} else if account.Name == "Equity" && account.Type == models.Equity && account.ParentAccountId == -1 {
|
} else if account.Name == "Equity" && account.Type == models.Equity && account.ParentAccountId == -1 {
|
||||||
equity = &(*accounts.Accounts)[i]
|
equity = (*accounts.Accounts)[i]
|
||||||
} else if account.Name == "Liabilities" && account.Type == models.Liability && account.ParentAccountId == -1 {
|
} else if account.Name == "Liabilities" && account.Type == models.Liability && account.ParentAccountId == -1 {
|
||||||
liabilities = &(*accounts.Accounts)[i]
|
liabilities = (*accounts.Accounts)[i]
|
||||||
} else if account.Name == "Expenses" && account.Type == models.Expense && account.ParentAccountId == -1 {
|
} else if account.Name == "Expenses" && account.Type == models.Expense && account.ParentAccountId == -1 {
|
||||||
expenses = &(*accounts.Accounts)[i]
|
expenses = (*accounts.Accounts)[i]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if income == nil {
|
if income == nil {
|
||||||
@ -61,15 +61,15 @@ func TestImportGnucash(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for i, account := range *accounts.Accounts {
|
for i, account := range *accounts.Accounts {
|
||||||
if account.Name == "Salary" && account.Type == models.Income && account.ParentAccountId == income.AccountId {
|
if account.Name == "Salary" && account.Type == models.Income && account.ParentAccountId == income.AccountId {
|
||||||
salary = &(*accounts.Accounts)[i]
|
salary = (*accounts.Accounts)[i]
|
||||||
} else if account.Name == "Opening Balances" && account.Type == models.Equity && account.ParentAccountId == equity.AccountId {
|
} else if account.Name == "Opening Balances" && account.Type == models.Equity && account.ParentAccountId == equity.AccountId {
|
||||||
openingbalances = &(*accounts.Accounts)[i]
|
openingbalances = (*accounts.Accounts)[i]
|
||||||
} else if account.Name == "Credit Card" && account.Type == models.Liability && account.ParentAccountId == liabilities.AccountId {
|
} else if account.Name == "Credit Card" && account.Type == models.Liability && account.ParentAccountId == liabilities.AccountId {
|
||||||
creditcard = &(*accounts.Accounts)[i]
|
creditcard = (*accounts.Accounts)[i]
|
||||||
} else if account.Name == "Groceries" && account.Type == models.Expense && account.ParentAccountId == expenses.AccountId {
|
} else if account.Name == "Groceries" && account.Type == models.Expense && account.ParentAccountId == expenses.AccountId {
|
||||||
groceries = &(*accounts.Accounts)[i]
|
groceries = (*accounts.Accounts)[i]
|
||||||
} else if account.Name == "Cable" && account.Type == models.Expense && account.ParentAccountId == expenses.AccountId {
|
} else if account.Name == "Cable" && account.Type == models.Expense && account.ParentAccountId == expenses.AccountId {
|
||||||
cable = &(*accounts.Accounts)[i]
|
cable = (*accounts.Accounts)[i]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if salary == nil {
|
if salary == nil {
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
package handlers
|
package handlers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/aclindsa/gorp"
|
|
||||||
"github.com/aclindsa/moneygo/internal/models"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store/db"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"path"
|
"path"
|
||||||
@ -16,7 +17,7 @@ type ResponseWriterWriter interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Context struct {
|
type Context struct {
|
||||||
Tx *Tx
|
Tx store.Tx
|
||||||
User *models.User
|
User *models.User
|
||||||
remainingURL string // portion of URL path not yet reached in the hierarchy
|
remainingURL string // portion of URL path not yet reached in the hierarchy
|
||||||
}
|
}
|
||||||
@ -46,11 +47,11 @@ func (c *Context) LastLevel() bool {
|
|||||||
type Handler func(*http.Request, *Context) ResponseWriterWriter
|
type Handler func(*http.Request, *Context) ResponseWriterWriter
|
||||||
|
|
||||||
type APIHandler struct {
|
type APIHandler struct {
|
||||||
DB *gorp.DbMap
|
Store *db.DbStore
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ah *APIHandler) txWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) {
|
func (ah *APIHandler) txWrapper(h Handler, r *http.Request, context *Context) (writer ResponseWriterWriter) {
|
||||||
tx, err := GetTx(ah.DB)
|
tx, err := ah.Store.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
|
@ -3,6 +3,7 @@ package handlers
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"github.com/aclindsa/moneygo/internal/models"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store"
|
||||||
"github.com/aclindsa/ofxgo"
|
"github.com/aclindsa/ofxgo"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
@ -23,7 +24,7 @@ func (od *OFXDownload) Read(json_str string) error {
|
|||||||
return dec.Decode(od)
|
return dec.Decode(od)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) ResponseWriterWriter {
|
func ofxImportHelper(tx store.Tx, r io.Reader, user *models.User, accountid int64) ResponseWriterWriter {
|
||||||
itl, err := ImportOFX(r)
|
itl, err := ImportOFX(r)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -38,7 +39,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) Re
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Return Account with this Id
|
// Return Account with this Id
|
||||||
account, err := GetAccount(tx, accountid, user.UserId)
|
account, err := tx.GetAccount(accountid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
@ -158,7 +159,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) Re
|
|||||||
split := new(models.Split)
|
split := new(models.Split)
|
||||||
r := new(big.Rat)
|
r := new(big.Rat)
|
||||||
r.Neg(&imbalance)
|
r.Neg(&imbalance)
|
||||||
security, err := GetSecurity(tx, imbalanced_security, user.UserId)
|
security, err := tx.GetSecurity(imbalanced_security, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
@ -186,7 +187,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) Re
|
|||||||
split.SecurityId = -1
|
split.SecurityId = -1
|
||||||
}
|
}
|
||||||
|
|
||||||
exists, err := SplitAlreadyImported(tx, split)
|
exists, err := tx.SplitExists(split)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print("Error checking if split was already imported:", err)
|
log.Print("Error checking if split was already imported:", err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
@ -201,7 +202,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) Re
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, transaction := range transactions {
|
for _, transaction := range transactions {
|
||||||
err := InsertTransaction(tx, &transaction, user)
|
err := tx.InsertTransaction(&transaction, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
@ -217,7 +218,7 @@ func OFXImportHandler(context *Context, r *http.Request, user *models.User, acco
|
|||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := GetAccount(context.Tx, accountid, user.UserId)
|
account, err := context.Tx.GetAccount(accountid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
@ -83,7 +83,7 @@ func findAccount(client *http.Client, name string, tipe models.AccountType, secu
|
|||||||
}
|
}
|
||||||
for _, account := range *accounts.Accounts {
|
for _, account := range *accounts.Accounts {
|
||||||
if account.Name == name && account.Type == tipe && account.SecurityId == securityid {
|
if account.Name == name && account.Type == tipe && account.SecurityId == securityid {
|
||||||
return &account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("Unable to find account: \"%s\"", name)
|
return nil, fmt.Errorf("Unable to find account: \"%s\"", name)
|
||||||
|
@ -2,82 +2,41 @@ package handlers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/aclindsa/moneygo/internal/models"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func CreatePriceIfNotExist(tx *Tx, price *models.Price) error {
|
func CreatePriceIfNotExist(tx store.Tx, price *models.Price) error {
|
||||||
if len(price.RemoteId) == 0 {
|
if len(price.RemoteId) == 0 {
|
||||||
// Always create a new price if we can't match on the RemoteId
|
// Always create a new price if we can't match on the RemoteId
|
||||||
err := tx.Insert(price)
|
err := tx.InsertPrice(price)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var prices []*models.Price
|
exists, err := tx.PriceExists(price)
|
||||||
|
|
||||||
_, err := tx.Select(&prices, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date=? AND Value=?", price.SecurityId, price.CurrencyId, price.Date, price.Value)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if exists {
|
||||||
if len(prices) > 0 {
|
|
||||||
return nil // price already exists
|
return nil // price already exists
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tx.Insert(price)
|
err = tx.InsertPrice(price)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetPrice(tx *Tx, priceid, securityid int64) (*models.Price, error) {
|
|
||||||
var p models.Price
|
|
||||||
err := tx.SelectOne(&p, "SELECT * from prices where PriceId=? AND SecurityId=?", priceid, securityid)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &p, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetPrices(tx *Tx, securityid int64) (*[]*models.Price, error) {
|
|
||||||
var prices []*models.Price
|
|
||||||
|
|
||||||
_, err := tx.Select(&prices, "SELECT * from prices where SecurityId=?", securityid)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &prices, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return the latest price for security in currency units before date
|
|
||||||
func GetLatestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
|
|
||||||
var p models.Price
|
|
||||||
err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &p, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return the earliest price for security in currency units after date
|
|
||||||
func GetEarliestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
|
|
||||||
var p models.Price
|
|
||||||
err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &p, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return the price for security in currency closest to date
|
// Return the price for security in currency closest to date
|
||||||
func GetClosestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
|
func GetClosestPrice(tx store.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
|
||||||
earliest, _ := GetEarliestPrice(tx, security, currency, date)
|
earliest, _ := tx.GetEarliestPrice(security, currency, date)
|
||||||
latest, err := GetLatestPrice(tx, security, currency, date)
|
latest, err := tx.GetLatestPrice(security, currency, date)
|
||||||
|
|
||||||
// Return early if either earliest or latest are invalid
|
// Return early if either earliest or latest are invalid
|
||||||
if earliest == nil {
|
if earliest == nil {
|
||||||
@ -96,7 +55,7 @@ func GetClosestPrice(tx *Tx, security, currency *models.Security, date *time.Tim
|
|||||||
}
|
}
|
||||||
|
|
||||||
func PriceHandler(r *http.Request, context *Context, user *models.User, securityid int64) ResponseWriterWriter {
|
func PriceHandler(r *http.Request, context *Context, user *models.User, securityid int64) ResponseWriterWriter {
|
||||||
security, err := GetSecurity(context.Tx, securityid, user.UserId)
|
security, err := context.Tx.GetSecurity(securityid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
@ -111,12 +70,12 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security
|
|||||||
if price.SecurityId != security.SecurityId {
|
if price.SecurityId != security.SecurityId {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
_, err = GetSecurity(context.Tx, price.CurrencyId, user.UserId)
|
_, err = context.Tx.GetSecurity(price.CurrencyId, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = context.Tx.Insert(&price)
|
err = context.Tx.InsertPrice(&price)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
@ -128,7 +87,7 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security
|
|||||||
//Return all this security's prices
|
//Return all this security's prices
|
||||||
var pl models.PriceList
|
var pl models.PriceList
|
||||||
|
|
||||||
prices, err := GetPrices(context.Tx, security.SecurityId)
|
prices, err := context.Tx.GetPrices(security.SecurityId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
@ -143,7 +102,7 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security
|
|||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
price, err := GetPrice(context.Tx, priceid, security.SecurityId)
|
price, err := context.Tx.GetPrice(priceid, security.SecurityId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
@ -160,30 +119,30 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security
|
|||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = GetSecurity(context.Tx, price.SecurityId, user.UserId)
|
_, err = context.Tx.GetSecurity(price.SecurityId, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
_, err = GetSecurity(context.Tx, price.CurrencyId, user.UserId)
|
_, err = context.Tx.GetSecurity(price.CurrencyId, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
count, err := context.Tx.Update(&price)
|
err = context.Tx.UpdatePrice(&price)
|
||||||
if err != nil || count != 1 {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &price
|
return &price
|
||||||
} else if r.Method == "DELETE" {
|
} else if r.Method == "DELETE" {
|
||||||
price, err := GetPrice(context.Tx, priceid, security.SecurityId)
|
price, err := context.Tx.GetPrice(priceid, security.SecurityId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
count, err := context.Tx.Delete(price)
|
err = context.Tx.DeletePrice(price)
|
||||||
if err != nil || count != 1 {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/aclindsa/moneygo/internal/models"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store"
|
||||||
"github.com/yuin/gopher-lua"
|
"github.com/yuin/gopher-lua"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -24,57 +25,7 @@ const (
|
|||||||
|
|
||||||
const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for
|
const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for
|
||||||
|
|
||||||
func GetReport(tx *Tx, reportid int64, userid int64) (*models.Report, error) {
|
func runReport(tx store.Tx, user *models.User, report *models.Report) (*models.Tabulation, error) {
|
||||||
var r models.Report
|
|
||||||
|
|
||||||
err := tx.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &r, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetReports(tx *Tx, userid int64) (*[]models.Report, error) {
|
|
||||||
var reports []models.Report
|
|
||||||
|
|
||||||
_, err := tx.Select(&reports, "SELECT * from reports where UserId=?", userid)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &reports, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func InsertReport(tx *Tx, r *models.Report) error {
|
|
||||||
err := tx.Insert(r)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func UpdateReport(tx *Tx, r *models.Report) error {
|
|
||||||
count, err := tx.Update(r)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if count != 1 {
|
|
||||||
return errors.New("Updated more than one report")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DeleteReport(tx *Tx, r *models.Report) error {
|
|
||||||
count, err := tx.Delete(r)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if count != 1 {
|
|
||||||
return errors.New("Deleted more than one report")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func runReport(tx *Tx, user *models.User, report *models.Report) (*models.Tabulation, error) {
|
|
||||||
// Create a new LState without opening the default libs for security
|
// Create a new LState without opening the default libs for security
|
||||||
L := lua.NewState(lua.Options{SkipOpenLibs: true})
|
L := lua.NewState(lua.Options{SkipOpenLibs: true})
|
||||||
defer L.Close()
|
defer L.Close()
|
||||||
@ -138,8 +89,8 @@ func runReport(tx *Tx, user *models.User, report *models.Report) (*models.Tabula
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReportTabulationHandler(tx *Tx, r *http.Request, user *models.User, reportid int64) ResponseWriterWriter {
|
func ReportTabulationHandler(tx store.Tx, r *http.Request, user *models.User, reportid int64) ResponseWriterWriter {
|
||||||
report, err := GetReport(tx, reportid, user.UserId)
|
report, err := tx.GetReport(reportid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
@ -174,7 +125,7 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
|||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = InsertReport(context.Tx, &report)
|
err = context.Tx.InsertReport(&report)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
@ -185,7 +136,7 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
|||||||
if context.LastLevel() {
|
if context.LastLevel() {
|
||||||
//Return all Reports
|
//Return all Reports
|
||||||
var rl models.ReportList
|
var rl models.ReportList
|
||||||
reports, err := GetReports(context.Tx, user.UserId)
|
reports, err := context.Tx.GetReports(user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
@ -203,7 +154,7 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
|||||||
return ReportTabulationHandler(context.Tx, r, user, reportid)
|
return ReportTabulationHandler(context.Tx, r, user, reportid)
|
||||||
} else {
|
} else {
|
||||||
// Return Report with this Id
|
// Return Report with this Id
|
||||||
report, err := GetReport(context.Tx, reportid, user.UserId)
|
report, err := context.Tx.GetReport(reportid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
@ -227,7 +178,7 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
|||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = UpdateReport(context.Tx, &report)
|
err = context.Tx.UpdateReport(&report)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
@ -235,12 +186,12 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
|||||||
|
|
||||||
return &report
|
return &report
|
||||||
} else if r.Method == "DELETE" {
|
} else if r.Method == "DELETE" {
|
||||||
report, err := GetReport(context.Tx, reportid, user.UserId)
|
report, err := context.Tx.GetReport(reportid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = DeleteReport(context.Tx, report)
|
err = context.Tx.DeleteReport(report)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
|
@ -4,8 +4,8 @@ package handlers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"github.com/aclindsa/moneygo/internal/models"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -50,108 +50,34 @@ func FindCurrencyTemplate(iso4217 int64) *models.Security {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetSecurity(tx *Tx, securityid int64, userid int64) (*models.Security, error) {
|
func UpdateSecurity(tx store.Tx, s *models.Security) (err error) {
|
||||||
var s models.Security
|
user, err := tx.GetUser(s.UserId)
|
||||||
|
|
||||||
err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &s, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetSecurities(tx *Tx, userid int64) (*[]*models.Security, error) {
|
|
||||||
var securities []*models.Security
|
|
||||||
|
|
||||||
_, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &securities, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func InsertSecurity(tx *Tx, s *models.Security) error {
|
|
||||||
err := tx.Insert(s)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func UpdateSecurity(tx *Tx, s *models.Security) (err error) {
|
|
||||||
user, err := GetUser(tx, s.UserId)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
} else if user.DefaultCurrency == s.SecurityId && s.Type != models.Currency {
|
} else if user.DefaultCurrency == s.SecurityId && s.Type != models.Currency {
|
||||||
return errors.New("Cannot change security which is user's default currency to be non-currency")
|
return errors.New("Cannot change security which is user's default currency to be non-currency")
|
||||||
}
|
}
|
||||||
|
|
||||||
count, err := tx.Update(s)
|
err = tx.UpdateSecurity(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if count > 1 {
|
|
||||||
return fmt.Errorf("Updated %d securities (expected 1)", count)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type SecurityInUseError struct {
|
func ImportGetCreateSecurity(tx store.Tx, userid int64, security *models.Security) (*models.Security, error) {
|
||||||
message string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e SecurityInUseError) Error() string {
|
|
||||||
return e.message
|
|
||||||
}
|
|
||||||
|
|
||||||
func DeleteSecurity(tx *Tx, s *models.Security) error {
|
|
||||||
// First, ensure no accounts are using this security
|
|
||||||
accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId)
|
|
||||||
|
|
||||||
if accounts != 0 {
|
|
||||||
return SecurityInUseError{"One or more accounts still use this security"}
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := GetUser(tx, s.UserId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
} else if user.DefaultCurrency == s.SecurityId {
|
|
||||||
return SecurityInUseError{"Cannot delete security which is user's default currency"}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove all prices involving this security (either of this security, or
|
|
||||||
// using it as a currency)
|
|
||||||
_, err = tx.Exec("DELETE FROM prices WHERE SecurityId=? OR CurrencyId=?", s.SecurityId, s.SecurityId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
count, err := tx.Delete(s)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if count != 1 {
|
|
||||||
return errors.New("Deleted more than one security")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ImportGetCreateSecurity(tx *Tx, userid int64, security *models.Security) (*models.Security, error) {
|
|
||||||
security.UserId = userid
|
security.UserId = userid
|
||||||
if len(security.AlternateId) == 0 {
|
if len(security.AlternateId) == 0 {
|
||||||
// Always create a new local security if we can't match on the AlternateId
|
// Always create a new local security if we can't match on the AlternateId
|
||||||
err := InsertSecurity(tx, security)
|
err := tx.InsertSecurity(security)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return security, nil
|
return security, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var securities []*models.Security
|
securities, err := tx.FindMatchingSecurities(security)
|
||||||
|
|
||||||
_, err := tx.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Preciseness=?", userid, security.Type, security.AlternateId, security.Precision)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -159,7 +85,7 @@ func ImportGetCreateSecurity(tx *Tx, userid int64, security *models.Security) (*
|
|||||||
// First try to find a case insensitive match on the name or symbol
|
// First try to find a case insensitive match on the name or symbol
|
||||||
upperName := strings.ToUpper(security.Name)
|
upperName := strings.ToUpper(security.Name)
|
||||||
upperSymbol := strings.ToUpper(security.Symbol)
|
upperSymbol := strings.ToUpper(security.Symbol)
|
||||||
for _, s := range securities {
|
for _, s := range *securities {
|
||||||
if (len(s.Name) > 0 && strings.ToUpper(s.Name) == upperName) ||
|
if (len(s.Name) > 0 && strings.ToUpper(s.Name) == upperName) ||
|
||||||
(len(s.Symbol) > 0 && strings.ToUpper(s.Symbol) == upperSymbol) {
|
(len(s.Symbol) > 0 && strings.ToUpper(s.Symbol) == upperSymbol) {
|
||||||
return s, nil
|
return s, nil
|
||||||
@ -168,7 +94,7 @@ func ImportGetCreateSecurity(tx *Tx, userid int64, security *models.Security) (*
|
|||||||
// if strings.Contains(strings.ToUpper(security.Name), upperSearch) ||
|
// if strings.Contains(strings.ToUpper(security.Name), upperSearch) ||
|
||||||
|
|
||||||
// Try to find a partial string match on the name or symbol
|
// Try to find a partial string match on the name or symbol
|
||||||
for _, s := range securities {
|
for _, s := range *securities {
|
||||||
sUpperName := strings.ToUpper(s.Name)
|
sUpperName := strings.ToUpper(s.Name)
|
||||||
sUpperSymbol := strings.ToUpper(s.Symbol)
|
sUpperSymbol := strings.ToUpper(s.Symbol)
|
||||||
if (len(upperName) > 0 && len(s.Name) > 0 && (strings.Contains(upperName, sUpperName) || strings.Contains(sUpperName, upperName))) ||
|
if (len(upperName) > 0 && len(s.Name) > 0 && (strings.Contains(upperName, sUpperName) || strings.Contains(sUpperName, upperName))) ||
|
||||||
@ -178,12 +104,12 @@ func ImportGetCreateSecurity(tx *Tx, userid int64, security *models.Security) (*
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Give up and return the first security in the list
|
// Give up and return the first security in the list
|
||||||
if len(securities) > 0 {
|
if len(*securities) > 0 {
|
||||||
return securities[0], nil
|
return (*securities)[0], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there wasn't even one security in the list, make a new one
|
// If there wasn't even one security in the list, make a new one
|
||||||
err = InsertSecurity(tx, security)
|
err = tx.InsertSecurity(security)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -216,7 +142,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
|||||||
security.SecurityId = -1
|
security.SecurityId = -1
|
||||||
security.UserId = user.UserId
|
security.UserId = user.UserId
|
||||||
|
|
||||||
err = InsertSecurity(context.Tx, &security)
|
err = context.Tx.InsertSecurity(&security)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
@ -228,7 +154,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
|||||||
//Return all securities
|
//Return all securities
|
||||||
var sl models.SecurityList
|
var sl models.SecurityList
|
||||||
|
|
||||||
securities, err := GetSecurities(context.Tx, user.UserId)
|
securities, err := context.Tx.GetSecurities(user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
@ -249,7 +175,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
|||||||
return PriceHandler(r, context, user, securityid)
|
return PriceHandler(r, context, user, securityid)
|
||||||
}
|
}
|
||||||
|
|
||||||
security, err := GetSecurity(context.Tx, securityid, user.UserId)
|
security, err := context.Tx.GetSecurity(securityid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
@ -283,13 +209,13 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
|||||||
|
|
||||||
return &security
|
return &security
|
||||||
} else if r.Method == "DELETE" {
|
} else if r.Method == "DELETE" {
|
||||||
security, err := GetSecurity(context.Tx, securityid, user.UserId)
|
security, err := context.Tx.GetSecurity(securityid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = DeleteSecurity(context.Tx, security)
|
err = context.Tx.DeleteSecurity(security)
|
||||||
if _, ok := err.(SecurityInUseError); ok {
|
if _, ok := err.(store.SecurityInUseError); ok {
|
||||||
return NewError(7 /*In Use Error*/)
|
return NewError(7 /*In Use Error*/)
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/aclindsa/moneygo/internal/models"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store"
|
||||||
"github.com/yuin/gopher-lua"
|
"github.com/yuin/gopher-lua"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -14,7 +15,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*models.Security, error)
|
|||||||
|
|
||||||
ctx := L.Context()
|
ctx := L.Context()
|
||||||
|
|
||||||
tx, ok := ctx.Value(dbContextKey).(*Tx)
|
tx, ok := ctx.Value(dbContextKey).(store.Tx)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("Couldn't find tx in lua's Context")
|
return nil, errors.New("Couldn't find tx in lua's Context")
|
||||||
}
|
}
|
||||||
@ -26,7 +27,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*models.Security, error)
|
|||||||
return nil, errors.New("Couldn't find User in lua's Context")
|
return nil, errors.New("Couldn't find User in lua's Context")
|
||||||
}
|
}
|
||||||
|
|
||||||
securities, err := GetSecurities(tx, user.UserId)
|
securities, err := tx.GetSecurities(user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -158,7 +159,7 @@ func luaClosestPrice(L *lua.LState) int {
|
|||||||
date := luaCheckTime(L, 3)
|
date := luaCheckTime(L, 3)
|
||||||
|
|
||||||
ctx := L.Context()
|
ctx := L.Context()
|
||||||
tx, ok := ctx.Value(dbContextKey).(*Tx)
|
tx, ok := ctx.Value(dbContextKey).(store.Tx)
|
||||||
if !ok {
|
if !ok {
|
||||||
panic("Couldn't find tx in lua's Context")
|
panic("Couldn't find tx in lua's Context")
|
||||||
}
|
}
|
||||||
|
@ -3,36 +3,37 @@ package handlers
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/aclindsa/moneygo/internal/models"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetSession(tx *Tx, r *http.Request) (*models.Session, error) {
|
func GetSession(tx store.Tx, r *http.Request) (*models.Session, error) {
|
||||||
var s models.Session
|
|
||||||
|
|
||||||
cookie, err := r.Cookie("moneygo-session")
|
cookie, err := r.Cookie("moneygo-session")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("moneygo-session cookie not set")
|
return nil, fmt.Errorf("moneygo-session cookie not set")
|
||||||
}
|
}
|
||||||
s.SessionSecret = cookie.Value
|
|
||||||
|
|
||||||
err = tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret)
|
s, err := tx.GetSession(cookie.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.Expires.Before(time.Now()) {
|
if s.Expires.Before(time.Now()) {
|
||||||
tx.Delete(&s)
|
err := tx.DeleteSession(s)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Unexpected error when attempting to delete expired session: %s", err)
|
||||||
|
}
|
||||||
return nil, fmt.Errorf("Session has expired")
|
return nil, fmt.Errorf("Session has expired")
|
||||||
}
|
}
|
||||||
return &s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteSessionIfExists(tx *Tx, r *http.Request) error {
|
func DeleteSessionIfExists(tx store.Tx, r *http.Request) error {
|
||||||
session, err := GetSession(tx, r)
|
session, err := GetSession(tx, r)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
_, err := tx.Delete(session)
|
err := tx.DeleteSession(session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -50,21 +51,26 @@ func (n *NewSessionWriter) Write(w http.ResponseWriter) error {
|
|||||||
return n.session.Write(w)
|
return n.session.Write(w)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSession(tx *Tx, r *http.Request, userid int64) (*NewSessionWriter, error) {
|
func NewSession(tx store.Tx, r *http.Request, userid int64) (*NewSessionWriter, error) {
|
||||||
|
err := DeleteSessionIfExists(tx, r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
s, err := models.NewSession(userid)
|
s, err := models.NewSession(userid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
existing, err := tx.SelectInt("SELECT count(*) from sessions where SessionSecret=?", s.SessionSecret)
|
exists, err := tx.SessionExists(s.SessionSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if existing > 0 {
|
if exists {
|
||||||
return nil, fmt.Errorf("%d session(s) exist with the generated session_secret", existing)
|
return nil, fmt.Errorf("Session already exists with the generated session_secret")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tx.Insert(s)
|
err = tx.InsertSession(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -79,22 +85,19 @@ func SessionHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
|||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
dbuser, err := GetUserByUsername(context.Tx, user.Username)
|
// Hash password before checking username to help mitigate timing
|
||||||
|
// attacks
|
||||||
|
user.HashPassword()
|
||||||
|
|
||||||
|
dbuser, err := context.Tx.GetUserByUsername(user.Username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(2 /*Unauthorized Access*/)
|
return NewError(2 /*Unauthorized Access*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
user.HashPassword()
|
|
||||||
if user.PasswordHash != dbuser.PasswordHash {
|
if user.PasswordHash != dbuser.PasswordHash {
|
||||||
return NewError(2 /*Unauthorized Access*/)
|
return NewError(2 /*Unauthorized Access*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = DeleteSessionIfExists(context.Tx, r)
|
|
||||||
if err != nil {
|
|
||||||
log.Print(err)
|
|
||||||
return NewError(999 /*Internal Error*/)
|
|
||||||
}
|
|
||||||
|
|
||||||
sessionwriter, err := NewSession(context.Tx, r, dbuser.UserId)
|
sessionwriter, err := NewSession(context.Tx, r, dbuser.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
|
@ -2,24 +2,18 @@ package handlers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"github.com/aclindsa/moneygo/internal/models"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store"
|
||||||
"log"
|
"log"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func SplitAlreadyImported(tx *Tx, s *models.Split) (bool, error) {
|
|
||||||
count, err := tx.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId)
|
|
||||||
return count == 1, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return a map of security ID's to big.Rat's containing the amount that
|
// Return a map of security ID's to big.Rat's containing the amount that
|
||||||
// security is imbalanced by
|
// security is imbalanced by
|
||||||
func GetTransactionImbalances(tx *Tx, t *models.Transaction) (map[int64]big.Rat, error) {
|
func GetTransactionImbalances(tx store.Tx, t *models.Transaction) (map[int64]big.Rat, error) {
|
||||||
sums := make(map[int64]big.Rat)
|
sums := make(map[int64]big.Rat)
|
||||||
|
|
||||||
if !t.Valid() {
|
if !t.Valid() {
|
||||||
@ -31,7 +25,7 @@ func GetTransactionImbalances(tx *Tx, t *models.Transaction) (map[int64]big.Rat,
|
|||||||
if t.Splits[i].AccountId != -1 {
|
if t.Splits[i].AccountId != -1 {
|
||||||
var err error
|
var err error
|
||||||
var account *models.Account
|
var account *models.Account
|
||||||
account, err = GetAccount(tx, t.Splits[i].AccountId, t.UserId)
|
account, err = tx.GetAccount(t.Splits[i].AccountId, t.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -47,7 +41,7 @@ func GetTransactionImbalances(tx *Tx, t *models.Transaction) (map[int64]big.Rat,
|
|||||||
|
|
||||||
// Returns true if all securities contained in this transaction are balanced,
|
// Returns true if all securities contained in this transaction are balanced,
|
||||||
// false otherwise
|
// false otherwise
|
||||||
func TransactionBalanced(tx *Tx, t *models.Transaction) (bool, error) {
|
func TransactionBalanced(tx store.Tx, t *models.Transaction) (bool, error) {
|
||||||
var zero big.Rat
|
var zero big.Rat
|
||||||
|
|
||||||
sums, err := GetTransactionImbalances(tx, t)
|
sums, err := GetTransactionImbalances(tx, t)
|
||||||
@ -63,219 +57,6 @@ func TransactionBalanced(tx *Tx, t *models.Transaction) (bool, error) {
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetTransaction(tx *Tx, transactionid int64, userid int64) (*models.Transaction, error) {
|
|
||||||
var t models.Transaction
|
|
||||||
|
|
||||||
err := tx.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = tx.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &t, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetTransactions(tx *Tx, userid int64) (*[]models.Transaction, error) {
|
|
||||||
var transactions []models.Transaction
|
|
||||||
|
|
||||||
_, err := tx.Select(&transactions, "SELECT * from transactions where UserId=?", userid)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range transactions {
|
|
||||||
_, err := tx.Select(&transactions[i].Splits, "SELECT * from splits where TransactionId=?", transactions[i].TransactionId)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &transactions, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func incrementAccountVersions(tx *Tx, user *models.User, accountids []int64) error {
|
|
||||||
for i := range accountids {
|
|
||||||
account, err := GetAccount(tx, accountids[i], user.UserId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
account.AccountVersion++
|
|
||||||
count, err := tx.Update(account)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if count != 1 {
|
|
||||||
return errors.New("Updated more than one account")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type AccountMissingError struct{}
|
|
||||||
|
|
||||||
func (ame AccountMissingError) Error() string {
|
|
||||||
return "Account missing"
|
|
||||||
}
|
|
||||||
|
|
||||||
func InsertTransaction(tx *Tx, t *models.Transaction, user *models.User) error {
|
|
||||||
// Map of any accounts with transaction splits being added
|
|
||||||
a_map := make(map[int64]bool)
|
|
||||||
for i := range t.Splits {
|
|
||||||
if t.Splits[i].AccountId != -1 {
|
|
||||||
existing, err := tx.SelectInt("SELECT count(*) from accounts where AccountId=?", t.Splits[i].AccountId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if existing != 1 {
|
|
||||||
return AccountMissingError{}
|
|
||||||
}
|
|
||||||
a_map[t.Splits[i].AccountId] = true
|
|
||||||
} else if t.Splits[i].SecurityId == -1 {
|
|
||||||
return AccountMissingError{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//increment versions for all accounts
|
|
||||||
var a_ids []int64
|
|
||||||
for id := range a_map {
|
|
||||||
a_ids = append(a_ids, id)
|
|
||||||
}
|
|
||||||
// ensure at least one of the splits is associated with an actual account
|
|
||||||
if len(a_ids) < 1 {
|
|
||||||
return AccountMissingError{}
|
|
||||||
}
|
|
||||||
err := incrementAccountVersions(tx, user, a_ids)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
t.UserId = user.UserId
|
|
||||||
err = tx.Insert(t)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range t.Splits {
|
|
||||||
t.Splits[i].TransactionId = t.TransactionId
|
|
||||||
t.Splits[i].SplitId = -1
|
|
||||||
err = tx.Insert(t.Splits[i])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func UpdateTransaction(tx *Tx, t *models.Transaction, user *models.User) error {
|
|
||||||
var existing_splits []*models.Split
|
|
||||||
|
|
||||||
_, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Map of any accounts with transaction splits being added
|
|
||||||
a_map := make(map[int64]bool)
|
|
||||||
|
|
||||||
// Make a map with any existing splits for this transaction
|
|
||||||
s_map := make(map[int64]bool)
|
|
||||||
for i := range existing_splits {
|
|
||||||
s_map[existing_splits[i].SplitId] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert splits, updating any pre-existing ones
|
|
||||||
for i := range t.Splits {
|
|
||||||
t.Splits[i].TransactionId = t.TransactionId
|
|
||||||
_, ok := s_map[t.Splits[i].SplitId]
|
|
||||||
if ok {
|
|
||||||
count, err := tx.Update(t.Splits[i])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if count > 1 {
|
|
||||||
return fmt.Errorf("Updated %d transaction splits while attempting to update only 1", count)
|
|
||||||
}
|
|
||||||
delete(s_map, t.Splits[i].SplitId)
|
|
||||||
} else {
|
|
||||||
t.Splits[i].SplitId = -1
|
|
||||||
err := tx.Insert(t.Splits[i])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if t.Splits[i].AccountId != -1 {
|
|
||||||
a_map[t.Splits[i].AccountId] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete any remaining pre-existing splits
|
|
||||||
for i := range existing_splits {
|
|
||||||
_, ok := s_map[existing_splits[i].SplitId]
|
|
||||||
if existing_splits[i].AccountId != -1 {
|
|
||||||
a_map[existing_splits[i].AccountId] = true
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
_, err := tx.Delete(existing_splits[i])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Increment versions for all accounts with modified splits
|
|
||||||
var a_ids []int64
|
|
||||||
for id := range a_map {
|
|
||||||
a_ids = append(a_ids, id)
|
|
||||||
}
|
|
||||||
err = incrementAccountVersions(tx, user, a_ids)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
count, err := tx.Update(t)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if count > 1 {
|
|
||||||
return fmt.Errorf("Updated %d transactions (expected 1)", count)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DeleteTransaction(tx *Tx, t *models.Transaction, user *models.User) error {
|
|
||||||
var accountids []int64
|
|
||||||
_, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = tx.Exec("DELETE FROM splits WHERE TransactionId=?", t.TransactionId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
count, err := tx.Delete(t)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if count != 1 {
|
|
||||||
return errors.New("Deleted more than one transaction")
|
|
||||||
}
|
|
||||||
|
|
||||||
err = incrementAccountVersions(tx, user, accountids)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
||||||
user, err := GetUserFromSession(context.Tx, r)
|
user, err := GetUserFromSession(context.Tx, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -296,7 +77,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
|
|||||||
|
|
||||||
for i := range transaction.Splits {
|
for i := range transaction.Splits {
|
||||||
transaction.Splits[i].SplitId = -1
|
transaction.Splits[i].SplitId = -1
|
||||||
_, err := GetAccount(context.Tx, transaction.Splits[i].AccountId, user.UserId)
|
_, err := context.Tx.GetAccount(transaction.Splits[i].AccountId, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
@ -310,9 +91,9 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
|
|||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = InsertTransaction(context.Tx, &transaction, user)
|
err = context.Tx.InsertTransaction(&transaction, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if _, ok := err.(AccountMissingError); ok {
|
if _, ok := err.(store.AccountMissingError); ok {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
} else {
|
} else {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
@ -325,7 +106,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
|
|||||||
if context.LastLevel() {
|
if context.LastLevel() {
|
||||||
//Return all Transactions
|
//Return all Transactions
|
||||||
var al models.TransactionList
|
var al models.TransactionList
|
||||||
transactions, err := GetTransactions(context.Tx, user.UserId)
|
transactions, err := context.Tx.GetTransactions(user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
@ -338,7 +119,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
transaction, err := GetTransaction(context.Tx, transactionid, user.UserId)
|
transaction, err := context.Tx.GetTransaction(transactionid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
@ -370,13 +151,13 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i := range transaction.Splits {
|
for i := range transaction.Splits {
|
||||||
_, err := GetAccount(context.Tx, transaction.Splits[i].AccountId, user.UserId)
|
_, err := context.Tx.GetAccount(transaction.Splits[i].AccountId, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = UpdateTransaction(context.Tx, &transaction, user)
|
err = context.Tx.UpdateTransaction(&transaction, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
@ -384,12 +165,12 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
|
|||||||
|
|
||||||
return &transaction
|
return &transaction
|
||||||
} else if r.Method == "DELETE" {
|
} else if r.Method == "DELETE" {
|
||||||
transaction, err := GetTransaction(context.Tx, transactionid, user.UserId)
|
transaction, err := context.Tx.GetTransaction(transactionid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = DeleteTransaction(context.Tx, transaction, user)
|
err = context.Tx.DeleteTransaction(transaction, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
@ -401,41 +182,9 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
|
|||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TransactionsBalanceDifference(tx *Tx, accountid int64, transactions []models.Transaction) (*big.Rat, error) {
|
func BalanceFromSplits(splits *[]*models.Split) (*big.Rat, error) {
|
||||||
var pageDifference, tmp big.Rat
|
|
||||||
for i := range transactions {
|
|
||||||
_, err := tx.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sum up the amounts from the splits we're returning so we can return
|
|
||||||
// an ending balance
|
|
||||||
for j := range transactions[i].Splits {
|
|
||||||
if transactions[i].Splits[j].AccountId == accountid {
|
|
||||||
rat_amount, err := models.GetBigAmount(transactions[i].Splits[j].Amount)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
tmp.Add(&pageDifference, rat_amount)
|
|
||||||
pageDifference.Set(&tmp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &pageDifference, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetAccountBalance(tx *Tx, user *models.User, accountid int64) (*big.Rat, error) {
|
|
||||||
var splits []models.Split
|
|
||||||
|
|
||||||
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?"
|
|
||||||
_, err := tx.Select(&splits, sql, accountid, user.UserId)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var balance, tmp big.Rat
|
var balance, tmp big.Rat
|
||||||
for _, s := range splits {
|
for _, s := range *splits {
|
||||||
rat_amount, err := models.GetBigAmount(s.Amount)
|
rat_amount, err := models.GetBigAmount(s.Amount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -447,132 +196,6 @@ func GetAccountBalance(tx *Tx, user *models.User, accountid int64) (*big.Rat, er
|
|||||||
return &balance, nil
|
return &balance, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Assumes accountid is valid and is owned by the current user
|
|
||||||
func GetAccountBalanceDate(tx *Tx, user *models.User, accountid int64, date *time.Time) (*big.Rat, error) {
|
|
||||||
var splits []models.Split
|
|
||||||
|
|
||||||
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?"
|
|
||||||
_, err := tx.Select(&splits, sql, accountid, user.UserId, date)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var balance, tmp big.Rat
|
|
||||||
for _, s := range splits {
|
|
||||||
rat_amount, err := models.GetBigAmount(s.Amount)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
tmp.Add(&balance, rat_amount)
|
|
||||||
balance.Set(&tmp)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &balance, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetAccountBalanceDateRange(tx *Tx, user *models.User, accountid int64, begin, end *time.Time) (*big.Rat, error) {
|
|
||||||
var splits []models.Split
|
|
||||||
|
|
||||||
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?"
|
|
||||||
_, err := tx.Select(&splits, sql, accountid, user.UserId, begin, end)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var balance, tmp big.Rat
|
|
||||||
for _, s := range splits {
|
|
||||||
rat_amount, err := models.GetBigAmount(s.Amount)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
tmp.Add(&balance, rat_amount)
|
|
||||||
balance.Set(&tmp)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &balance, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetAccountTransactions(tx *Tx, user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) {
|
|
||||||
var transactions []models.Transaction
|
|
||||||
var atl models.AccountTransactionsList
|
|
||||||
|
|
||||||
var sqlsort, balanceLimitOffset string
|
|
||||||
var balanceLimitOffsetArg uint64
|
|
||||||
if sort == "date-asc" {
|
|
||||||
sqlsort = " ORDER BY transactions.Date ASC, transactions.TransactionId ASC"
|
|
||||||
balanceLimitOffset = " LIMIT ?"
|
|
||||||
balanceLimitOffsetArg = page * limit
|
|
||||||
} else if sort == "date-desc" {
|
|
||||||
numSplits, err := tx.SelectInt("SELECT count(*) FROM splits")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
sqlsort = " ORDER BY transactions.Date DESC, transactions.TransactionId DESC"
|
|
||||||
balanceLimitOffset = fmt.Sprintf(" LIMIT %d OFFSET ?", numSplits)
|
|
||||||
balanceLimitOffsetArg = (page + 1) * limit
|
|
||||||
}
|
|
||||||
|
|
||||||
var sqloffset string
|
|
||||||
if page > 0 {
|
|
||||||
sqloffset = fmt.Sprintf(" OFFSET %d", page*limit)
|
|
||||||
}
|
|
||||||
|
|
||||||
account, err := GetAccount(tx, accountid, user.UserId)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
atl.Account = account
|
|
||||||
|
|
||||||
sql := "SELECT DISTINCT transactions.* FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + " LIMIT ?" + sqloffset
|
|
||||||
_, err = tx.Select(&transactions, sql, user.UserId, accountid, limit)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
atl.Transactions = &transactions
|
|
||||||
|
|
||||||
pageDifference, err := TransactionsBalanceDifference(tx, accountid, transactions)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
count, err := tx.SelectInt("SELECT count(DISTINCT transactions.TransactionId) FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?", user.UserId, accountid)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
atl.TotalTransactions = count
|
|
||||||
|
|
||||||
security, err := GetSecurity(tx, atl.Account.SecurityId, user.UserId)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if security == nil {
|
|
||||||
return nil, errors.New("Security not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sum all the splits for all transaction splits for this account that
|
|
||||||
// occurred before the page we're returning
|
|
||||||
var amounts []string
|
|
||||||
sql = "SELECT s.Amount FROM splits AS s INNER JOIN (SELECT DISTINCT transactions.Date, transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ") as t ON s.TransactionId = t.TransactionId WHERE s.AccountId=?"
|
|
||||||
_, err = tx.Select(&amounts, sql, user.UserId, accountid, balanceLimitOffsetArg, accountid)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var tmp, balance big.Rat
|
|
||||||
for _, amount := range amounts {
|
|
||||||
rat_amount, err := models.GetBigAmount(amount)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
tmp.Add(&balance, rat_amount)
|
|
||||||
balance.Set(&tmp)
|
|
||||||
}
|
|
||||||
atl.BeginningBalance = balance.FloatString(security.Precision)
|
|
||||||
atl.EndingBalance = tmp.Add(&balance, pageDifference).FloatString(security.Precision)
|
|
||||||
|
|
||||||
return &atl, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return only those transactions which have at least one split pertaining to
|
// Return only those transactions which have at least one split pertaining to
|
||||||
// an account
|
// an account
|
||||||
func AccountTransactionsHandler(context *Context, r *http.Request, user *models.User, accountid int64) ResponseWriterWriter {
|
func AccountTransactionsHandler(context *Context, r *http.Request, user *models.User, accountid int64) ResponseWriterWriter {
|
||||||
@ -608,7 +231,7 @@ func AccountTransactionsHandler(context *Context, r *http.Request, user *models.
|
|||||||
sort = sortstring
|
sort = sortstring
|
||||||
}
|
}
|
||||||
|
|
||||||
accountTransactions, err := GetAccountTransactions(context.Tx, user, accountid, sort, page, limit)
|
accountTransactions, err := context.Tx.GetAccountTransactions(user, accountid, sort, page, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
|
@ -276,7 +276,7 @@ func TestGetTransactions(t *testing.T) {
|
|||||||
found := false
|
found := false
|
||||||
for _, tran := range *tl.Transactions {
|
for _, tran := range *tl.Transactions {
|
||||||
if tran.TransactionId == curr.TransactionId {
|
if tran.TransactionId == curr.TransactionId {
|
||||||
ensureTransactionsMatch(t, &curr, &tran, nil, true, true)
|
ensureTransactionsMatch(t, &curr, tran, nil, true, true)
|
||||||
if _, ok := foundIds[tran.TransactionId]; ok {
|
if _, ok := foundIds[tran.TransactionId]; ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -410,7 +410,7 @@ func helperTestAccountTransactions(t *testing.T, d *TestData, account *models.Ac
|
|||||||
}
|
}
|
||||||
if atl.Transactions != nil {
|
if atl.Transactions != nil {
|
||||||
for _, tran := range *atl.Transactions {
|
for _, tran := range *atl.Transactions {
|
||||||
transactions = append(transactions, tran)
|
transactions = append(transactions, *tran)
|
||||||
}
|
}
|
||||||
lastFetchCount = int64(len(*atl.Transactions))
|
lastFetchCount = int64(len(*atl.Transactions))
|
||||||
} else {
|
} else {
|
||||||
|
@ -2,8 +2,8 @@ package handlers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"github.com/aclindsa/moneygo/internal/models"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
@ -14,41 +14,21 @@ func (ueu UserExistsError) Error() string {
|
|||||||
return "User exists"
|
return "User exists"
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUser(tx *Tx, userid int64) (*models.User, error) {
|
func InsertUser(tx store.Tx, u *models.User) error {
|
||||||
var u models.User
|
|
||||||
|
|
||||||
err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &u, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetUserByUsername(tx *Tx, username string) (*models.User, error) {
|
|
||||||
var u models.User
|
|
||||||
|
|
||||||
err := tx.SelectOne(&u, "SELECT * from users where Username=?", username)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &u, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func InsertUser(tx *Tx, u *models.User) error {
|
|
||||||
security_template := FindCurrencyTemplate(u.DefaultCurrency)
|
security_template := FindCurrencyTemplate(u.DefaultCurrency)
|
||||||
if security_template == nil {
|
if security_template == nil {
|
||||||
return errors.New("Invalid ISO4217 Default Currency")
|
return errors.New("Invalid ISO4217 Default Currency")
|
||||||
}
|
}
|
||||||
|
|
||||||
existing, err := tx.SelectInt("SELECT count(*) from users where Username=?", u.Username)
|
exists, err := tx.UsernameExists(u.Username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if existing > 0 {
|
if exists {
|
||||||
return UserExistsError{}
|
return UserExistsError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tx.Insert(u)
|
err = tx.InsertUser(u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -58,33 +38,31 @@ func InsertUser(tx *Tx, u *models.User) error {
|
|||||||
security = *security_template
|
security = *security_template
|
||||||
security.UserId = u.UserId
|
security.UserId = u.UserId
|
||||||
|
|
||||||
err = InsertSecurity(tx, &security)
|
err = tx.InsertSecurity(&security)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update the user's DefaultCurrency to our new SecurityId
|
// Update the user's DefaultCurrency to our new SecurityId
|
||||||
u.DefaultCurrency = security.SecurityId
|
u.DefaultCurrency = security.SecurityId
|
||||||
count, err := tx.Update(u)
|
err = tx.UpdateUser(u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if count != 1 {
|
|
||||||
return errors.New("Would have updated more than one user")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserFromSession(tx *Tx, r *http.Request) (*models.User, error) {
|
func GetUserFromSession(tx store.Tx, r *http.Request) (*models.User, error) {
|
||||||
s, err := GetSession(tx, r)
|
s, err := GetSession(tx, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return GetUser(tx, s.UserId)
|
return tx.GetUser(s.UserId)
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateUser(tx *Tx, u *models.User) error {
|
func UpdateUser(tx store.Tx, u *models.User) error {
|
||||||
security, err := GetSecurity(tx, u.DefaultCurrency, u.UserId)
|
security, err := tx.GetSecurity(u.DefaultCurrency, u.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if security.UserId != u.UserId || security.SecurityId != u.DefaultCurrency {
|
} else if security.UserId != u.UserId || security.SecurityId != u.DefaultCurrency {
|
||||||
@ -93,49 +71,7 @@ func UpdateUser(tx *Tx, u *models.User) error {
|
|||||||
return errors.New("New DefaultCurrency security is not a currency")
|
return errors.New("New DefaultCurrency security is not a currency")
|
||||||
}
|
}
|
||||||
|
|
||||||
count, err := tx.Update(u)
|
err = tx.UpdateUser(u)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
} else if count != 1 {
|
|
||||||
return errors.New("Would have updated more than one user")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DeleteUser(tx *Tx, u *models.User) error {
|
|
||||||
count, err := tx.Delete(u)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if count != 1 {
|
|
||||||
return fmt.Errorf("No user to delete")
|
|
||||||
}
|
|
||||||
_, err = tx.Exec("DELETE FROM prices WHERE prices.SecurityId IN (SELECT securities.SecurityId FROM securities WHERE securities.UserId=?)", u.UserId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = tx.Exec("DELETE FROM splits WHERE splits.TransactionId IN (SELECT transactions.TransactionId FROM transactions WHERE transactions.UserId=?)", u.UserId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = tx.Exec("DELETE FROM transactions WHERE transactions.UserId=?", u.UserId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = tx.Exec("DELETE FROM securities WHERE securities.UserId=?", u.UserId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = tx.Exec("DELETE FROM accounts WHERE accounts.UserId=?", u.UserId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = tx.Exec("DELETE FROM reports WHERE reports.UserId=?", u.UserId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = tx.Exec("DELETE FROM sessions WHERE sessions.UserId=?", u.UserId)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -204,7 +140,7 @@ func UserHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
|||||||
|
|
||||||
return user
|
return user
|
||||||
} else if r.Method == "DELETE" {
|
} else if r.Method == "DELETE" {
|
||||||
err := DeleteUser(context.Tx, user)
|
err := context.Tx.DeleteUser(user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return NewError(999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
|
@ -94,7 +94,7 @@ type Account struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type AccountList struct {
|
type AccountList struct {
|
||||||
Accounts *[]Account `json:"accounts"`
|
Accounts *[]*Account `json:"accounts"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Account) Write(w http.ResponseWriter) error {
|
func (a *Account) Write(w http.ResponseWriter) error {
|
||||||
|
@ -28,7 +28,7 @@ func (r *Report) Read(json_str string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ReportList struct {
|
type ReportList struct {
|
||||||
Reports *[]Report `json:"reports"`
|
Reports *[]*Report `json:"reports"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rl *ReportList) Write(w http.ResponseWriter) error {
|
func (rl *ReportList) Write(w http.ResponseWriter) error {
|
||||||
|
@ -82,12 +82,12 @@ type Transaction struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type TransactionList struct {
|
type TransactionList struct {
|
||||||
Transactions *[]Transaction `json:"transactions"`
|
Transactions *[]*Transaction `json:"transactions"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AccountTransactionsList struct {
|
type AccountTransactionsList struct {
|
||||||
Account *Account
|
Account *Account
|
||||||
Transactions *[]Transaction
|
Transactions *[]*Transaction
|
||||||
TotalTransactions int64
|
TotalTransactions int64
|
||||||
BeginningBalance string
|
BeginningBalance string
|
||||||
EndingBalance string
|
EndingBalance string
|
||||||
|
133
internal/store/db/accounts.go
Normal file
133
internal/store/db/accounts.go
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (tx *Tx) GetAccount(accountid int64, userid int64) (*models.Account, error) {
|
||||||
|
var account models.Account
|
||||||
|
|
||||||
|
err := tx.SelectOne(&account, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &account, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) GetAccounts(userid int64) (*[]*models.Account, error) {
|
||||||
|
var accounts []*models.Account
|
||||||
|
|
||||||
|
_, err := tx.Select(&accounts, "SELECT * from accounts where UserId=?", userid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &accounts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) FindMatchingAccounts(account *models.Account) (*[]*models.Account, error) {
|
||||||
|
var accounts []*models.Account
|
||||||
|
|
||||||
|
_, err := tx.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND Type=? AND Name=? AND ParentAccountId=? ORDER BY AccountId ASC", account.UserId, account.SecurityId, account.Type, account.Name, account.ParentAccountId)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &accounts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) insertUpdateAccount(account *models.Account, insert bool) error {
|
||||||
|
found := make(map[int64]bool)
|
||||||
|
if !insert {
|
||||||
|
found[account.AccountId] = true
|
||||||
|
}
|
||||||
|
parentid := account.ParentAccountId
|
||||||
|
depth := 0
|
||||||
|
for parentid != -1 {
|
||||||
|
depth += 1
|
||||||
|
if depth > 100 {
|
||||||
|
return store.TooMuchNestingError{}
|
||||||
|
}
|
||||||
|
|
||||||
|
var a models.Account
|
||||||
|
err := tx.SelectOne(&a, "SELECT * from accounts where AccountId=?", parentid)
|
||||||
|
if err != nil {
|
||||||
|
return store.ParentAccountMissingError{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insertion by itself can never result in circular dependencies
|
||||||
|
if insert {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
found[parentid] = true
|
||||||
|
parentid = a.ParentAccountId
|
||||||
|
if _, ok := found[parentid]; ok {
|
||||||
|
return store.CircularAccountsError{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if insert {
|
||||||
|
err := tx.Insert(account)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
oldacct, err := tx.GetAccount(account.AccountId, account.UserId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
account.AccountVersion = oldacct.AccountVersion + 1
|
||||||
|
|
||||||
|
count, err := tx.Update(account)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if count != 1 {
|
||||||
|
return errors.New("Updated more than one account")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) InsertAccount(account *models.Account) error {
|
||||||
|
return tx.insertUpdateAccount(account, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) UpdateAccount(account *models.Account) error {
|
||||||
|
return tx.insertUpdateAccount(account, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) DeleteAccount(account *models.Account) error {
|
||||||
|
if account.ParentAccountId != -1 {
|
||||||
|
// Re-parent splits to this account's parent account if this account isn't a root account
|
||||||
|
_, err := tx.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", account.ParentAccountId, account.AccountId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Delete splits if this account is a root account
|
||||||
|
_, err := tx.Exec("DELETE FROM splits WHERE AccountId=?", account.AccountId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-parent child accounts to this account's parent account
|
||||||
|
_, err := tx.Exec("UPDATE accounts SET ParentAccountId=? WHERE ParentAccountId=?", account.ParentAccountId, account.AccountId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
count, err := tx.Delete(account)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if count != 1 {
|
||||||
|
return errors.New("Was going to delete more than one account")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
@ -6,6 +6,7 @@ import (
|
|||||||
"github.com/aclindsa/gorp"
|
"github.com/aclindsa/gorp"
|
||||||
"github.com/aclindsa/moneygo/internal/config"
|
"github.com/aclindsa/moneygo/internal/config"
|
||||||
"github.com/aclindsa/moneygo/internal/models"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store"
|
||||||
_ "github.com/go-sql-driver/mysql"
|
_ "github.com/go-sql-driver/mysql"
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
@ -18,7 +19,7 @@ import (
|
|||||||
// implementation's string type specified by the same.
|
// implementation's string type specified by the same.
|
||||||
const luaMaxLengthBuffer int = 4096
|
const luaMaxLengthBuffer int = 4096
|
||||||
|
|
||||||
func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) {
|
func getDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) {
|
||||||
var dialect gorp.Dialect
|
var dialect gorp.Dialect
|
||||||
if dbtype == config.SQLite {
|
if dbtype == config.SQLite {
|
||||||
dialect = gorp.SqliteDialect{}
|
dialect = gorp.SqliteDialect{}
|
||||||
@ -38,11 +39,11 @@ func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) {
|
|||||||
dbmap := &gorp.DbMap{Db: db, Dialect: dialect}
|
dbmap := &gorp.DbMap{Db: db, Dialect: dialect}
|
||||||
dbmap.AddTableWithName(models.User{}, "users").SetKeys(true, "UserId")
|
dbmap.AddTableWithName(models.User{}, "users").SetKeys(true, "UserId")
|
||||||
dbmap.AddTableWithName(models.Session{}, "sessions").SetKeys(true, "SessionId")
|
dbmap.AddTableWithName(models.Session{}, "sessions").SetKeys(true, "SessionId")
|
||||||
dbmap.AddTableWithName(models.Account{}, "accounts").SetKeys(true, "AccountId")
|
|
||||||
dbmap.AddTableWithName(models.Security{}, "securities").SetKeys(true, "SecurityId")
|
dbmap.AddTableWithName(models.Security{}, "securities").SetKeys(true, "SecurityId")
|
||||||
|
dbmap.AddTableWithName(models.Price{}, "prices").SetKeys(true, "PriceId")
|
||||||
|
dbmap.AddTableWithName(models.Account{}, "accounts").SetKeys(true, "AccountId")
|
||||||
dbmap.AddTableWithName(models.Transaction{}, "transactions").SetKeys(true, "TransactionId")
|
dbmap.AddTableWithName(models.Transaction{}, "transactions").SetKeys(true, "TransactionId")
|
||||||
dbmap.AddTableWithName(models.Split{}, "splits").SetKeys(true, "SplitId")
|
dbmap.AddTableWithName(models.Split{}, "splits").SetKeys(true, "SplitId")
|
||||||
dbmap.AddTableWithName(models.Price{}, "prices").SetKeys(true, "PriceId")
|
|
||||||
rtable := dbmap.AddTableWithName(models.Report{}, "reports").SetKeys(true, "ReportId")
|
rtable := dbmap.AddTableWithName(models.Report{}, "reports").SetKeys(true, "ReportId")
|
||||||
rtable.ColMap("Lua").SetMaxSize(models.LuaMaxLength + luaMaxLengthBuffer)
|
rtable.ColMap("Lua").SetMaxSize(models.LuaMaxLength + luaMaxLengthBuffer)
|
||||||
|
|
||||||
@ -54,9 +55,50 @@ func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) {
|
|||||||
return dbmap, nil
|
return dbmap, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetDSN(dbtype config.DbType, dsn string) string {
|
func getDSN(dbtype config.DbType, dsn string) string {
|
||||||
if dbtype == config.MySQL && !strings.Contains(dsn, "parseTime=true") {
|
if dbtype == config.MySQL && !strings.Contains(dsn, "parseTime=true") {
|
||||||
log.Fatalf("The DSN for MySQL MUST contain 'parseTime=True' but does not!")
|
log.Fatalf("The DSN for MySQL MUST contain 'parseTime=True' but does not!")
|
||||||
}
|
}
|
||||||
return dsn
|
return dsn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type DbStore struct {
|
||||||
|
dbMap *gorp.DbMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DbStore) Empty() error {
|
||||||
|
return db.dbMap.TruncateTables()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DbStore) Begin() (store.Tx, error) {
|
||||||
|
tx, err := db.dbMap.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Tx{db.dbMap.Dialect, tx}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DbStore) Close() error {
|
||||||
|
err := db.dbMap.Db.Close()
|
||||||
|
db.dbMap = nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetStore(dbtype config.DbType, dsn string) (store *DbStore, err error) {
|
||||||
|
dsn = getDSN(dbtype, dsn)
|
||||||
|
database, err := sql.Open(dbtype.String(), dsn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
database.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
dbmap, err := getDbMap(database, dbtype)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &DbStore{dbmap}, nil
|
||||||
|
}
|
78
internal/store/db/prices.go
Normal file
78
internal/store/db/prices.go
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (tx *Tx) PriceExists(price *models.Price) (bool, error) {
|
||||||
|
var prices []*models.Price
|
||||||
|
_, err := tx.Select(&prices, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date=? AND Value=?", price.SecurityId, price.CurrencyId, price.Date, price.Value)
|
||||||
|
return len(prices) > 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) InsertPrice(price *models.Price) error {
|
||||||
|
return tx.Insert(price)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) GetPrice(priceid, securityid int64) (*models.Price, error) {
|
||||||
|
var price models.Price
|
||||||
|
err := tx.SelectOne(&price, "SELECT * from prices where PriceId=? AND SecurityId=?", priceid, securityid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &price, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) GetPrices(securityid int64) (*[]*models.Price, error) {
|
||||||
|
var prices []*models.Price
|
||||||
|
|
||||||
|
_, err := tx.Select(&prices, "SELECT * from prices where SecurityId=?", securityid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &prices, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the latest price for security in currency units before date
|
||||||
|
func (tx *Tx) GetLatestPrice(security, currency *models.Security, date *time.Time) (*models.Price, error) {
|
||||||
|
var price models.Price
|
||||||
|
err := tx.SelectOne(&price, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &price, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the earliest price for security in currency units after date
|
||||||
|
func (tx *Tx) GetEarliestPrice(security, currency *models.Security, date *time.Time) (*models.Price, error) {
|
||||||
|
var price models.Price
|
||||||
|
err := tx.SelectOne(&price, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &price, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) UpdatePrice(price *models.Price) error {
|
||||||
|
count, err := tx.Update(price)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if count != 1 {
|
||||||
|
return fmt.Errorf("Expected to update 1 price, was going to update %d", count)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) DeletePrice(price *models.Price) error {
|
||||||
|
count, err := tx.Delete(price)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if count != 1 {
|
||||||
|
return fmt.Errorf("Expected to delete 1 price, was going to delete %d", count)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
56
internal/store/db/reports.go
Normal file
56
internal/store/db/reports.go
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (tx *Tx) GetReport(reportid int64, userid int64) (*models.Report, error) {
|
||||||
|
var r models.Report
|
||||||
|
|
||||||
|
err := tx.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &r, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) GetReports(userid int64) (*[]*models.Report, error) {
|
||||||
|
var reports []*models.Report
|
||||||
|
|
||||||
|
_, err := tx.Select(&reports, "SELECT * from reports where UserId=?", userid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &reports, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) InsertReport(report *models.Report) error {
|
||||||
|
err := tx.Insert(report)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) UpdateReport(report *models.Report) error {
|
||||||
|
count, err := tx.Update(report)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if count != 1 {
|
||||||
|
return fmt.Errorf("Expected to update 1 report, was going to update %d", count)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) DeleteReport(report *models.Report) error {
|
||||||
|
count, err := tx.Delete(report)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if count != 1 {
|
||||||
|
return fmt.Errorf("Expected to delete 1 report, was going to delete %d", count)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
88
internal/store/db/securities.go
Normal file
88
internal/store/db/securities.go
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (tx *Tx) GetSecurity(securityid int64, userid int64) (*models.Security, error) {
|
||||||
|
var s models.Security
|
||||||
|
|
||||||
|
err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) GetSecurities(userid int64) (*[]*models.Security, error) {
|
||||||
|
var securities []*models.Security
|
||||||
|
|
||||||
|
_, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &securities, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) FindMatchingSecurities(security *models.Security) (*[]*models.Security, error) {
|
||||||
|
var securities []*models.Security
|
||||||
|
|
||||||
|
_, err := tx.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Preciseness=?", security.UserId, security.Type, security.AlternateId, security.Precision)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &securities, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) InsertSecurity(s *models.Security) error {
|
||||||
|
err := tx.Insert(s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) UpdateSecurity(security *models.Security) error {
|
||||||
|
count, err := tx.Update(security)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if count != 1 {
|
||||||
|
return fmt.Errorf("Expected to update 1 security, was going to update %d", count)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) DeleteSecurity(s *models.Security) error {
|
||||||
|
// First, ensure no accounts are using this security
|
||||||
|
accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId)
|
||||||
|
|
||||||
|
if accounts != 0 {
|
||||||
|
return store.SecurityInUseError{"One or more accounts still use this security"}
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := tx.GetUser(s.UserId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
} else if user.DefaultCurrency == s.SecurityId {
|
||||||
|
return store.SecurityInUseError{"Cannot delete security which is user's default currency"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove all prices involving this security (either of this security, or
|
||||||
|
// using it as a currency)
|
||||||
|
_, err = tx.Exec("DELETE FROM prices WHERE SecurityId=? OR CurrencyId=?", s.SecurityId, s.SecurityId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
count, err := tx.Delete(s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if count != 1 {
|
||||||
|
return fmt.Errorf("Expected to delete 1 security, was going to delete %d", count)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
42
internal/store/db/sessions.go
Normal file
42
internal/store/db/sessions.go
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (tx *Tx) InsertSession(session *models.Session) error {
|
||||||
|
return tx.Insert(session)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) GetSession(secret string) (*models.Session, error) {
|
||||||
|
var s models.Session
|
||||||
|
|
||||||
|
err := tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", secret)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.Expires.Before(time.Now()) {
|
||||||
|
tx.Delete(&s)
|
||||||
|
return nil, fmt.Errorf("Session has expired")
|
||||||
|
}
|
||||||
|
return &s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) SessionExists(secret string) (bool, error) {
|
||||||
|
existing, err := tx.SelectInt("SELECT count(*) from sessions where SessionSecret=?", secret)
|
||||||
|
return existing != 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) DeleteSession(session *models.Session) error {
|
||||||
|
count, err := tx.Delete(session)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if count != 1 {
|
||||||
|
return fmt.Errorf("Expected to delete 1 user, was going to delete %d", count)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
361
internal/store/db/transactions.go
Normal file
361
internal/store/db/transactions.go
Normal file
@ -0,0 +1,361 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store"
|
||||||
|
"math/big"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (tx *Tx) incrementAccountVersions(user *models.User, accountids []int64) error {
|
||||||
|
for i := range accountids {
|
||||||
|
account, err := tx.GetAccount(accountids[i], user.UserId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
account.AccountVersion++
|
||||||
|
count, err := tx.Update(account)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if count != 1 {
|
||||||
|
return errors.New("Updated more than one account")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) InsertTransaction(t *models.Transaction, user *models.User) error {
|
||||||
|
// Map of any accounts with transaction splits being added
|
||||||
|
a_map := make(map[int64]bool)
|
||||||
|
for i := range t.Splits {
|
||||||
|
if t.Splits[i].AccountId != -1 {
|
||||||
|
existing, err := tx.SelectInt("SELECT count(*) from accounts where AccountId=?", t.Splits[i].AccountId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if existing != 1 {
|
||||||
|
return store.AccountMissingError{}
|
||||||
|
}
|
||||||
|
a_map[t.Splits[i].AccountId] = true
|
||||||
|
} else if t.Splits[i].SecurityId == -1 {
|
||||||
|
return store.AccountMissingError{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//increment versions for all accounts
|
||||||
|
var a_ids []int64
|
||||||
|
for id := range a_map {
|
||||||
|
a_ids = append(a_ids, id)
|
||||||
|
}
|
||||||
|
// ensure at least one of the splits is associated with an actual account
|
||||||
|
if len(a_ids) < 1 {
|
||||||
|
return store.AccountMissingError{}
|
||||||
|
}
|
||||||
|
err := tx.incrementAccountVersions(user, a_ids)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
t.UserId = user.UserId
|
||||||
|
err = tx.Insert(t)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range t.Splits {
|
||||||
|
t.Splits[i].TransactionId = t.TransactionId
|
||||||
|
t.Splits[i].SplitId = -1
|
||||||
|
err = tx.Insert(t.Splits[i])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) SplitExists(s *models.Split) (bool, error) {
|
||||||
|
count, err := tx.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId)
|
||||||
|
return count == 1, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) GetTransaction(transactionid int64, userid int64) (*models.Transaction, error) {
|
||||||
|
var t models.Transaction
|
||||||
|
|
||||||
|
err := tx.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = tx.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &t, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) GetTransactions(userid int64) (*[]*models.Transaction, error) {
|
||||||
|
var transactions []*models.Transaction
|
||||||
|
|
||||||
|
_, err := tx.Select(&transactions, "SELECT * from transactions where UserId=?", userid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range transactions {
|
||||||
|
_, err := tx.Select(&transactions[i].Splits, "SELECT * from splits where TransactionId=?", transactions[i].TransactionId)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &transactions, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) UpdateTransaction(t *models.Transaction, user *models.User) error {
|
||||||
|
var existing_splits []*models.Split
|
||||||
|
|
||||||
|
_, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Map of any accounts with transaction splits being added
|
||||||
|
a_map := make(map[int64]bool)
|
||||||
|
|
||||||
|
// Make a map with any existing splits for this transaction
|
||||||
|
s_map := make(map[int64]bool)
|
||||||
|
for i := range existing_splits {
|
||||||
|
s_map[existing_splits[i].SplitId] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert splits, updating any pre-existing ones
|
||||||
|
for i := range t.Splits {
|
||||||
|
t.Splits[i].TransactionId = t.TransactionId
|
||||||
|
_, ok := s_map[t.Splits[i].SplitId]
|
||||||
|
if ok {
|
||||||
|
count, err := tx.Update(t.Splits[i])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if count > 1 {
|
||||||
|
return fmt.Errorf("Updated %d transaction splits while attempting to update only 1", count)
|
||||||
|
}
|
||||||
|
delete(s_map, t.Splits[i].SplitId)
|
||||||
|
} else {
|
||||||
|
t.Splits[i].SplitId = -1
|
||||||
|
err := tx.Insert(t.Splits[i])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if t.Splits[i].AccountId != -1 {
|
||||||
|
a_map[t.Splits[i].AccountId] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete any remaining pre-existing splits
|
||||||
|
for i := range existing_splits {
|
||||||
|
_, ok := s_map[existing_splits[i].SplitId]
|
||||||
|
if existing_splits[i].AccountId != -1 {
|
||||||
|
a_map[existing_splits[i].AccountId] = true
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
_, err := tx.Delete(existing_splits[i])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Increment versions for all accounts with modified splits
|
||||||
|
var a_ids []int64
|
||||||
|
for id := range a_map {
|
||||||
|
a_ids = append(a_ids, id)
|
||||||
|
}
|
||||||
|
err = tx.incrementAccountVersions(user, a_ids)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
count, err := tx.Update(t)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if count > 1 {
|
||||||
|
return fmt.Errorf("Updated %d transactions (expected 1)", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) DeleteTransaction(t *models.Transaction, user *models.User) error {
|
||||||
|
var accountids []int64
|
||||||
|
_, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = tx.Exec("DELETE FROM splits WHERE TransactionId=?", t.TransactionId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
count, err := tx.Delete(t)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if count != 1 {
|
||||||
|
return errors.New("Deleted more than one transaction")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.incrementAccountVersions(user, accountids)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) GetAccountSplits(user *models.User, accountid int64) (*[]*models.Split, error) {
|
||||||
|
var splits []*models.Split
|
||||||
|
|
||||||
|
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?"
|
||||||
|
_, err := tx.Select(&splits, sql, accountid, user.UserId)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &splits, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assumes accountid is valid and is owned by the current user
|
||||||
|
func (tx *Tx) GetAccountSplitsDate(user *models.User, accountid int64, date *time.Time) (*[]*models.Split, error) {
|
||||||
|
var splits []*models.Split
|
||||||
|
|
||||||
|
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?"
|
||||||
|
_, err := tx.Select(&splits, sql, accountid, user.UserId, date)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &splits, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) GetAccountSplitsDateRange(user *models.User, accountid int64, begin, end *time.Time) (*[]*models.Split, error) {
|
||||||
|
var splits []*models.Split
|
||||||
|
|
||||||
|
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?"
|
||||||
|
_, err := tx.Select(&splits, sql, accountid, user.UserId, begin, end)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &splits, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) transactionsBalanceDifference(accountid int64, transactions []*models.Transaction) (*big.Rat, error) {
|
||||||
|
var pageDifference, tmp big.Rat
|
||||||
|
for i := range transactions {
|
||||||
|
_, err := tx.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sum up the amounts from the splits we're returning so we can return
|
||||||
|
// an ending balance
|
||||||
|
for j := range transactions[i].Splits {
|
||||||
|
if transactions[i].Splits[j].AccountId == accountid {
|
||||||
|
rat_amount, err := models.GetBigAmount(transactions[i].Splits[j].Amount)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tmp.Add(&pageDifference, rat_amount)
|
||||||
|
pageDifference.Set(&tmp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &pageDifference, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) GetAccountTransactions(user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) {
|
||||||
|
var transactions []*models.Transaction
|
||||||
|
var atl models.AccountTransactionsList
|
||||||
|
|
||||||
|
var sqlsort, balanceLimitOffset string
|
||||||
|
var balanceLimitOffsetArg uint64
|
||||||
|
if sort == "date-asc" {
|
||||||
|
sqlsort = " ORDER BY transactions.Date ASC, transactions.TransactionId ASC"
|
||||||
|
balanceLimitOffset = " LIMIT ?"
|
||||||
|
balanceLimitOffsetArg = page * limit
|
||||||
|
} else if sort == "date-desc" {
|
||||||
|
numSplits, err := tx.SelectInt("SELECT count(*) FROM splits")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
sqlsort = " ORDER BY transactions.Date DESC, transactions.TransactionId DESC"
|
||||||
|
balanceLimitOffset = fmt.Sprintf(" LIMIT %d OFFSET ?", numSplits)
|
||||||
|
balanceLimitOffsetArg = (page + 1) * limit
|
||||||
|
}
|
||||||
|
|
||||||
|
var sqloffset string
|
||||||
|
if page > 0 {
|
||||||
|
sqloffset = fmt.Sprintf(" OFFSET %d", page*limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := tx.GetAccount(accountid, user.UserId)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
atl.Account = account
|
||||||
|
|
||||||
|
sql := "SELECT DISTINCT transactions.* FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + " LIMIT ?" + sqloffset
|
||||||
|
_, err = tx.Select(&transactions, sql, user.UserId, accountid, limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
atl.Transactions = &transactions
|
||||||
|
|
||||||
|
pageDifference, err := tx.transactionsBalanceDifference(accountid, transactions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
count, err := tx.SelectInt("SELECT count(DISTINCT transactions.TransactionId) FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?", user.UserId, accountid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
atl.TotalTransactions = count
|
||||||
|
|
||||||
|
security, err := tx.GetSecurity(atl.Account.SecurityId, user.UserId)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if security == nil {
|
||||||
|
return nil, errors.New("Security not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sum all the splits for all transaction splits for this account that
|
||||||
|
// occurred before the page we're returning
|
||||||
|
var amounts []string
|
||||||
|
sql = "SELECT s.Amount FROM splits AS s INNER JOIN (SELECT DISTINCT transactions.Date, transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ") as t ON s.TransactionId = t.TransactionId WHERE s.AccountId=?"
|
||||||
|
_, err = tx.Select(&amounts, sql, user.UserId, accountid, balanceLimitOffsetArg, accountid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var tmp, balance big.Rat
|
||||||
|
for _, amount := range amounts {
|
||||||
|
rat_amount, err := models.GetBigAmount(amount)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tmp.Add(&balance, rat_amount)
|
||||||
|
balance.Set(&tmp)
|
||||||
|
}
|
||||||
|
atl.BeginningBalance = balance.FloatString(security.Precision)
|
||||||
|
atl.EndingBalance = tmp.Add(&balance, pageDifference).FloatString(security.Precision)
|
||||||
|
|
||||||
|
return &atl, nil
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package handlers
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
@ -41,7 +41,20 @@ func (tx *Tx) Insert(list ...interface{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (tx *Tx) Update(list ...interface{}) (int64, error) {
|
func (tx *Tx) Update(list ...interface{}) (int64, error) {
|
||||||
return tx.Tx.Update(list...)
|
count, err := tx.Tx.Update(list...)
|
||||||
|
if count == 0 {
|
||||||
|
switch tx.Dialect.(type) {
|
||||||
|
case gorp.MySQLDialect:
|
||||||
|
// Always return 1 for 0 if we're using MySQL because it returns
|
||||||
|
// count=0 if the row data was unchanged, even if the row existed
|
||||||
|
|
||||||
|
// TODO Find another way to fix this without risking ignoring
|
||||||
|
// errors
|
||||||
|
|
||||||
|
count = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tx *Tx) Delete(list ...interface{}) (int64, error) {
|
func (tx *Tx) Delete(list ...interface{}) (int64, error) {
|
||||||
@ -55,11 +68,3 @@ func (tx *Tx) Commit() error {
|
|||||||
func (tx *Tx) Rollback() error {
|
func (tx *Tx) Rollback() error {
|
||||||
return tx.Tx.Rollback()
|
return tx.Tx.Rollback()
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetTx(db *gorp.DbMap) (*Tx, error) {
|
|
||||||
tx, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &Tx{db.Dialect, tx}, nil
|
|
||||||
}
|
|
86
internal/store/db/users.go
Normal file
86
internal/store/db/users.go
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (tx *Tx) UsernameExists(username string) (bool, error) {
|
||||||
|
existing, err := tx.SelectInt("SELECT count(*) from users where Username=?", username)
|
||||||
|
return existing != 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) InsertUser(user *models.User) error {
|
||||||
|
return tx.Insert(user)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) GetUser(userid int64) (*models.User, error) {
|
||||||
|
var u models.User
|
||||||
|
|
||||||
|
err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) GetUserByUsername(username string) (*models.User, error) {
|
||||||
|
var u models.User
|
||||||
|
|
||||||
|
err := tx.SelectOne(&u, "SELECT * from users where Username=?", username)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) UpdateUser(user *models.User) error {
|
||||||
|
count, err := tx.Update(user)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if count != 1 {
|
||||||
|
return fmt.Errorf("Expected to update 1 user, was going to update %d", count)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) DeleteUser(user *models.User) error {
|
||||||
|
count, err := tx.Delete(user)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if count != 1 {
|
||||||
|
return fmt.Errorf("Expected to delete 1 user, was going to delete %d", count)
|
||||||
|
}
|
||||||
|
_, err = tx.Exec("DELETE FROM prices WHERE prices.SecurityId IN (SELECT securities.SecurityId FROM securities WHERE securities.UserId=?)", user.UserId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = tx.Exec("DELETE FROM splits WHERE splits.TransactionId IN (SELECT transactions.TransactionId FROM transactions WHERE transactions.UserId=?)", user.UserId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = tx.Exec("DELETE FROM transactions WHERE transactions.UserId=?", user.UserId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = tx.Exec("DELETE FROM securities WHERE securities.UserId=?", user.UserId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = tx.Exec("DELETE FROM accounts WHERE accounts.UserId=?", user.UserId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = tx.Exec("DELETE FROM reports WHERE reports.UserId=?", user.UserId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = tx.Exec("DELETE FROM sessions WHERE sessions.UserId=?", user.UserId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
123
internal/store/store.go
Normal file
123
internal/store/store.go
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UserStore interface {
|
||||||
|
UsernameExists(username string) (bool, error)
|
||||||
|
InsertUser(user *models.User) error
|
||||||
|
GetUser(userid int64) (*models.User, error)
|
||||||
|
GetUserByUsername(username string) (*models.User, error)
|
||||||
|
UpdateUser(user *models.User) error
|
||||||
|
DeleteUser(user *models.User) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type SessionStore interface {
|
||||||
|
SessionExists(secret string) (bool, error)
|
||||||
|
InsertSession(session *models.Session) error
|
||||||
|
GetSession(secret string) (*models.Session, error)
|
||||||
|
DeleteSession(session *models.Session) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type SecurityInUseError struct {
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e SecurityInUseError) Error() string {
|
||||||
|
return e.Message
|
||||||
|
}
|
||||||
|
|
||||||
|
type SecurityStore interface {
|
||||||
|
InsertSecurity(security *models.Security) error
|
||||||
|
GetSecurity(securityid int64, userid int64) (*models.Security, error)
|
||||||
|
GetSecurities(userid int64) (*[]*models.Security, error)
|
||||||
|
FindMatchingSecurities(security *models.Security) (*[]*models.Security, error)
|
||||||
|
UpdateSecurity(security *models.Security) error
|
||||||
|
DeleteSecurity(security *models.Security) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type PriceStore interface {
|
||||||
|
PriceExists(price *models.Price) (bool, error)
|
||||||
|
InsertPrice(price *models.Price) error
|
||||||
|
GetPrice(priceid, securityid int64) (*models.Price, error)
|
||||||
|
GetPrices(securityid int64) (*[]*models.Price, error)
|
||||||
|
GetLatestPrice(security, currency *models.Security, date *time.Time) (*models.Price, error)
|
||||||
|
GetEarliestPrice(security, currency *models.Security, date *time.Time) (*models.Price, error)
|
||||||
|
UpdatePrice(price *models.Price) error
|
||||||
|
DeletePrice(price *models.Price) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type ParentAccountMissingError struct{}
|
||||||
|
|
||||||
|
func (pame ParentAccountMissingError) Error() string {
|
||||||
|
return "Parent account missing"
|
||||||
|
}
|
||||||
|
|
||||||
|
type TooMuchNestingError struct{}
|
||||||
|
|
||||||
|
func (tmne TooMuchNestingError) Error() string {
|
||||||
|
return "Too much account nesting"
|
||||||
|
}
|
||||||
|
|
||||||
|
type CircularAccountsError struct{}
|
||||||
|
|
||||||
|
func (cae CircularAccountsError) Error() string {
|
||||||
|
return "Would result in circular account relationship"
|
||||||
|
}
|
||||||
|
|
||||||
|
type AccountStore interface {
|
||||||
|
InsertAccount(account *models.Account) error
|
||||||
|
GetAccount(accountid int64, userid int64) (*models.Account, error)
|
||||||
|
GetAccounts(userid int64) (*[]*models.Account, error)
|
||||||
|
FindMatchingAccounts(account *models.Account) (*[]*models.Account, error)
|
||||||
|
UpdateAccount(account *models.Account) error
|
||||||
|
DeleteAccount(account *models.Account) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type AccountMissingError struct{}
|
||||||
|
|
||||||
|
func (ame AccountMissingError) Error() string {
|
||||||
|
return "Account missing"
|
||||||
|
}
|
||||||
|
|
||||||
|
type TransactionStore interface {
|
||||||
|
SplitExists(s *models.Split) (bool, error)
|
||||||
|
InsertTransaction(t *models.Transaction, user *models.User) error
|
||||||
|
GetTransaction(transactionid int64, userid int64) (*models.Transaction, error)
|
||||||
|
GetTransactions(userid int64) (*[]*models.Transaction, error)
|
||||||
|
UpdateTransaction(t *models.Transaction, user *models.User) error
|
||||||
|
DeleteTransaction(t *models.Transaction, user *models.User) error
|
||||||
|
GetAccountSplits(user *models.User, accountid int64) (*[]*models.Split, error)
|
||||||
|
GetAccountSplitsDate(user *models.User, accountid int64, date *time.Time) (*[]*models.Split, error)
|
||||||
|
GetAccountSplitsDateRange(user *models.User, accountid int64, begin, end *time.Time) (*[]*models.Split, error)
|
||||||
|
GetAccountTransactions(user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ReportStore interface {
|
||||||
|
InsertReport(report *models.Report) error
|
||||||
|
GetReport(reportid int64, userid int64) (*models.Report, error)
|
||||||
|
GetReports(userid int64) (*[]*models.Report, error)
|
||||||
|
UpdateReport(report *models.Report) error
|
||||||
|
DeleteReport(report *models.Report) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tx interface {
|
||||||
|
Commit() error
|
||||||
|
Rollback() error
|
||||||
|
|
||||||
|
UserStore
|
||||||
|
SessionStore
|
||||||
|
SecurityStore
|
||||||
|
PriceStore
|
||||||
|
AccountStore
|
||||||
|
TransactionStore
|
||||||
|
ReportStore
|
||||||
|
}
|
||||||
|
|
||||||
|
type Store interface {
|
||||||
|
Empty() error
|
||||||
|
Begin() (Tx, error)
|
||||||
|
Close() error
|
||||||
|
}
|
15
main.go
15
main.go
@ -3,11 +3,10 @@ package main
|
|||||||
//go:generate make
|
//go:generate make
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"flag"
|
"flag"
|
||||||
"github.com/aclindsa/moneygo/internal/config"
|
"github.com/aclindsa/moneygo/internal/config"
|
||||||
"github.com/aclindsa/moneygo/internal/db"
|
|
||||||
"github.com/aclindsa/moneygo/internal/handlers"
|
"github.com/aclindsa/moneygo/internal/handlers"
|
||||||
|
"github.com/aclindsa/moneygo/internal/store/db"
|
||||||
"github.com/kabukky/httpscerts"
|
"github.com/kabukky/httpscerts"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
@ -67,21 +66,15 @@ func staticHandler(w http.ResponseWriter, r *http.Request, basedir string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
dsn := db.GetDSN(cfg.MoneyGo.DBType, cfg.MoneyGo.DSN)
|
db, err := db.GetStore(cfg.MoneyGo.DBType, cfg.MoneyGo.DSN)
|
||||||
database, err := sql.Open(cfg.MoneyGo.DBType.String(), dsn)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
defer database.Close()
|
|
||||||
|
|
||||||
dbmap, err := db.GetDbMap(database, cfg.MoneyGo.DBType)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
// Get ServeMux for API and add our own handlers for files
|
// Get ServeMux for API and add our own handlers for files
|
||||||
servemux := http.NewServeMux()
|
servemux := http.NewServeMux()
|
||||||
servemux.Handle("/v1/", &handlers.APIHandler{DB: dbmap})
|
servemux.Handle("/v1/", &handlers.APIHandler{Store: db})
|
||||||
servemux.HandleFunc("/", FileHandlerFunc(rootHandler, cfg.MoneyGo.Basedir))
|
servemux.HandleFunc("/", FileHandlerFunc(rootHandler, cfg.MoneyGo.Basedir))
|
||||||
servemux.HandleFunc("/static/", FileHandlerFunc(staticHandler, cfg.MoneyGo.Basedir))
|
servemux.HandleFunc("/static/", FileHandlerFunc(staticHandler, cfg.MoneyGo.Basedir))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user