From cec769b6b8dc09e5a0ffe65390365900d3c280b2 Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Mon, 30 Jan 2017 21:04:18 -0500 Subject: [PATCH] lua: Query account balances at dates --- accounts_lua.go | 64 ++++++++++++++++++++++++++++++++----------------- balance_lua.go | 3 ++- date_lua.go | 10 ++++++++ transactions.go | 58 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 112 insertions(+), 23 deletions(-) diff --git a/accounts_lua.go b/accounts_lua.go index f84bbe6..59b3bf6 100644 --- a/accounts_lua.go +++ b/accounts_lua.go @@ -4,6 +4,7 @@ import ( "context" "errors" "github.com/yuin/gopher-lua" + "math/big" ) const luaAccountTypeName = "account" @@ -61,9 +62,9 @@ func luaRegisterAccounts(L *lua.LState) { L.SetField(mt, "__tostring", L.NewFunction(luaAccount__tostring)) L.SetField(mt, "__eq", L.NewFunction(luaAccount__eq)) L.SetField(mt, "__metatable", lua.LString("protected")) + getAccountsFn := L.NewFunction(luaGetAccounts) L.SetField(mt, "get_all", getAccountsFn) - // also register the get_accounts function as a global in its own right L.SetGlobal("get_accounts", getAccountsFn) } @@ -121,27 +122,7 @@ func luaAccount__index(L *lua.LState) int { case "Type", "type": L.Push(lua.LNumber(float64(a.Type))) case "Balance", "balance": - ctx := L.Context() - user, ok := ctx.Value(userContextKey).(*User) - if !ok { - panic("Couldn't find User in lua's Context") - } - security_map, err := luaContextGetSecurities(L) - if err != nil { - panic("account.security couldn't fetch securities") - } - security, ok := security_map[a.SecurityId] - if !ok { - panic("SecurityId not in lua security_map") - } - rat, err := GetAccountBalance(user, a.AccountId) - if err != nil { - panic("Failed to GetAccountBalance:" + err.Error()) - } - var b Balance - b.Security = security - b.Amount = rat - L.Push(BalanceToLua(L, &b)) + L.Push(L.NewFunction(luaAccountBalance)) default: L.ArgError(2, "unexpected account attribute: "+field) } @@ -149,6 +130,45 @@ func luaAccount__index(L *lua.LState) int { return 1 } +func luaAccountBalance(L *lua.LState) int { + a := luaCheckAccount(L, 1) + + ctx := L.Context() + user, ok := ctx.Value(userContextKey).(*User) + if !ok { + panic("Couldn't find User in lua's Context") + } + security_map, err := luaContextGetSecurities(L) + if err != nil { + panic("account.security couldn't fetch securities") + } + security, ok := security_map[a.SecurityId] + if !ok { + panic("SecurityId not in lua security_map") + } + date := luaWeakCheckTime(L, 2) + var b Balance + var rat *big.Rat + if date != nil { + end := luaWeakCheckTime(L, 3) + if end != nil { + rat, err = GetAccountBalanceDateRange(user, a.AccountId, date, end) + } else { + rat, err = GetAccountBalanceDate(user, a.AccountId, date) + } + } else { + rat, err = GetAccountBalance(user, a.AccountId) + } + if err != nil { + panic("Failed to GetAccountBalance:" + err.Error()) + } + b.Amount = rat + b.Security = security + L.Push(BalanceToLua(L, &b)) + + return 1 +} + func luaAccount__tostring(L *lua.LState) int { a := luaCheckAccount(L, 1) diff --git a/balance_lua.go b/balance_lua.go index 64886fd..220edb0 100644 --- a/balance_lua.go +++ b/balance_lua.go @@ -52,6 +52,7 @@ func luaWeakCheckBalance(L *lua.LState, n int) *Balance { if balance, ok := ud.Value.(*Balance); ok { return balance } + L.ArgError(n, "balance expected") } return nil } @@ -109,7 +110,7 @@ func luaBalance__index(L *lua.LState) int { func luaBalance__tostring(L *lua.LState) int { b := luaCheckBalance(L, 1) - L.Push(lua.LString(b.Security.Symbol + b.Amount.FloatString(b.Security.Precision))) + L.Push(lua.LString(b.Security.Symbol + " " + b.Amount.FloatString(b.Security.Precision))) return 1 } diff --git a/date_lua.go b/date_lua.go index 31dfd0a..bbb57ff 100644 --- a/date_lua.go +++ b/date_lua.go @@ -37,6 +37,16 @@ func luaCheckTime(L *lua.LState, n int) *time.Time { return nil } +func luaWeakCheckTime(L *lua.LState, n int) *time.Time { + v := L.Get(n) + if ud, ok := v.(*lua.LUserData); ok { + if date, ok := ud.Value.(*time.Time); ok { + return date + } + } + return nil +} + func luaWeakCheckTableFieldInt(L *lua.LState, T *lua.LTable, n int, name string, def int) int { lv := T.RawGetString(name) if lv == lua.LNil { diff --git a/transactions.go b/transactions.go index 4a5c428..6df36fe 100644 --- a/transactions.go +++ b/transactions.go @@ -670,6 +670,64 @@ func GetAccountBalance(user *User, accountid int64) (*big.Rat, error) { return pageDifference, nil } +func GetAccountBalanceDate(user *User, accountid int64, date *time.Time) (*big.Rat, error) { + var transactions []Transaction + transaction, err := DB.Begin() + if err != nil { + return nil, err + } + + sql := "SELECT DISTINCT transactions.* FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=? AND transactions.Date < ?" + _, err = transaction.Select(&transactions, sql, user.UserId, accountid, date) + if err != nil { + transaction.Rollback() + return nil, err + } + + pageDifference, err := TransactionsBalanceDifference(transaction, accountid, transactions) + if err != nil { + transaction.Rollback() + return nil, err + } + + err = transaction.Commit() + if err != nil { + transaction.Rollback() + return nil, err + } + + return pageDifference, nil +} + +func GetAccountBalanceDateRange(user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) { + var transactions []Transaction + transaction, err := DB.Begin() + if err != nil { + return nil, err + } + + sql := "SELECT DISTINCT transactions.* FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=? AND transactions.Date >= ? AND transactions.Date < ?" + _, err = transaction.Select(&transactions, sql, user.UserId, accountid, begin, end) + if err != nil { + transaction.Rollback() + return nil, err + } + + pageDifference, err := TransactionsBalanceDifference(transaction, accountid, transactions) + if err != nil { + transaction.Rollback() + return nil, err + } + + err = transaction.Commit() + if err != nil { + transaction.Rollback() + return nil, err + } + + return pageDifference, nil +} + func GetAccountTransactions(user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) { var transactions []Transaction var atl AccountTransactionsList