From d059cd19ee675a625dddec63274d8071ad65b9b0 Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Fri, 27 Jan 2017 21:50:02 -0500 Subject: [PATCH] lua: Add account balances --- accounts_lua.go | 22 +++++++++++++ balance_lua.go | 56 +++++++++++++++++++++++++++++++++ reports.go | 2 ++ transactions.go | 82 +++++++++++++++++++++++++++++++++++-------------- 4 files changed, 139 insertions(+), 23 deletions(-) create mode 100644 balance_lua.go diff --git a/accounts_lua.go b/accounts_lua.go index 303666b..f84bbe6 100644 --- a/accounts_lua.go +++ b/accounts_lua.go @@ -120,6 +120,28 @@ func luaAccount__index(L *lua.LState) int { L.Push(lua.LString(a.Name)) case "Type", "type": L.Push(lua.LNumber(float64(a.Type))) + case "Balance", "balance": + ctx := L.Context() + user, ok := ctx.Value(userContextKey).(*User) + if !ok { + panic("Couldn't find User in lua's Context") + } + security_map, err := luaContextGetSecurities(L) + if err != nil { + panic("account.security couldn't fetch securities") + } + security, ok := security_map[a.SecurityId] + if !ok { + panic("SecurityId not in lua security_map") + } + rat, err := GetAccountBalance(user, a.AccountId) + if err != nil { + panic("Failed to GetAccountBalance:" + err.Error()) + } + var b Balance + b.Security = security + b.Amount = rat + L.Push(BalanceToLua(L, &b)) default: L.ArgError(2, "unexpected account attribute: "+field) } diff --git a/balance_lua.go b/balance_lua.go new file mode 100644 index 0000000..907da6a --- /dev/null +++ b/balance_lua.go @@ -0,0 +1,56 @@ +package main + +import ( + "github.com/yuin/gopher-lua" + "math/big" +) + +type Balance struct { + Security *Security + Amount *big.Rat +} + +const luaBalanceTypeName = "balance" + +// Registers my balance type to given L. +func luaRegisterBalances(L *lua.LState) { + mt := L.NewTypeMetatable(luaBalanceTypeName) + L.SetGlobal("balance", mt) + L.SetField(mt, "__tostring", L.NewFunction(luaBalance__tostring)) + L.SetField(mt, "__eq", L.NewFunction(luaBalance__eq)) + L.SetField(mt, "__metatable", lua.LString("protected")) +} + +func BalanceToLua(L *lua.LState, balance *Balance) *lua.LUserData { + ud := L.NewUserData() + ud.Value = balance + L.SetMetatable(ud, L.GetTypeMetatable(luaBalanceTypeName)) + return ud +} + +// Checks whether the first lua argument is a *LUserData with *Balance and returns this *Balance. +func luaCheckBalance(L *lua.LState, n int) *Balance { + ud := L.CheckUserData(n) + if balance, ok := ud.Value.(*Balance); ok { + return balance + } + L.ArgError(n, "balance expected") + return nil +} + +func luaBalance__tostring(L *lua.LState) int { + b := luaCheckBalance(L, 1) + + L.Push(lua.LString(b.Security.Symbol + b.Amount.FloatString(b.Security.Precision))) + + return 1 +} + +func luaBalance__eq(L *lua.LState) int { + a := luaCheckBalance(L, 1) + b := luaCheckBalance(L, 2) + + L.Push(lua.LBool(a.Security.SecurityId == b.Security.SecurityId && a.Amount.Cmp(b.Amount) == 0)) + + return 1 +} diff --git a/reports.go b/reports.go index d8e18de..b109081 100644 --- a/reports.go +++ b/reports.go @@ -15,6 +15,7 @@ const ( userContextKey key = iota accountsContextKey securitiesContextKey + balanceContextKey ) const luaTimeoutSeconds time.Duration = 5 // maximum time a lua request can run for @@ -58,6 +59,7 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) { luaRegisterAccounts(L) luaRegisterSecurities(L) + luaRegisterBalances(L) err := L.DoString(`accounts = account.get_all() last_parent = nil diff --git a/transactions.go b/transactions.go index 1d7e0f2..4a5c428 100644 --- a/transactions.go +++ b/transactions.go @@ -617,6 +617,59 @@ func TransactionHandler(w http.ResponseWriter, r *http.Request) { } } +func TransactionsBalanceDifference(transaction *gorp.Transaction, accountid int64, transactions []Transaction) (*big.Rat, error) { + var pageDifference, tmp big.Rat + for i := range transactions { + _, err := transaction.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId) + if err != nil { + return nil, err + } + + // Sum up the amounts from the splits we're returning so we can return + // an ending balance + for j := range transactions[i].Splits { + if transactions[i].Splits[j].AccountId == accountid { + rat_amount, err := GetBigAmount(transactions[i].Splits[j].Amount) + if err != nil { + return nil, err + } + tmp.Add(&pageDifference, rat_amount) + pageDifference.Set(&tmp) + } + } + } + return &pageDifference, nil +} + +func GetAccountBalance(user *User, accountid int64) (*big.Rat, error) { + var transactions []Transaction + transaction, err := DB.Begin() + if err != nil { + return nil, err + } + + sql := "SELECT DISTINCT transactions.* FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + _, err = transaction.Select(&transactions, sql, user.UserId, accountid) + if err != nil { + transaction.Rollback() + return nil, err + } + + pageDifference, err := TransactionsBalanceDifference(transaction, accountid, transactions) + if err != nil { + transaction.Rollback() + return nil, err + } + + err = transaction.Commit() + if err != nil { + transaction.Rollback() + return nil, err + } + + return pageDifference, nil +} + func GetAccountTransactions(user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) { var transactions []Transaction var atl AccountTransactionsList @@ -663,27 +716,10 @@ func GetAccountTransactions(user *User, accountid int64, sort string, page uint6 } atl.Transactions = &transactions - var pageDifference, tmp big.Rat - for i := range transactions { - _, err = transaction.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId) - if err != nil { - transaction.Rollback() - return nil, err - } - - // Sum up the amounts from the splits we're returning so we can return - // an ending balance - for j := range transactions[i].Splits { - if transactions[i].Splits[j].AccountId == accountid { - rat_amount, err := GetBigAmount(transactions[i].Splits[j].Amount) - if err != nil { - transaction.Rollback() - return nil, err - } - tmp.Add(&pageDifference, rat_amount) - pageDifference.Set(&tmp) - } - } + pageDifference, err := TransactionsBalanceDifference(transaction, accountid, transactions) + if err != nil { + transaction.Rollback() + return nil, err } count, err := transaction.SelectInt("SELECT count(DISTINCT transactions.TransactionId) FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?", user.UserId, accountid) @@ -711,7 +747,7 @@ func GetAccountTransactions(user *User, accountid int64, sort string, page uint6 return nil, err } - var balance big.Rat + var tmp, balance big.Rat for _, amount := range amounts { rat_amount, err := GetBigAmount(amount) if err != nil { @@ -722,7 +758,7 @@ func GetAccountTransactions(user *User, accountid int64, sort string, page uint6 balance.Set(&tmp) } atl.BeginningBalance = balance.FloatString(security.Precision) - atl.EndingBalance = tmp.Add(&balance, &pageDifference).FloatString(security.Precision) + atl.EndingBalance = tmp.Add(&balance, pageDifference).FloatString(security.Precision) err = transaction.Commit() if err != nil {