From f72c86ef58d927c0136b0bdd57d2d68112438e19 Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Sun, 3 Dec 2017 06:38:22 -0500 Subject: [PATCH] Split securities into models --- internal/db/db.go | 2 +- internal/handlers/accounts.go | 3 +- internal/handlers/balance_lua.go | 3 +- internal/handlers/gnucash.go | 15 +-- internal/handlers/gnucash_test.go | 3 +- internal/handlers/imports.go | 2 +- internal/handlers/ofx.go | 49 ++++----- internal/handlers/ofx_test.go | 7 +- internal/handlers/prices.go | 6 +- .../handlers/scripts/gen_security_list.py | 11 ++- internal/handlers/securities.go | 99 +++++-------------- internal/handlers/securities_lua.go | 16 +-- internal/handlers/securities_test.go | 19 ++-- internal/handlers/security_templates_test.go | 7 +- internal/handlers/testdata_test.go | 13 +-- internal/handlers/users.go | 4 +- internal/models/securities.go | 62 ++++++++++++ 17 files changed, 170 insertions(+), 151 deletions(-) create mode 100644 internal/models/securities.go diff --git a/internal/db/db.go b/internal/db/db.go index 4ee3192..0d71bfa 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -37,7 +37,7 @@ func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { dbmap.AddTableWithName(models.User{}, "users").SetKeys(true, "UserId") dbmap.AddTableWithName(models.Session{}, "sessions").SetKeys(true, "SessionId") dbmap.AddTableWithName(handlers.Account{}, "accounts").SetKeys(true, "AccountId") - dbmap.AddTableWithName(handlers.Security{}, "securities").SetKeys(true, "SecurityId") + dbmap.AddTableWithName(models.Security{}, "securities").SetKeys(true, "SecurityId") dbmap.AddTableWithName(handlers.Transaction{}, "transactions").SetKeys(true, "TransactionId") dbmap.AddTableWithName(handlers.Split{}, "splits").SetKeys(true, "SplitId") dbmap.AddTableWithName(handlers.Price{}, "prices").SetKeys(true, "PriceId") diff --git a/internal/handlers/accounts.go b/internal/handlers/accounts.go index 9bdc6f3..dd3fa14 100644 --- a/internal/handlers/accounts.go +++ b/internal/handlers/accounts.go @@ -3,6 +3,7 @@ package handlers import ( "encoding/json" "errors" + "github.com/aclindsa/moneygo/internal/models" "log" "net/http" "strings" @@ -214,7 +215,7 @@ func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*Account, error) func GetImbalanceAccount(tx *Tx, userid int64, securityid int64) (*Account, error) { var imbalanceAccount Account var account Account - xxxtemplate := FindSecurityTemplate("XXX", Currency) + xxxtemplate := FindSecurityTemplate("XXX", models.Currency) if xxxtemplate == nil { return nil, errors.New("Couldn't find XXX security template") } diff --git a/internal/handlers/balance_lua.go b/internal/handlers/balance_lua.go index 118d0d6..c4d6b63 100644 --- a/internal/handlers/balance_lua.go +++ b/internal/handlers/balance_lua.go @@ -1,12 +1,13 @@ package handlers import ( + "github.com/aclindsa/moneygo/internal/models" "github.com/yuin/gopher-lua" "math/big" ) type Balance struct { - Security *Security + Security *models.Security Amount *big.Rat } diff --git a/internal/handlers/gnucash.go b/internal/handlers/gnucash.go index 75a949c..9670f1e 100644 --- a/internal/handlers/gnucash.go +++ b/internal/handlers/gnucash.go @@ -6,6 +6,7 @@ import ( "encoding/xml" "errors" "fmt" + "github.com/aclindsa/moneygo/internal/models" "io" "log" "math" @@ -22,7 +23,7 @@ type GnucashXMLCommodity struct { XCode string `xml:"http://www.gnucash.org/XML/cmdty xcode"` } -type GnucashCommodity struct{ Security } +type GnucashCommodity struct{ models.Security } func (gc *GnucashCommodity) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { var gxc GnucashXMLCommodity @@ -35,12 +36,12 @@ func (gc *GnucashCommodity) UnmarshalXML(d *xml.Decoder, start xml.StartElement) gc.Description = gxc.Description gc.AlternateId = gxc.XCode - gc.Security.Type = Stock // assumed default + gc.Security.Type = models.Stock // assumed default if gxc.Type == "ISO4217" { - gc.Security.Type = Currency + gc.Security.Type = models.Currency // Get the number from our templates for the AlternateId because // Gnucash uses 'id' (our Name) to supply the string ISO4217 code - template := FindSecurityTemplate(gxc.Name, Currency) + template := FindSecurityTemplate(gxc.Name, models.Currency) if template == nil { return errors.New("Unable to find security template for Gnucash ISO4217 commodity") } @@ -125,7 +126,7 @@ type GnucashXMLImport struct { } type GnucashImport struct { - Securities []Security + Securities []models.Security Accounts []Account Transactions []Transaction Prices []Price @@ -143,7 +144,7 @@ func ImportGnucash(r io.Reader) (*GnucashImport, error) { } // Fixup securities, making a map of them as we go - securityMap := make(map[string]Security) + securityMap := make(map[string]models.Security) for i := range gncxml.Commodities { s := gncxml.Commodities[i].Security s.SecurityId = int64(i + 1) @@ -169,7 +170,7 @@ func ImportGnucash(r io.Reader) (*GnucashImport, error) { if !ok { return nil, fmt.Errorf("Unable to find currency '%s' for price '%s'", price.Currency.Name, price.Id) } - if currency.Type != Currency { + if currency.Type != models.Currency { return nil, fmt.Errorf("Currency for imported price isn't actually a currency\n") } p.PriceId = int64(i + 1) diff --git a/internal/handlers/gnucash_test.go b/internal/handlers/gnucash_test.go index 1cdf9d0..7eacc03 100644 --- a/internal/handlers/gnucash_test.go +++ b/internal/handlers/gnucash_test.go @@ -2,6 +2,7 @@ package handlers_test import ( "github.com/aclindsa/moneygo/internal/handlers" + "github.com/aclindsa/moneygo/internal/models" "net/http" "testing" ) @@ -94,7 +95,7 @@ func TestImportGnucash(t *testing.T) { accountBalanceHelper(t, d.clients[0], groceries, "287.56") // 87.19 from preexisting transactions and 200.37 from Gnucash accountBalanceHelper(t, d.clients[0], cable, "89.98") - var ge *handlers.Security + var ge *models.Security securities, err := getSecurities(d.clients[0]) if err != nil { t.Fatalf("Error fetching securities: %s\n", err) diff --git a/internal/handlers/imports.go b/internal/handlers/imports.go index a90fb8b..442c4e3 100644 --- a/internal/handlers/imports.go +++ b/internal/handlers/imports.go @@ -57,7 +57,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) Re // Find matching existing securities or create new ones for those // referenced by the OFX import. Also create a map from placeholder import // SecurityIds to the actual SecurityIDs - var securitymap = make(map[int64]Security) + var securitymap = make(map[int64]models.Security) for _, ofxsecurity := range itl.Securities { // save off since ImportGetCreateSecurity overwrites SecurityId on // ofxsecurity diff --git a/internal/handlers/ofx.go b/internal/handlers/ofx.go index 8c08a67..befd147 100644 --- a/internal/handlers/ofx.go +++ b/internal/handlers/ofx.go @@ -3,26 +3,27 @@ package handlers import ( "errors" "fmt" + "github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/ofxgo" "io" "math/big" ) type OFXImport struct { - Securities []Security + Securities []models.Security Accounts []Account Transactions []Transaction // Balances map[int64]string // map AccountIDs to ending balances } -func (i *OFXImport) GetSecurity(ofxsecurityid int64) (*Security, error) { +func (i *OFXImport) GetSecurity(ofxsecurityid int64) (*models.Security, error) { if ofxsecurityid < 0 || ofxsecurityid > int64(len(i.Securities)) { return nil, errors.New("OFXImport.GetSecurity: SecurityID out of range") } return &i.Securities[ofxsecurityid], nil } -func (i *OFXImport) GetSecurityAlternateId(alternateid string, securityType SecurityType) (*Security, error) { +func (i *OFXImport) GetSecurityAlternateId(alternateid string, securityType models.SecurityType) (*models.Security, error) { for _, security := range i.Securities { if alternateid == security.AlternateId && securityType == security.Type { return &security, nil @@ -32,18 +33,18 @@ func (i *OFXImport) GetSecurityAlternateId(alternateid string, securityType Secu return nil, errors.New("OFXImport.FindSecurity: Unable to find security") } -func (i *OFXImport) GetAddCurrency(isoname string) (*Security, error) { +func (i *OFXImport) GetAddCurrency(isoname string) (*models.Security, error) { for _, security := range i.Securities { - if isoname == security.Name && Currency == security.Type { + if isoname == security.Name && models.Currency == security.Type { return &security, nil } } - template := FindSecurityTemplate(isoname, Currency) + template := FindSecurityTemplate(isoname, models.Currency) if template == nil { return nil, fmt.Errorf("Failed to find Security for \"%s\"", isoname) } - var security Security = *template + var security models.Security = *template security.SecurityId = int64(len(i.Securities) + 1) i.Securities = append(i.Securities, security) @@ -186,13 +187,13 @@ func (i *OFXImport) importSecurities(seclist *ofxgo.SecurityList) error { } else { return errors.New("Can't import unrecognized type satisfying ofxgo.Security interface") } - s := Security{ + s := models.Security{ SecurityId: int64(len(i.Securities) + 1), Name: string(si.SecName), Description: string(si.Memo), Symbol: string(si.Ticker), Precision: 5, // TODO How to actually determine this? - Type: Stock, + Type: models.Stock, AlternateId: string(si.SecID.UniqueID), } if len(s.Description) == 0 { @@ -214,10 +215,10 @@ func (i *OFXImport) GetInvTran(invtran *ofxgo.InvTran) Transaction { return t } -func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *Security, account *Account) (*Transaction, error) { +func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *models.Security, account *Account) (*Transaction, error) { t := i.GetInvTran(&buy.InvTran) - security, err := i.GetSecurityAlternateId(string(buy.SecID.UniqueID), Stock) + security, err := i.GetSecurityAlternateId(string(buy.SecID.UniqueID), models.Stock) if err != nil { return nil, err } @@ -348,10 +349,10 @@ func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *Security, account * return &t, nil } -func (i *OFXImport) GetIncomeTran(income *ofxgo.Income, curdef *Security, account *Account) (*Transaction, error) { +func (i *OFXImport) GetIncomeTran(income *ofxgo.Income, curdef *models.Security, account *Account) (*Transaction, error) { t := i.GetInvTran(&income.InvTran) - security, err := i.GetSecurityAlternateId(string(income.SecID.UniqueID), Stock) + security, err := i.GetSecurityAlternateId(string(income.SecID.UniqueID), models.Stock) if err != nil { return nil, err } @@ -394,10 +395,10 @@ func (i *OFXImport) GetIncomeTran(income *ofxgo.Income, curdef *Security, accoun return &t, nil } -func (i *OFXImport) GetInvExpenseTran(expense *ofxgo.InvExpense, curdef *Security, account *Account) (*Transaction, error) { +func (i *OFXImport) GetInvExpenseTran(expense *ofxgo.InvExpense, curdef *models.Security, account *Account) (*Transaction, error) { t := i.GetInvTran(&expense.InvTran) - security, err := i.GetSecurityAlternateId(string(expense.SecID.UniqueID), Stock) + security, err := i.GetSecurityAlternateId(string(expense.SecID.UniqueID), models.Stock) if err != nil { return nil, err } @@ -439,7 +440,7 @@ func (i *OFXImport) GetInvExpenseTran(expense *ofxgo.InvExpense, curdef *Securit return &t, nil } -func (i *OFXImport) GetMarginInterestTran(marginint *ofxgo.MarginInterest, curdef *Security, account *Account) (*Transaction, error) { +func (i *OFXImport) GetMarginInterestTran(marginint *ofxgo.MarginInterest, curdef *models.Security, account *Account) (*Transaction, error) { t := i.GetInvTran(&marginint.InvTran) memo := string(marginint.InvTran.Memo) @@ -478,10 +479,10 @@ func (i *OFXImport) GetMarginInterestTran(marginint *ofxgo.MarginInterest, curde return &t, nil } -func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *Security, account *Account) (*Transaction, error) { +func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Security, account *Account) (*Transaction, error) { t := i.GetInvTran(&reinvest.InvTran) - security, err := i.GetSecurityAlternateId(string(reinvest.SecID.UniqueID), Stock) + security, err := i.GetSecurityAlternateId(string(reinvest.SecID.UniqueID), models.Stock) if err != nil { return nil, err } @@ -634,10 +635,10 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *Security, return &t, nil } -func (i *OFXImport) GetRetOfCapTran(retofcap *ofxgo.RetOfCap, curdef *Security, account *Account) (*Transaction, error) { +func (i *OFXImport) GetRetOfCapTran(retofcap *ofxgo.RetOfCap, curdef *models.Security, account *Account) (*Transaction, error) { t := i.GetInvTran(&retofcap.InvTran) - security, err := i.GetSecurityAlternateId(string(retofcap.SecID.UniqueID), Stock) + security, err := i.GetSecurityAlternateId(string(retofcap.SecID.UniqueID), models.Stock) if err != nil { return nil, err } @@ -679,10 +680,10 @@ func (i *OFXImport) GetRetOfCapTran(retofcap *ofxgo.RetOfCap, curdef *Security, return &t, nil } -func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *Security, account *Account) (*Transaction, error) { +func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *models.Security, account *Account) (*Transaction, error) { t := i.GetInvTran(&sell.InvTran) - security, err := i.GetSecurityAlternateId(string(sell.SecID.UniqueID), Stock) + security, err := i.GetSecurityAlternateId(string(sell.SecID.UniqueID), models.Stock) if err != nil { return nil, err } @@ -819,7 +820,7 @@ func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *Security, accoun func (i *OFXImport) GetTransferTran(transfer *ofxgo.Transfer, account *Account) (*Transaction, error) { t := i.GetInvTran(&transfer.InvTran) - security, err := i.GetSecurityAlternateId(string(transfer.SecID.UniqueID), Stock) + security, err := i.GetSecurityAlternateId(string(transfer.SecID.UniqueID), models.Stock) if err != nil { return nil, err } @@ -858,7 +859,7 @@ func (i *OFXImport) GetTransferTran(transfer *ofxgo.Transfer, account *Account) return &t, nil } -func (i *OFXImport) AddInvTransaction(invtran *ofxgo.InvTransaction, account *Account, curdef *Security) error { +func (i *OFXImport) AddInvTransaction(invtran *ofxgo.InvTransaction, account *Account, curdef *models.Security) error { if curdef.SecurityId < 1 || curdef.SecurityId > int64(len(i.Securities)) { return errors.New("Internal error: security index not found in OFX import\n") } diff --git a/internal/handlers/ofx_test.go b/internal/handlers/ofx_test.go index baf452a..11c3f04 100644 --- a/internal/handlers/ofx_test.go +++ b/internal/handlers/ofx_test.go @@ -3,6 +3,7 @@ package handlers_test import ( "fmt" "github.com/aclindsa/moneygo/internal/handlers" + "github.com/aclindsa/moneygo/internal/models" "net/http" "strconv" "testing" @@ -63,7 +64,7 @@ func TestImportOFXCreditCard(t *testing.T) { }) } -func findSecurity(client *http.Client, symbol string, tipe handlers.SecurityType) (*handlers.Security, error) { +func findSecurity(client *http.Client, symbol string, tipe models.SecurityType) (*models.Security, error) { securities, err := getSecurities(client) if err != nil { return nil, err @@ -125,7 +126,7 @@ func TestImportOFX401kMutualFunds(t *testing.T) { // Make sure the security was created and that the trading account has // the right value - security, err := findSecurity(d.clients[0], "VANGUARD TARGET 2045", handlers.Stock) + security, err := findSecurity(d.clients[0], "VANGUARD TARGET 2045", models.Stock) if err != nil { t.Fatalf("Error finding VANGUARD TARGET 2045 security: %s\n", err) } @@ -204,7 +205,7 @@ func TestImportOFXBrokerage(t *testing.T) { } for _, check := range checks { - security, err := findSecurity(d.clients[0], check.Ticker, handlers.Stock) + security, err := findSecurity(d.clients[0], check.Ticker, models.Stock) if err != nil { t.Fatalf("Error finding security: %s\n", err) } diff --git a/internal/handlers/prices.go b/internal/handlers/prices.go index fa2c058..2689378 100644 --- a/internal/handlers/prices.go +++ b/internal/handlers/prices.go @@ -90,7 +90,7 @@ func GetPrices(tx *Tx, securityid int64) (*[]*Price, error) { } // Return the latest price for security in currency units before date -func GetLatestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Price, error) { +func GetLatestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*Price, error) { var p Price err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date) if err != nil { @@ -100,7 +100,7 @@ func GetLatestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Pri } // Return the earliest price for security in currency units after date -func GetEarliestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Price, error) { +func GetEarliestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*Price, error) { var p Price err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date) if err != nil { @@ -110,7 +110,7 @@ func GetEarliestPrice(tx *Tx, security, currency *Security, date *time.Time) (*P } // Return the price for security in currency closest to date -func GetClosestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Price, error) { +func GetClosestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*Price, error) { earliest, _ := GetEarliestPrice(tx, security, currency, date) latest, err := GetLatestPrice(tx, security, currency, date) diff --git a/internal/handlers/scripts/gen_security_list.py b/internal/handlers/scripts/gen_security_list.py index 49450d6..a51723a 100755 --- a/internal/handlers/scripts/gen_security_list.py +++ b/internal/handlers/scripts/gen_security_list.py @@ -26,7 +26,7 @@ class Security(object): self.type = _type self.precision = precision def unicode(self): - s = """\tSecurity{ + s = """\t{ \t\tName: \"%s\", \t\tDescription: \"%s\", \t\tSymbol: \"%s\", @@ -72,7 +72,7 @@ def process_ccyntry(currency_list, node): else: precision = int(n.firstChild.nodeValue) if nameSet and numberSet: - currency_list.add(Security(name, description, number, "Currency", precision)) + currency_list.add(Security(name, description, number, "models.Currency", precision)) def get_currency_list(): currency_list = SecurityList("ISO 4217, from http://www.currency-iso.org/en/home/tables/table-a1.html") @@ -97,7 +97,7 @@ def get_cusip_list(filename): cusip = row[0] name = row[1] description = ",".join(row[2:]) - cusip_list.add(Security(name, description, cusip, "Stock", 5)) + cusip_list.add(Security(name, description, cusip, "models.Stock", 5)) return cusip_list def main(): @@ -105,7 +105,10 @@ def main(): cusip_list = get_cusip_list('cusip_list.csv') print("package handlers\n") - print("var SecurityTemplates = []Security{") + print("import (") + print("\t\"github.com/aclindsa/moneygo/internal/models\"") + print(")\n") + print("var SecurityTemplates = []models.Security{") print(currency_list.unicode()) print(cusip_list.unicode()) print("}") diff --git a/internal/handlers/securities.go b/internal/handlers/securities.go index 5aabaa1..ab58de4 100644 --- a/internal/handlers/securities.go +++ b/internal/handlers/securities.go @@ -3,9 +3,9 @@ package handlers //go:generate make import ( - "encoding/json" "errors" "fmt" + "github.com/aclindsa/moneygo/internal/models" "log" "net/http" "net/url" @@ -13,64 +13,9 @@ import ( "strings" ) -type SecurityType int64 - -const ( - Currency SecurityType = 1 - Stock = 2 -) - -func GetSecurityType(typestring string) SecurityType { - if strings.EqualFold(typestring, "currency") { - return Currency - } else if strings.EqualFold(typestring, "stock") { - return Stock - } else { - return 0 - } -} - -type Security struct { - SecurityId int64 - UserId int64 - Name string - Description string - Symbol string - // Number of decimal digits (to the right of the decimal point) this - // security is precise to - Precision int `db:"Preciseness"` - Type SecurityType - // AlternateId is CUSIP for Type=Stock, ISO4217 for Type=Currency - AlternateId string -} - -type SecurityList struct { - Securities *[]*Security `json:"securities"` -} - -func (s *Security) Read(json_str string) error { - dec := json.NewDecoder(strings.NewReader(json_str)) - return dec.Decode(s) -} - -func (s *Security) Write(w http.ResponseWriter) error { - enc := json.NewEncoder(w) - return enc.Encode(s) -} - -func (sl *SecurityList) Read(json_str string) error { - dec := json.NewDecoder(strings.NewReader(json_str)) - return dec.Decode(sl) -} - -func (sl *SecurityList) Write(w http.ResponseWriter) error { - enc := json.NewEncoder(w) - return enc.Encode(sl) -} - -func SearchSecurityTemplates(search string, _type SecurityType, limit int64) []*Security { +func SearchSecurityTemplates(search string, _type models.SecurityType, limit int64) []*models.Security { upperSearch := strings.ToUpper(search) - var results []*Security + var results []*models.Security for i, security := range SecurityTemplates { if strings.Contains(strings.ToUpper(security.Name), upperSearch) || strings.Contains(strings.ToUpper(security.Description), upperSearch) || @@ -86,7 +31,7 @@ func SearchSecurityTemplates(search string, _type SecurityType, limit int64) []* return results } -func FindSecurityTemplate(name string, _type SecurityType) *Security { +func FindSecurityTemplate(name string, _type models.SecurityType) *models.Security { for _, security := range SecurityTemplates { if name == security.Name && _type == security.Type { return &security @@ -95,18 +40,18 @@ func FindSecurityTemplate(name string, _type SecurityType) *Security { return nil } -func FindCurrencyTemplate(iso4217 int64) *Security { +func FindCurrencyTemplate(iso4217 int64) *models.Security { iso4217string := strconv.FormatInt(iso4217, 10) for _, security := range SecurityTemplates { - if security.Type == Currency && security.AlternateId == iso4217string { + if security.Type == models.Currency && security.AlternateId == iso4217string { return &security } } return nil } -func GetSecurity(tx *Tx, securityid int64, userid int64) (*Security, error) { - var s Security +func GetSecurity(tx *Tx, securityid int64, userid int64) (*models.Security, error) { + var s models.Security err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid) if err != nil { @@ -115,8 +60,8 @@ func GetSecurity(tx *Tx, securityid int64, userid int64) (*Security, error) { return &s, nil } -func GetSecurities(tx *Tx, userid int64) (*[]*Security, error) { - var securities []*Security +func GetSecurities(tx *Tx, userid int64) (*[]*models.Security, error) { + var securities []*models.Security _, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid) if err != nil { @@ -125,7 +70,7 @@ func GetSecurities(tx *Tx, userid int64) (*[]*Security, error) { return &securities, nil } -func InsertSecurity(tx *Tx, s *Security) error { +func InsertSecurity(tx *Tx, s *models.Security) error { err := tx.Insert(s) if err != nil { return err @@ -133,11 +78,11 @@ func InsertSecurity(tx *Tx, s *Security) error { return nil } -func UpdateSecurity(tx *Tx, s *Security) (err error) { +func UpdateSecurity(tx *Tx, s *models.Security) (err error) { user, err := GetUser(tx, s.UserId) if err != nil { return - } else if user.DefaultCurrency == s.SecurityId && s.Type != Currency { + } else if user.DefaultCurrency == s.SecurityId && s.Type != models.Currency { return errors.New("Cannot change security which is user's default currency to be non-currency") } @@ -160,7 +105,7 @@ func (e SecurityInUseError) Error() string { return e.message } -func DeleteSecurity(tx *Tx, s *Security) error { +func DeleteSecurity(tx *Tx, s *models.Security) error { // First, ensure no accounts are using this security accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId) @@ -193,7 +138,7 @@ func DeleteSecurity(tx *Tx, s *Security) error { return nil } -func ImportGetCreateSecurity(tx *Tx, userid int64, security *Security) (*Security, error) { +func ImportGetCreateSecurity(tx *Tx, userid int64, security *models.Security) (*models.Security, error) { security.UserId = userid if len(security.AlternateId) == 0 { // Always create a new local security if we can't match on the AlternateId @@ -204,7 +149,7 @@ func ImportGetCreateSecurity(tx *Tx, userid int64, security *Security) (*Securit return security, nil } - var securities []*Security + var securities []*models.Security _, err := tx.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Preciseness=?", userid, security.Type, security.AlternateId, security.Precision) if err != nil { @@ -264,7 +209,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { return PriceHandler(r, context, user, securityid) } - var security Security + var security models.Security if err := ReadJSON(r, &security); err != nil { return NewError(3 /*Invalid Request*/) } @@ -281,7 +226,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { } else if r.Method == "GET" { if context.LastLevel() { //Return all securities - var sl SecurityList + var sl models.SecurityList securities, err := GetSecurities(context.Tx, user.UserId) if err != nil { @@ -324,7 +269,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { } if r.Method == "PUT" { - var security Security + var security models.Security if err := ReadJSON(r, &security); err != nil || security.SecurityId != securityid { return NewError(3 /*Invalid Request*/) } @@ -359,17 +304,17 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { func SecurityTemplateHandler(r *http.Request, context *Context) ResponseWriterWriter { if r.Method == "GET" { - var sl SecurityList + var sl models.SecurityList query, _ := url.ParseQuery(r.URL.RawQuery) var limit int64 = -1 search := query.Get("search") - var _type SecurityType = 0 + var _type models.SecurityType = 0 typestring := query.Get("type") if len(typestring) > 0 { - _type = GetSecurityType(typestring) + _type = models.GetSecurityType(typestring) if _type == 0 { return NewError(3 /*Invalid Request*/) } diff --git a/internal/handlers/securities_lua.go b/internal/handlers/securities_lua.go index b294c1c..12783ce 100644 --- a/internal/handlers/securities_lua.go +++ b/internal/handlers/securities_lua.go @@ -9,8 +9,8 @@ import ( const luaSecurityTypeName = "security" -func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) { - var security_map map[int64]*Security +func luaContextGetSecurities(L *lua.LState) (map[int64]*models.Security, error) { + var security_map map[int64]*models.Security ctx := L.Context() @@ -19,7 +19,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) { return nil, errors.New("Couldn't find tx in lua's Context") } - security_map, ok = ctx.Value(securitiesContextKey).(map[int64]*Security) + security_map, ok = ctx.Value(securitiesContextKey).(map[int64]*models.Security) if !ok { user, ok := ctx.Value(userContextKey).(*models.User) if !ok { @@ -31,7 +31,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) { return nil, err } - security_map = make(map[int64]*Security) + security_map = make(map[int64]*models.Security) for i := range *securities { security_map[(*securities)[i].SecurityId] = (*securities)[i] } @@ -43,7 +43,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) { return security_map, nil } -func luaContextGetDefaultCurrency(L *lua.LState) (*Security, error) { +func luaContextGetDefaultCurrency(L *lua.LState) (*models.Security, error) { security_map, err := luaContextGetSecurities(L) if err != nil { return nil, err @@ -107,7 +107,7 @@ func luaRegisterSecurities(L *lua.LState) { L.SetGlobal("get_default_currency", getDefaultCurrencyFn) } -func SecurityToLua(L *lua.LState, security *Security) *lua.LUserData { +func SecurityToLua(L *lua.LState, security *models.Security) *lua.LUserData { ud := L.NewUserData() ud.Value = security L.SetMetatable(ud, L.GetTypeMetatable(luaSecurityTypeName)) @@ -115,9 +115,9 @@ func SecurityToLua(L *lua.LState, security *Security) *lua.LUserData { } // Checks whether the first lua argument is a *LUserData with *Security and returns this *Security. -func luaCheckSecurity(L *lua.LState, n int) *Security { +func luaCheckSecurity(L *lua.LState, n int) *models.Security { ud := L.CheckUserData(n) - if security, ok := ud.Value.(*Security); ok { + if security, ok := ud.Value.(*models.Security); ok { return security } L.ArgError(n, "security expected") diff --git a/internal/handlers/securities_test.go b/internal/handlers/securities_test.go index aab0a0c..8d786ad 100644 --- a/internal/handlers/securities_test.go +++ b/internal/handlers/securities_test.go @@ -2,19 +2,20 @@ package handlers_test import ( "github.com/aclindsa/moneygo/internal/handlers" + "github.com/aclindsa/moneygo/internal/models" "net/http" "strconv" "testing" ) -func createSecurity(client *http.Client, security *handlers.Security) (*handlers.Security, error) { - var s handlers.Security +func createSecurity(client *http.Client, security *models.Security) (*models.Security, error) { + var s models.Security err := create(client, security, &s, "/v1/securities/") return &s, err } -func getSecurity(client *http.Client, securityid int64) (*handlers.Security, error) { - var s handlers.Security +func getSecurity(client *http.Client, securityid int64) (*models.Security, error) { + var s models.Security err := read(client, &s, "/v1/securities/"+strconv.FormatInt(securityid, 10)) if err != nil { return nil, err @@ -22,8 +23,8 @@ func getSecurity(client *http.Client, securityid int64) (*handlers.Security, err return &s, nil } -func getSecurities(client *http.Client) (*handlers.SecurityList, error) { - var sl handlers.SecurityList +func getSecurities(client *http.Client) (*models.SecurityList, error) { + var sl models.SecurityList err := read(client, &sl, "/v1/securities/") if err != nil { return nil, err @@ -31,8 +32,8 @@ func getSecurities(client *http.Client) (*handlers.SecurityList, error) { return &sl, nil } -func updateSecurity(client *http.Client, security *handlers.Security) (*handlers.Security, error) { - var s handlers.Security +func updateSecurity(client *http.Client, security *models.Security) (*models.Security, error) { + var s models.Security err := update(client, security, &s, "/v1/securities/"+strconv.FormatInt(security.SecurityId, 10)) if err != nil { return nil, err @@ -40,7 +41,7 @@ func updateSecurity(client *http.Client, security *handlers.Security) (*handlers return &s, nil } -func deleteSecurity(client *http.Client, s *handlers.Security) error { +func deleteSecurity(client *http.Client, s *models.Security) error { err := remove(client, "/v1/securities/"+strconv.FormatInt(s.SecurityId, 10)) if err != nil { return err diff --git a/internal/handlers/security_templates_test.go b/internal/handlers/security_templates_test.go index 04baac6..1728576 100644 --- a/internal/handlers/security_templates_test.go +++ b/internal/handlers/security_templates_test.go @@ -2,12 +2,13 @@ package handlers_test import ( "github.com/aclindsa/moneygo/internal/handlers" + "github.com/aclindsa/moneygo/internal/models" "io/ioutil" "testing" ) func TestSecurityTemplates(t *testing.T) { - var sl handlers.SecurityList + var sl models.SecurityList response, err := server.Client().Get(server.URL + "/v1/securitytemplates/?search=USD&type=currency") if err != nil { t.Fatal(err) @@ -30,7 +31,7 @@ func TestSecurityTemplates(t *testing.T) { num_usd := 0 if sl.Securities != nil { for _, s := range *sl.Securities { - if s.Type != handlers.Currency { + if s.Type != models.Currency { t.Fatalf("Requested Currency-only security templates, received a non-Currency template for %s", s.Name) } @@ -46,7 +47,7 @@ func TestSecurityTemplates(t *testing.T) { } func TestSecurityTemplateLimit(t *testing.T) { - var sl handlers.SecurityList + var sl models.SecurityList response, err := server.Client().Get(server.URL + "/v1/securitytemplates/?search=e&limit=5") if err != nil { t.Fatal(err) diff --git a/internal/handlers/testdata_test.go b/internal/handlers/testdata_test.go index db13381..c01544d 100644 --- a/internal/handlers/testdata_test.go +++ b/internal/handlers/testdata_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "github.com/aclindsa/moneygo/internal/handlers" + "github.com/aclindsa/moneygo/internal/models" "net/http" "strings" "testing" @@ -36,7 +37,7 @@ type TestData struct { initialized bool users []User clients []*http.Client - securities []handlers.Security + securities []models.Security prices []handlers.Price accounts []handlers.Account // accounts must appear after their parents in this slice transactions []handlers.Transaction @@ -170,14 +171,14 @@ var data = []TestData{ Email: "bbob+moneygo@my-domain.com", }, }, - securities: []handlers.Security{ + securities: []models.Security{ { UserId: 0, Name: "USD", Description: "US Dollar", Symbol: "$", Precision: 2, - Type: handlers.Currency, + Type: models.Currency, AlternateId: "840", }, { @@ -186,7 +187,7 @@ var data = []TestData{ Description: "SPDR S&P 500 ETF Trust", Symbol: "SPY", Precision: 5, - Type: handlers.Stock, + Type: models.Stock, AlternateId: "78462F103", }, { @@ -195,7 +196,7 @@ var data = []TestData{ Description: "Euro", Symbol: "€", Precision: 2, - Type: handlers.Currency, + Type: models.Currency, AlternateId: "978", }, { @@ -204,7 +205,7 @@ var data = []TestData{ Description: "Euro", Symbol: "€", Precision: 2, - Type: handlers.Currency, + Type: models.Currency, AlternateId: "978", }, }, diff --git a/internal/handlers/users.go b/internal/handlers/users.go index 66b5737..ba1a9d0 100644 --- a/internal/handlers/users.go +++ b/internal/handlers/users.go @@ -54,7 +54,7 @@ func InsertUser(tx *Tx, u *models.User) error { } // Copy the security template and give it our new UserId - var security Security + var security models.Security security = *security_template security.UserId = u.UserId @@ -89,7 +89,7 @@ func UpdateUser(tx *Tx, u *models.User) error { return err } else if security.UserId != u.UserId || security.SecurityId != u.DefaultCurrency { return errors.New("UserId and DefaultCurrency don't match the fetched security") - } else if security.Type != Currency { + } else if security.Type != models.Currency { return errors.New("New DefaultCurrency security is not a currency") } diff --git a/internal/models/securities.go b/internal/models/securities.go new file mode 100644 index 0000000..67557be --- /dev/null +++ b/internal/models/securities.go @@ -0,0 +1,62 @@ +package models + +import ( + "encoding/json" + "net/http" + "strings" +) + +type SecurityType int64 + +const ( + Currency SecurityType = 1 + Stock = 2 +) + +func GetSecurityType(typestring string) SecurityType { + if strings.EqualFold(typestring, "currency") { + return Currency + } else if strings.EqualFold(typestring, "stock") { + return Stock + } else { + return 0 + } +} + +type Security struct { + SecurityId int64 + UserId int64 + Name string + Description string + Symbol string + // Number of decimal digits (to the right of the decimal point) this + // security is precise to + Precision int `db:"Preciseness"` + Type SecurityType + // AlternateId is CUSIP for Type=Stock, ISO4217 for Type=Currency + AlternateId string +} + +type SecurityList struct { + Securities *[]*Security `json:"securities"` +} + +func (s *Security) Read(json_str string) error { + dec := json.NewDecoder(strings.NewReader(json_str)) + return dec.Decode(s) +} + +func (s *Security) Write(w http.ResponseWriter) error { + enc := json.NewEncoder(w) + return enc.Encode(s) +} + +func (sl *SecurityList) Read(json_str string) error { + dec := json.NewDecoder(strings.NewReader(json_str)) + return dec.Decode(sl) +} + +func (sl *SecurityList) Write(w http.ResponseWriter) error { + enc := json.NewEncoder(w) + return enc.Encode(sl) +}