diff --git a/internal/db/db.go b/internal/db/db.go index 80e3b9d..8622af0 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -13,6 +13,8 @@ import ( "strings" ) +const luaMaxLengthBuffer int = 4096 + func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { var dialect gorp.Dialect if dbtype == config.SQLite { @@ -36,7 +38,8 @@ func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { dbmap.AddTableWithName(handlers.Transaction{}, "transactions").SetKeys(true, "TransactionId") dbmap.AddTableWithName(handlers.Split{}, "splits").SetKeys(true, "SplitId") dbmap.AddTableWithName(handlers.Price{}, "prices").SetKeys(true, "PriceId") - dbmap.AddTableWithName(handlers.Report{}, "reports").SetKeys(true, "ReportId") + rtable := dbmap.AddTableWithName(handlers.Report{}, "reports").SetKeys(true, "ReportId") + rtable.ColMap("Lua").SetMaxSize(handlers.LuaMaxLength + luaMaxLengthBuffer) err := dbmap.CreateTablesIfNotExists() if err != nil { diff --git a/internal/handlers/reports.go b/internal/handlers/reports.go index 985ff07..6c313ef 100644 --- a/internal/handlers/reports.go +++ b/internal/handlers/reports.go @@ -39,6 +39,10 @@ type Report struct { Lua string } +// The maximum length (in bytes) the Lua code may be. This is used to set the +// max size of the database columns (with an added fudge factor) +const LuaMaxLength int = 65536 + func (r *Report) Write(w http.ResponseWriter) error { enc := json.NewEncoder(w) return enc.Encode(r) @@ -234,6 +238,10 @@ func ReportHandler(r *http.Request, tx *Tx) ResponseWriterWriter { report.ReportId = -1 report.UserId = user.UserId + if len(report.Lua) >= LuaMaxLength { + return NewError(3 /*Invalid Request*/) + } + err = InsertReport(tx, &report) if err != nil { log.Print(err) @@ -292,6 +300,10 @@ func ReportHandler(r *http.Request, tx *Tx) ResponseWriterWriter { } report.UserId = user.UserId + if len(report.Lua) >= LuaMaxLength { + return NewError(3 /*Invalid Request*/) + } + err = UpdateReport(tx, &report) if err != nil { log.Print(err) diff --git a/internal/handlers/reports_test.go b/internal/handlers/reports_test.go index f4c2b44..a9e6c0e 100644 --- a/internal/handlers/reports_test.go +++ b/internal/handlers/reports_test.go @@ -50,7 +50,7 @@ func deleteReport(client *http.Client, r *handlers.Report) error { func TestCreateReport(t *testing.T) { RunWith(t, &data[0], func(t *testing.T, d *TestData) { - for i := 1; i < len(data[0].reports); i++ { + for i := 0; i < len(data[0].reports); i++ { orig := data[0].reports[i] r := d.reports[i] @@ -63,13 +63,26 @@ func TestCreateReport(t *testing.T) { if r.Lua != orig.Lua { t.Errorf("Lua doesn't match") } + + r.Lua = string(make([]byte, handlers.LuaMaxLength+1)) + _, err := createReport(d.clients[orig.UserId], &r) + if err == nil { + t.Fatalf("Expected error creating report with too-long Lua") + } + if herr, ok := err.(*handlers.Error); ok { + if herr.ErrorId != 3 { // Invalid requeset + t.Fatalf("Unexpected API error creating report with too-long Lua: %s", herr) + } + } else { + t.Fatalf("Unexpected error creating report with too-long Lua") + } } }) } func TestGetReport(t *testing.T) { RunWith(t, &data[0], func(t *testing.T, d *TestData) { - for i := 1; i < len(data[0].reports); i++ { + for i := 0; i < len(data[0].reports); i++ { orig := data[0].reports[i] curr := d.reports[i] @@ -87,9 +100,49 @@ func TestGetReport(t *testing.T) { }) } +func TestGetReports(t *testing.T) { + RunWith(t, &data[0], func(t *testing.T, d *TestData) { + rl, err := getReports(d.clients[0]) + if err != nil { + t.Fatalf("Error fetching reports: %s\n", err) + } + + numreports := 0 + foundIds := make(map[int64]bool) + for i := 0; i < len(data[0].reports); i++ { + orig := data[0].reports[i] + curr := d.reports[i] + + if curr.UserId != d.users[0].UserId { + continue + } + numreports += 1 + + found := false + for _, r := range *rl.Reports { + if orig.Name == r.Name && orig.Lua == r.Lua { + if _, ok := foundIds[r.ReportId]; ok { + continue + } + foundIds[r.ReportId] = true + found = true + break + } + } + if !found { + t.Errorf("Unable to find matching report: %+v", orig) + } + } + + if numreports != len(*rl.Reports) { + t.Fatalf("Expected %d reports, received %d", numreports, len(*rl.Reports)) + } + }) +} + func TestUpdateReport(t *testing.T) { RunWith(t, &data[0], func(t *testing.T, d *TestData) { - for i := 1; i < len(data[0].reports); i++ { + for i := 0; i < len(data[0].reports); i++ { orig := data[0].reports[i] curr := d.reports[i] @@ -110,13 +163,26 @@ func TestUpdateReport(t *testing.T) { if r.Lua != curr.Lua { t.Errorf("Lua doesn't match") } + + r.Lua = string(make([]byte, handlers.LuaMaxLength+1)) + _, err = updateReport(d.clients[orig.UserId], r) + if err == nil { + t.Fatalf("Expected error updating report with too-long Lua") + } + if herr, ok := err.(*handlers.Error); ok { + if herr.ErrorId != 3 { // Invalid requeset + t.Fatalf("Unexpected API error updating report with too-long Lua: %s", herr) + } + } else { + t.Fatalf("Unexpected error updating report with too-long Lua") + } } }) } func TestDeleteReport(t *testing.T) { RunWith(t, &data[0], func(t *testing.T, d *TestData) { - for i := 1; i < len(data[0].reports); i++ { + for i := 0; i < len(data[0].reports); i++ { orig := data[0].reports[i] curr := d.reports[i] diff --git a/internal/handlers/testdata_test.go b/internal/handlers/testdata_test.go index c505e59..f380bd7 100644 --- a/internal/handlers/testdata_test.go +++ b/internal/handlers/testdata_test.go @@ -321,7 +321,7 @@ var data = []TestData{ reports: []handlers.Report{ handlers.Report{ UserId: 0, - Name: "", + Name: "Monthly Expenses", Lua: ` function account_series_map(accounts, tabulation) map = {}