From 5504d37482803f72b991940a62ea69cf69bf103a Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Fri, 3 Nov 2017 20:50:19 -0400 Subject: [PATCH] testing: Improve testing CRUD for reports Ensure we don't get silent errors if the Lua code is longer than the database column, don't leave out the first report from testing, test fetching multiple reports. --- internal/db/db.go | 5 +- internal/handlers/reports.go | 12 +++++ internal/handlers/reports_test.go | 74 ++++++++++++++++++++++++++++-- internal/handlers/testdata_test.go | 2 +- 4 files changed, 87 insertions(+), 6 deletions(-) diff --git a/internal/db/db.go b/internal/db/db.go index 80e3b9d..8622af0 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -13,6 +13,8 @@ import ( "strings" ) +const luaMaxLengthBuffer int = 4096 + func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { var dialect gorp.Dialect if dbtype == config.SQLite { @@ -36,7 +38,8 @@ func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { dbmap.AddTableWithName(handlers.Transaction{}, "transactions").SetKeys(true, "TransactionId") dbmap.AddTableWithName(handlers.Split{}, "splits").SetKeys(true, "SplitId") dbmap.AddTableWithName(handlers.Price{}, "prices").SetKeys(true, "PriceId") - dbmap.AddTableWithName(handlers.Report{}, "reports").SetKeys(true, "ReportId") + rtable := dbmap.AddTableWithName(handlers.Report{}, "reports").SetKeys(true, "ReportId") + rtable.ColMap("Lua").SetMaxSize(handlers.LuaMaxLength + luaMaxLengthBuffer) err := dbmap.CreateTablesIfNotExists() if err != nil { diff --git a/internal/handlers/reports.go b/internal/handlers/reports.go index 985ff07..6c313ef 100644 --- a/internal/handlers/reports.go +++ b/internal/handlers/reports.go @@ -39,6 +39,10 @@ type Report struct { Lua string } +// The maximum length (in bytes) the Lua code may be. This is used to set the +// max size of the database columns (with an added fudge factor) +const LuaMaxLength int = 65536 + func (r *Report) Write(w http.ResponseWriter) error { enc := json.NewEncoder(w) return enc.Encode(r) @@ -234,6 +238,10 @@ func ReportHandler(r *http.Request, tx *Tx) ResponseWriterWriter { report.ReportId = -1 report.UserId = user.UserId + if len(report.Lua) >= LuaMaxLength { + return NewError(3 /*Invalid Request*/) + } + err = InsertReport(tx, &report) if err != nil { log.Print(err) @@ -292,6 +300,10 @@ func ReportHandler(r *http.Request, tx *Tx) ResponseWriterWriter { } report.UserId = user.UserId + if len(report.Lua) >= LuaMaxLength { + return NewError(3 /*Invalid Request*/) + } + err = UpdateReport(tx, &report) if err != nil { log.Print(err) diff --git a/internal/handlers/reports_test.go b/internal/handlers/reports_test.go index f4c2b44..a9e6c0e 100644 --- a/internal/handlers/reports_test.go +++ b/internal/handlers/reports_test.go @@ -50,7 +50,7 @@ func deleteReport(client *http.Client, r *handlers.Report) error { func TestCreateReport(t *testing.T) { RunWith(t, &data[0], func(t *testing.T, d *TestData) { - for i := 1; i < len(data[0].reports); i++ { + for i := 0; i < len(data[0].reports); i++ { orig := data[0].reports[i] r := d.reports[i] @@ -63,13 +63,26 @@ func TestCreateReport(t *testing.T) { if r.Lua != orig.Lua { t.Errorf("Lua doesn't match") } + + r.Lua = string(make([]byte, handlers.LuaMaxLength+1)) + _, err := createReport(d.clients[orig.UserId], &r) + if err == nil { + t.Fatalf("Expected error creating report with too-long Lua") + } + if herr, ok := err.(*handlers.Error); ok { + if herr.ErrorId != 3 { // Invalid requeset + t.Fatalf("Unexpected API error creating report with too-long Lua: %s", herr) + } + } else { + t.Fatalf("Unexpected error creating report with too-long Lua") + } } }) } func TestGetReport(t *testing.T) { RunWith(t, &data[0], func(t *testing.T, d *TestData) { - for i := 1; i < len(data[0].reports); i++ { + for i := 0; i < len(data[0].reports); i++ { orig := data[0].reports[i] curr := d.reports[i] @@ -87,9 +100,49 @@ func TestGetReport(t *testing.T) { }) } +func TestGetReports(t *testing.T) { + RunWith(t, &data[0], func(t *testing.T, d *TestData) { + rl, err := getReports(d.clients[0]) + if err != nil { + t.Fatalf("Error fetching reports: %s\n", err) + } + + numreports := 0 + foundIds := make(map[int64]bool) + for i := 0; i < len(data[0].reports); i++ { + orig := data[0].reports[i] + curr := d.reports[i] + + if curr.UserId != d.users[0].UserId { + continue + } + numreports += 1 + + found := false + for _, r := range *rl.Reports { + if orig.Name == r.Name && orig.Lua == r.Lua { + if _, ok := foundIds[r.ReportId]; ok { + continue + } + foundIds[r.ReportId] = true + found = true + break + } + } + if !found { + t.Errorf("Unable to find matching report: %+v", orig) + } + } + + if numreports != len(*rl.Reports) { + t.Fatalf("Expected %d reports, received %d", numreports, len(*rl.Reports)) + } + }) +} + func TestUpdateReport(t *testing.T) { RunWith(t, &data[0], func(t *testing.T, d *TestData) { - for i := 1; i < len(data[0].reports); i++ { + for i := 0; i < len(data[0].reports); i++ { orig := data[0].reports[i] curr := d.reports[i] @@ -110,13 +163,26 @@ func TestUpdateReport(t *testing.T) { if r.Lua != curr.Lua { t.Errorf("Lua doesn't match") } + + r.Lua = string(make([]byte, handlers.LuaMaxLength+1)) + _, err = updateReport(d.clients[orig.UserId], r) + if err == nil { + t.Fatalf("Expected error updating report with too-long Lua") + } + if herr, ok := err.(*handlers.Error); ok { + if herr.ErrorId != 3 { // Invalid requeset + t.Fatalf("Unexpected API error updating report with too-long Lua: %s", herr) + } + } else { + t.Fatalf("Unexpected error updating report with too-long Lua") + } } }) } func TestDeleteReport(t *testing.T) { RunWith(t, &data[0], func(t *testing.T, d *TestData) { - for i := 1; i < len(data[0].reports); i++ { + for i := 0; i < len(data[0].reports); i++ { orig := data[0].reports[i] curr := d.reports[i] diff --git a/internal/handlers/testdata_test.go b/internal/handlers/testdata_test.go index c505e59..f380bd7 100644 --- a/internal/handlers/testdata_test.go +++ b/internal/handlers/testdata_test.go @@ -321,7 +321,7 @@ var data = []TestData{ reports: []handlers.Report{ handlers.Report{ UserId: 0, - Name: "", + Name: "Monthly Expenses", Lua: ` function account_series_map(accounts, tabulation) map = {}