Move to a consistent way of handling IDs in URLs

This commit is contained in:
Aaron Lindsay 2017-11-12 21:12:49 -05:00
parent e99abfe866
commit 9624f0c5bc
9 changed files with 66 additions and 94 deletions

View File

@ -5,7 +5,6 @@ import (
"errors"
"log"
"net/http"
"regexp"
"strings"
)
@ -100,14 +99,6 @@ type AccountList struct {
Accounts *[]Account `json:"accounts"`
}
var accountTransactionsRE *regexp.Regexp
var accountImportRE *regexp.Regexp
func init() {
accountTransactionsRE = regexp.MustCompile(`^/v1/accounts/[0-9]+/transactions/?$`)
accountImportRE = regexp.MustCompile(`^/v1/accounts/[0-9]+/imports/[a-z]+/?$`)
}
func (a *Account) Write(w http.ResponseWriter) error {
enc := json.NewEncoder(w)
return enc.Encode(a)
@ -384,18 +375,12 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
}
if r.Method == "POST" {
// if URL looks like /v1/accounts/[0-9]+/imports, use the account
// import handler
if accountImportRE.MatchString(r.URL.Path) {
var accountid int64
var importtype string
n, err := GetURLPieces(r.URL.Path, "/v1/accounts/%d/imports/%s", &accountid, &importtype)
if err != nil || n != 2 {
log.Print(err)
return NewError(999 /*Internal Error*/)
if !context.LastLevel() {
accountid, err := context.NextID()
if err != nil || context.NextLevel() != "imports" {
return NewError(3 /*Invalid Request*/)
}
return AccountImportHandler(context, r, user, accountid, importtype)
return AccountImportHandler(context, r, user, accountid)
}
account_json := r.PostFormValue("account")
@ -433,10 +418,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
return ResponseWrapper{201, &account}
} else if r.Method == "GET" {
var accountid int64
n, err := GetURLPieces(r.URL.Path, "/v1/accounts/%d", &accountid)
if err != nil || n != 1 {
if context.LastLevel() {
//Return all Accounts
var al AccountList
accounts, err := GetAccounts(context.Tx, user.UserId)
@ -446,13 +428,14 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
}
al.Accounts = accounts
return &al
} else {
// if URL looks like /account/[0-9]+/transactions, use the account
// transaction handler
if accountTransactionsRE.MatchString(r.URL.Path) {
return AccountTransactionsHandler(context, r, user, accountid)
}
}
accountid, err := context.NextID()
if err != nil {
return NewError(3 /*Invalid Request*/)
}
if context.LastLevel() {
// Return Account with this Id
account, err := GetAccount(context.Tx, accountid, user.UserId)
if err != nil {
@ -460,9 +443,11 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter {
}
return account
} else if context.NextLevel() == "transactions" {
return AccountTransactionsHandler(context, r, user, accountid)
}
} else {
accountid, err := GetURLID(r.URL.Path)
accountid, err := context.NextID()
if err != nil {
return NewError(3 /*Invalid Request*/)
}

View File

@ -5,6 +5,7 @@ import (
"log"
"net/http"
"path"
"strconv"
"strings"
)
@ -35,6 +36,14 @@ func (c *Context) NextLevel() string {
return split[0]
}
func (c *Context) NextID() (int64, error) {
return strconv.ParseInt(c.NextLevel(), 0, 64)
}
func (c *Context) LastLevel() bool {
return len(c.remainingURL) == 0
}
type Handler func(*http.Request, *Context) ResponseWriterWriter
type APIHandler struct {

View File

@ -335,9 +335,10 @@ func OFXFileImportHandler(context *Context, r *http.Request, user *User, account
/*
* Assumes the User is a valid, signed-in user, but accountid has not yet been validated
*/
func AccountImportHandler(context *Context, r *http.Request, user *User, accountid int64, importtype string) ResponseWriterWriter {
func AccountImportHandler(context *Context, r *http.Request, user *User, accountid int64) ResponseWriterWriter {
switch importtype {
importType := context.NextLevel()
switch importType {
case "ofx":
return OFXImportHandler(context, r, user, accountid)
case "ofxfile":

View File

@ -165,10 +165,7 @@ func PriceHandler(r *http.Request, context *Context) ResponseWriterWriter {
return ResponseWrapper{201, &price}
} else if r.Method == "GET" {
var priceid int64
n, err := GetURLPieces(r.URL.Path, "/v1/prices/%d", &priceid)
if err != nil || n != 1 {
if context.LastLevel() {
//Return all prices
var pl PriceList
@ -180,16 +177,21 @@ func PriceHandler(r *http.Request, context *Context) ResponseWriterWriter {
pl.Prices = prices
return &pl
} else {
price, err := GetPrice(context.Tx, priceid, user.UserId)
if err != nil {
return NewError(3 /*Invalid Request*/)
}
return price
}
priceid, err := context.NextID()
if err != nil {
return NewError(3 /*Invalid Request*/)
}
price, err := GetPrice(context.Tx, priceid, user.UserId)
if err != nil {
return NewError(3 /*Invalid Request*/)
}
return price
} else {
priceid, err := GetURLID(r.URL.Path)
priceid, err := context.NextID()
if err != nil {
return NewError(3 /*Invalid Request*/)
}

View File

@ -8,17 +8,10 @@ import (
"github.com/yuin/gopher-lua"
"log"
"net/http"
"regexp"
"strings"
"time"
)
var reportTabulationRE *regexp.Regexp
func init() {
reportTabulationRE = regexp.MustCompile(`^/v1/reports/[0-9]+/tabulations/?$`)
}
//type and value to store user in lua's Context
type key int
@ -255,19 +248,7 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter {
return ResponseWrapper{201, &report}
} else if r.Method == "GET" {
if reportTabulationRE.MatchString(r.URL.Path) {
var reportid int64
n, err := GetURLPieces(r.URL.Path, "/v1/reports/%d/tabulations", &reportid)
if err != nil || n != 1 {
log.Print(err)
return NewError(999 /*InternalError*/)
}
return ReportTabulationHandler(context.Tx, r, user, reportid)
}
var reportid int64
n, err := GetURLPieces(r.URL.Path, "/v1/reports/%d", &reportid)
if err != nil || n != 1 {
if context.LastLevel() {
//Return all Reports
var rl ReportList
reports, err := GetReports(context.Tx, user.UserId)
@ -277,6 +258,15 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter {
}
rl.Reports = reports
return &rl
}
reportid, err := context.NextID()
if err != nil {
return NewError(3 /*Invalid Request*/)
}
if context.NextLevel() == "tabulations" {
return ReportTabulationHandler(context.Tx, r, user, reportid)
} else {
// Return Report with this Id
report, err := GetReport(context.Tx, reportid, user.UserId)
@ -287,7 +277,7 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter {
return report
}
} else {
reportid, err := GetURLID(r.URL.Path)
reportid, err := context.NextID()
if err != nil {
return NewError(3 /*Invalid Request*/)
}

View File

@ -274,10 +274,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter {
return ResponseWrapper{201, &security}
} else if r.Method == "GET" {
var securityid int64
n, err := GetURLPieces(r.URL.Path, "/v1/securities/%d", &securityid)
if err != nil || n != 1 {
if context.LastLevel() {
//Return all securities
var sl SecurityList
@ -290,6 +287,10 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter {
sl.Securities = securities
return &sl
} else {
securityid, err := context.NextID()
if err != nil {
return NewError(3 /*Invalid Request*/)
}
security, err := GetSecurity(context.Tx, securityid, user.UserId)
if err != nil {
return NewError(3 /*Invalid Request*/)
@ -298,7 +299,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter {
return security
}
} else {
securityid, err := GetURLID(r.URL.Path)
securityid, err := context.NextID()
if err != nil {
return NewError(3 /*Invalid Request*/)
}

View File

@ -452,9 +452,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
return &transaction
} else if r.Method == "GET" {
transactionid, err := GetURLID(r.URL.Path)
if err != nil {
if context.LastLevel() {
//Return all Transactions
var al TransactionList
transactions, err := GetTransactions(context.Tx, user.UserId)
@ -466,6 +464,10 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
return &al
} else {
//Return Transaction with this Id
transactionid, err := context.NextID()
if err != nil {
return NewError(3 /*Invalid Request*/)
}
transaction, err := GetTransaction(context.Tx, transactionid, user.UserId)
if err != nil {
return NewError(3 /*Invalid Request*/)
@ -473,7 +475,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
return transaction
}
} else {
transactionid, err := GetURLID(r.URL.Path)
transactionid, err := context.NextID()
if err != nil {
return NewError(3 /*Invalid Request*/)
}
@ -518,11 +520,6 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter
return &transaction
} else if r.Method == "DELETE" {
transactionid, err := GetURLID(r.URL.Path)
if err != nil {
return NewError(3 /*Invalid Request*/)
}
transaction, err := GetTransaction(context.Tx, transactionid, user.UserId)
if err != nil {
return NewError(3 /*Invalid Request*/)

View File

@ -207,7 +207,7 @@ func UserHandler(r *http.Request, context *Context) ResponseWriterWriter {
return NewError(1 /*Not Signed In*/)
}
userid, err := GetURLID(r.URL.Path)
userid, err := context.NextID()
if err != nil {
return NewError(3 /*Invalid Request*/)
}

View File

@ -3,21 +3,8 @@ package handlers
import (
"fmt"
"net/http"
"strconv"
"strings"
)
func GetURLID(url string) (int64, error) {
pieces := strings.Split(strings.Trim(url, "/"), "/")
return strconv.ParseInt(pieces[len(pieces)-1], 10, 0)
}
func GetURLPieces(url string, format string, a ...interface{}) (int, error) {
url = strings.Replace(url, "/", " ", -1)
format = strings.Replace(format, "/", " ", -1)
return fmt.Sscanf(url, format, a...)
}
type ResponseWrapper struct {
Code int
Writer ResponseWriterWriter