From c48c50d2c55600defd699c26e39722d0bcf5075e Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Thu, 16 Nov 2017 19:17:51 -0500 Subject: [PATCH] API: Move prices under securities For example, instead of GETting /prices/5 to query a price with ID 5, you now must GET /securities/2/prices/5 (assuming price 5's SecurityId is 2) --- internal/handlers/handlers.go | 2 - internal/handlers/prices.go | 25 +++++----- internal/handlers/prices_test.go | 76 ++++++++++++++++-------------- internal/handlers/securities.go | 26 ++++++++++ internal/handlers/testdata_test.go | 16 +++++++ 5 files changed, 95 insertions(+), 50 deletions(-) diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index e4f1294..2022a2a 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -94,8 +94,6 @@ func (ah *APIHandler) route(r *http.Request) ResponseWriterWriter { return ah.txWrapper(SecurityHandler, r, context) case "securitytemplates": return SecurityTemplateHandler(r, context) - case "prices": - return ah.txWrapper(PriceHandler, r, context) case "accounts": return ah.txWrapper(AccountHandler, r, context) case "transactions": diff --git a/internal/handlers/prices.go b/internal/handlers/prices.go index d517dc9..2027497 100644 --- a/internal/handlers/prices.go +++ b/internal/handlers/prices.go @@ -69,19 +69,19 @@ func CreatePriceIfNotExist(tx *Tx, price *Price) error { return nil } -func GetPrice(tx *Tx, priceid, userid int64) (*Price, error) { +func GetPrice(tx *Tx, priceid, securityid int64) (*Price, error) { var p Price - err := tx.SelectOne(&p, "SELECT * from prices where PriceId=? AND SecurityId IN (SELECT SecurityId FROM securities WHERE UserId=?)", priceid, userid) + err := tx.SelectOne(&p, "SELECT * from prices where PriceId=? AND SecurityId=?", priceid, securityid) if err != nil { return nil, err } return &p, nil } -func GetPrices(tx *Tx, userid int64) (*[]*Price, error) { +func GetPrices(tx *Tx, securityid int64) (*[]*Price, error) { var prices []*Price - _, err := tx.Select(&prices, "SELECT * from prices where SecurityId IN (SELECT SecurityId FROM securities WHERE UserId=?)", userid) + _, err := tx.Select(&prices, "SELECT * from prices where SecurityId=?", securityid) if err != nil { return nil, err } @@ -129,10 +129,10 @@ func GetClosestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Pr } } -func PriceHandler(r *http.Request, context *Context) ResponseWriterWriter { - user, err := GetUserFromSession(context.Tx, r) +func PriceHandler(r *http.Request, context *Context, user *User, securityid int64) ResponseWriterWriter { + security, err := GetSecurity(context.Tx, securityid, user.UserId) if err != nil { - return NewError(1 /*Not Signed In*/) + return NewError(3 /*Invalid Request*/) } if r.Method == "POST" { @@ -142,8 +142,7 @@ func PriceHandler(r *http.Request, context *Context) ResponseWriterWriter { } price.PriceId = -1 - _, err = GetSecurity(context.Tx, price.SecurityId, user.UserId) - if err != nil { + if price.SecurityId != security.SecurityId { return NewError(3 /*Invalid Request*/) } _, err = GetSecurity(context.Tx, price.CurrencyId, user.UserId) @@ -160,10 +159,10 @@ func PriceHandler(r *http.Request, context *Context) ResponseWriterWriter { return ResponseWrapper{201, &price} } else if r.Method == "GET" { if context.LastLevel() { - //Return all prices + //Return all this security's prices var pl PriceList - prices, err := GetPrices(context.Tx, user.UserId) + prices, err := GetPrices(context.Tx, security.SecurityId) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -178,7 +177,7 @@ func PriceHandler(r *http.Request, context *Context) ResponseWriterWriter { return NewError(3 /*Invalid Request*/) } - price, err := GetPrice(context.Tx, priceid, user.UserId) + price, err := GetPrice(context.Tx, priceid, security.SecurityId) if err != nil { return NewError(3 /*Invalid Request*/) } @@ -212,7 +211,7 @@ func PriceHandler(r *http.Request, context *Context) ResponseWriterWriter { return &price } else if r.Method == "DELETE" { - price, err := GetPrice(context.Tx, priceid, user.UserId) + price, err := GetPrice(context.Tx, priceid, security.SecurityId) if err != nil { return NewError(3 /*Invalid Request*/) } diff --git a/internal/handlers/prices_test.go b/internal/handlers/prices_test.go index f799df1..d362614 100644 --- a/internal/handlers/prices_test.go +++ b/internal/handlers/prices_test.go @@ -10,22 +10,22 @@ import ( func createPrice(client *http.Client, price *handlers.Price) (*handlers.Price, error) { var p handlers.Price - err := create(client, price, &p, "/v1/prices/") + err := create(client, price, &p, "/v1/securities/"+strconv.FormatInt(price.SecurityId, 10)+"/prices/") return &p, err } -func getPrice(client *http.Client, priceid int64) (*handlers.Price, error) { +func getPrice(client *http.Client, priceid, securityid int64) (*handlers.Price, error) { var p handlers.Price - err := read(client, &p, "/v1/prices/"+strconv.FormatInt(priceid, 10)) + err := read(client, &p, "/v1/securities/"+strconv.FormatInt(securityid, 10)+"/prices/"+strconv.FormatInt(priceid, 10)) if err != nil { return nil, err } return &p, nil } -func getPrices(client *http.Client) (*handlers.PriceList, error) { +func getPrices(client *http.Client, securityid int64) (*handlers.PriceList, error) { var pl handlers.PriceList - err := read(client, &pl, "/v1/prices/") + err := read(client, &pl, "/v1/securities/"+strconv.FormatInt(securityid, 10)+"/prices/") if err != nil { return nil, err } @@ -34,7 +34,7 @@ func getPrices(client *http.Client) (*handlers.PriceList, error) { func updatePrice(client *http.Client, price *handlers.Price) (*handlers.Price, error) { var p handlers.Price - err := update(client, price, &p, "/v1/prices/"+strconv.FormatInt(price.PriceId, 10)) + err := update(client, price, &p, "/v1/securities/"+strconv.FormatInt(price.SecurityId, 10)+"/prices/"+strconv.FormatInt(price.PriceId, 10)) if err != nil { return nil, err } @@ -42,7 +42,7 @@ func updatePrice(client *http.Client, price *handlers.Price) (*handlers.Price, e } func deletePrice(client *http.Client, p *handlers.Price) error { - err := remove(client, "/v1/prices/"+strconv.FormatInt(p.PriceId, 10)) + err := remove(client, "/v1/securities/"+strconv.FormatInt(p.SecurityId, 10)+"/prices/"+strconv.FormatInt(p.PriceId, 10)) if err != nil { return err } @@ -84,7 +84,7 @@ func TestGetPrice(t *testing.T) { curr := d.prices[i] userid := data[0].securities[orig.SecurityId].UserId - p, err := getPrice(d.clients[userid], curr.PriceId) + p, err := getPrice(d.clients[userid], curr.PriceId, curr.SecurityId) if err != nil { t.Fatalf("Error fetching price: %s\n", err) } @@ -109,39 +109,45 @@ func TestGetPrice(t *testing.T) { func TestGetPrices(t *testing.T) { RunWith(t, &data[0], func(t *testing.T, d *TestData) { - pl, err := getPrices(d.clients[0]) - if err != nil { - t.Fatalf("Error fetching prices: %s\n", err) - } - - numprices := 0 - foundIds := make(map[int64]bool) - for i := 0; i < len(data[0].prices); i++ { - orig := data[0].prices[i] - - if data[0].securities[orig.SecurityId].UserId != 0 { + for origsecurityid, security := range d.securities { + if data[0].securities[origsecurityid].UserId != 0 { continue } - numprices += 1 - found := false - for _, p := range *pl.Prices { - if p.SecurityId == d.securities[orig.SecurityId].SecurityId && p.CurrencyId == d.securities[orig.CurrencyId].SecurityId && p.Date != orig.Date && p.Value != orig.Value && p.RemoteId != orig.RemoteId { - if _, ok := foundIds[p.PriceId]; ok { - continue + pl, err := getPrices(d.clients[0], security.SecurityId) + if err != nil { + t.Fatalf("Error fetching prices: %s\n", err) + } + + numprices := 0 + foundIds := make(map[int64]bool) + for i := 0; i < len(data[0].prices); i++ { + orig := data[0].prices[i] + + if orig.SecurityId != int64(origsecurityid) { + continue + } + numprices += 1 + + found := false + for _, p := range *pl.Prices { + if p.SecurityId == d.securities[orig.SecurityId].SecurityId && p.CurrencyId == d.securities[orig.CurrencyId].SecurityId && p.Date == orig.Date && p.Value == orig.Value && p.RemoteId == orig.RemoteId { + if _, ok := foundIds[p.PriceId]; ok { + continue + } + foundIds[p.PriceId] = true + found = true + break } - foundIds[p.PriceId] = true - found = true - break + } + if !found { + t.Errorf("Unable to find matching price: %+v", orig) } } - if !found { - t.Errorf("Unable to find matching price: %+v", orig) - } - } - if numprices != len(*pl.Prices) { - t.Fatalf("Expected %d prices, received %d", numprices, len(*pl.Prices)) + if numprices != len(*pl.Prices) { + t.Fatalf("Expected %d prices, received %d", numprices, len(*pl.Prices)) + } } }) } @@ -196,7 +202,7 @@ func TestDeletePrice(t *testing.T) { t.Fatalf("Error deleting price: %s\n", err) } - _, err = getPrice(d.clients[userid], curr.PriceId) + _, err = getPrice(d.clients[userid], curr.PriceId, curr.SecurityId) if err == nil { t.Fatalf("Expected error fetching deleted price") } diff --git a/internal/handlers/securities.go b/internal/handlers/securities.go index fa9e94b..0ca96d0 100644 --- a/internal/handlers/securities.go +++ b/internal/handlers/securities.go @@ -253,6 +253,17 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { } if r.Method == "POST" { + if !context.LastLevel() { + securityid, err := context.NextID() + if err != nil { + return NewError(3 /*Invalid Request*/) + } + if context.NextLevel() != "prices" { + return NewError(3 /*Invalid Request*/) + } + return PriceHandler(r, context, user, securityid) + } + var security Security if err := ReadJSON(r, &security); err != nil { return NewError(3 /*Invalid Request*/) @@ -285,6 +296,14 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { if err != nil { return NewError(3 /*Invalid Request*/) } + + if !context.LastLevel() { + if context.NextLevel() != "prices" { + return NewError(3 /*Invalid Request*/) + } + return PriceHandler(r, context, user, securityid) + } + security, err := GetSecurity(context.Tx, securityid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) @@ -297,6 +316,13 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { if err != nil { return NewError(3 /*Invalid Request*/) } + if !context.LastLevel() { + if context.NextLevel() != "prices" { + return NewError(3 /*Invalid Request*/) + } + return PriceHandler(r, context, user, securityid) + } + if r.Method == "PUT" { var security Security if err := ReadJSON(r, &security); err != nil || security.SecurityId != securityid { diff --git a/internal/handlers/testdata_test.go b/internal/handlers/testdata_test.go index e6dc770..143229e 100644 --- a/internal/handlers/testdata_test.go +++ b/internal/handlers/testdata_test.go @@ -200,6 +200,15 @@ var data = []TestData{ Type: handlers.Currency, AlternateId: "978", }, + handlers.Security{ + UserId: 0, + Name: "EUR", + Description: "Euro", + Symbol: "€", + Precision: 2, + Type: handlers.Currency, + AlternateId: "978", + }, }, prices: []handlers.Price{ handlers.Price{ @@ -230,6 +239,13 @@ var data = []TestData{ Value: "227.21", RemoteId: "12387-129831-1241", }, + handlers.Price{ + SecurityId: 0, + CurrencyId: 3, + Date: time.Date(2017, time.November, 16, 18, 49, 53, 0, time.UTC), + Value: "0.85", + RemoteId: "USDEUR819298714", + }, }, accounts: []handlers.Account{ handlers.Account{