From f3becb7f5ce9bfac11f1b166d3bf7114c7637fb2 Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Fri, 27 Jan 2017 11:04:39 -0500 Subject: [PATCH] lua: Add security and account types --- accounts_lua.go | 156 ++++++++++++++++++++++++++++++++++++++++++++++ reports.go | 28 +++++++-- securities_lua.go | 127 +++++++++++++++++++++++++++++++++++++ 3 files changed, 307 insertions(+), 4 deletions(-) create mode 100644 accounts_lua.go create mode 100644 securities_lua.go diff --git a/accounts_lua.go b/accounts_lua.go new file mode 100644 index 0000000..303666b --- /dev/null +++ b/accounts_lua.go @@ -0,0 +1,156 @@ +package main + +import ( + "context" + "errors" + "github.com/yuin/gopher-lua" +) + +const luaAccountTypeName = "account" + +func luaContextGetAccounts(L *lua.LState) (map[int64]*Account, error) { + var account_map map[int64]*Account + + ctx := L.Context() + + account_map, ok := ctx.Value(accountsContextKey).(map[int64]*Account) + if !ok { + user, ok := ctx.Value(userContextKey).(*User) + if !ok { + return nil, errors.New("Couldn't find User in lua's Context") + } + + accounts, err := GetAccounts(user.UserId) + if err != nil { + return nil, err + } + + account_map = make(map[int64]*Account) + for i := range *accounts { + account_map[(*accounts)[i].AccountId] = &(*accounts)[i] + } + + ctx = context.WithValue(ctx, accountsContextKey, account_map) + L.SetContext(ctx) + } + + return account_map, nil +} + +func luaGetAccounts(L *lua.LState) int { + account_map, err := luaContextGetAccounts(L) + if err != nil { + panic("luaGetAccounts couldn't fetch accounts") + } + + table := L.NewTable() + + for accountid := range account_map { + table.RawSetInt(int(accountid), AccountToLua(L, account_map[accountid])) + } + + L.Push(table) + return 1 +} + +// Registers my account type to given L. +func luaRegisterAccounts(L *lua.LState) { + mt := L.NewTypeMetatable(luaAccountTypeName) + L.SetGlobal("account", mt) + L.SetField(mt, "__index", L.NewFunction(luaAccount__index)) + L.SetField(mt, "__tostring", L.NewFunction(luaAccount__tostring)) + L.SetField(mt, "__eq", L.NewFunction(luaAccount__eq)) + L.SetField(mt, "__metatable", lua.LString("protected")) + getAccountsFn := L.NewFunction(luaGetAccounts) + L.SetField(mt, "get_all", getAccountsFn) + + // also register the get_accounts function as a global in its own right + L.SetGlobal("get_accounts", getAccountsFn) +} + +func AccountToLua(L *lua.LState, account *Account) *lua.LUserData { + ud := L.NewUserData() + ud.Value = account + L.SetMetatable(ud, L.GetTypeMetatable(luaAccountTypeName)) + return ud +} + +// Checks whether the first lua argument is a *LUserData with *Account and returns this *Account. +func luaCheckAccount(L *lua.LState, n int) *Account { + ud := L.CheckUserData(n) + if account, ok := ud.Value.(*Account); ok { + return account + } + L.ArgError(n, "account expected") + return nil +} + +func luaAccount__index(L *lua.LState) int { + a := luaCheckAccount(L, 1) + field := L.CheckString(2) + + switch field { + case "AccountId", "accountid": + L.Push(lua.LNumber(float64(a.AccountId))) + case "Security", "security": + security_map, err := luaContextGetSecurities(L) + if err != nil { + panic("account.security couldn't fetch securities") + } + if security, ok := security_map[a.SecurityId]; ok { + L.Push(SecurityToLua(L, security)) + } else { + panic("SecurityId not in lua security_map") + } + case "Parent", "parent", "ParentAccount", "parentaccount": + if a.ParentAccountId == -1 { + L.Push(lua.LNil) + } else { + account_map, err := luaContextGetAccounts(L) + if err != nil { + panic("account.parent couldn't fetch accounts") + } + if parent, ok := account_map[a.ParentAccountId]; ok { + L.Push(AccountToLua(L, parent)) + } else { + panic("ParentAccountId not in lua account_map") + } + } + case "Name", "name": + L.Push(lua.LString(a.Name)) + case "Type", "type": + L.Push(lua.LNumber(float64(a.Type))) + default: + L.ArgError(2, "unexpected account attribute: "+field) + } + + return 1 +} + +func luaAccount__tostring(L *lua.LState) int { + a := luaCheckAccount(L, 1) + + account_map, err := luaContextGetAccounts(L) + if err != nil { + panic("luaGetAccounts couldn't fetch accounts") + } + + full_name := a.Name + for a.ParentAccountId != -1 { + a = account_map[a.ParentAccountId] + full_name = a.Name + "/" + full_name + } + + L.Push(lua.LString(full_name)) + + return 1 +} + +func luaAccount__eq(L *lua.LState) int { + a := luaCheckAccount(L, 1) + b := luaCheckAccount(L, 2) + + L.Push(lua.LBool(a.AccountId == b.AccountId)) + + return 1 +} diff --git a/reports.go b/reports.go index 1ba8b8b..d8e18de 100644 --- a/reports.go +++ b/reports.go @@ -3,6 +3,7 @@ package main import ( "context" "github.com/yuin/gopher-lua" + "log" "net/http" "time" ) @@ -10,9 +11,13 @@ import ( //type and value to store user in lua's Context type key int -const userContextKey key = 0 +const ( + userContextKey key = iota + accountsContextKey + securitiesContextKey +) -const luaTimeoutSeconds uint = 5 // maximum time a lua request can run for +const luaTimeoutSeconds time.Duration = 5 // maximum time a lua request can run for func ReportHandler(w http.ResponseWriter, r *http.Request) { user, err := GetUserFromSession(r) @@ -50,10 +55,25 @@ func ReportHandler(w http.ResponseWriter, r *http.Request) { panic(err) } } - err := L.DoString(`print("Hello World")`) + + luaRegisterAccounts(L) + luaRegisterSecurities(L) + + err := L.DoString(`accounts = account.get_all() +last_parent = nil +for id, account in pairs(accounts) do + parent = account.parent + print(account, parent, account.security) + if parent then + print(last_parent, parent) + print("parent equals last:", last_parent == parent) + last_parent = parent + end +end +`) if err != nil { - panic(err) + log.Print("lua:" + err.Error()) } } } diff --git a/securities_lua.go b/securities_lua.go new file mode 100644 index 0000000..77feb5b --- /dev/null +++ b/securities_lua.go @@ -0,0 +1,127 @@ +package main + +import ( + "context" + "errors" + "github.com/yuin/gopher-lua" +) + +const luaSecurityTypeName = "security" + +func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) { + var security_map map[int64]*Security + + ctx := L.Context() + + security_map, ok := ctx.Value(securitiesContextKey).(map[int64]*Security) + if !ok { + user, ok := ctx.Value(userContextKey).(*User) + if !ok { + return nil, errors.New("Couldn't find User in lua's Context") + } + + securities, err := GetSecurities(user.UserId) + if err != nil { + return nil, err + } + + security_map = make(map[int64]*Security) + for i := range *securities { + security_map[(*securities)[i].SecurityId] = (*securities)[i] + } + + ctx = context.WithValue(ctx, securitiesContextKey, security_map) + L.SetContext(ctx) + } + + return security_map, nil +} + +func luaGetSecurities(L *lua.LState) int { + security_map, err := luaContextGetSecurities(L) + if err != nil { + panic("luaGetSecurities couldn't fetch securities") + } + + table := L.NewTable() + + for securityid := range security_map { + table.RawSetInt(int(securityid), SecurityToLua(L, security_map[securityid])) + } + + L.Push(table) + return 1 +} + +// Registers my security type to given L. +func luaRegisterSecurities(L *lua.LState) { + mt := L.NewTypeMetatable(luaSecurityTypeName) + L.SetGlobal("security", mt) + L.SetField(mt, "__index", L.NewFunction(luaSecurity__index)) + L.SetField(mt, "__tostring", L.NewFunction(luaSecurity__tostring)) + L.SetField(mt, "__eq", L.NewFunction(luaSecurity__eq)) + L.SetField(mt, "__metatable", lua.LString("protected")) + getSecuritiesFn := L.NewFunction(luaGetSecurities) + L.SetField(mt, "get_all", getSecuritiesFn) + + // also register the get_securities function as a global in its own right + L.SetGlobal("get_securities", getSecuritiesFn) +} + +func SecurityToLua(L *lua.LState, security *Security) *lua.LUserData { + ud := L.NewUserData() + ud.Value = security + L.SetMetatable(ud, L.GetTypeMetatable(luaSecurityTypeName)) + return ud +} + +// Checks whether the first lua argument is a *LUserData with *Security and returns this *Security. +func luaCheckSecurity(L *lua.LState, n int) *Security { + ud := L.CheckUserData(n) + if security, ok := ud.Value.(*Security); ok { + return security + } + L.ArgError(n, "security expected") + return nil +} + +func luaSecurity__index(L *lua.LState) int { + a := luaCheckSecurity(L, 1) + field := L.CheckString(2) + + switch field { + case "SecurityId", "securityid": + L.Push(lua.LNumber(float64(a.SecurityId))) + case "Name", "name": + L.Push(lua.LString(a.Name)) + case "Description", "description": + L.Push(lua.LString(a.Description)) + case "Symbol", "symbol": + L.Push(lua.LString(a.Symbol)) + case "Precision", "precision": + L.Push(lua.LNumber(float64(a.Precision))) + case "Type", "type": + L.Push(lua.LNumber(float64(a.Type))) + default: + L.ArgError(2, "unexpected security attribute: "+field) + } + + return 1 +} + +func luaSecurity__tostring(L *lua.LState) int { + s := luaCheckSecurity(L, 1) + + L.Push(lua.LString(s.Name + " - " + s.Description + " (" + s.Symbol + ")")) + + return 1 +} + +func luaSecurity__eq(L *lua.LState) int { + a := luaCheckSecurity(L, 1) + b := luaCheckSecurity(L, 2) + + L.Push(lua.LBool(a.SecurityId == b.SecurityId)) + + return 1 +}