mirror of
https://github.com/aclindsa/moneygo.git
synced 2024-10-31 16:00:05 -04:00
lua: Add security and account types
This commit is contained in:
parent
806ceb2f5c
commit
f3becb7f5c
156
accounts_lua.go
Normal file
156
accounts_lua.go
Normal file
@ -0,0 +1,156 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
const luaAccountTypeName = "account"
|
||||
|
||||
func luaContextGetAccounts(L *lua.LState) (map[int64]*Account, error) {
|
||||
var account_map map[int64]*Account
|
||||
|
||||
ctx := L.Context()
|
||||
|
||||
account_map, ok := ctx.Value(accountsContextKey).(map[int64]*Account)
|
||||
if !ok {
|
||||
user, ok := ctx.Value(userContextKey).(*User)
|
||||
if !ok {
|
||||
return nil, errors.New("Couldn't find User in lua's Context")
|
||||
}
|
||||
|
||||
accounts, err := GetAccounts(user.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
account_map = make(map[int64]*Account)
|
||||
for i := range *accounts {
|
||||
account_map[(*accounts)[i].AccountId] = &(*accounts)[i]
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, accountsContextKey, account_map)
|
||||
L.SetContext(ctx)
|
||||
}
|
||||
|
||||
return account_map, nil
|
||||
}
|
||||
|
||||
func luaGetAccounts(L *lua.LState) int {
|
||||
account_map, err := luaContextGetAccounts(L)
|
||||
if err != nil {
|
||||
panic("luaGetAccounts couldn't fetch accounts")
|
||||
}
|
||||
|
||||
table := L.NewTable()
|
||||
|
||||
for accountid := range account_map {
|
||||
table.RawSetInt(int(accountid), AccountToLua(L, account_map[accountid]))
|
||||
}
|
||||
|
||||
L.Push(table)
|
||||
return 1
|
||||
}
|
||||
|
||||
// Registers my account type to given L.
|
||||
func luaRegisterAccounts(L *lua.LState) {
|
||||
mt := L.NewTypeMetatable(luaAccountTypeName)
|
||||
L.SetGlobal("account", mt)
|
||||
L.SetField(mt, "__index", L.NewFunction(luaAccount__index))
|
||||
L.SetField(mt, "__tostring", L.NewFunction(luaAccount__tostring))
|
||||
L.SetField(mt, "__eq", L.NewFunction(luaAccount__eq))
|
||||
L.SetField(mt, "__metatable", lua.LString("protected"))
|
||||
getAccountsFn := L.NewFunction(luaGetAccounts)
|
||||
L.SetField(mt, "get_all", getAccountsFn)
|
||||
|
||||
// also register the get_accounts function as a global in its own right
|
||||
L.SetGlobal("get_accounts", getAccountsFn)
|
||||
}
|
||||
|
||||
func AccountToLua(L *lua.LState, account *Account) *lua.LUserData {
|
||||
ud := L.NewUserData()
|
||||
ud.Value = account
|
||||
L.SetMetatable(ud, L.GetTypeMetatable(luaAccountTypeName))
|
||||
return ud
|
||||
}
|
||||
|
||||
// Checks whether the first lua argument is a *LUserData with *Account and returns this *Account.
|
||||
func luaCheckAccount(L *lua.LState, n int) *Account {
|
||||
ud := L.CheckUserData(n)
|
||||
if account, ok := ud.Value.(*Account); ok {
|
||||
return account
|
||||
}
|
||||
L.ArgError(n, "account expected")
|
||||
return nil
|
||||
}
|
||||
|
||||
func luaAccount__index(L *lua.LState) int {
|
||||
a := luaCheckAccount(L, 1)
|
||||
field := L.CheckString(2)
|
||||
|
||||
switch field {
|
||||
case "AccountId", "accountid":
|
||||
L.Push(lua.LNumber(float64(a.AccountId)))
|
||||
case "Security", "security":
|
||||
security_map, err := luaContextGetSecurities(L)
|
||||
if err != nil {
|
||||
panic("account.security couldn't fetch securities")
|
||||
}
|
||||
if security, ok := security_map[a.SecurityId]; ok {
|
||||
L.Push(SecurityToLua(L, security))
|
||||
} else {
|
||||
panic("SecurityId not in lua security_map")
|
||||
}
|
||||
case "Parent", "parent", "ParentAccount", "parentaccount":
|
||||
if a.ParentAccountId == -1 {
|
||||
L.Push(lua.LNil)
|
||||
} else {
|
||||
account_map, err := luaContextGetAccounts(L)
|
||||
if err != nil {
|
||||
panic("account.parent couldn't fetch accounts")
|
||||
}
|
||||
if parent, ok := account_map[a.ParentAccountId]; ok {
|
||||
L.Push(AccountToLua(L, parent))
|
||||
} else {
|
||||
panic("ParentAccountId not in lua account_map")
|
||||
}
|
||||
}
|
||||
case "Name", "name":
|
||||
L.Push(lua.LString(a.Name))
|
||||
case "Type", "type":
|
||||
L.Push(lua.LNumber(float64(a.Type)))
|
||||
default:
|
||||
L.ArgError(2, "unexpected account attribute: "+field)
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaAccount__tostring(L *lua.LState) int {
|
||||
a := luaCheckAccount(L, 1)
|
||||
|
||||
account_map, err := luaContextGetAccounts(L)
|
||||
if err != nil {
|
||||
panic("luaGetAccounts couldn't fetch accounts")
|
||||
}
|
||||
|
||||
full_name := a.Name
|
||||
for a.ParentAccountId != -1 {
|
||||
a = account_map[a.ParentAccountId]
|
||||
full_name = a.Name + "/" + full_name
|
||||
}
|
||||
|
||||
L.Push(lua.LString(full_name))
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaAccount__eq(L *lua.LState) int {
|
||||
a := luaCheckAccount(L, 1)
|
||||
b := luaCheckAccount(L, 2)
|
||||
|
||||
L.Push(lua.LBool(a.AccountId == b.AccountId))
|
||||
|
||||
return 1
|
||||
}
|
28
reports.go
28
reports.go
@ -3,6 +3,7 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"github.com/yuin/gopher-lua"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
@ -10,9 +11,13 @@ import (
|
||||
//type and value to store user in lua's Context
|
||||
type key int
|
||||
|
||||
const userContextKey key = 0
|
||||
const (
|
||||
userContextKey key = iota
|
||||
accountsContextKey
|
||||
securitiesContextKey
|
||||
)
|
||||
|
||||
const luaTimeoutSeconds uint = 5 // maximum time a lua request can run for
|
||||
const luaTimeoutSeconds time.Duration = 5 // maximum time a lua request can run for
|
||||
|
||||
func ReportHandler(w http.ResponseWriter, r *http.Request) {
|
||||
user, err := GetUserFromSession(r)
|
||||
@ -50,10 +55,25 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
err := L.DoString(`print("Hello World")`)
|
||||
|
||||
luaRegisterAccounts(L)
|
||||
luaRegisterSecurities(L)
|
||||
|
||||
err := L.DoString(`accounts = account.get_all()
|
||||
last_parent = nil
|
||||
for id, account in pairs(accounts) do
|
||||
parent = account.parent
|
||||
print(account, parent, account.security)
|
||||
if parent then
|
||||
print(last_parent, parent)
|
||||
print("parent equals last:", last_parent == parent)
|
||||
last_parent = parent
|
||||
end
|
||||
end
|
||||
`)
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Print("lua:" + err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
127
securities_lua.go
Normal file
127
securities_lua.go
Normal file
@ -0,0 +1,127 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
const luaSecurityTypeName = "security"
|
||||
|
||||
func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) {
|
||||
var security_map map[int64]*Security
|
||||
|
||||
ctx := L.Context()
|
||||
|
||||
security_map, ok := ctx.Value(securitiesContextKey).(map[int64]*Security)
|
||||
if !ok {
|
||||
user, ok := ctx.Value(userContextKey).(*User)
|
||||
if !ok {
|
||||
return nil, errors.New("Couldn't find User in lua's Context")
|
||||
}
|
||||
|
||||
securities, err := GetSecurities(user.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
security_map = make(map[int64]*Security)
|
||||
for i := range *securities {
|
||||
security_map[(*securities)[i].SecurityId] = (*securities)[i]
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, securitiesContextKey, security_map)
|
||||
L.SetContext(ctx)
|
||||
}
|
||||
|
||||
return security_map, nil
|
||||
}
|
||||
|
||||
func luaGetSecurities(L *lua.LState) int {
|
||||
security_map, err := luaContextGetSecurities(L)
|
||||
if err != nil {
|
||||
panic("luaGetSecurities couldn't fetch securities")
|
||||
}
|
||||
|
||||
table := L.NewTable()
|
||||
|
||||
for securityid := range security_map {
|
||||
table.RawSetInt(int(securityid), SecurityToLua(L, security_map[securityid]))
|
||||
}
|
||||
|
||||
L.Push(table)
|
||||
return 1
|
||||
}
|
||||
|
||||
// Registers my security type to given L.
|
||||
func luaRegisterSecurities(L *lua.LState) {
|
||||
mt := L.NewTypeMetatable(luaSecurityTypeName)
|
||||
L.SetGlobal("security", mt)
|
||||
L.SetField(mt, "__index", L.NewFunction(luaSecurity__index))
|
||||
L.SetField(mt, "__tostring", L.NewFunction(luaSecurity__tostring))
|
||||
L.SetField(mt, "__eq", L.NewFunction(luaSecurity__eq))
|
||||
L.SetField(mt, "__metatable", lua.LString("protected"))
|
||||
getSecuritiesFn := L.NewFunction(luaGetSecurities)
|
||||
L.SetField(mt, "get_all", getSecuritiesFn)
|
||||
|
||||
// also register the get_securities function as a global in its own right
|
||||
L.SetGlobal("get_securities", getSecuritiesFn)
|
||||
}
|
||||
|
||||
func SecurityToLua(L *lua.LState, security *Security) *lua.LUserData {
|
||||
ud := L.NewUserData()
|
||||
ud.Value = security
|
||||
L.SetMetatable(ud, L.GetTypeMetatable(luaSecurityTypeName))
|
||||
return ud
|
||||
}
|
||||
|
||||
// Checks whether the first lua argument is a *LUserData with *Security and returns this *Security.
|
||||
func luaCheckSecurity(L *lua.LState, n int) *Security {
|
||||
ud := L.CheckUserData(n)
|
||||
if security, ok := ud.Value.(*Security); ok {
|
||||
return security
|
||||
}
|
||||
L.ArgError(n, "security expected")
|
||||
return nil
|
||||
}
|
||||
|
||||
func luaSecurity__index(L *lua.LState) int {
|
||||
a := luaCheckSecurity(L, 1)
|
||||
field := L.CheckString(2)
|
||||
|
||||
switch field {
|
||||
case "SecurityId", "securityid":
|
||||
L.Push(lua.LNumber(float64(a.SecurityId)))
|
||||
case "Name", "name":
|
||||
L.Push(lua.LString(a.Name))
|
||||
case "Description", "description":
|
||||
L.Push(lua.LString(a.Description))
|
||||
case "Symbol", "symbol":
|
||||
L.Push(lua.LString(a.Symbol))
|
||||
case "Precision", "precision":
|
||||
L.Push(lua.LNumber(float64(a.Precision)))
|
||||
case "Type", "type":
|
||||
L.Push(lua.LNumber(float64(a.Type)))
|
||||
default:
|
||||
L.ArgError(2, "unexpected security attribute: "+field)
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaSecurity__tostring(L *lua.LState) int {
|
||||
s := luaCheckSecurity(L, 1)
|
||||
|
||||
L.Push(lua.LString(s.Name + " - " + s.Description + " (" + s.Symbol + ")"))
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
func luaSecurity__eq(L *lua.LState) int {
|
||||
a := luaCheckSecurity(L, 1)
|
||||
b := luaCheckSecurity(L, 2)
|
||||
|
||||
L.Push(lua.LBool(a.SecurityId == b.SecurityId))
|
||||
|
||||
return 1
|
||||
}
|
Loading…
Reference in New Issue
Block a user