diff --git a/internal/db/db.go b/internal/db/db.go index a23a029..a33fb08 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -5,7 +5,6 @@ import ( "fmt" "github.com/aclindsa/gorp" "github.com/aclindsa/moneygo/internal/config" - "github.com/aclindsa/moneygo/internal/handlers" "github.com/aclindsa/moneygo/internal/models" _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" @@ -44,8 +43,8 @@ func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { dbmap.AddTableWithName(models.Transaction{}, "transactions").SetKeys(true, "TransactionId") dbmap.AddTableWithName(models.Split{}, "splits").SetKeys(true, "SplitId") dbmap.AddTableWithName(models.Price{}, "prices").SetKeys(true, "PriceId") - rtable := dbmap.AddTableWithName(handlers.Report{}, "reports").SetKeys(true, "ReportId") - rtable.ColMap("Lua").SetMaxSize(handlers.LuaMaxLength + luaMaxLengthBuffer) + rtable := dbmap.AddTableWithName(models.Report{}, "reports").SetKeys(true, "ReportId") + rtable.ColMap("Lua").SetMaxSize(models.LuaMaxLength + luaMaxLengthBuffer) err := dbmap.CreateTablesIfNotExists() if err != nil { diff --git a/internal/handlers/reports.go b/internal/handlers/reports.go index cab32d1..bec9525 100644 --- a/internal/handlers/reports.go +++ b/internal/handlers/reports.go @@ -2,14 +2,12 @@ package handlers import ( "context" - "encoding/json" "errors" "fmt" "github.com/aclindsa/moneygo/internal/models" "github.com/yuin/gopher-lua" "log" "net/http" - "strings" "time" ) @@ -26,67 +24,8 @@ const ( const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for -type Report struct { - ReportId int64 - UserId int64 - Name string - Lua string -} - -// The maximum length (in bytes) the Lua code may be. This is used to set the -// max size of the database columns (with an added fudge factor) -const LuaMaxLength int = 65536 - -func (r *Report) Write(w http.ResponseWriter) error { - enc := json.NewEncoder(w) - return enc.Encode(r) -} - -func (r *Report) Read(json_str string) error { - dec := json.NewDecoder(strings.NewReader(json_str)) - return dec.Decode(r) -} - -type ReportList struct { - Reports *[]Report `json:"reports"` -} - -func (rl *ReportList) Write(w http.ResponseWriter) error { - enc := json.NewEncoder(w) - return enc.Encode(rl) -} - -func (rl *ReportList) Read(json_str string) error { - dec := json.NewDecoder(strings.NewReader(json_str)) - return dec.Decode(rl) -} - -type Series struct { - Values []float64 - Series map[string]*Series -} - -type Tabulation struct { - ReportId int64 - Title string - Subtitle string - Units string - Labels []string - Series map[string]*Series -} - -func (t *Tabulation) Write(w http.ResponseWriter) error { - enc := json.NewEncoder(w) - return enc.Encode(t) -} - -func (t *Tabulation) Read(json_str string) error { - dec := json.NewDecoder(strings.NewReader(json_str)) - return dec.Decode(t) -} - -func GetReport(tx *Tx, reportid int64, userid int64) (*Report, error) { - var r Report +func GetReport(tx *Tx, reportid int64, userid int64) (*models.Report, error) { + var r models.Report err := tx.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid) if err != nil { @@ -95,8 +34,8 @@ func GetReport(tx *Tx, reportid int64, userid int64) (*Report, error) { return &r, nil } -func GetReports(tx *Tx, userid int64) (*[]Report, error) { - var reports []Report +func GetReports(tx *Tx, userid int64) (*[]models.Report, error) { + var reports []models.Report _, err := tx.Select(&reports, "SELECT * from reports where UserId=?", userid) if err != nil { @@ -105,7 +44,7 @@ func GetReports(tx *Tx, userid int64) (*[]Report, error) { return &reports, nil } -func InsertReport(tx *Tx, r *Report) error { +func InsertReport(tx *Tx, r *models.Report) error { err := tx.Insert(r) if err != nil { return err @@ -113,7 +52,7 @@ func InsertReport(tx *Tx, r *Report) error { return nil } -func UpdateReport(tx *Tx, r *Report) error { +func UpdateReport(tx *Tx, r *models.Report) error { count, err := tx.Update(r) if err != nil { return err @@ -124,7 +63,7 @@ func UpdateReport(tx *Tx, r *Report) error { return nil } -func DeleteReport(tx *Tx, r *Report) error { +func DeleteReport(tx *Tx, r *models.Report) error { count, err := tx.Delete(r) if err != nil { return err @@ -135,7 +74,7 @@ func DeleteReport(tx *Tx, r *Report) error { return nil } -func runReport(tx *Tx, user *models.User, report *Report) (*Tabulation, error) { +func runReport(tx *Tx, user *models.User, report *models.Report) (*models.Tabulation, error) { // Create a new LState without opening the default libs for security L := lua.NewState(lua.Options{SkipOpenLibs: true}) defer L.Close() @@ -189,7 +128,7 @@ func runReport(tx *Tx, user *models.User, report *Report) (*Tabulation, error) { value := L.Get(-1) if ud, ok := value.(*lua.LUserData); ok { - if tabulation, ok := ud.Value.(*Tabulation); ok { + if tabulation, ok := ud.Value.(*models.Tabulation); ok { return tabulation, nil } else { return nil, fmt.Errorf("generate() for %s (Id: %d) didn't return a tabulation", report.Name, report.ReportId) @@ -224,14 +163,14 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter { } if r.Method == "POST" { - var report Report + var report models.Report if err := ReadJSON(r, &report); err != nil { return NewError(3 /*Invalid Request*/) } report.ReportId = -1 report.UserId = user.UserId - if len(report.Lua) >= LuaMaxLength { + if len(report.Lua) >= models.LuaMaxLength { return NewError(3 /*Invalid Request*/) } @@ -245,7 +184,7 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter { } else if r.Method == "GET" { if context.LastLevel() { //Return all Reports - var rl ReportList + var rl models.ReportList reports, err := GetReports(context.Tx, user.UserId) if err != nil { log.Print(err) @@ -278,13 +217,13 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter { } if r.Method == "PUT" { - var report Report + var report models.Report if err := ReadJSON(r, &report); err != nil || report.ReportId != reportid { return NewError(3 /*Invalid Request*/) } report.UserId = user.UserId - if len(report.Lua) >= LuaMaxLength { + if len(report.Lua) >= models.LuaMaxLength { return NewError(3 /*Invalid Request*/) } diff --git a/internal/handlers/reports_lua.go b/internal/handlers/reports_lua.go index 2904a2b..51d919d 100644 --- a/internal/handlers/reports_lua.go +++ b/internal/handlers/reports_lua.go @@ -1,6 +1,7 @@ package handlers import ( + "github.com/aclindsa/moneygo/internal/models" "github.com/yuin/gopher-lua" ) @@ -21,9 +22,9 @@ func luaRegisterTabulations(L *lua.LState) { } // Checks whether the first lua argument is a *LUserData with *Tabulation and returns *Tabulation -func luaCheckTabulation(L *lua.LState, n int) *Tabulation { +func luaCheckTabulation(L *lua.LState, n int) *models.Tabulation { ud := L.CheckUserData(n) - if tabulation, ok := ud.Value.(*Tabulation); ok { + if tabulation, ok := ud.Value.(*models.Tabulation); ok { return tabulation } L.ArgError(n, "tabulation expected") @@ -31,9 +32,9 @@ func luaCheckTabulation(L *lua.LState, n int) *Tabulation { } // Checks whether the first lua argument is a *LUserData with *Series and returns *Series -func luaCheckSeries(L *lua.LState, n int) *Series { +func luaCheckSeries(L *lua.LState, n int) *models.Series { ud := L.CheckUserData(n) - if series, ok := ud.Value.(*Series); ok { + if series, ok := ud.Value.(*models.Series); ok { return series } L.ArgError(n, "series expected") @@ -43,9 +44,9 @@ func luaCheckSeries(L *lua.LState, n int) *Series { func luaTabulationNew(L *lua.LState) int { numvalues := L.CheckInt(1) ud := L.NewUserData() - ud.Value = &Tabulation{ + ud.Value = &models.Tabulation{ Labels: make([]string, numvalues), - Series: make(map[string]*Series), + Series: make(map[string]*models.Series), } L.SetMetatable(ud, L.GetTypeMetatable(luaTabulationTypeName)) L.Push(ud) @@ -94,8 +95,8 @@ func luaTabulationSeries(L *lua.LState) int { if ok { ud.Value = s } else { - tabulation.Series[name] = &Series{ - Series: make(map[string]*Series), + tabulation.Series[name] = &models.Series{ + Series: make(map[string]*models.Series), Values: make([]float64, cap(tabulation.Labels)), } ud.Value = tabulation.Series[name] @@ -175,8 +176,8 @@ func luaSeriesSeries(L *lua.LState) int { if ok { ud.Value = s } else { - parent.Series[name] = &Series{ - Series: make(map[string]*Series), + parent.Series[name] = &models.Series{ + Series: make(map[string]*models.Series), Values: make([]float64, cap(parent.Values)), } ud.Value = parent.Series[name] diff --git a/internal/handlers/reports_lua_test.go b/internal/handlers/reports_lua_test.go index 1ba1fa7..bf67f5b 100644 --- a/internal/handlers/reports_lua_test.go +++ b/internal/handlers/reports_lua_test.go @@ -2,7 +2,7 @@ package handlers_test import ( "fmt" - "github.com/aclindsa/moneygo/internal/handlers" + "github.com/aclindsa/moneygo/internal/models" "net/http" "testing" ) @@ -25,7 +25,7 @@ function generate() t:title(tostring(test())) return t end`, lt.Lua) - r := handlers.Report{ + r := models.Report{ Name: lt.Name, Lua: lua, } diff --git a/internal/handlers/reports_test.go b/internal/handlers/reports_test.go index 624cf0a..6f97b8c 100644 --- a/internal/handlers/reports_test.go +++ b/internal/handlers/reports_test.go @@ -2,19 +2,20 @@ package handlers_test import ( "github.com/aclindsa/moneygo/internal/handlers" + "github.com/aclindsa/moneygo/internal/models" "net/http" "strconv" "testing" ) -func createReport(client *http.Client, report *handlers.Report) (*handlers.Report, error) { - var r handlers.Report +func createReport(client *http.Client, report *models.Report) (*models.Report, error) { + var r models.Report err := create(client, report, &r, "/v1/reports/") return &r, err } -func getReport(client *http.Client, reportid int64) (*handlers.Report, error) { - var r handlers.Report +func getReport(client *http.Client, reportid int64) (*models.Report, error) { + var r models.Report err := read(client, &r, "/v1/reports/"+strconv.FormatInt(reportid, 10)) if err != nil { return nil, err @@ -22,8 +23,8 @@ func getReport(client *http.Client, reportid int64) (*handlers.Report, error) { return &r, nil } -func getReports(client *http.Client) (*handlers.ReportList, error) { - var rl handlers.ReportList +func getReports(client *http.Client) (*models.ReportList, error) { + var rl models.ReportList err := read(client, &rl, "/v1/reports/") if err != nil { return nil, err @@ -31,8 +32,8 @@ func getReports(client *http.Client) (*handlers.ReportList, error) { return &rl, nil } -func updateReport(client *http.Client, report *handlers.Report) (*handlers.Report, error) { - var r handlers.Report +func updateReport(client *http.Client, report *models.Report) (*models.Report, error) { + var r models.Report err := update(client, report, &r, "/v1/reports/"+strconv.FormatInt(report.ReportId, 10)) if err != nil { return nil, err @@ -40,7 +41,7 @@ func updateReport(client *http.Client, report *handlers.Report) (*handlers.Repor return &r, nil } -func deleteReport(client *http.Client, r *handlers.Report) error { +func deleteReport(client *http.Client, r *models.Report) error { err := remove(client, "/v1/reports/"+strconv.FormatInt(r.ReportId, 10)) if err != nil { return err @@ -48,8 +49,8 @@ func deleteReport(client *http.Client, r *handlers.Report) error { return nil } -func tabulateReport(client *http.Client, reportid int64) (*handlers.Tabulation, error) { - var t handlers.Tabulation +func tabulateReport(client *http.Client, reportid int64) (*models.Tabulation, error) { + var t models.Tabulation err := read(client, &t, "/v1/reports/"+strconv.FormatInt(reportid, 10)+"/tabulations") if err != nil { return nil, err @@ -73,7 +74,7 @@ func TestCreateReport(t *testing.T) { t.Errorf("Lua doesn't match") } - r.Lua = string(make([]byte, handlers.LuaMaxLength+1)) + r.Lua = string(make([]byte, models.LuaMaxLength+1)) _, err := createReport(d.clients[orig.UserId], &r) if err == nil { t.Fatalf("Expected error creating report with too-long Lua") @@ -173,7 +174,7 @@ func TestUpdateReport(t *testing.T) { t.Errorf("Lua doesn't match") } - r.Lua = string(make([]byte, handlers.LuaMaxLength+1)) + r.Lua = string(make([]byte, models.LuaMaxLength+1)) _, err = updateReport(d.clients[orig.UserId], r) if err == nil { t.Fatalf("Expected error updating report with too-long Lua") @@ -214,7 +215,7 @@ func TestDeleteReport(t *testing.T) { } }) } -func seriesEqualityHelper(t *testing.T, orig, curr map[string]*handlers.Series, name string) { +func seriesEqualityHelper(t *testing.T, orig, curr map[string]*models.Series, name string) { if orig == nil || curr == nil { if orig != nil { t.Fatalf("`%s` series unexpectedly nil", name) @@ -242,7 +243,7 @@ func seriesEqualityHelper(t *testing.T, orig, curr map[string]*handlers.Series, } } -func tabulationEqualityHelper(t *testing.T, orig, curr *handlers.Tabulation) { +func tabulationEqualityHelper(t *testing.T, orig, curr *models.Tabulation) { if orig.Title != curr.Title { t.Errorf("Tabulation Title doesn't match") } diff --git a/internal/handlers/testdata_test.go b/internal/handlers/testdata_test.go index 65bfee8..3d15888 100644 --- a/internal/handlers/testdata_test.go +++ b/internal/handlers/testdata_test.go @@ -3,7 +3,6 @@ package handlers_test import ( "encoding/json" "fmt" - "github.com/aclindsa/moneygo/internal/handlers" "github.com/aclindsa/moneygo/internal/models" "net/http" "strings" @@ -41,8 +40,8 @@ type TestData struct { prices []models.Price accounts []models.Account // accounts must appear after their parents in this slice transactions []models.Transaction - reports []handlers.Report - tabulations []handlers.Tabulation + reports []models.Report + tabulations []models.Tabulation } type TestDataFunc func(*testing.T, *TestData) @@ -382,7 +381,7 @@ var data = []TestData{ }, }, }, - reports: []handlers.Report{ + reports: []models.Report{ { UserId: 0, Name: "This Year's Monthly Expenses", @@ -440,39 +439,39 @@ function generate() end`, }, }, - tabulations: []handlers.Tabulation{ + tabulations: []models.Tabulation{ { ReportId: 0, Title: "2017 Monthly Expenses", Subtitle: "This is my subtitle", Units: "USD", Labels: []string{"2017-01-01", "2017-02-01", "2017-03-01", "2017-04-01", "2017-05-01", "2017-06-01", "2017-07-01", "2017-08-01", "2017-09-01", "2017-10-01", "2017-11-01", "2017-12-01"}, - Series: map[string]*handlers.Series{ + Series: map[string]*models.Series{ "Assets": { Values: []float64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - Series: map[string]*handlers.Series{ + Series: map[string]*models.Series{ "Credit Union Checking": { Values: []float64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - Series: map[string]*handlers.Series{}, + Series: map[string]*models.Series{}, }, }, }, "Expenses": { Values: []float64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - Series: map[string]*handlers.Series{ + Series: map[string]*models.Series{ "Groceries": { Values: []float64{0, 0, 0, 0, 0, 0, 0, 0, 0, 87.19, 0, 0}, - Series: map[string]*handlers.Series{}, + Series: map[string]*models.Series{}, }, "Cable": { Values: []float64{0, 0, 0, 0, 0, 0, 0, 0, 39.99, 0, 0, 0}, - Series: map[string]*handlers.Series{}, + Series: map[string]*models.Series{}, }, }, }, "Credit Card": { Values: []float64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - Series: map[string]*handlers.Series{}, + Series: map[string]*models.Series{}, }, }, }, diff --git a/internal/models/reports.go b/internal/models/reports.go new file mode 100644 index 0000000..493fd21 --- /dev/null +++ b/internal/models/reports.go @@ -0,0 +1,66 @@ +package models + +import ( + "encoding/json" + "net/http" + "strings" +) + +type Report struct { + ReportId int64 + UserId int64 + Name string + Lua string +} + +// The maximum length (in bytes) the Lua code may be. This is used to set the +// max size of the database columns (with an added fudge factor) +const LuaMaxLength int = 65536 + +func (r *Report) Write(w http.ResponseWriter) error { + enc := json.NewEncoder(w) + return enc.Encode(r) +} + +func (r *Report) Read(json_str string) error { + dec := json.NewDecoder(strings.NewReader(json_str)) + return dec.Decode(r) +} + +type ReportList struct { + Reports *[]Report `json:"reports"` +} + +func (rl *ReportList) Write(w http.ResponseWriter) error { + enc := json.NewEncoder(w) + return enc.Encode(rl) +} + +func (rl *ReportList) Read(json_str string) error { + dec := json.NewDecoder(strings.NewReader(json_str)) + return dec.Decode(rl) +} + +type Series struct { + Values []float64 + Series map[string]*Series +} + +type Tabulation struct { + ReportId int64 + Title string + Subtitle string + Units string + Labels []string + Series map[string]*Series +} + +func (t *Tabulation) Write(w http.ResponseWriter) error { + enc := json.NewEncoder(w) + return enc.Encode(t) +} + +func (t *Tabulation) Read(json_str string) error { + dec := json.NewDecoder(strings.NewReader(json_str)) + return dec.Decode(t) +}