1
0
mirror of https://github.com/aclindsa/moneygo.git synced 2025-06-13 13:39:23 -04:00

Split Lua reports into own package

This commit is contained in:
2017-12-10 20:50:37 -05:00
parent ac8afec6c1
commit d5bea1102d
11 changed files with 138 additions and 130 deletions

View File

@ -0,0 +1,238 @@
package reports
import (
"context"
"errors"
"github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store"
"github.com/yuin/gopher-lua"
"math/big"
"strings"
)
const luaAccountTypeName = "account"
func luaContextGetAccounts(L *lua.LState) (map[int64]*models.Account, error) {
var account_map map[int64]*models.Account
ctx := L.Context()
tx, ok := ctx.Value(dbContextKey).(store.Tx)
if !ok {
return nil, errors.New("Couldn't find tx in lua's Context")
}
account_map, ok = ctx.Value(accountsContextKey).(map[int64]*models.Account)
if !ok {
user, ok := ctx.Value(userContextKey).(*models.User)
if !ok {
return nil, errors.New("Couldn't find User in lua's Context")
}
accounts, err := tx.GetAccounts(user.UserId)
if err != nil {
return nil, err
}
account_map = make(map[int64]*models.Account)
for i := range *accounts {
account_map[(*accounts)[i].AccountId] = (*accounts)[i]
}
ctx = context.WithValue(ctx, accountsContextKey, account_map)
L.SetContext(ctx)
}
return account_map, nil
}
func luaGetAccounts(L *lua.LState) int {
account_map, err := luaContextGetAccounts(L)
if err != nil {
panic("luaGetAccounts couldn't fetch accounts")
}
table := L.NewTable()
for accountid := range account_map {
table.RawSetInt(int(accountid), AccountToLua(L, account_map[accountid]))
}
L.Push(table)
return 1
}
func luaRegisterAccounts(L *lua.LState) {
mt := L.NewTypeMetatable(luaAccountTypeName)
L.SetGlobal("account", mt)
L.SetField(mt, "__index", L.NewFunction(luaAccount__index))
L.SetField(mt, "__tostring", L.NewFunction(luaAccount__tostring))
L.SetField(mt, "__eq", L.NewFunction(luaAccount__eq))
L.SetField(mt, "__metatable", lua.LString("protected"))
for _, accttype := range models.AccountTypes {
L.SetField(mt, accttype.String(), lua.LNumber(float64(accttype)))
}
getAccountsFn := L.NewFunction(luaGetAccounts)
L.SetField(mt, "get_all", getAccountsFn)
// also register the get_accounts function as a global in its own right
L.SetGlobal("get_accounts", getAccountsFn)
}
func AccountToLua(L *lua.LState, account *models.Account) *lua.LUserData {
ud := L.NewUserData()
ud.Value = account
L.SetMetatable(ud, L.GetTypeMetatable(luaAccountTypeName))
return ud
}
// Checks whether the first lua argument is a *LUserData with *Account and returns this *Account.
func luaCheckAccount(L *lua.LState, n int) *models.Account {
ud := L.CheckUserData(n)
if account, ok := ud.Value.(*models.Account); ok {
return account
}
L.ArgError(n, "account expected")
return nil
}
func luaAccount__index(L *lua.LState) int {
a := luaCheckAccount(L, 1)
field := L.CheckString(2)
switch field {
case "AccountId", "accountid":
L.Push(lua.LNumber(float64(a.AccountId)))
case "Security", "security":
security_map, err := luaContextGetSecurities(L)
if err != nil {
panic("account.security couldn't fetch securities")
}
if security, ok := security_map[a.SecurityId]; ok {
L.Push(SecurityToLua(L, security))
} else {
panic("SecurityId not in lua security_map")
}
case "SecurityId", "securityid":
L.Push(lua.LNumber(float64(a.SecurityId)))
case "Parent", "parent", "ParentAccount", "parentaccount":
if a.ParentAccountId == -1 {
L.Push(lua.LNil)
} else {
account_map, err := luaContextGetAccounts(L)
if err != nil {
panic("account.parent couldn't fetch accounts")
}
if parent, ok := account_map[a.ParentAccountId]; ok {
L.Push(AccountToLua(L, parent))
} else {
panic("ParentAccountId not in lua account_map")
}
}
case "Name", "name":
L.Push(lua.LString(a.Name))
case "Type", "type":
L.Push(lua.LNumber(float64(a.Type)))
case "TypeName", "Typename":
L.Push(lua.LString(a.Type.String()))
case "typename":
L.Push(lua.LString(strings.ToLower(a.Type.String())))
case "Balance", "balance":
L.Push(L.NewFunction(luaAccountBalance))
default:
L.ArgError(2, "unexpected account attribute: "+field)
}
return 1
}
func balanceFromSplits(splits *[]*models.Split) (*big.Rat, error) {
var balance, tmp big.Rat
for _, s := range *splits {
rat_amount, err := models.GetBigAmount(s.Amount)
if err != nil {
return nil, err
}
tmp.Add(&balance, rat_amount)
balance.Set(&tmp)
}
return &balance, nil
}
func luaAccountBalance(L *lua.LState) int {
a := luaCheckAccount(L, 1)
ctx := L.Context()
tx, ok := ctx.Value(dbContextKey).(store.Tx)
if !ok {
panic("Couldn't find tx in lua's Context")
}
user, ok := ctx.Value(userContextKey).(*models.User)
if !ok {
panic("Couldn't find User in lua's Context")
}
security_map, err := luaContextGetSecurities(L)
if err != nil {
panic("account.security couldn't fetch securities")
}
security, ok := security_map[a.SecurityId]
if !ok {
panic("SecurityId not in lua security_map")
}
date := luaWeakCheckTime(L, 2)
var splits *[]*models.Split
if date != nil {
end := luaWeakCheckTime(L, 3)
if end != nil {
splits, err = tx.GetAccountSplitsDateRange(user, a.AccountId, date, end)
} else {
splits, err = tx.GetAccountSplitsDate(user, a.AccountId, date)
}
} else {
splits, err = tx.GetAccountSplits(user, a.AccountId)
}
if err != nil {
panic("Failed to fetch splits for account:" + err.Error())
}
rat, err := balanceFromSplits(splits)
if err != nil {
panic("Failed to calculate balance for account:" + err.Error())
}
b := &Balance{
Amount: rat,
Security: security,
}
L.Push(BalanceToLua(L, b))
return 1
}
func luaAccount__tostring(L *lua.LState) int {
a := luaCheckAccount(L, 1)
account_map, err := luaContextGetAccounts(L)
if err != nil {
panic("luaGetAccounts couldn't fetch accounts")
}
full_name := a.Name
for a.ParentAccountId != -1 {
a = account_map[a.ParentAccountId]
full_name = a.Name + "/" + full_name
}
L.Push(lua.LString(full_name))
return 1
}
func luaAccount__eq(L *lua.LState) int {
a := luaCheckAccount(L, 1)
b := luaCheckAccount(L, 2)
L.Push(lua.LBool(a.AccountId == b.AccountId))
return 1
}

225
internal/reports/balance.go Normal file
View File

@ -0,0 +1,225 @@
package reports
import (
"github.com/aclindsa/moneygo/internal/models"
"github.com/yuin/gopher-lua"
"math/big"
)
type Balance struct {
Security *models.Security
Amount *big.Rat
}
const luaBalanceTypeName = "balance"
func luaRegisterBalances(L *lua.LState) {
mt := L.NewTypeMetatable(luaBalanceTypeName)
L.SetGlobal("balance", mt)
L.SetField(mt, "__index", L.NewFunction(luaBalance__index))
L.SetField(mt, "__tostring", L.NewFunction(luaBalance__tostring))
L.SetField(mt, "__eq", L.NewFunction(luaBalance__eq))
L.SetField(mt, "__lt", L.NewFunction(luaBalance__lt))
L.SetField(mt, "__le", L.NewFunction(luaBalance__le))
L.SetField(mt, "__add", L.NewFunction(luaBalance__add))
L.SetField(mt, "__sub", L.NewFunction(luaBalance__sub))
L.SetField(mt, "__mul", L.NewFunction(luaBalance__mul))
L.SetField(mt, "__div", L.NewFunction(luaBalance__div))
L.SetField(mt, "__unm", L.NewFunction(luaBalance__unm))
L.SetField(mt, "__metatable", lua.LString("protected"))
}
func BalanceToLua(L *lua.LState, balance *Balance) *lua.LUserData {
ud := L.NewUserData()
ud.Value = balance
L.SetMetatable(ud, L.GetTypeMetatable(luaBalanceTypeName))
return ud
}
// Checks whether the first lua argument is a *LUserData with *Balance and returns this *Balance.
func luaCheckBalance(L *lua.LState, n int) *Balance {
ud := L.CheckUserData(n)
if balance, ok := ud.Value.(*Balance); ok {
return balance
}
L.ArgError(n, "balance expected")
return nil
}
func luaWeakCheckBalance(L *lua.LState, n int) *Balance {
v := L.Get(n)
if ud, ok := v.(*lua.LUserData); ok {
if balance, ok := ud.Value.(*Balance); ok {
return balance
}
L.ArgError(n, "balance expected")
}
return nil
}
func luaGetBalanceOperands(L *lua.LState, m int, n int) (*Balance, *Balance) {
bm := luaWeakCheckBalance(L, m)
bn := luaWeakCheckBalance(L, n)
if bm != nil && bn != nil {
return bm, bn
} else if bm != nil {
nn := L.CheckNumber(n)
var balance Balance
var rat big.Rat
balance.Security = bm.Security
balance.Amount = rat.SetFloat64(float64(nn))
if balance.Amount == nil {
L.ArgError(n, "non-finite float invalid for operand to balance arithemetic")
return nil, nil
}
return bm, &balance
} else if bn != nil {
nm := L.CheckNumber(m)
var balance Balance
var rat big.Rat
balance.Security = bn.Security
balance.Amount = rat.SetFloat64(float64(nm))
if balance.Amount == nil {
L.ArgError(m, "non-finite float invalid for operand to balance arithemetic")
return nil, nil
}
return &balance, bn
}
L.ArgError(m, "balance expected")
return nil, nil
}
func luaBalance__index(L *lua.LState) int {
a := luaCheckBalance(L, 1)
field := L.CheckString(2)
switch field {
case "Security", "security":
L.Push(SecurityToLua(L, a.Security))
case "Amount", "amount":
float, _ := a.Amount.Float64()
L.Push(lua.LNumber(float))
default:
L.ArgError(2, "unexpected balance attribute: "+field)
}
return 1
}
func luaBalance__tostring(L *lua.LState) int {
b := luaCheckBalance(L, 1)
L.Push(lua.LString(b.Security.Symbol + " " + b.Amount.FloatString(b.Security.Precision)))
return 1
}
func luaBalance__eq(L *lua.LState) int {
a := luaCheckBalance(L, 1)
b := luaCheckBalance(L, 2)
L.Push(lua.LBool(a.Security.SecurityId == b.Security.SecurityId && a.Amount.Cmp(b.Amount) == 0))
return 1
}
func luaBalance__lt(L *lua.LState) int {
a := luaCheckBalance(L, 1)
b := luaCheckBalance(L, 2)
if a.Security.SecurityId != b.Security.SecurityId {
L.ArgError(2, "Can't compare balances with different securities")
}
L.Push(lua.LBool(a.Amount.Cmp(b.Amount) < 0))
return 1
}
func luaBalance__le(L *lua.LState) int {
a := luaCheckBalance(L, 1)
b := luaCheckBalance(L, 2)
if a.Security.SecurityId != b.Security.SecurityId {
L.ArgError(2, "Can't compare balances with different securities")
}
L.Push(lua.LBool(a.Amount.Cmp(b.Amount) <= 0))
return 1
}
func luaBalance__add(L *lua.LState) int {
a, b := luaGetBalanceOperands(L, 1, 2)
if a.Security.SecurityId != b.Security.SecurityId {
L.ArgError(2, "Can't add balances with different securities")
}
var balance Balance
var rat big.Rat
balance.Security = a.Security
balance.Amount = rat.Add(a.Amount, b.Amount)
L.Push(BalanceToLua(L, &balance))
return 1
}
func luaBalance__sub(L *lua.LState) int {
a, b := luaGetBalanceOperands(L, 1, 2)
if a.Security.SecurityId != b.Security.SecurityId {
L.ArgError(2, "Can't subtract balances with different securities")
}
var balance Balance
var rat big.Rat
balance.Security = a.Security
balance.Amount = rat.Sub(a.Amount, b.Amount)
L.Push(BalanceToLua(L, &balance))
return 1
}
func luaBalance__mul(L *lua.LState) int {
a, b := luaGetBalanceOperands(L, 1, 2)
if a.Security.SecurityId != b.Security.SecurityId {
L.ArgError(2, "Can't multiply balances with different securities")
}
var balance Balance
var rat big.Rat
balance.Security = a.Security
balance.Amount = rat.Mul(a.Amount, b.Amount)
L.Push(BalanceToLua(L, &balance))
return 1
}
func luaBalance__div(L *lua.LState) int {
a, b := luaGetBalanceOperands(L, 1, 2)
if a.Security.SecurityId != b.Security.SecurityId {
L.ArgError(2, "Can't divide balances with different securities")
}
var balance Balance
var rat big.Rat
balance.Security = a.Security
balance.Amount = rat.Quo(a.Amount, b.Amount)
L.Push(BalanceToLua(L, &balance))
return 1
}
func luaBalance__unm(L *lua.LState) int {
b := luaCheckBalance(L, 1)
var balance Balance
var rat big.Rat
balance.Security = b.Security
balance.Amount = rat.Neg(b.Amount)
L.Push(BalanceToLua(L, &balance))
return 1
}

170
internal/reports/date.go Normal file
View File

@ -0,0 +1,170 @@
package reports
import (
"github.com/yuin/gopher-lua"
"time"
)
const luaDateTypeName = "date"
const timeFormat = "2006-01-02"
func luaRegisterDates(L *lua.LState) {
mt := L.NewTypeMetatable(luaDateTypeName)
L.SetGlobal("date", mt)
L.SetField(mt, "new", L.NewFunction(luaDateNew))
L.SetField(mt, "now", L.NewFunction(luaDateNow))
L.SetField(mt, "__index", L.NewFunction(luaDate__index))
L.SetField(mt, "__tostring", L.NewFunction(luaDate__tostring))
L.SetField(mt, "__eq", L.NewFunction(luaDate__eq))
L.SetField(mt, "__lt", L.NewFunction(luaDate__lt))
L.SetField(mt, "__le", L.NewFunction(luaDate__le))
L.SetField(mt, "__add", L.NewFunction(luaDate__add))
L.SetField(mt, "__sub", L.NewFunction(luaDate__sub))
L.SetField(mt, "__metatable", lua.LString("protected"))
}
func TimeToLua(L *lua.LState, date *time.Time) *lua.LUserData {
ud := L.NewUserData()
ud.Value = date
L.SetMetatable(ud, L.GetTypeMetatable(luaDateTypeName))
return ud
}
// Checks whether the first lua argument is a *LUserData with *Time and returns this *Time.
func luaCheckTime(L *lua.LState, n int) *time.Time {
ud := L.CheckUserData(n)
if date, ok := ud.Value.(*time.Time); ok {
return date
}
L.ArgError(n, "date expected")
return nil
}
func luaWeakCheckTime(L *lua.LState, n int) *time.Time {
v := L.Get(n)
if ud, ok := v.(*lua.LUserData); ok {
if date, ok := ud.Value.(*time.Time); ok {
return date
}
}
return nil
}
func luaWeakCheckTableFieldInt(L *lua.LState, T *lua.LTable, n int, name string, def int) int {
lv := T.RawGetString(name)
if lv == lua.LNil {
return def
}
if i, ok := lv.(lua.LNumber); ok {
return int(i)
}
L.ArgError(n, "table field '"+name+"' expected to be int")
return def
}
func luaDateNew(L *lua.LState) int {
// TODO make this track the user's timezone
v := L.Get(1)
if s, ok := v.(lua.LString); ok {
date, err := time.ParseInLocation(timeFormat, s.String(), time.Local)
if err != nil {
L.ArgError(1, "error parsing date string: "+err.Error())
return 0
}
L.Push(TimeToLua(L, &date))
return 1
}
var year, month, day int
if t, ok := v.(*lua.LTable); ok {
year = luaWeakCheckTableFieldInt(L, t, 1, "year", 0)
month = luaWeakCheckTableFieldInt(L, t, 1, "month", 1)
day = luaWeakCheckTableFieldInt(L, t, 1, "day", 1)
} else {
year = L.CheckInt(1)
month = L.CheckInt(2)
day = L.CheckInt(3)
}
date := time.Date(year, time.Month(month), day, 0, 0, 0, 0, time.Local)
L.Push(TimeToLua(L, &date))
return 1
}
func luaDateNow(L *lua.LState) int {
now := time.Now()
date := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.Local)
L.Push(TimeToLua(L, &date))
return 1
}
func luaDate__index(L *lua.LState) int {
d := luaCheckTime(L, 1)
field := L.CheckString(2)
switch field {
case "Year", "year":
L.Push(lua.LNumber(d.Year()))
case "Month", "month":
L.Push(lua.LNumber(float64(d.Month())))
case "Day", "day":
L.Push(lua.LNumber(float64(d.Day())))
default:
L.ArgError(2, "unexpected date attribute: "+field)
}
return 1
}
func luaDate__tostring(L *lua.LState) int {
a := luaCheckTime(L, 1)
L.Push(lua.LString(a.Format(timeFormat)))
return 1
}
func luaDate__eq(L *lua.LState) int {
a := luaCheckTime(L, 1)
b := luaCheckTime(L, 2)
L.Push(lua.LBool(a.Equal(*b)))
return 1
}
func luaDate__lt(L *lua.LState) int {
a := luaCheckTime(L, 1)
b := luaCheckTime(L, 2)
L.Push(lua.LBool(a.Before(*b)))
return 1
}
func luaDate__le(L *lua.LState) int {
a := luaCheckTime(L, 1)
b := luaCheckTime(L, 2)
L.Push(lua.LBool(a.Equal(*b) || a.Before(*b)))
return 1
}
func luaDate__add(L *lua.LState) int {
a := luaCheckTime(L, 1)
b := luaCheckTime(L, 2)
date := a.AddDate(b.Year(), int(b.Month()), b.Day())
L.Push(TimeToLua(L, &date))
return 1
}
func luaDate__sub(L *lua.LState) int {
a := luaCheckTime(L, 1)
b := luaCheckTime(L, 2)
date := a.AddDate(-b.Year(), -int(b.Month()), -b.Day())
L.Push(TimeToLua(L, &date))
return 1
}

View File

@ -0,0 +1,92 @@
package reports
import (
"github.com/aclindsa/moneygo/internal/models"
"github.com/yuin/gopher-lua"
)
const luaPriceTypeName = "price"
func luaRegisterPrices(L *lua.LState) {
mt := L.NewTypeMetatable(luaPriceTypeName)
L.SetGlobal("price", mt)
L.SetField(mt, "__index", L.NewFunction(luaPrice__index))
L.SetField(mt, "__tostring", L.NewFunction(luaPrice__tostring))
L.SetField(mt, "__metatable", lua.LString("protected"))
}
func PriceToLua(L *lua.LState, price *models.Price) *lua.LUserData {
ud := L.NewUserData()
ud.Value = price
L.SetMetatable(ud, L.GetTypeMetatable(luaPriceTypeName))
return ud
}
// Checks whether the first lua argument is a *LUserData with *Price and returns this *Price.
func luaCheckPrice(L *lua.LState, n int) *models.Price {
ud := L.CheckUserData(n)
if price, ok := ud.Value.(*models.Price); ok {
return price
}
L.ArgError(n, "price expected")
return nil
}
func luaPrice__index(L *lua.LState) int {
p := luaCheckPrice(L, 1)
field := L.CheckString(2)
switch field {
case "PriceId", "priceid":
L.Push(lua.LNumber(float64(p.PriceId)))
case "Security", "security":
security_map, err := luaContextGetSecurities(L)
if err != nil {
panic("luaContextGetSecurities couldn't fetch securities")
}
s, ok := security_map[p.SecurityId]
if !ok {
panic("Price's security not found for user")
}
L.Push(SecurityToLua(L, s))
case "Currency", "currency":
security_map, err := luaContextGetSecurities(L)
if err != nil {
panic("luaContextGetSecurities couldn't fetch securities")
}
c, ok := security_map[p.CurrencyId]
if !ok {
panic("Price's currency not found for user")
}
L.Push(SecurityToLua(L, c))
case "Value", "value":
amt, err := models.GetBigAmount(p.Value)
if err != nil {
panic(err)
}
float, _ := amt.Float64()
L.Push(lua.LNumber(float))
default:
L.ArgError(2, "unexpected price attribute: "+field)
}
return 1
}
func luaPrice__tostring(L *lua.LState) int {
p := luaCheckPrice(L, 1)
security_map, err := luaContextGetSecurities(L)
if err != nil {
panic("luaContextGetSecurities couldn't fetch securities")
}
s, ok1 := security_map[p.SecurityId]
c, ok2 := security_map[p.CurrencyId]
if !ok1 || !ok2 {
panic("Price's currency or security not found for user")
}
L.Push(lua.LString(p.Value + " " + c.Symbol + " (" + s.Symbol + ")"))
return 1
}

View File

@ -0,0 +1,88 @@
package reports
import (
"context"
"errors"
"fmt"
"github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store"
"github.com/yuin/gopher-lua"
"time"
)
//type and value to store user in lua's Context
type key int
const (
userContextKey key = iota
accountsContextKey
securitiesContextKey
balanceContextKey
dbContextKey
)
const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for
func RunReport(tx store.Tx, user *models.User, report *models.Report) (*models.Tabulation, error) {
// Create a new LState without opening the default libs for security
L := lua.NewState(lua.Options{SkipOpenLibs: true})
defer L.Close()
// Create a new context holding the current user with a timeout
ctx := context.WithValue(context.Background(), userContextKey, user)
ctx = context.WithValue(ctx, dbContextKey, tx)
ctx, cancel := context.WithTimeout(ctx, luaTimeoutSeconds*time.Second)
defer cancel()
L.SetContext(ctx)
for _, pair := range []struct {
n string
f lua.LGFunction
}{
{lua.LoadLibName, lua.OpenPackage}, // Must be first
{lua.BaseLibName, lua.OpenBase},
{lua.TabLibName, lua.OpenTable},
{lua.StringLibName, lua.OpenString},
{lua.MathLibName, lua.OpenMath},
} {
if err := L.CallByParam(lua.P{
Fn: L.NewFunction(pair.f),
NRet: 0,
Protect: true,
}, lua.LString(pair.n)); err != nil {
return nil, errors.New("Error initializing Lua packages")
}
}
luaRegisterAccounts(L)
luaRegisterSecurities(L)
luaRegisterBalances(L)
luaRegisterDates(L)
luaRegisterTabulations(L)
luaRegisterPrices(L)
err := L.DoString(report.Lua)
if err != nil {
return nil, err
}
if err := L.CallByParam(lua.P{
Fn: L.GetGlobal("generate"),
NRet: 1,
Protect: true,
}); err != nil {
return nil, err
}
value := L.Get(-1)
if ud, ok := value.(*lua.LUserData); ok {
if tabulation, ok := ud.Value.(*models.Tabulation); ok {
return tabulation, nil
} else {
return nil, fmt.Errorf("generate() for %s (Id: %d) didn't return a tabulation", report.Name, report.ReportId)
}
} else {
return nil, fmt.Errorf("generate() for %s (Id: %d) didn't even return LUserData", report.Name, report.ReportId)
}
}

View File

@ -0,0 +1,214 @@
package reports
import (
"context"
"errors"
"github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/moneygo/internal/store"
"github.com/yuin/gopher-lua"
"time"
)
const luaSecurityTypeName = "security"
func luaContextGetSecurities(L *lua.LState) (map[int64]*models.Security, error) {
var security_map map[int64]*models.Security
ctx := L.Context()
tx, ok := ctx.Value(dbContextKey).(store.Tx)
if !ok {
return nil, errors.New("Couldn't find tx in lua's Context")
}
security_map, ok = ctx.Value(securitiesContextKey).(map[int64]*models.Security)
if !ok {
user, ok := ctx.Value(userContextKey).(*models.User)
if !ok {
return nil, errors.New("Couldn't find User in lua's Context")
}
securities, err := tx.GetSecurities(user.UserId)
if err != nil {
return nil, err
}
security_map = make(map[int64]*models.Security)
for i := range *securities {
security_map[(*securities)[i].SecurityId] = (*securities)[i]
}
ctx = context.WithValue(ctx, securitiesContextKey, security_map)
L.SetContext(ctx)
}
return security_map, nil
}
func luaContextGetDefaultCurrency(L *lua.LState) (*models.Security, error) {
security_map, err := luaContextGetSecurities(L)
if err != nil {
return nil, err
}
ctx := L.Context()
user, ok := ctx.Value(userContextKey).(*models.User)
if !ok {
return nil, errors.New("Couldn't find User in lua's Context")
}
if security, ok := security_map[user.DefaultCurrency]; ok {
return security, nil
} else {
return nil, errors.New("DefaultCurrency not in lua security_map")
}
}
func luaGetDefaultCurrency(L *lua.LState) int {
defcurrency, err := luaContextGetDefaultCurrency(L)
if err != nil {
panic("luaGetDefaultCurrency couldn't fetch default currency")
}
L.Push(SecurityToLua(L, defcurrency))
return 1
}
func luaGetSecurities(L *lua.LState) int {
security_map, err := luaContextGetSecurities(L)
if err != nil {
panic("luaGetSecurities couldn't fetch securities")
}
table := L.NewTable()
for securityid := range security_map {
table.RawSetInt(int(securityid), SecurityToLua(L, security_map[securityid]))
}
L.Push(table)
return 1
}
func luaRegisterSecurities(L *lua.LState) {
mt := L.NewTypeMetatable(luaSecurityTypeName)
L.SetGlobal("security", mt)
L.SetField(mt, "__index", L.NewFunction(luaSecurity__index))
L.SetField(mt, "__tostring", L.NewFunction(luaSecurity__tostring))
L.SetField(mt, "__eq", L.NewFunction(luaSecurity__eq))
L.SetField(mt, "__metatable", lua.LString("protected"))
getSecuritiesFn := L.NewFunction(luaGetSecurities)
L.SetField(mt, "get_all", getSecuritiesFn)
getDefaultCurrencyFn := L.NewFunction(luaGetDefaultCurrency)
L.SetField(mt, "get_default", getDefaultCurrencyFn)
// also register the get_securities and get_default functions as globals in
// their own right
L.SetGlobal("get_securities", getSecuritiesFn)
L.SetGlobal("get_default_currency", getDefaultCurrencyFn)
}
func SecurityToLua(L *lua.LState, security *models.Security) *lua.LUserData {
ud := L.NewUserData()
ud.Value = security
L.SetMetatable(ud, L.GetTypeMetatable(luaSecurityTypeName))
return ud
}
// Checks whether the first lua argument is a *LUserData with *Security and returns this *Security.
func luaCheckSecurity(L *lua.LState, n int) *models.Security {
ud := L.CheckUserData(n)
if security, ok := ud.Value.(*models.Security); ok {
return security
}
L.ArgError(n, "security expected")
return nil
}
func luaSecurity__index(L *lua.LState) int {
a := luaCheckSecurity(L, 1)
field := L.CheckString(2)
switch field {
case "SecurityId", "securityid":
L.Push(lua.LNumber(float64(a.SecurityId)))
case "Name", "name":
L.Push(lua.LString(a.Name))
case "Description", "description":
L.Push(lua.LString(a.Description))
case "Symbol", "symbol":
L.Push(lua.LString(a.Symbol))
case "Precision", "precision":
L.Push(lua.LNumber(float64(a.Precision)))
case "Type", "type":
L.Push(lua.LNumber(float64(a.Type)))
case "ClosestPrice", "closestprice":
L.Push(L.NewFunction(luaClosestPrice))
case "AlternateId", "alternateid":
L.Push(lua.LString(a.AlternateId))
default:
L.ArgError(2, "unexpected security attribute: "+field)
}
return 1
}
// Return the price for security in currency closest to date
func getClosestPrice(tx store.Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) {
earliest, _ := tx.GetEarliestPrice(security, currency, date)
latest, err := tx.GetLatestPrice(security, currency, date)
// Return early if either earliest or latest are invalid
if earliest == nil {
return latest, err
} else if err != nil {
return earliest, nil
}
howlate := earliest.Date.Sub(*date)
howearly := date.Sub(latest.Date)
if howearly < howlate {
return latest, nil
} else {
return earliest, nil
}
}
func luaClosestPrice(L *lua.LState) int {
s := luaCheckSecurity(L, 1)
c := luaCheckSecurity(L, 2)
date := luaCheckTime(L, 3)
ctx := L.Context()
tx, ok := ctx.Value(dbContextKey).(store.Tx)
if !ok {
panic("Couldn't find tx in lua's Context")
}
p, err := getClosestPrice(tx, s, c, date)
if err != nil {
L.Push(lua.LNil)
} else {
L.Push(PriceToLua(L, p))
}
return 1
}
func luaSecurity__tostring(L *lua.LState) int {
s := luaCheckSecurity(L, 1)
L.Push(lua.LString(s.Name + " - " + s.Description + " (" + s.Symbol + ")"))
return 1
}
func luaSecurity__eq(L *lua.LState) int {
a := luaCheckSecurity(L, 1)
b := luaCheckSecurity(L, 2)
L.Push(lua.LBool(a.SecurityId == b.SecurityId))
return 1
}

View File

@ -0,0 +1,188 @@
package reports
import (
"github.com/aclindsa/moneygo/internal/models"
"github.com/yuin/gopher-lua"
)
const luaTabulationTypeName = "tabulation"
const luaSeriesTypeName = "series"
func luaRegisterTabulations(L *lua.LState) {
mtr := L.NewTypeMetatable(luaTabulationTypeName)
L.SetGlobal("tabulation", mtr)
L.SetField(mtr, "new", L.NewFunction(luaTabulationNew))
L.SetField(mtr, "__index", L.NewFunction(luaTabulation__index))
L.SetField(mtr, "__metatable", lua.LString("protected"))
mts := L.NewTypeMetatable(luaSeriesTypeName)
L.SetGlobal("series", mts)
L.SetField(mts, "__index", L.NewFunction(luaSeries__index))
L.SetField(mts, "__metatable", lua.LString("protected"))
}
// Checks whether the first lua argument is a *LUserData with *Tabulation and returns *Tabulation
func luaCheckTabulation(L *lua.LState, n int) *models.Tabulation {
ud := L.CheckUserData(n)
if tabulation, ok := ud.Value.(*models.Tabulation); ok {
return tabulation
}
L.ArgError(n, "tabulation expected")
return nil
}
// Checks whether the first lua argument is a *LUserData with *Series and returns *Series
func luaCheckSeries(L *lua.LState, n int) *models.Series {
ud := L.CheckUserData(n)
if series, ok := ud.Value.(*models.Series); ok {
return series
}
L.ArgError(n, "series expected")
return nil
}
func luaTabulationNew(L *lua.LState) int {
numvalues := L.CheckInt(1)
ud := L.NewUserData()
ud.Value = &models.Tabulation{
Labels: make([]string, numvalues),
Series: make(map[string]*models.Series),
}
L.SetMetatable(ud, L.GetTypeMetatable(luaTabulationTypeName))
L.Push(ud)
return 1
}
func luaTabulation__index(L *lua.LState) int {
field := L.CheckString(2)
switch field {
case "Label", "label":
L.Push(L.NewFunction(luaTabulationLabel))
case "Series", "series":
L.Push(L.NewFunction(luaTabulationSeries))
case "Title", "title":
L.Push(L.NewFunction(luaTabulationTitle))
case "Subtitle", "subtitle":
L.Push(L.NewFunction(luaTabulationSubtitle))
case "Units", "units":
L.Push(L.NewFunction(luaTabulationUnits))
default:
L.ArgError(2, "unexpected tabulation attribute: "+field)
}
return 1
}
func luaTabulationLabel(L *lua.LState) int {
tabulation := luaCheckTabulation(L, 1)
labelnumber := L.CheckInt(2)
label := L.CheckString(3)
if labelnumber > cap(tabulation.Labels) || labelnumber < 1 {
L.ArgError(2, "Label index must be between 1 and the number of data points, inclusive")
}
tabulation.Labels[labelnumber-1] = label
return 0
}
func luaTabulationSeries(L *lua.LState) int {
tabulation := luaCheckTabulation(L, 1)
name := L.CheckString(2)
ud := L.NewUserData()
s, ok := tabulation.Series[name]
if ok {
ud.Value = s
} else {
tabulation.Series[name] = &models.Series{
Series: make(map[string]*models.Series),
Values: make([]float64, cap(tabulation.Labels)),
}
ud.Value = tabulation.Series[name]
}
L.SetMetatable(ud, L.GetTypeMetatable(luaSeriesTypeName))
L.Push(ud)
return 1
}
func luaTabulationTitle(L *lua.LState) int {
tabulation := luaCheckTabulation(L, 1)
if L.GetTop() == 2 {
tabulation.Title = L.CheckString(2)
return 0
}
L.Push(lua.LString(tabulation.Title))
return 1
}
func luaTabulationSubtitle(L *lua.LState) int {
tabulation := luaCheckTabulation(L, 1)
if L.GetTop() == 2 {
tabulation.Subtitle = L.CheckString(2)
return 0
}
L.Push(lua.LString(tabulation.Subtitle))
return 1
}
func luaTabulationUnits(L *lua.LState) int {
tabulation := luaCheckTabulation(L, 1)
if L.GetTop() == 2 {
tabulation.Units = L.CheckString(2)
return 0
}
L.Push(lua.LString(tabulation.Units))
return 1
}
func luaSeries__index(L *lua.LState) int {
field := L.CheckString(2)
switch field {
case "Value", "value":
L.Push(L.NewFunction(luaSeriesValue))
case "Series", "series":
L.Push(L.NewFunction(luaSeriesSeries))
default:
L.ArgError(2, "unexpected series attribute: "+field)
}
return 1
}
func luaSeriesValue(L *lua.LState) int {
series := luaCheckSeries(L, 1)
valuenumber := L.CheckInt(2)
value := float64(L.CheckNumber(3))
if valuenumber > cap(series.Values) || valuenumber < 1 {
L.ArgError(2, "value index must be between 1 and the number of data points, inclusive")
}
series.Values[valuenumber-1] = value
return 0
}
func luaSeriesSeries(L *lua.LState) int {
parent := luaCheckSeries(L, 1)
name := L.CheckString(2)
ud := L.NewUserData()
s, ok := parent.Series[name]
if ok {
ud.Value = s
} else {
parent.Series[name] = &models.Series{
Series: make(map[string]*models.Series),
Values: make([]float64, cap(parent.Values)),
}
ud.Value = parent.Series[name]
}
L.SetMetatable(ud, L.GetTypeMetatable(luaSeriesTypeName))
L.Push(ud)
return 1
}