1
0
mirror of https://github.com/aclindsa/moneygo.git synced 2024-09-21 12:20:04 -04:00
moneygo/internal/handlers/reports.go
Aaron Lindsay 5504d37482 testing: Improve testing CRUD for reports
Ensure we don't get silent errors if the Lua code is longer than the
database column, don't leave out the first report from testing, test
fetching multiple reports.
2017-11-03 20:50:19 -04:00

331 lines
7.3 KiB
Go

package handlers
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/yuin/gopher-lua"
"log"
"net/http"
"regexp"
"strings"
"time"
)
// reportTabulationRE matches request paths of the form
// /report/<numeric id>/tabulation, with an optional trailing slash.
var reportTabulationRE = regexp.MustCompile(`^/report/[0-9]+/tabulation/?$`)
// key is the type of the keys used to store values in the context.Context
// that is attached to a Lua state. Using a dedicated unexported type keeps
// these keys distinct from context keys of any other type.
type key int

// Context keys. In this file, runReport stores the current user under
// userContextKey and the database transaction under dbContextKey.
// NOTE(review): the accounts, securities and balance keys are not used in
// this file; presumably the Lua helper code elsewhere in the package uses
// them — confirm there.
const (
	userContextKey key = iota
	accountsContextKey
	securitiesContextKey
	balanceContextKey
	dbContextKey
)
// luaTimeoutSeconds is the number of seconds a Lua report may run for. It is
// typed as a time.Duration only so that it can be multiplied by time.Second
// (as runReport does) to obtain the actual timeout; used on its own it is
// NOT a 30-second duration.
const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for
// Report is a named piece of Lua code belonging to a user. Reports are looked
// up by the combination of ReportId and UserId.
type Report struct {
	ReportId int64
	UserId   int64
	Name     string
	Lua      string
}

// LuaMaxLength is the maximum length (in bytes) the Lua code may be. This is
// used to set the max size of the database columns (with an added fudge
// factor).
const LuaMaxLength int = 65536

// Write encodes the report as a single JSON document (followed by a newline)
// to w and returns any encoding or write error.
func (r *Report) Write(w http.ResponseWriter) error {
	return json.NewEncoder(w).Encode(r)
}

// Read decodes the first JSON value found in json_str into the report and
// returns any decoding error.
func (r *Report) Read(json_str string) error {
	return json.NewDecoder(strings.NewReader(json_str)).Decode(r)
}
// ReportList wraps a slice of reports so that it is serialized as a JSON
// object with a single "reports" member.
type ReportList struct {
	Reports *[]Report `json:"reports"`
}

// Write encodes the list as a single JSON document (followed by a newline)
// to w and returns any encoding or write error.
func (rl *ReportList) Write(w http.ResponseWriter) error {
	return json.NewEncoder(w).Encode(rl)
}

// Read decodes the first JSON value found in json_str into the list and
// returns any decoding error.
func (rl *ReportList) Read(json_str string) error {
	return json.NewDecoder(strings.NewReader(json_str)).Decode(rl)
}
// Series is a node in a tree of value series: it carries its own values and
// any number of named child series.
type Series struct {
	Values []float64
	Series map[string]*Series
}

// Tabulation is the result of running a report: descriptive text fields, a
// list of labels and a set of named top-level series.
// NOTE(review): presumably Labels has one entry per element of each
// Series.Values — nothing in this file enforces that; confirm against the
// code that builds tabulations.
type Tabulation struct {
	ReportId int64
	Title    string
	Subtitle string
	Units    string
	Labels   []string
	Series   map[string]*Series
}

// Write encodes the tabulation as a single JSON document (followed by a
// newline) to w and returns any encoding or write error.
func (t *Tabulation) Write(w http.ResponseWriter) error {
	return json.NewEncoder(w).Encode(t)
}
// GetReport fetches the report with the given reportid that belongs to the
// given userid. It returns a nil report and the underlying error if the
// query fails.
func GetReport(tx *Tx, reportid int64, userid int64) (*Report, error) {
	report := new(Report)
	if err := tx.SelectOne(report, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid); err != nil {
		return nil, err
	}
	return report, nil
}
// GetReports fetches every report belonging to the given userid. It returns
// a nil pointer and the underlying error if the query fails.
func GetReports(tx *Tx, userid int64) (*[]Report, error) {
	var found []Report
	if _, err := tx.Select(&found, "SELECT * from reports where UserId=?", userid); err != nil {
		return nil, err
	}
	return &found, nil
}
// InsertReport stores r as a new report within the transaction and returns
// any error reported by the insert.
func InsertReport(tx *Tx, r *Report) error {
	return tx.Insert(r)
}
// UpdateReport writes r back to the database within the transaction.
//
// It returns the underlying error if the update fails, and an error if the
// count reported by tx.Update is anything other than exactly one. The
// zero-count case gets its own message so that logs do not claim multiple
// reports were updated when in fact none were.
// NOTE(review): the count is presumably the number of affected rows; on some
// databases that can be zero when the stored values are unchanged — confirm
// against the driver in use.
func UpdateReport(tx *Tx, r *Report) error {
	count, err := tx.Update(r)
	if err != nil {
		return err
	}
	switch {
	case count == 0:
		return errors.New("No report was updated")
	case count != 1:
		return errors.New("Updated more than one report")
	}
	return nil
}
// DeleteReport removes r from the database within the transaction.
//
// It returns the underlying error if the delete fails, and an error if the
// count reported by tx.Delete is anything other than exactly one. The
// zero-count case gets its own message so that logs do not claim multiple
// reports were deleted when in fact none were.
func DeleteReport(tx *Tx, r *Report) error {
	count, err := tx.Delete(r)
	if err != nil {
		return err
	}
	switch {
	case count == 0:
		return errors.New("No report was deleted")
	case count != 1:
		return errors.New("Deleted more than one report")
	}
	return nil
}
// runReport executes the Lua code of report in a fresh Lua state and returns
// the Tabulation produced by the script's global generate() function.
//
// The state is created without the default libraries; only the package,
// base, table, string and math libraries are opened before the project's own
// Lua bindings are registered. The user and the transaction are made
// available through the state's context, which is cancelled after
// luaTimeoutSeconds seconds.
//
// An error is returned if a library fails to open, if the script or
// generate() fails, or if generate() does not return userdata wrapping a
// *Tabulation.
func runReport(tx *Tx, user *User, report *Report) (*Tabulation, error) {
	// Create a new LState without opening the default libs for security
	state := lua.NewState(lua.Options{SkipOpenLibs: true})
	defer state.Close()

	// Build the context carrying the current user and the transaction, and
	// bound the script's running time with a timeout.
	base := context.WithValue(context.Background(), userContextKey, user)
	base = context.WithValue(base, dbContextKey, tx)
	ctx, cancel := context.WithTimeout(base, luaTimeoutSeconds*time.Second)
	defer cancel()
	state.SetContext(ctx)

	// Standard libraries that scripts are allowed to use.
	allowedLibs := []struct {
		name string
		open lua.LGFunction
	}{
		{lua.LoadLibName, lua.OpenPackage}, // Must be first
		{lua.BaseLibName, lua.OpenBase},
		{lua.TabLibName, lua.OpenTable},
		{lua.StringLibName, lua.OpenString},
		{lua.MathLibName, lua.OpenMath},
	}
	for _, lib := range allowedLibs {
		call := lua.P{
			Fn:      state.NewFunction(lib.open),
			NRet:    0,
			Protect: true,
		}
		if err := state.CallByParam(call, lua.LString(lib.name)); err != nil {
			return nil, errors.New("Error initializing Lua packages")
		}
	}

	// Register the project-specific Lua bindings.
	luaRegisterAccounts(state)
	luaRegisterSecurities(state)
	luaRegisterBalances(state)
	luaRegisterDates(state)
	luaRegisterTabulations(state)
	luaRegisterPrices(state)

	// Load the user's script, then invoke its generate() entry point.
	if err := state.DoString(report.Lua); err != nil {
		return nil, err
	}

	generate := lua.P{
		Fn:      state.GetGlobal("generate"),
		NRet:    1,
		Protect: true,
	}
	if err := state.CallByParam(generate); err != nil {
		return nil, err
	}

	// The single return value sits on top of the Lua stack.
	userData, ok := state.Get(-1).(*lua.LUserData)
	if !ok {
		return nil, fmt.Errorf("generate() for %s (Id: %d) didn't even return LUserData", report.Name, report.ReportId)
	}
	tabulation, ok := userData.Value.(*Tabulation)
	if !ok {
		return nil, fmt.Errorf("generate() for %s (Id: %d) didn't return a tabulation", report.Name, report.ReportId)
	}
	return tabulation, nil
}
// ReportTabulationHandler looks up the report with the given reportid for
// user, runs it, and returns the resulting tabulation with its ReportId set
// to reportid. If the report cannot be fetched or fails to run, an
// "Invalid Request" (3) error is returned; run failures are also logged.
// The request r is currently unused.
func ReportTabulationHandler(tx *Tx, r *http.Request, user *User, reportid int64) ResponseWriterWriter {
	report, lookupErr := GetReport(tx, reportid, user.UserId)
	if lookupErr != nil {
		return NewError(3 /*Invalid Request*/)
	}

	result, runErr := runReport(tx, user, report)
	if runErr == nil {
		result.ReportId = reportid
		return result
	}

	// TODO handle different failure cases differently
	log.Print("runReport returned:", runErr)
	return NewError(3 /*Invalid Request*/)
}
// ReportHandler handles report requests on behalf of the user identified by
// the request's session.
//
// Behavior by HTTP method:
//   - POST: creates a report from the JSON in the "report" form value and
//     responds with status 201 and the stored report.
//   - GET: a path matching /report/<id>/tabulation runs the report and
//     returns its tabulation; a path of the form /report/<id> returns that
//     report; any other path returns all of the user's reports.
//   - PUT: replaces the report identified by the URL with the JSON in the
//     "report" form value. The ReportId in the JSON must match the URL, and
//     the report must already belong to the user.
//   - DELETE: deletes the user's report identified by the URL.
//
// Errors are reported through NewError: 1 when no user is signed in, 3 for
// invalid requests (including unsupported methods and reports the user does
// not own), and 999 for internal failures, which are also logged.
func ReportHandler(r *http.Request, tx *Tx) ResponseWriterWriter {
	user, err := GetUserFromSession(tx, r)
	if err != nil {
		return NewError(1 /*Not Signed In*/)
	}

	switch r.Method {
	case "POST":
		report_json := r.PostFormValue("report")
		if report_json == "" {
			return NewError(3 /*Invalid Request*/)
		}

		var report Report
		if err := report.Read(report_json); err != nil {
			return NewError(3 /*Invalid Request*/)
		}
		// Never trust the client-supplied id or owner.
		report.ReportId = -1
		report.UserId = user.UserId

		if len(report.Lua) >= LuaMaxLength {
			return NewError(3 /*Invalid Request*/)
		}

		if err := InsertReport(tx, &report); err != nil {
			log.Print(err)
			return NewError(999 /*Internal Error*/)
		}

		return ResponseWrapper{201, &report}

	case "GET":
		if reportTabulationRE.MatchString(r.URL.Path) {
			var reportid int64
			n, err := GetURLPieces(r.URL.Path, "/report/%d/tabulation", &reportid)
			if err != nil || n != 1 {
				log.Print(err)
				return NewError(999 /*InternalError*/)
			}
			return ReportTabulationHandler(tx, r, user, reportid)
		}

		var reportid int64
		n, err := GetURLPieces(r.URL.Path, "/report/%d", &reportid)
		if err != nil || n != 1 {
			// No report id in the path: return all Reports
			reports, err := GetReports(tx, user.UserId)
			if err != nil {
				log.Print(err)
				return NewError(999 /*Internal Error*/)
			}
			return &ReportList{Reports: reports}
		}

		// Return Report with this Id
		report, err := GetReport(tx, reportid, user.UserId)
		if err != nil {
			return NewError(3 /*Invalid Request*/)
		}
		return report

	case "PUT", "DELETE":
		reportid, err := GetURLID(r.URL.Path)
		if err != nil {
			return NewError(3 /*Invalid Request*/)
		}

		if r.Method == "PUT" {
			report_json := r.PostFormValue("report")
			if report_json == "" {
				return NewError(3 /*Invalid Request*/)
			}

			var report Report
			if err := report.Read(report_json); err != nil || report.ReportId != reportid {
				return NewError(3 /*Invalid Request*/)
			}
			report.UserId = user.UserId

			if len(report.Lua) >= LuaMaxLength {
				return NewError(3 /*Invalid Request*/)
			}

			// Make sure the report being replaced already belongs to this
			// user. UpdateReport is given no user id to match against and
			// report.UserId was just overwritten above, so without this
			// check a user could overwrite (and take ownership of) a report
			// they do not own.
			// NOTE(review): this assumes tx.Update matches rows by primary
			// key only — confirm against the table mapping.
			if _, err := GetReport(tx, reportid, user.UserId); err != nil {
				return NewError(3 /*Invalid Request*/)
			}

			if err := UpdateReport(tx, &report); err != nil {
				log.Print(err)
				return NewError(999 /*Internal Error*/)
			}

			return &report
		}

		// DELETE: only reports owned by the user can be fetched, and
		// therefore deleted.
		report, err := GetReport(tx, reportid, user.UserId)
		if err != nil {
			return NewError(3 /*Invalid Request*/)
		}

		if err := DeleteReport(tx, report); err != nil {
			log.Print(err)
			return NewError(999 /*Internal Error*/)
		}

		return SuccessWriter{}
	}

	// Unsupported method.
	return NewError(3 /*Invalid Request*/)
}