From 9b1b682baca0550b0f7b1353633957ba8ede57ee Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Fri, 10 Nov 2017 20:13:49 -0500 Subject: [PATCH] prices: Implement initial API and tests --- internal/handlers/handlers.go | 1 + internal/handlers/prices.go | 165 ++++++++++++++++++++-- internal/handlers/prices_test.go | 212 +++++++++++++++++++++++++++++ internal/handlers/testdata_test.go | 43 +++++- 4 files changed, 412 insertions(+), 9 deletions(-) create mode 100644 internal/handlers/prices_test.go diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index d518e92..3ce7173 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -55,6 +55,7 @@ func GetHandler(db *gorp.DbMap) *http.ServeMux { servemux.HandleFunc("/session/", TxHandlerFunc(SessionHandler, db)) servemux.HandleFunc("/user/", TxHandlerFunc(UserHandler, db)) servemux.HandleFunc("/security/", TxHandlerFunc(SecurityHandler, db)) + servemux.HandleFunc("/price/", TxHandlerFunc(PriceHandler, db)) servemux.HandleFunc("/securitytemplate/", SecurityTemplateHandler) servemux.HandleFunc("/account/", TxHandlerFunc(AccountHandler, db)) servemux.HandleFunc("/transaction/", TxHandlerFunc(TransactionHandler, db)) diff --git a/internal/handlers/prices.go b/internal/handlers/prices.go index 81b32d0..897ca1b 100644 --- a/internal/handlers/prices.go +++ b/internal/handlers/prices.go @@ -1,6 +1,10 @@ package handlers import ( + "encoding/json" + "log" + "net/http" + "strings" "time" ) @@ -13,18 +17,34 @@ type Price struct { RemoteId string // unique ID from source, for detecting duplicates } -func InsertPrice(tx *Tx, p *Price) error { - err := tx.Insert(p) - if err != nil { - return err - } - return nil +type PriceList struct { + Prices *[]*Price `json:"prices"` +} + +func (p *Price) Read(json_str string) error { + dec := json.NewDecoder(strings.NewReader(json_str)) + return dec.Decode(p) +} + +func (p *Price) Write(w http.ResponseWriter) error { + enc := json.NewEncoder(w) + return enc.Encode(p) +} + +func (pl *PriceList) Read(json_str string) error { + dec := json.NewDecoder(strings.NewReader(json_str)) + return dec.Decode(pl) +} + +func (pl *PriceList) Write(w http.ResponseWriter) error { + enc := json.NewEncoder(w) + return enc.Encode(pl) } func CreatePriceIfNotExist(tx *Tx, price *Price) error { if len(price.RemoteId) == 0 { // Always create a new price if we can't match on the RemoteId - err := InsertPrice(tx, price) + err := tx.Insert(price) if err != nil { return err } @@ -42,13 +62,32 @@ func CreatePriceIfNotExist(tx *Tx, price *Price) error { return nil // price already exists } - err = InsertPrice(tx, price) + err = tx.Insert(price) if err != nil { return err } return nil } +func GetPrice(tx *Tx, priceid, userid int64) (*Price, error) { + var p Price + err := tx.SelectOne(&p, "SELECT * from prices where PriceId=? AND SecurityId IN (SELECT SecurityId FROM securities WHERE UserId=?)", priceid, userid) + if err != nil { + return nil, err + } + return &p, nil +} + +func GetPrices(tx *Tx, userid int64) (*[]*Price, error) { + var prices []*Price + + _, err := tx.Select(&prices, "SELECT * from prices where SecurityId IN (SELECT SecurityId FROM securities WHERE UserId=?)", userid) + if err != nil { + return nil, err + } + return &prices, nil +} + // Return the latest price for security in currency units before date func GetLatestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Price, error) { var p Price @@ -89,3 +128,113 @@ func GetClosestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Pr return earliest, nil } } + +func PriceHandler(r *http.Request, tx *Tx) ResponseWriterWriter { + user, err := GetUserFromSession(tx, r) + if err != nil { + return NewError(1 /*Not Signed In*/) + } + + if r.Method == "POST" { + price_json := r.PostFormValue("price") + if price_json == "" { + return NewError(3 /*Invalid Request*/) + } + + var price Price + err := price.Read(price_json) + if err != nil { + return NewError(3 /*Invalid Request*/) + } + price.PriceId = -1 + + _, err = GetSecurity(tx, price.SecurityId, user.UserId) + if err != nil { + return NewError(3 /*Invalid Request*/) + } + _, err = GetSecurity(tx, price.CurrencyId, user.UserId) + if err != nil { + return NewError(3 /*Invalid Request*/) + } + + err = tx.Insert(&price) + if err != nil { + log.Print(err) + return NewError(999 /*Internal Error*/) + } + + return ResponseWrapper{201, &price} + } else if r.Method == "GET" { + var priceid int64 + n, err := GetURLPieces(r.URL.Path, "/price/%d", &priceid) + + if err != nil || n != 1 { + //Return all prices + var pl PriceList + + prices, err := GetPrices(tx, user.UserId) + if err != nil { + log.Print(err) + return NewError(999 /*Internal Error*/) + } + + pl.Prices = prices + return &pl + } else { + price, err := GetPrice(tx, priceid, user.UserId) + if err != nil { + return NewError(3 /*Invalid Request*/) + } + + return price + } + } else { + priceid, err := GetURLID(r.URL.Path) + if err != nil { + return NewError(3 /*Invalid Request*/) + } + if r.Method == "PUT" { + price_json := r.PostFormValue("price") + if price_json == "" { + return NewError(3 /*Invalid Request*/) + } + + var price Price + err := price.Read(price_json) + if err != nil || price.PriceId != priceid { + return NewError(3 /*Invalid Request*/) + } + + _, err = GetSecurity(tx, price.SecurityId, user.UserId) + if err != nil { + return NewError(3 /*Invalid Request*/) + } + _, err = GetSecurity(tx, price.CurrencyId, user.UserId) + if err != nil { + return NewError(3 /*Invalid Request*/) + } + + count, err := tx.Update(&price) + if err != nil || count != 1 { + log.Print(err) + return NewError(999 /*Internal Error*/) + } + + return &price + } else if r.Method == "DELETE" { + price, err := GetPrice(tx, priceid, user.UserId) + if err != nil { + return NewError(3 /*Invalid Request*/) + } + + count, err := tx.Delete(price) + if err != nil || count != 1 { + log.Print(err) + return NewError(999 /*Internal Error*/) + } + + return SuccessWriter{} + } + } + return NewError(3 /*Invalid Request*/) +} diff --git a/internal/handlers/prices_test.go b/internal/handlers/prices_test.go new file mode 100644 index 0000000..7c019d5 --- /dev/null +++ b/internal/handlers/prices_test.go @@ -0,0 +1,212 @@ +package handlers_test + +import ( + "github.com/aclindsa/moneygo/internal/handlers" + "net/http" + "strconv" + "testing" + "time" +) + +func createPrice(client *http.Client, price *handlers.Price) (*handlers.Price, error) { + var p handlers.Price + err := create(client, price, &p, "/price/", "price") + return &p, err +} + +func getPrice(client *http.Client, priceid int64) (*handlers.Price, error) { + var p handlers.Price + err := read(client, &p, "/price/"+strconv.FormatInt(priceid, 10), "price") + if err != nil { + return nil, err + } + return &p, nil +} + +func getPrices(client *http.Client) (*handlers.PriceList, error) { + var pl handlers.PriceList + err := read(client, &pl, "/price/", "prices") + if err != nil { + return nil, err + } + return &pl, nil +} + +func updatePrice(client *http.Client, price *handlers.Price) (*handlers.Price, error) { + var p handlers.Price + err := update(client, price, &p, "/price/"+strconv.FormatInt(price.PriceId, 10), "price") + if err != nil { + return nil, err + } + return &p, nil +} + +func deletePrice(client *http.Client, p *handlers.Price) error { + err := remove(client, "/price/"+strconv.FormatInt(p.PriceId, 10), "price") + if err != nil { + return err + } + return nil +} + +func TestCreatePrice(t *testing.T) { + RunWith(t, &data[0], func(t *testing.T, d *TestData) { + for i := 0; i < len(data[0].prices); i++ { + orig := data[0].prices[i] + p := d.prices[i] + + if p.PriceId == 0 { + t.Errorf("Unable to create price: %+v", p) + } + if p.SecurityId != d.securities[orig.SecurityId].SecurityId { + t.Errorf("SecurityId doesn't match") + } + if p.CurrencyId != d.securities[orig.CurrencyId].SecurityId { + t.Errorf("CurrencyId doesn't match") + } + if p.Date != orig.Date { + t.Errorf("Date doesn't match") + } + if p.Value != orig.Value { + t.Errorf("Value doesn't match") + } + if p.RemoteId != orig.RemoteId { + t.Errorf("RemoteId doesn't match") + } + } + }) +} + +func TestGetPrice(t *testing.T) { + RunWith(t, &data[0], func(t *testing.T, d *TestData) { + for i := 0; i < len(data[0].prices); i++ { + orig := data[0].prices[i] + curr := d.prices[i] + + userid := data[0].securities[orig.SecurityId].UserId + p, err := getPrice(d.clients[userid], curr.PriceId) + if err != nil { + t.Fatalf("Error fetching price: %s\n", err) + } + if p.SecurityId != d.securities[orig.SecurityId].SecurityId { + t.Errorf("SecurityId doesn't match") + } + if p.CurrencyId != d.securities[orig.CurrencyId].SecurityId { + t.Errorf("CurrencyId doesn't match") + } + if p.Date != orig.Date { + t.Errorf("Date doesn't match") + } + if p.Value != orig.Value { + t.Errorf("Value doesn't match") + } + if p.RemoteId != orig.RemoteId { + t.Errorf("RemoteId doesn't match") + } + } + }) +} + +func TestGetPrices(t *testing.T) { + RunWith(t, &data[0], func(t *testing.T, d *TestData) { + pl, err := getPrices(d.clients[0]) + if err != nil { + t.Fatalf("Error fetching prices: %s\n", err) + } + + numprices := 0 + foundIds := make(map[int64]bool) + for i := 0; i < len(data[0].prices); i++ { + orig := data[0].prices[i] + + if data[0].securities[orig.SecurityId].UserId != 0 { + continue + } + numprices += 1 + + found := false + for _, p := range *pl.Prices { + if p.SecurityId == d.securities[orig.SecurityId].SecurityId && p.CurrencyId == d.securities[orig.CurrencyId].SecurityId && p.Date != orig.Date && p.Value != orig.Value && p.RemoteId != orig.RemoteId { + if _, ok := foundIds[p.PriceId]; ok { + continue + } + foundIds[p.PriceId] = true + found = true + break + } + } + if !found { + t.Errorf("Unable to find matching price: %+v", orig) + } + } + + if numprices != len(*pl.Prices) { + t.Fatalf("Expected %d prices, received %d", numprices, len(*pl.Prices)) + } + }) +} + +func TestUpdatePrice(t *testing.T) { + RunWith(t, &data[0], func(t *testing.T, d *TestData) { + for i := 0; i < len(data[0].prices); i++ { + orig := data[0].prices[i] + curr := d.prices[i] + + tmp := curr.SecurityId + curr.SecurityId = curr.CurrencyId + curr.CurrencyId = tmp + curr.Value = "5.55" + curr.Date = time.Date(2019, time.June, 5, 12, 5, 6, 7, time.UTC) + curr.RemoteId = "something" + + userid := data[0].securities[orig.SecurityId].UserId + p, err := updatePrice(d.clients[userid], &curr) + if err != nil { + t.Fatalf("Error updating price: %s\n", err) + } + + if p.SecurityId != curr.SecurityId { + t.Errorf("SecurityId doesn't match") + } + if p.CurrencyId != curr.CurrencyId { + t.Errorf("CurrencyId doesn't match") + } + if p.Date != curr.Date { + t.Errorf("Date doesn't match") + } + if p.Value != curr.Value { + t.Errorf("Value doesn't match") + } + if p.RemoteId != curr.RemoteId { + t.Errorf("RemoteId doesn't match") + } + } + }) +} + +func TestDeletePrice(t *testing.T) { + RunWith(t, &data[0], func(t *testing.T, d *TestData) { + for i := 0; i < len(data[0].prices); i++ { + orig := data[0].prices[i] + curr := d.prices[i] + + userid := data[0].securities[orig.SecurityId].UserId + err := deletePrice(d.clients[userid], &curr) + if err != nil { + t.Fatalf("Error deleting price: %s\n", err) + } + + _, err = getPrice(d.clients[userid], curr.PriceId) + if err == nil { + t.Fatalf("Expected error fetching deleted price") + } + if herr, ok := err.(*handlers.Error); ok { + if herr.ErrorId != 3 { // Invalid requeset + t.Fatalf("Unexpected API error fetching deleted price: %s", herr) + } + } else { + t.Fatalf("Unexpected error fetching deleted price") + } + } + }) +} diff --git a/internal/handlers/testdata_test.go b/internal/handlers/testdata_test.go index 5bd81b3..e6dc770 100644 --- a/internal/handlers/testdata_test.go +++ b/internal/handlers/testdata_test.go @@ -37,9 +37,9 @@ type TestData struct { users []User clients []*http.Client securities []handlers.Security + prices []handlers.Price accounts []handlers.Account // accounts must appear after their parents in this slice transactions []handlers.Transaction - prices []handlers.Price reports []handlers.Report tabulations []handlers.Tabulation } @@ -90,6 +90,17 @@ func (t *TestData) Initialize() (*TestData, error) { t2.securities = append(t2.securities, *s2) } + for _, price := range t.prices { + userid := t.securities[price.SecurityId].UserId + price.SecurityId = t2.securities[price.SecurityId].SecurityId + price.CurrencyId = t2.securities[price.CurrencyId].SecurityId + p2, err := createPrice(t2.clients[userid], &price) + if err != nil { + return nil, err + } + t2.prices = append(t2.prices, *p2) + } + for _, account := range t.accounts { account.SecurityId = t2.securities[account.SecurityId].SecurityId if account.ParentAccountId != -1 { @@ -190,6 +201,36 @@ var data = []TestData{ AlternateId: "978", }, }, + prices: []handlers.Price{ + handlers.Price{ + SecurityId: 1, + CurrencyId: 0, + Date: time.Date(2017, time.January, 2, 21, 0, 0, 0, time.UTC), + Value: "225.24", + RemoteId: "12387-129831-1238", + }, + handlers.Price{ + SecurityId: 1, + CurrencyId: 0, + Date: time.Date(2017, time.January, 3, 21, 0, 0, 0, time.UTC), + Value: "226.58", + RemoteId: "12387-129831-1239", + }, + handlers.Price{ + SecurityId: 1, + CurrencyId: 0, + Date: time.Date(2017, time.January, 4, 21, 0, 0, 0, time.UTC), + Value: "226.40", + RemoteId: "12387-129831-1240", + }, + handlers.Price{ + SecurityId: 1, + CurrencyId: 0, + Date: time.Date(2017, time.January, 5, 21, 0, 0, 0, time.UTC), + Value: "227.21", + RemoteId: "12387-129831-1241", + }, + }, accounts: []handlers.Account{ handlers.Account{ UserId: 0,