From e70be1647c4d2f6895ea266073a802e0dd198d12 Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Sat, 2 Dec 2017 06:14:47 -0500 Subject: [PATCH 1/6] Begin splitting models from handlers with User --- internal/db/db.go | 3 +- internal/handlers/accounts_lua.go | 5 +-- internal/handlers/handlers.go | 3 +- internal/handlers/imports.go | 9 ++--- internal/handlers/prices.go | 3 +- internal/handlers/reports.go | 5 +-- internal/handlers/securities_lua.go | 5 +-- internal/handlers/sessions.go | 3 +- internal/handlers/transactions.go | 19 +++++----- internal/handlers/users.go | 54 ++++++----------------------- internal/models/users.go | 39 +++++++++++++++++++++ 11 files changed, 82 insertions(+), 66 deletions(-) create mode 100644 internal/models/users.go diff --git a/internal/db/db.go b/internal/db/db.go index 68cfd6d..fbb7b9d 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -6,6 +6,7 @@ import ( "github.com/aclindsa/gorp" "github.com/aclindsa/moneygo/internal/config" "github.com/aclindsa/moneygo/internal/handlers" + "github.com/aclindsa/moneygo/internal/models" _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" @@ -33,7 +34,7 @@ func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { } dbmap := &gorp.DbMap{Db: db, Dialect: dialect} - dbmap.AddTableWithName(handlers.User{}, "users").SetKeys(true, "UserId") + dbmap.AddTableWithName(models.User{}, "users").SetKeys(true, "UserId") dbmap.AddTableWithName(handlers.Session{}, "sessions").SetKeys(true, "SessionId") dbmap.AddTableWithName(handlers.Account{}, "accounts").SetKeys(true, "AccountId") dbmap.AddTableWithName(handlers.Security{}, "securities").SetKeys(true, "SecurityId") diff --git a/internal/handlers/accounts_lua.go b/internal/handlers/accounts_lua.go index 4edf440..8a4d5db 100644 --- a/internal/handlers/accounts_lua.go +++ b/internal/handlers/accounts_lua.go @@ -3,6 +3,7 @@ package handlers import ( "context" "errors" + "github.com/aclindsa/moneygo/internal/models" "github.com/yuin/gopher-lua" "math/big" "strings" @@ -22,7 +23,7 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*Account, error) { account_map, ok = ctx.Value(accountsContextKey).(map[int64]*Account) if !ok { - user, ok := ctx.Value(userContextKey).(*User) + user, ok := ctx.Value(userContextKey).(*models.User) if !ok { return nil, errors.New("Couldn't find User in lua's Context") } @@ -153,7 +154,7 @@ func luaAccountBalance(L *lua.LState) int { if !ok { panic("Couldn't find tx in lua's Context") } - user, ok := ctx.Value(userContextKey).(*User) + user, ok := ctx.Value(userContextKey).(*models.User) if !ok { panic("Couldn't find User in lua's Context") } diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index e42419d..309eed1 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -2,6 +2,7 @@ package handlers import ( "github.com/aclindsa/gorp" + "github.com/aclindsa/moneygo/internal/models" "log" "net/http" "path" @@ -16,7 +17,7 @@ type ResponseWriterWriter interface { type Context struct { Tx *Tx - User *User + User *models.User remainingURL string // portion of URL path not yet reached in the hierarchy } diff --git a/internal/handlers/imports.go b/internal/handlers/imports.go index f83ab0c..a90fb8b 100644 --- a/internal/handlers/imports.go +++ b/internal/handlers/imports.go @@ -2,6 +2,7 @@ package handlers import ( "encoding/json" + "github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/ofxgo" "io" "log" @@ -22,7 +23,7 @@ func (od *OFXDownload) Read(json_str string) error { return dec.Decode(od) } -func ofxImportHelper(tx *Tx, r io.Reader, user *User, accountid int64) ResponseWriterWriter { +func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) ResponseWriterWriter { itl, err := ImportOFX(r) if err != nil { @@ -210,7 +211,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *User, accountid int64) ResponseW return SuccessWriter{} } -func OFXImportHandler(context *Context, r *http.Request, user *User, accountid int64) ResponseWriterWriter { +func OFXImportHandler(context *Context, r *http.Request, user *models.User, accountid int64) ResponseWriterWriter { var ofxdownload OFXDownload if err := ReadJSON(r, &ofxdownload); err != nil { return NewError(3 /*Invalid Request*/) @@ -305,7 +306,7 @@ func OFXImportHandler(context *Context, r *http.Request, user *User, accountid i return ofxImportHelper(context.Tx, response.Body, user, accountid) } -func OFXFileImportHandler(context *Context, r *http.Request, user *User, accountid int64) ResponseWriterWriter { +func OFXFileImportHandler(context *Context, r *http.Request, user *models.User, accountid int64) ResponseWriterWriter { multipartReader, err := r.MultipartReader() if err != nil { return NewError(3 /*Invalid Request*/) @@ -329,7 +330,7 @@ func OFXFileImportHandler(context *Context, r *http.Request, user *User, account /* * Assumes the User is a valid, signed-in user, but accountid has not yet been validated */ -func AccountImportHandler(context *Context, r *http.Request, user *User, accountid int64) ResponseWriterWriter { +func AccountImportHandler(context *Context, r *http.Request, user *models.User, accountid int64) ResponseWriterWriter { importType := context.NextLevel() switch importType { diff --git a/internal/handlers/prices.go b/internal/handlers/prices.go index 2027497..fa2c058 100644 --- a/internal/handlers/prices.go +++ b/internal/handlers/prices.go @@ -2,6 +2,7 @@ package handlers import ( "encoding/json" + "github.com/aclindsa/moneygo/internal/models" "log" "net/http" "strings" @@ -129,7 +130,7 @@ func GetClosestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Pr } } -func PriceHandler(r *http.Request, context *Context, user *User, securityid int64) ResponseWriterWriter { +func PriceHandler(r *http.Request, context *Context, user *models.User, securityid int64) ResponseWriterWriter { security, err := GetSecurity(context.Tx, securityid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) diff --git a/internal/handlers/reports.go b/internal/handlers/reports.go index 97c4b64..cab32d1 100644 --- a/internal/handlers/reports.go +++ b/internal/handlers/reports.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/aclindsa/moneygo/internal/models" "github.com/yuin/gopher-lua" "log" "net/http" @@ -134,7 +135,7 @@ func DeleteReport(tx *Tx, r *Report) error { return nil } -func runReport(tx *Tx, user *User, report *Report) (*Tabulation, error) { +func runReport(tx *Tx, user *models.User, report *Report) (*Tabulation, error) { // Create a new LState without opening the default libs for security L := lua.NewState(lua.Options{SkipOpenLibs: true}) defer L.Close() @@ -198,7 +199,7 @@ func runReport(tx *Tx, user *User, report *Report) (*Tabulation, error) { } } -func ReportTabulationHandler(tx *Tx, r *http.Request, user *User, reportid int64) ResponseWriterWriter { +func ReportTabulationHandler(tx *Tx, r *http.Request, user *models.User, reportid int64) ResponseWriterWriter { report, err := GetReport(tx, reportid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) diff --git a/internal/handlers/securities_lua.go b/internal/handlers/securities_lua.go index 555bdc4..b294c1c 100644 --- a/internal/handlers/securities_lua.go +++ b/internal/handlers/securities_lua.go @@ -3,6 +3,7 @@ package handlers import ( "context" "errors" + "github.com/aclindsa/moneygo/internal/models" "github.com/yuin/gopher-lua" ) @@ -20,7 +21,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) { security_map, ok = ctx.Value(securitiesContextKey).(map[int64]*Security) if !ok { - user, ok := ctx.Value(userContextKey).(*User) + user, ok := ctx.Value(userContextKey).(*models.User) if !ok { return nil, errors.New("Couldn't find User in lua's Context") } @@ -50,7 +51,7 @@ func luaContextGetDefaultCurrency(L *lua.LState) (*Security, error) { ctx := L.Context() - user, ok := ctx.Value(userContextKey).(*User) + user, ok := ctx.Value(userContextKey).(*models.User) if !ok { return nil, errors.New("Couldn't find User in lua's Context") } diff --git a/internal/handlers/sessions.go b/internal/handlers/sessions.go index 55c9317..81954b1 100644 --- a/internal/handlers/sessions.go +++ b/internal/handlers/sessions.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "encoding/json" "fmt" + "github.com/aclindsa/moneygo/internal/models" "io" "log" "net/http" @@ -120,7 +121,7 @@ func NewSession(tx *Tx, r *http.Request, userid int64) (*NewSessionWriter, error func SessionHandler(r *http.Request, context *Context) ResponseWriterWriter { if r.Method == "POST" || r.Method == "PUT" { - var user User + var user models.User if err := ReadJSON(r, &user); err != nil { return NewError(3 /*Invalid Request*/) } diff --git a/internal/handlers/transactions.go b/internal/handlers/transactions.go index a6adc8e..60c97b2 100644 --- a/internal/handlers/transactions.go +++ b/internal/handlers/transactions.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/aclindsa/moneygo/internal/models" "log" "math/big" "net/http" @@ -221,7 +222,7 @@ func GetTransactions(tx *Tx, userid int64) (*[]Transaction, error) { return &transactions, nil } -func incrementAccountVersions(tx *Tx, user *User, accountids []int64) error { +func incrementAccountVersions(tx *Tx, user *models.User, accountids []int64) error { for i := range accountids { account, err := GetAccount(tx, accountids[i], user.UserId) if err != nil { @@ -245,7 +246,7 @@ func (ame AccountMissingError) Error() string { return "Account missing" } -func InsertTransaction(tx *Tx, t *Transaction, user *User) error { +func InsertTransaction(tx *Tx, t *Transaction, user *models.User) error { // Map of any accounts with transaction splits being added a_map := make(map[int64]bool) for i := range t.Splits { @@ -295,7 +296,7 @@ func InsertTransaction(tx *Tx, t *Transaction, user *User) error { return nil } -func UpdateTransaction(tx *Tx, t *Transaction, user *User) error { +func UpdateTransaction(tx *Tx, t *Transaction, user *models.User) error { var existing_splits []*Split _, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId) @@ -372,7 +373,7 @@ func UpdateTransaction(tx *Tx, t *Transaction, user *User) error { return nil } -func DeleteTransaction(tx *Tx, t *Transaction, user *User) error { +func DeleteTransaction(tx *Tx, t *Transaction, user *models.User) error { var accountids []int64 _, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId) if err != nil { @@ -549,7 +550,7 @@ func TransactionsBalanceDifference(tx *Tx, accountid int64, transactions []Trans return &pageDifference, nil } -func GetAccountBalance(tx *Tx, user *User, accountid int64) (*big.Rat, error) { +func GetAccountBalance(tx *Tx, user *models.User, accountid int64) (*big.Rat, error) { var splits []Split sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?" @@ -572,7 +573,7 @@ func GetAccountBalance(tx *Tx, user *User, accountid int64) (*big.Rat, error) { } // Assumes accountid is valid and is owned by the current user -func GetAccountBalanceDate(tx *Tx, user *User, accountid int64, date *time.Time) (*big.Rat, error) { +func GetAccountBalanceDate(tx *Tx, user *models.User, accountid int64, date *time.Time) (*big.Rat, error) { var splits []Split sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?" @@ -594,7 +595,7 @@ func GetAccountBalanceDate(tx *Tx, user *User, accountid int64, date *time.Time) return &balance, nil } -func GetAccountBalanceDateRange(tx *Tx, user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) { +func GetAccountBalanceDateRange(tx *Tx, user *models.User, accountid int64, begin, end *time.Time) (*big.Rat, error) { var splits []Split sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?" @@ -616,7 +617,7 @@ func GetAccountBalanceDateRange(tx *Tx, user *User, accountid int64, begin, end return &balance, nil } -func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) { +func GetAccountTransactions(tx *Tx, user *models.User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) { var transactions []Transaction var atl AccountTransactionsList @@ -699,7 +700,7 @@ func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, pa // Return only those transactions which have at least one split pertaining to // an account -func AccountTransactionsHandler(context *Context, r *http.Request, user *User, accountid int64) ResponseWriterWriter { +func AccountTransactionsHandler(context *Context, r *http.Request, user *models.User, accountid int64) ResponseWriterWriter { var page uint64 = 0 var limit uint64 = 50 var sort string = "date-desc" diff --git a/internal/handlers/users.go b/internal/handlers/users.go index 1dfc5d0..66b5737 100644 --- a/internal/handlers/users.go +++ b/internal/handlers/users.go @@ -1,53 +1,21 @@ package handlers import ( - "crypto/sha256" - "encoding/json" "errors" "fmt" - "io" + "github.com/aclindsa/moneygo/internal/models" "log" "net/http" - "strings" ) -type User struct { - UserId int64 - DefaultCurrency int64 // SecurityId of default currency, or ISO4217 code for it if creating new user - Name string - Username string - Password string `db:"-"` - PasswordHash string `json:"-"` - Email string -} - -const BogusPassword = "password" - type UserExistsError struct{} func (ueu UserExistsError) Error() string { return "User exists" } -func (u *User) Write(w http.ResponseWriter) error { - enc := json.NewEncoder(w) - return enc.Encode(u) -} - -func (u *User) Read(json_str string) error { - dec := json.NewDecoder(strings.NewReader(json_str)) - return dec.Decode(u) -} - -func (u *User) HashPassword() { - password_hasher := sha256.New() - io.WriteString(password_hasher, u.Password) - u.PasswordHash = fmt.Sprintf("%x", password_hasher.Sum(nil)) - u.Password = "" -} - -func GetUser(tx *Tx, userid int64) (*User, error) { - var u User +func GetUser(tx *Tx, userid int64) (*models.User, error) { + var u models.User err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid) if err != nil { @@ -56,8 +24,8 @@ func GetUser(tx *Tx, userid int64) (*User, error) { return &u, nil } -func GetUserByUsername(tx *Tx, username string) (*User, error) { - var u User +func GetUserByUsername(tx *Tx, username string) (*models.User, error) { + var u models.User err := tx.SelectOne(&u, "SELECT * from users where Username=?", username) if err != nil { @@ -66,7 +34,7 @@ func GetUserByUsername(tx *Tx, username string) (*User, error) { return &u, nil } -func InsertUser(tx *Tx, u *User) error { +func InsertUser(tx *Tx, u *models.User) error { security_template := FindCurrencyTemplate(u.DefaultCurrency) if security_template == nil { return errors.New("Invalid ISO4217 Default Currency") @@ -107,7 +75,7 @@ func InsertUser(tx *Tx, u *User) error { return nil } -func GetUserFromSession(tx *Tx, r *http.Request) (*User, error) { +func GetUserFromSession(tx *Tx, r *http.Request) (*models.User, error) { s, err := GetSession(tx, r) if err != nil { return nil, err @@ -115,7 +83,7 @@ func GetUserFromSession(tx *Tx, r *http.Request) (*User, error) { return GetUser(tx, s.UserId) } -func UpdateUser(tx *Tx, u *User) error { +func UpdateUser(tx *Tx, u *models.User) error { security, err := GetSecurity(tx, u.DefaultCurrency, u.UserId) if err != nil { return err @@ -135,7 +103,7 @@ func UpdateUser(tx *Tx, u *User) error { return nil } -func DeleteUser(tx *Tx, u *User) error { +func DeleteUser(tx *Tx, u *models.User) error { count, err := tx.Delete(u) if err != nil { return err @@ -177,7 +145,7 @@ func DeleteUser(tx *Tx, u *User) error { func UserHandler(r *http.Request, context *Context) ResponseWriterWriter { if r.Method == "POST" { - var user User + var user models.User if err := ReadJSON(r, &user); err != nil { return NewError(3 /*Invalid Request*/) } @@ -221,7 +189,7 @@ func UserHandler(r *http.Request, context *Context) ResponseWriterWriter { } // If the user didn't create a new password, keep their old one - if user.Password != BogusPassword { + if user.Password != models.BogusPassword { user.HashPassword() } else { user.Password = "" diff --git a/internal/models/users.go b/internal/models/users.go new file mode 100644 index 0000000..c5f45ce --- /dev/null +++ b/internal/models/users.go @@ -0,0 +1,39 @@ +package models + +import ( + "crypto/sha256" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" +) + +type User struct { + UserId int64 + DefaultCurrency int64 // SecurityId of default currency, or ISO4217 code for it if creating new user + Name string + Username string + Password string `db:"-"` + PasswordHash string `json:"-"` + Email string +} + +const BogusPassword = "password" + +func (u *User) Write(w http.ResponseWriter) error { + enc := json.NewEncoder(w) + return enc.Encode(u) +} + +func (u *User) Read(json_str string) error { + dec := json.NewDecoder(strings.NewReader(json_str)) + return dec.Decode(u) +} + +func (u *User) HashPassword() { + password_hasher := sha256.New() + io.WriteString(password_hasher, u.Password) + u.PasswordHash = fmt.Sprintf("%x", password_hasher.Sum(nil)) + u.Password = "" +} From 3f4d6d15a1441cfff0ebb11157c603f902a944dc Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Sun, 3 Dec 2017 06:11:38 -0500 Subject: [PATCH 2/6] Split Sessions into models --- internal/db/db.go | 2 +- internal/handlers/sessions.go | 63 ++++------------------------ internal/handlers/sessions_test.go | 5 ++- internal/models/sessions.go | 67 ++++++++++++++++++++++++++++++ 4 files changed, 79 insertions(+), 58 deletions(-) create mode 100644 internal/models/sessions.go diff --git a/internal/db/db.go b/internal/db/db.go index fbb7b9d..4ee3192 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -35,7 +35,7 @@ func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { dbmap := &gorp.DbMap{Db: db, Dialect: dialect} dbmap.AddTableWithName(models.User{}, "users").SetKeys(true, "UserId") - dbmap.AddTableWithName(handlers.Session{}, "sessions").SetKeys(true, "SessionId") + dbmap.AddTableWithName(models.Session{}, "sessions").SetKeys(true, "SessionId") dbmap.AddTableWithName(handlers.Account{}, "accounts").SetKeys(true, "AccountId") dbmap.AddTableWithName(handlers.Security{}, "securities").SetKeys(true, "SecurityId") dbmap.AddTableWithName(handlers.Transaction{}, "transactions").SetKeys(true, "TransactionId") diff --git a/internal/handlers/sessions.go b/internal/handlers/sessions.go index 81954b1..8349613 100644 --- a/internal/handlers/sessions.go +++ b/internal/handlers/sessions.go @@ -1,38 +1,15 @@ package handlers import ( - "crypto/rand" - "encoding/base64" - "encoding/json" "fmt" "github.com/aclindsa/moneygo/internal/models" - "io" "log" "net/http" - "strings" "time" ) -type Session struct { - SessionId int64 - SessionSecret string `json:"-"` - UserId int64 - Created time.Time - Expires time.Time -} - -func (s *Session) Write(w http.ResponseWriter) error { - enc := json.NewEncoder(w) - return enc.Encode(s) -} - -func (s *Session) Read(json_str string) error { - dec := json.NewDecoder(strings.NewReader(json_str)) - return dec.Decode(s) -} - -func GetSession(tx *Tx, r *http.Request) (*Session, error) { - var s Session +func GetSession(tx *Tx, r *http.Request) (*models.Session, error) { + var s models.Session cookie, err := r.Cookie("moneygo-session") if err != nil { @@ -63,16 +40,8 @@ func DeleteSessionIfExists(tx *Tx, r *http.Request) error { return nil } -func NewSessionCookie() (string, error) { - bits := make([]byte, 128) - if _, err := io.ReadFull(rand.Reader, bits); err != nil { - return "", err - } - return base64.StdEncoding.EncodeToString(bits), nil -} - type NewSessionWriter struct { - session *Session + session *models.Session cookie *http.Cookie } @@ -82,14 +51,12 @@ func (n *NewSessionWriter) Write(w http.ResponseWriter) error { } func NewSession(tx *Tx, r *http.Request, userid int64) (*NewSessionWriter, error) { - s := Session{} - - session_secret, err := NewSessionCookie() + s, err := models.NewSession(userid) if err != nil { return nil, err } - existing, err := tx.SelectInt("SELECT count(*) from sessions where SessionSecret=?", session_secret) + existing, err := tx.SelectInt("SELECT count(*) from sessions where SessionSecret=?", s.SessionSecret) if err != nil { return nil, err } @@ -97,26 +64,12 @@ func NewSession(tx *Tx, r *http.Request, userid int64) (*NewSessionWriter, error return nil, fmt.Errorf("%d session(s) exist with the generated session_secret", existing) } - cookie := http.Cookie{ - Name: "moneygo-session", - Value: session_secret, - Path: "/", - Domain: r.URL.Host, - Expires: time.Now().AddDate(0, 1, 0), // a month from now - Secure: true, - HttpOnly: true, - } - - s.SessionSecret = session_secret - s.UserId = userid - s.Created = time.Now() - s.Expires = cookie.Expires - - err = tx.Insert(&s) + err = tx.Insert(s) if err != nil { return nil, err } - return &NewSessionWriter{&s, &cookie}, nil + + return &NewSessionWriter{s, s.Cookie(r.URL.Host)}, nil } func SessionHandler(r *http.Request, context *Context) ResponseWriterWriter { diff --git a/internal/handlers/sessions_test.go b/internal/handlers/sessions_test.go index c0bc6b4..e065bac 100644 --- a/internal/handlers/sessions_test.go +++ b/internal/handlers/sessions_test.go @@ -3,6 +3,7 @@ package handlers_test import ( "fmt" "github.com/aclindsa/moneygo/internal/handlers" + "github.com/aclindsa/moneygo/internal/models" "net/http" "net/http/cookiejar" "net/url" @@ -26,8 +27,8 @@ func newSession(user *User) (*http.Client, error) { return &client, nil } -func getSession(client *http.Client) (*handlers.Session, error) { - var s handlers.Session +func getSession(client *http.Client) (*models.Session, error) { + var s models.Session err := read(client, &s, "/v1/sessions/") return &s, err } diff --git a/internal/models/sessions.go b/internal/models/sessions.go new file mode 100644 index 0000000..872bc3c --- /dev/null +++ b/internal/models/sessions.go @@ -0,0 +1,67 @@ +package models + +import ( + "crypto/rand" + "encoding/base64" + "encoding/json" + "io" + "net/http" + "strings" + "time" +) + +type Session struct { + SessionId int64 + SessionSecret string `json:"-"` + UserId int64 + Created time.Time + Expires time.Time +} + +func (s *Session) Cookie(domain string) *http.Cookie { + return &http.Cookie{ + Name: "moneygo-session", + Value: s.SessionSecret, + Path: "/", + Domain: domain, + Expires: s.Expires, + Secure: true, + HttpOnly: true, + } +} + +func (s *Session) Write(w http.ResponseWriter) error { + enc := json.NewEncoder(w) + return enc.Encode(s) +} + +func (s *Session) Read(json_str string) error { + dec := json.NewDecoder(strings.NewReader(json_str)) + return dec.Decode(s) +} + +func newSessionSecret() (string, error) { + bits := make([]byte, 128) + if _, err := io.ReadFull(rand.Reader, bits); err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(bits), nil +} + +func NewSession(userid int64) (*Session, error) { + session_secret, err := newSessionSecret() + if err != nil { + return nil, err + } + + now := time.Now() + + s := Session{ + SessionSecret: session_secret, + UserId: userid, + Created: now, + Expires: now.AddDate(0, 1, 0), // a month from now + } + + return &s, nil +} From f72c86ef58d927c0136b0bdd57d2d68112438e19 Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Sun, 3 Dec 2017 06:38:22 -0500 Subject: [PATCH 3/6] Split securities into models --- internal/db/db.go | 2 +- internal/handlers/accounts.go | 3 +- internal/handlers/balance_lua.go | 3 +- internal/handlers/gnucash.go | 15 +-- internal/handlers/gnucash_test.go | 3 +- internal/handlers/imports.go | 2 +- internal/handlers/ofx.go | 49 ++++----- internal/handlers/ofx_test.go | 7 +- internal/handlers/prices.go | 6 +- .../handlers/scripts/gen_security_list.py | 11 ++- internal/handlers/securities.go | 99 +++++-------------- internal/handlers/securities_lua.go | 16 +-- internal/handlers/securities_test.go | 19 ++-- internal/handlers/security_templates_test.go | 7 +- internal/handlers/testdata_test.go | 13 +-- internal/handlers/users.go | 4 +- internal/models/securities.go | 62 ++++++++++++ 17 files changed, 170 insertions(+), 151 deletions(-) create mode 100644 internal/models/securities.go diff --git a/internal/db/db.go b/internal/db/db.go index 4ee3192..0d71bfa 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -37,7 +37,7 @@ func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { dbmap.AddTableWithName(models.User{}, "users").SetKeys(true, "UserId") dbmap.AddTableWithName(models.Session{}, "sessions").SetKeys(true, "SessionId") dbmap.AddTableWithName(handlers.Account{}, "accounts").SetKeys(true, "AccountId") - dbmap.AddTableWithName(handlers.Security{}, "securities").SetKeys(true, "SecurityId") + dbmap.AddTableWithName(models.Security{}, "securities").SetKeys(true, "SecurityId") dbmap.AddTableWithName(handlers.Transaction{}, "transactions").SetKeys(true, "TransactionId") dbmap.AddTableWithName(handlers.Split{}, "splits").SetKeys(true, "SplitId") dbmap.AddTableWithName(handlers.Price{}, "prices").SetKeys(true, "PriceId") diff --git a/internal/handlers/accounts.go b/internal/handlers/accounts.go index 9bdc6f3..dd3fa14 100644 --- a/internal/handlers/accounts.go +++ b/internal/handlers/accounts.go @@ -3,6 +3,7 @@ package handlers import ( "encoding/json" "errors" + "github.com/aclindsa/moneygo/internal/models" "log" "net/http" "strings" @@ -214,7 +215,7 @@ func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*Account, error) func GetImbalanceAccount(tx *Tx, userid int64, securityid int64) (*Account, error) { var imbalanceAccount Account var account Account - xxxtemplate := FindSecurityTemplate("XXX", Currency) + xxxtemplate := FindSecurityTemplate("XXX", models.Currency) if xxxtemplate == nil { return nil, errors.New("Couldn't find XXX security template") } diff --git a/internal/handlers/balance_lua.go b/internal/handlers/balance_lua.go index 118d0d6..c4d6b63 100644 --- a/internal/handlers/balance_lua.go +++ b/internal/handlers/balance_lua.go @@ -1,12 +1,13 @@ package handlers import ( + "github.com/aclindsa/moneygo/internal/models" "github.com/yuin/gopher-lua" "math/big" ) type Balance struct { - Security *Security + Security *models.Security Amount *big.Rat } diff --git a/internal/handlers/gnucash.go b/internal/handlers/gnucash.go index 75a949c..9670f1e 100644 --- a/internal/handlers/gnucash.go +++ b/internal/handlers/gnucash.go @@ -6,6 +6,7 @@ import ( "encoding/xml" "errors" "fmt" + "github.com/aclindsa/moneygo/internal/models" "io" "log" "math" @@ -22,7 +23,7 @@ type GnucashXMLCommodity struct { XCode string `xml:"http://www.gnucash.org/XML/cmdty xcode"` } -type GnucashCommodity struct{ Security } +type GnucashCommodity struct{ models.Security } func (gc *GnucashCommodity) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { var gxc GnucashXMLCommodity @@ -35,12 +36,12 @@ func (gc *GnucashCommodity) UnmarshalXML(d *xml.Decoder, start xml.StartElement) gc.Description = gxc.Description gc.AlternateId = gxc.XCode - gc.Security.Type = Stock // assumed default + gc.Security.Type = models.Stock // assumed default if gxc.Type == "ISO4217" { - gc.Security.Type = Currency + gc.Security.Type = models.Currency // Get the number from our templates for the AlternateId because // Gnucash uses 'id' (our Name) to supply the string ISO4217 code - template := FindSecurityTemplate(gxc.Name, Currency) + template := FindSecurityTemplate(gxc.Name, models.Currency) if template == nil { return errors.New("Unable to find security template for Gnucash ISO4217 commodity") } @@ -125,7 +126,7 @@ type GnucashXMLImport struct { } type GnucashImport struct { - Securities []Security + Securities []models.Security Accounts []Account Transactions []Transaction Prices []Price @@ -143,7 +144,7 @@ func ImportGnucash(r io.Reader) (*GnucashImport, error) { } // Fixup securities, making a map of them as we go - securityMap := make(map[string]Security) + securityMap := make(map[string]models.Security) for i := range gncxml.Commodities { s := gncxml.Commodities[i].Security s.SecurityId = int64(i + 1) @@ -169,7 +170,7 @@ func ImportGnucash(r io.Reader) (*GnucashImport, error) { if !ok { return nil, fmt.Errorf("Unable to find currency '%s' for price '%s'", price.Currency.Name, price.Id) } - if currency.Type != Currency { + if currency.Type != models.Currency { return nil, fmt.Errorf("Currency for imported price isn't actually a currency\n") } p.PriceId = int64(i + 1) diff --git a/internal/handlers/gnucash_test.go b/internal/handlers/gnucash_test.go index 1cdf9d0..7eacc03 100644 --- a/internal/handlers/gnucash_test.go +++ b/internal/handlers/gnucash_test.go @@ -2,6 +2,7 @@ package handlers_test import ( "github.com/aclindsa/moneygo/internal/handlers" + "github.com/aclindsa/moneygo/internal/models" "net/http" "testing" ) @@ -94,7 +95,7 @@ func TestImportGnucash(t *testing.T) { accountBalanceHelper(t, d.clients[0], groceries, "287.56") // 87.19 from preexisting transactions and 200.37 from Gnucash accountBalanceHelper(t, d.clients[0], cable, "89.98") - var ge *handlers.Security + var ge *models.Security securities, err := getSecurities(d.clients[0]) if err != nil { t.Fatalf("Error fetching securities: %s\n", err) diff --git a/internal/handlers/imports.go b/internal/handlers/imports.go index a90fb8b..442c4e3 100644 --- a/internal/handlers/imports.go +++ b/internal/handlers/imports.go @@ -57,7 +57,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) Re // Find matching existing securities or create new ones for those // referenced by the OFX import. Also create a map from placeholder import // SecurityIds to the actual SecurityIDs - var securitymap = make(map[int64]Security) + var securitymap = make(map[int64]models.Security) for _, ofxsecurity := range itl.Securities { // save off since ImportGetCreateSecurity overwrites SecurityId on // ofxsecurity diff --git a/internal/handlers/ofx.go b/internal/handlers/ofx.go index 8c08a67..befd147 100644 --- a/internal/handlers/ofx.go +++ b/internal/handlers/ofx.go @@ -3,26 +3,27 @@ package handlers import ( "errors" "fmt" + "github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/ofxgo" "io" "math/big" ) type OFXImport struct { - Securities []Security + Securities []models.Security Accounts []Account Transactions []Transaction // Balances map[int64]string // map AccountIDs to ending balances } -func (i *OFXImport) GetSecurity(ofxsecurityid int64) (*Security, error) { +func (i *OFXImport) GetSecurity(ofxsecurityid int64) (*models.Security, error) { if ofxsecurityid < 0 || ofxsecurityid > int64(len(i.Securities)) { return nil, errors.New("OFXImport.GetSecurity: SecurityID out of range") } return &i.Securities[ofxsecurityid], nil } -func (i *OFXImport) GetSecurityAlternateId(alternateid string, securityType SecurityType) (*Security, error) { +func (i *OFXImport) GetSecurityAlternateId(alternateid string, securityType models.SecurityType) (*models.Security, error) { for _, security := range i.Securities { if alternateid == security.AlternateId && securityType == security.Type { return &security, nil @@ -32,18 +33,18 @@ func (i *OFXImport) GetSecurityAlternateId(alternateid string, securityType Secu return nil, errors.New("OFXImport.FindSecurity: Unable to find security") } -func (i *OFXImport) GetAddCurrency(isoname string) (*Security, error) { +func (i *OFXImport) GetAddCurrency(isoname string) (*models.Security, error) { for _, security := range i.Securities { - if isoname == security.Name && Currency == security.Type { + if isoname == security.Name && models.Currency == security.Type { return &security, nil } } - template := FindSecurityTemplate(isoname, Currency) + template := FindSecurityTemplate(isoname, models.Currency) if template == nil { return nil, fmt.Errorf("Failed to find Security for \"%s\"", isoname) } - var security Security = *template + var security models.Security = *template security.SecurityId = int64(len(i.Securities) + 1) i.Securities = append(i.Securities, security) @@ -186,13 +187,13 @@ func (i *OFXImport) importSecurities(seclist *ofxgo.SecurityList) error { } else { return errors.New("Can't import unrecognized type satisfying ofxgo.Security interface") } - s := Security{ + s := models.Security{ SecurityId: int64(len(i.Securities) + 1), Name: string(si.SecName), Description: string(si.Memo), Symbol: string(si.Ticker), Precision: 5, // TODO How to actually determine this? - Type: Stock, + Type: models.Stock, AlternateId: string(si.SecID.UniqueID), } if len(s.Description) == 0 { @@ -214,10 +215,10 @@ func (i *OFXImport) GetInvTran(invtran *ofxgo.InvTran) Transaction { return t } -func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *Security, account *Account) (*Transaction, error) { +func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *models.Security, account *Account) (*Transaction, error) { t := i.GetInvTran(&buy.InvTran) - security, err := i.GetSecurityAlternateId(string(buy.SecID.UniqueID), Stock) + security, err := i.GetSecurityAlternateId(string(buy.SecID.UniqueID), models.Stock) if err != nil { return nil, err } @@ -348,10 +349,10 @@ func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *Security, account * return &t, nil } -func (i *OFXImport) GetIncomeTran(income *ofxgo.Income, curdef *Security, account *Account) (*Transaction, error) { +func (i *OFXImport) GetIncomeTran(income *ofxgo.Income, curdef *models.Security, account *Account) (*Transaction, error) { t := i.GetInvTran(&income.InvTran) - security, err := i.GetSecurityAlternateId(string(income.SecID.UniqueID), Stock) + security, err := i.GetSecurityAlternateId(string(income.SecID.UniqueID), models.Stock) if err != nil { return nil, err } @@ -394,10 +395,10 @@ func (i *OFXImport) GetIncomeTran(income *ofxgo.Income, curdef *Security, accoun return &t, nil } -func (i *OFXImport) GetInvExpenseTran(expense *ofxgo.InvExpense, curdef *Security, account *Account) (*Transaction, error) { +func (i *OFXImport) GetInvExpenseTran(expense *ofxgo.InvExpense, curdef *models.Security, account *Account) (*Transaction, error) { t := i.GetInvTran(&expense.InvTran) - security, err := i.GetSecurityAlternateId(string(expense.SecID.UniqueID), Stock) + security, err := i.GetSecurityAlternateId(string(expense.SecID.UniqueID), models.Stock) if err != nil { return nil, err } @@ -439,7 +440,7 @@ func (i *OFXImport) GetInvExpenseTran(expense *ofxgo.InvExpense, curdef *Securit return &t, nil } -func (i *OFXImport) GetMarginInterestTran(marginint *ofxgo.MarginInterest, curdef *Security, account *Account) (*Transaction, error) { +func (i *OFXImport) GetMarginInterestTran(marginint *ofxgo.MarginInterest, curdef *models.Security, account *Account) (*Transaction, error) { t := i.GetInvTran(&marginint.InvTran) memo := string(marginint.InvTran.Memo) @@ -478,10 +479,10 @@ func (i *OFXImport) GetMarginInterestTran(marginint *ofxgo.MarginInterest, curde return &t, nil } -func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *Security, account *Account) (*Transaction, error) { +func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Security, account *Account) (*Transaction, error) { t := i.GetInvTran(&reinvest.InvTran) - security, err := i.GetSecurityAlternateId(string(reinvest.SecID.UniqueID), Stock) + security, err := i.GetSecurityAlternateId(string(reinvest.SecID.UniqueID), models.Stock) if err != nil { return nil, err } @@ -634,10 +635,10 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *Security, return &t, nil } -func (i *OFXImport) GetRetOfCapTran(retofcap *ofxgo.RetOfCap, curdef *Security, account *Account) (*Transaction, error) { +func (i *OFXImport) GetRetOfCapTran(retofcap *ofxgo.RetOfCap, curdef *models.Security, account *Account) (*Transaction, error) { t := i.GetInvTran(&retofcap.InvTran) - security, err := i.GetSecurityAlternateId(string(retofcap.SecID.UniqueID), Stock) + security, err := i.GetSecurityAlternateId(string(retofcap.SecID.UniqueID), models.Stock) if err != nil { return nil, err } @@ -679,10 +680,10 @@ func (i *OFXImport) GetRetOfCapTran(retofcap *ofxgo.RetOfCap, curdef *Security, return &t, nil } -func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *Security, account *Account) (*Transaction, error) { +func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *models.Security, account *Account) (*Transaction, error) { t := i.GetInvTran(&sell.InvTran) - security, err := i.GetSecurityAlternateId(string(sell.SecID.UniqueID), Stock) + security, err := i.GetSecurityAlternateId(string(sell.SecID.UniqueID), models.Stock) if err != nil { return nil, err } @@ -819,7 +820,7 @@ func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *Security, accoun func (i *OFXImport) GetTransferTran(transfer *ofxgo.Transfer, account *Account) (*Transaction, error) { t := i.GetInvTran(&transfer.InvTran) - security, err := i.GetSecurityAlternateId(string(transfer.SecID.UniqueID), Stock) + security, err := i.GetSecurityAlternateId(string(transfer.SecID.UniqueID), models.Stock) if err != nil { return nil, err } @@ -858,7 +859,7 @@ func (i *OFXImport) GetTransferTran(transfer *ofxgo.Transfer, account *Account) return &t, nil } -func (i *OFXImport) AddInvTransaction(invtran *ofxgo.InvTransaction, account *Account, curdef *Security) error { +func (i *OFXImport) AddInvTransaction(invtran *ofxgo.InvTransaction, account *Account, curdef *models.Security) error { if curdef.SecurityId < 1 || curdef.SecurityId > int64(len(i.Securities)) { return errors.New("Internal error: security index not found in OFX import\n") } diff --git a/internal/handlers/ofx_test.go b/internal/handlers/ofx_test.go index baf452a..11c3f04 100644 --- a/internal/handlers/ofx_test.go +++ b/internal/handlers/ofx_test.go @@ -3,6 +3,7 @@ package handlers_test import ( "fmt" "github.com/aclindsa/moneygo/internal/handlers" + "github.com/aclindsa/moneygo/internal/models" "net/http" "strconv" "testing" @@ -63,7 +64,7 @@ func TestImportOFXCreditCard(t *testing.T) { }) } -func findSecurity(client *http.Client, symbol string, tipe handlers.SecurityType) (*handlers.Security, error) { +func findSecurity(client *http.Client, symbol string, tipe models.SecurityType) (*models.Security, error) { securities, err := getSecurities(client) if err != nil { return nil, err @@ -125,7 +126,7 @@ func TestImportOFX401kMutualFunds(t *testing.T) { // Make sure the security was created and that the trading account has // the right value - security, err := findSecurity(d.clients[0], "VANGUARD TARGET 2045", handlers.Stock) + security, err := findSecurity(d.clients[0], "VANGUARD TARGET 2045", models.Stock) if err != nil { t.Fatalf("Error finding VANGUARD TARGET 2045 security: %s\n", err) } @@ -204,7 +205,7 @@ func TestImportOFXBrokerage(t *testing.T) { } for _, check := range checks { - security, err := findSecurity(d.clients[0], check.Ticker, handlers.Stock) + security, err := findSecurity(d.clients[0], check.Ticker, models.Stock) if err != nil { t.Fatalf("Error finding security: %s\n", err) } diff --git a/internal/handlers/prices.go b/internal/handlers/prices.go index fa2c058..2689378 100644 --- a/internal/handlers/prices.go +++ b/internal/handlers/prices.go @@ -90,7 +90,7 @@ func GetPrices(tx *Tx, securityid int64) (*[]*Price, error) { } // Return the latest price for security in currency units before date -func GetLatestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Price, error) { +func GetLatestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*Price, error) { var p Price err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date) if err != nil { @@ -100,7 +100,7 @@ func GetLatestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Pri } // Return the earliest price for security in currency units after date -func GetEarliestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Price, error) { +func GetEarliestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*Price, error) { var p Price err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date) if err != nil { @@ -110,7 +110,7 @@ func GetEarliestPrice(tx *Tx, security, currency *Security, date *time.Time) (*P } // Return the price for security in currency closest to date -func GetClosestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Price, error) { +func GetClosestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*Price, error) { earliest, _ := GetEarliestPrice(tx, security, currency, date) latest, err := GetLatestPrice(tx, security, currency, date) diff --git a/internal/handlers/scripts/gen_security_list.py b/internal/handlers/scripts/gen_security_list.py index 49450d6..a51723a 100755 --- a/internal/handlers/scripts/gen_security_list.py +++ b/internal/handlers/scripts/gen_security_list.py @@ -26,7 +26,7 @@ class Security(object): self.type = _type self.precision = precision def unicode(self): - s = """\tSecurity{ + s = """\t{ \t\tName: \"%s\", \t\tDescription: \"%s\", \t\tSymbol: \"%s\", @@ -72,7 +72,7 @@ def process_ccyntry(currency_list, node): else: precision = int(n.firstChild.nodeValue) if nameSet and numberSet: - currency_list.add(Security(name, description, number, "Currency", precision)) + currency_list.add(Security(name, description, number, "models.Currency", precision)) def get_currency_list(): currency_list = SecurityList("ISO 4217, from http://www.currency-iso.org/en/home/tables/table-a1.html") @@ -97,7 +97,7 @@ def get_cusip_list(filename): cusip = row[0] name = row[1] description = ",".join(row[2:]) - cusip_list.add(Security(name, description, cusip, "Stock", 5)) + cusip_list.add(Security(name, description, cusip, "models.Stock", 5)) return cusip_list def main(): @@ -105,7 +105,10 @@ def main(): cusip_list = get_cusip_list('cusip_list.csv') print("package handlers\n") - print("var SecurityTemplates = []Security{") + print("import (") + print("\t\"github.com/aclindsa/moneygo/internal/models\"") + print(")\n") + print("var SecurityTemplates = []models.Security{") print(currency_list.unicode()) print(cusip_list.unicode()) print("}") diff --git a/internal/handlers/securities.go b/internal/handlers/securities.go index 5aabaa1..ab58de4 100644 --- a/internal/handlers/securities.go +++ b/internal/handlers/securities.go @@ -3,9 +3,9 @@ package handlers //go:generate make import ( - "encoding/json" "errors" "fmt" + "github.com/aclindsa/moneygo/internal/models" "log" "net/http" "net/url" @@ -13,64 +13,9 @@ import ( "strings" ) -type SecurityType int64 - -const ( - Currency SecurityType = 1 - Stock = 2 -) - -func GetSecurityType(typestring string) SecurityType { - if strings.EqualFold(typestring, "currency") { - return Currency - } else if strings.EqualFold(typestring, "stock") { - return Stock - } else { - return 0 - } -} - -type Security struct { - SecurityId int64 - UserId int64 - Name string - Description string - Symbol string - // Number of decimal digits (to the right of the decimal point) this - // security is precise to - Precision int `db:"Preciseness"` - Type SecurityType - // AlternateId is CUSIP for Type=Stock, ISO4217 for Type=Currency - AlternateId string -} - -type SecurityList struct { - Securities *[]*Security `json:"securities"` -} - -func (s *Security) Read(json_str string) error { - dec := json.NewDecoder(strings.NewReader(json_str)) - return dec.Decode(s) -} - -func (s *Security) Write(w http.ResponseWriter) error { - enc := json.NewEncoder(w) - return enc.Encode(s) -} - -func (sl *SecurityList) Read(json_str string) error { - dec := json.NewDecoder(strings.NewReader(json_str)) - return dec.Decode(sl) -} - -func (sl *SecurityList) Write(w http.ResponseWriter) error { - enc := json.NewEncoder(w) - return enc.Encode(sl) -} - -func SearchSecurityTemplates(search string, _type SecurityType, limit int64) []*Security { +func SearchSecurityTemplates(search string, _type models.SecurityType, limit int64) []*models.Security { upperSearch := strings.ToUpper(search) - var results []*Security + var results []*models.Security for i, security := range SecurityTemplates { if strings.Contains(strings.ToUpper(security.Name), upperSearch) || strings.Contains(strings.ToUpper(security.Description), upperSearch) || @@ -86,7 +31,7 @@ func SearchSecurityTemplates(search string, _type SecurityType, limit int64) []* return results } -func FindSecurityTemplate(name string, _type SecurityType) *Security { +func FindSecurityTemplate(name string, _type models.SecurityType) *models.Security { for _, security := range SecurityTemplates { if name == security.Name && _type == security.Type { return &security @@ -95,18 +40,18 @@ func FindSecurityTemplate(name string, _type SecurityType) *Security { return nil } -func FindCurrencyTemplate(iso4217 int64) *Security { +func FindCurrencyTemplate(iso4217 int64) *models.Security { iso4217string := strconv.FormatInt(iso4217, 10) for _, security := range SecurityTemplates { - if security.Type == Currency && security.AlternateId == iso4217string { + if security.Type == models.Currency && security.AlternateId == iso4217string { return &security } } return nil } -func GetSecurity(tx *Tx, securityid int64, userid int64) (*Security, error) { - var s Security +func GetSecurity(tx *Tx, securityid int64, userid int64) (*models.Security, error) { + var s models.Security err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid) if err != nil { @@ -115,8 +60,8 @@ func GetSecurity(tx *Tx, securityid int64, userid int64) (*Security, error) { return &s, nil } -func GetSecurities(tx *Tx, userid int64) (*[]*Security, error) { - var securities []*Security +func GetSecurities(tx *Tx, userid int64) (*[]*models.Security, error) { + var securities []*models.Security _, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid) if err != nil { @@ -125,7 +70,7 @@ func GetSecurities(tx *Tx, userid int64) (*[]*Security, error) { return &securities, nil } -func InsertSecurity(tx *Tx, s *Security) error { +func InsertSecurity(tx *Tx, s *models.Security) error { err := tx.Insert(s) if err != nil { return err @@ -133,11 +78,11 @@ func InsertSecurity(tx *Tx, s *Security) error { return nil } -func UpdateSecurity(tx *Tx, s *Security) (err error) { +func UpdateSecurity(tx *Tx, s *models.Security) (err error) { user, err := GetUser(tx, s.UserId) if err != nil { return - } else if user.DefaultCurrency == s.SecurityId && s.Type != Currency { + } else if user.DefaultCurrency == s.SecurityId && s.Type != models.Currency { return errors.New("Cannot change security which is user's default currency to be non-currency") } @@ -160,7 +105,7 @@ func (e SecurityInUseError) Error() string { return e.message } -func DeleteSecurity(tx *Tx, s *Security) error { +func DeleteSecurity(tx *Tx, s *models.Security) error { // First, ensure no accounts are using this security accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId) @@ -193,7 +138,7 @@ func DeleteSecurity(tx *Tx, s *Security) error { return nil } -func ImportGetCreateSecurity(tx *Tx, userid int64, security *Security) (*Security, error) { +func ImportGetCreateSecurity(tx *Tx, userid int64, security *models.Security) (*models.Security, error) { security.UserId = userid if len(security.AlternateId) == 0 { // Always create a new local security if we can't match on the AlternateId @@ -204,7 +149,7 @@ func ImportGetCreateSecurity(tx *Tx, userid int64, security *Security) (*Securit return security, nil } - var securities []*Security + var securities []*models.Security _, err := tx.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Preciseness=?", userid, security.Type, security.AlternateId, security.Precision) if err != nil { @@ -264,7 +209,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { return PriceHandler(r, context, user, securityid) } - var security Security + var security models.Security if err := ReadJSON(r, &security); err != nil { return NewError(3 /*Invalid Request*/) } @@ -281,7 +226,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { } else if r.Method == "GET" { if context.LastLevel() { //Return all securities - var sl SecurityList + var sl models.SecurityList securities, err := GetSecurities(context.Tx, user.UserId) if err != nil { @@ -324,7 +269,7 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { } if r.Method == "PUT" { - var security Security + var security models.Security if err := ReadJSON(r, &security); err != nil || security.SecurityId != securityid { return NewError(3 /*Invalid Request*/) } @@ -359,17 +304,17 @@ func SecurityHandler(r *http.Request, context *Context) ResponseWriterWriter { func SecurityTemplateHandler(r *http.Request, context *Context) ResponseWriterWriter { if r.Method == "GET" { - var sl SecurityList + var sl models.SecurityList query, _ := url.ParseQuery(r.URL.RawQuery) var limit int64 = -1 search := query.Get("search") - var _type SecurityType = 0 + var _type models.SecurityType = 0 typestring := query.Get("type") if len(typestring) > 0 { - _type = GetSecurityType(typestring) + _type = models.GetSecurityType(typestring) if _type == 0 { return NewError(3 /*Invalid Request*/) } diff --git a/internal/handlers/securities_lua.go b/internal/handlers/securities_lua.go index b294c1c..12783ce 100644 --- a/internal/handlers/securities_lua.go +++ b/internal/handlers/securities_lua.go @@ -9,8 +9,8 @@ import ( const luaSecurityTypeName = "security" -func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) { - var security_map map[int64]*Security +func luaContextGetSecurities(L *lua.LState) (map[int64]*models.Security, error) { + var security_map map[int64]*models.Security ctx := L.Context() @@ -19,7 +19,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) { return nil, errors.New("Couldn't find tx in lua's Context") } - security_map, ok = ctx.Value(securitiesContextKey).(map[int64]*Security) + security_map, ok = ctx.Value(securitiesContextKey).(map[int64]*models.Security) if !ok { user, ok := ctx.Value(userContextKey).(*models.User) if !ok { @@ -31,7 +31,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) { return nil, err } - security_map = make(map[int64]*Security) + security_map = make(map[int64]*models.Security) for i := range *securities { security_map[(*securities)[i].SecurityId] = (*securities)[i] } @@ -43,7 +43,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) { return security_map, nil } -func luaContextGetDefaultCurrency(L *lua.LState) (*Security, error) { +func luaContextGetDefaultCurrency(L *lua.LState) (*models.Security, error) { security_map, err := luaContextGetSecurities(L) if err != nil { return nil, err @@ -107,7 +107,7 @@ func luaRegisterSecurities(L *lua.LState) { L.SetGlobal("get_default_currency", getDefaultCurrencyFn) } -func SecurityToLua(L *lua.LState, security *Security) *lua.LUserData { +func SecurityToLua(L *lua.LState, security *models.Security) *lua.LUserData { ud := L.NewUserData() ud.Value = security L.SetMetatable(ud, L.GetTypeMetatable(luaSecurityTypeName)) @@ -115,9 +115,9 @@ func SecurityToLua(L *lua.LState, security *Security) *lua.LUserData { } // Checks whether the first lua argument is a *LUserData with *Security and returns this *Security. -func luaCheckSecurity(L *lua.LState, n int) *Security { +func luaCheckSecurity(L *lua.LState, n int) *models.Security { ud := L.CheckUserData(n) - if security, ok := ud.Value.(*Security); ok { + if security, ok := ud.Value.(*models.Security); ok { return security } L.ArgError(n, "security expected") diff --git a/internal/handlers/securities_test.go b/internal/handlers/securities_test.go index aab0a0c..8d786ad 100644 --- a/internal/handlers/securities_test.go +++ b/internal/handlers/securities_test.go @@ -2,19 +2,20 @@ package handlers_test import ( "github.com/aclindsa/moneygo/internal/handlers" + "github.com/aclindsa/moneygo/internal/models" "net/http" "strconv" "testing" ) -func createSecurity(client *http.Client, security *handlers.Security) (*handlers.Security, error) { - var s handlers.Security +func createSecurity(client *http.Client, security *models.Security) (*models.Security, error) { + var s models.Security err := create(client, security, &s, "/v1/securities/") return &s, err } -func getSecurity(client *http.Client, securityid int64) (*handlers.Security, error) { - var s handlers.Security +func getSecurity(client *http.Client, securityid int64) (*models.Security, error) { + var s models.Security err := read(client, &s, "/v1/securities/"+strconv.FormatInt(securityid, 10)) if err != nil { return nil, err @@ -22,8 +23,8 @@ func getSecurity(client *http.Client, securityid int64) (*handlers.Security, err return &s, nil } -func getSecurities(client *http.Client) (*handlers.SecurityList, error) { - var sl handlers.SecurityList +func getSecurities(client *http.Client) (*models.SecurityList, error) { + var sl models.SecurityList err := read(client, &sl, "/v1/securities/") if err != nil { return nil, err @@ -31,8 +32,8 @@ func getSecurities(client *http.Client) (*handlers.SecurityList, error) { return &sl, nil } -func updateSecurity(client *http.Client, security *handlers.Security) (*handlers.Security, error) { - var s handlers.Security +func updateSecurity(client *http.Client, security *models.Security) (*models.Security, error) { + var s models.Security err := update(client, security, &s, "/v1/securities/"+strconv.FormatInt(security.SecurityId, 10)) if err != nil { return nil, err @@ -40,7 +41,7 @@ func updateSecurity(client *http.Client, security *handlers.Security) (*handlers return &s, nil } -func deleteSecurity(client *http.Client, s *handlers.Security) error { +func deleteSecurity(client *http.Client, s *models.Security) error { err := remove(client, "/v1/securities/"+strconv.FormatInt(s.SecurityId, 10)) if err != nil { return err diff --git a/internal/handlers/security_templates_test.go b/internal/handlers/security_templates_test.go index 04baac6..1728576 100644 --- a/internal/handlers/security_templates_test.go +++ b/internal/handlers/security_templates_test.go @@ -2,12 +2,13 @@ package handlers_test import ( "github.com/aclindsa/moneygo/internal/handlers" + "github.com/aclindsa/moneygo/internal/models" "io/ioutil" "testing" ) func TestSecurityTemplates(t *testing.T) { - var sl handlers.SecurityList + var sl models.SecurityList response, err := server.Client().Get(server.URL + "/v1/securitytemplates/?search=USD&type=currency") if err != nil { t.Fatal(err) @@ -30,7 +31,7 @@ func TestSecurityTemplates(t *testing.T) { num_usd := 0 if sl.Securities != nil { for _, s := range *sl.Securities { - if s.Type != handlers.Currency { + if s.Type != models.Currency { t.Fatalf("Requested Currency-only security templates, received a non-Currency template for %s", s.Name) } @@ -46,7 +47,7 @@ func TestSecurityTemplates(t *testing.T) { } func TestSecurityTemplateLimit(t *testing.T) { - var sl handlers.SecurityList + var sl models.SecurityList response, err := server.Client().Get(server.URL + "/v1/securitytemplates/?search=e&limit=5") if err != nil { t.Fatal(err) diff --git a/internal/handlers/testdata_test.go b/internal/handlers/testdata_test.go index db13381..c01544d 100644 --- a/internal/handlers/testdata_test.go +++ b/internal/handlers/testdata_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "github.com/aclindsa/moneygo/internal/handlers" + "github.com/aclindsa/moneygo/internal/models" "net/http" "strings" "testing" @@ -36,7 +37,7 @@ type TestData struct { initialized bool users []User clients []*http.Client - securities []handlers.Security + securities []models.Security prices []handlers.Price accounts []handlers.Account // accounts must appear after their parents in this slice transactions []handlers.Transaction @@ -170,14 +171,14 @@ var data = []TestData{ Email: "bbob+moneygo@my-domain.com", }, }, - securities: []handlers.Security{ + securities: []models.Security{ { UserId: 0, Name: "USD", Description: "US Dollar", Symbol: "$", Precision: 2, - Type: handlers.Currency, + Type: models.Currency, AlternateId: "840", }, { @@ -186,7 +187,7 @@ var data = []TestData{ Description: "SPDR S&P 500 ETF Trust", Symbol: "SPY", Precision: 5, - Type: handlers.Stock, + Type: models.Stock, AlternateId: "78462F103", }, { @@ -195,7 +196,7 @@ var data = []TestData{ Description: "Euro", Symbol: "€", Precision: 2, - Type: handlers.Currency, + Type: models.Currency, AlternateId: "978", }, { @@ -204,7 +205,7 @@ var data = []TestData{ Description: "Euro", Symbol: "€", Precision: 2, - Type: handlers.Currency, + Type: models.Currency, AlternateId: "978", }, }, diff --git a/internal/handlers/users.go b/internal/handlers/users.go index 66b5737..ba1a9d0 100644 --- a/internal/handlers/users.go +++ b/internal/handlers/users.go @@ -54,7 +54,7 @@ func InsertUser(tx *Tx, u *models.User) error { } // Copy the security template and give it our new UserId - var security Security + var security models.Security security = *security_template security.UserId = u.UserId @@ -89,7 +89,7 @@ func UpdateUser(tx *Tx, u *models.User) error { return err } else if security.UserId != u.UserId || security.SecurityId != u.DefaultCurrency { return errors.New("UserId and DefaultCurrency don't match the fetched security") - } else if security.Type != Currency { + } else if security.Type != models.Currency { return errors.New("New DefaultCurrency security is not a currency") } diff --git a/internal/models/securities.go b/internal/models/securities.go new file mode 100644 index 0000000..67557be --- /dev/null +++ b/internal/models/securities.go @@ -0,0 +1,62 @@ +package models + +import ( + "encoding/json" + "net/http" + "strings" +) + +type SecurityType int64 + +const ( + Currency SecurityType = 1 + Stock = 2 +) + +func GetSecurityType(typestring string) SecurityType { + if strings.EqualFold(typestring, "currency") { + return Currency + } else if strings.EqualFold(typestring, "stock") { + return Stock + } else { + return 0 + } +} + +type Security struct { + SecurityId int64 + UserId int64 + Name string + Description string + Symbol string + // Number of decimal digits (to the right of the decimal point) this + // security is precise to + Precision int `db:"Preciseness"` + Type SecurityType + // AlternateId is CUSIP for Type=Stock, ISO4217 for Type=Currency + AlternateId string +} + +type SecurityList struct { + Securities *[]*Security `json:"securities"` +} + +func (s *Security) Read(json_str string) error { + dec := json.NewDecoder(strings.NewReader(json_str)) + return dec.Decode(s) +} + +func (s *Security) Write(w http.ResponseWriter) error { + enc := json.NewEncoder(w) + return enc.Encode(s) +} + +func (sl *SecurityList) Read(json_str string) error { + dec := json.NewDecoder(strings.NewReader(json_str)) + return dec.Decode(sl) +} + +func (sl *SecurityList) Write(w http.ResponseWriter) error { + enc := json.NewEncoder(w) + return enc.Encode(sl) +} From 128ea57c4dfe87e669ca6c48d48db0f5277dc676 Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Mon, 4 Dec 2017 05:55:25 -0500 Subject: [PATCH 4/6] Split accounts and transactions into models --- internal/db/db.go | 6 +- internal/handlers/accounts.go | 163 +++------------ internal/handlers/accounts_lua.go | 16 +- internal/handlers/accounts_test.go | 21 +- internal/handlers/common_test.go | 3 +- internal/handlers/gnucash.go | 42 ++-- internal/handlers/gnucash_test.go | 21 +- internal/handlers/imports.go | 18 +- internal/handlers/ofx.go | 270 ++++++++++++------------- internal/handlers/ofx_test.go | 21 +- internal/handlers/prices_lua.go | 3 +- internal/handlers/testdata_test.go | 50 ++--- internal/handlers/transactions.go | 185 +++-------------- internal/handlers/transactions_test.go | 51 ++--- internal/models/accounts.go | 118 +++++++++++ internal/models/transactions.go | 133 ++++++++++++ 16 files changed, 568 insertions(+), 553 deletions(-) create mode 100644 internal/models/accounts.go create mode 100644 internal/models/transactions.go diff --git a/internal/db/db.go b/internal/db/db.go index 0d71bfa..b92e9bd 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -36,10 +36,10 @@ func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { dbmap := &gorp.DbMap{Db: db, Dialect: dialect} dbmap.AddTableWithName(models.User{}, "users").SetKeys(true, "UserId") dbmap.AddTableWithName(models.Session{}, "sessions").SetKeys(true, "SessionId") - dbmap.AddTableWithName(handlers.Account{}, "accounts").SetKeys(true, "AccountId") + dbmap.AddTableWithName(models.Account{}, "accounts").SetKeys(true, "AccountId") dbmap.AddTableWithName(models.Security{}, "securities").SetKeys(true, "SecurityId") - dbmap.AddTableWithName(handlers.Transaction{}, "transactions").SetKeys(true, "TransactionId") - dbmap.AddTableWithName(handlers.Split{}, "splits").SetKeys(true, "SplitId") + dbmap.AddTableWithName(models.Transaction{}, "transactions").SetKeys(true, "TransactionId") + dbmap.AddTableWithName(models.Split{}, "splits").SetKeys(true, "SplitId") dbmap.AddTableWithName(handlers.Price{}, "prices").SetKeys(true, "PriceId") rtable := dbmap.AddTableWithName(handlers.Report{}, "reports").SetKeys(true, "ReportId") rtable.ColMap("Lua").SetMaxSize(handlers.LuaMaxLength + luaMaxLengthBuffer) diff --git a/internal/handlers/accounts.go b/internal/handlers/accounts.go index dd3fa14..2812fb4 100644 --- a/internal/handlers/accounts.go +++ b/internal/handlers/accounts.go @@ -1,127 +1,14 @@ package handlers import ( - "encoding/json" "errors" "github.com/aclindsa/moneygo/internal/models" "log" "net/http" - "strings" ) -type AccountType int64 - -const ( - Bank AccountType = 1 // start at 1 so that the default (0) is invalid - Cash = 2 - Asset = 3 - Liability = 4 - Investment = 5 - Income = 6 - Expense = 7 - Trading = 8 - Equity = 9 - Receivable = 10 - Payable = 11 -) - -var AccountTypes = []AccountType{ - Bank, - Cash, - Asset, - Liability, - Investment, - Income, - Expense, - Trading, - Equity, - Receivable, - Payable, -} - -func (t AccountType) String() string { - switch t { - case Bank: - return "Bank" - case Cash: - return "Cash" - case Asset: - return "Asset" - case Liability: - return "Liability" - case Investment: - return "Investment" - case Income: - return "Income" - case Expense: - return "Expense" - case Trading: - return "Trading" - case Equity: - return "Equity" - case Receivable: - return "Receivable" - case Payable: - return "Payable" - } - return "" -} - -type Account struct { - AccountId int64 - ExternalAccountId string - UserId int64 - SecurityId int64 - ParentAccountId int64 // -1 if this account is at the root - Type AccountType - Name string - - // monotonically-increasing account transaction version number. Used for - // allowing a client to ensure they have a consistent version when paging - // through transactions. - AccountVersion int64 `json:"Version"` - - // Optional fields specifying how to fetch transactions from a bank via OFX - OFXURL string - OFXORG string - OFXFID string - OFXUser string - OFXBankID string // OFX BankID (BrokerID if AcctType == Investment) - OFXAcctID string - OFXAcctType string // ofxgo.acctType - OFXClientUID string - OFXAppID string - OFXAppVer string - OFXVersion string - OFXNoIndent bool -} - -type AccountList struct { - Accounts *[]Account `json:"accounts"` -} - -func (a *Account) Write(w http.ResponseWriter) error { - enc := json.NewEncoder(w) - return enc.Encode(a) -} - -func (a *Account) Read(json_str string) error { - dec := json.NewDecoder(strings.NewReader(json_str)) - return dec.Decode(a) -} - -func (al *AccountList) Write(w http.ResponseWriter) error { - enc := json.NewEncoder(w) - return enc.Encode(al) -} - -func (al *AccountList) Read(json_str string) error { - dec := json.NewDecoder(strings.NewReader(json_str)) - return dec.Decode(al) -} - -func GetAccount(tx *Tx, accountid int64, userid int64) (*Account, error) { - var a Account +func GetAccount(tx *Tx, accountid int64, userid int64) (*models.Account, error) { + var a models.Account err := tx.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid) if err != nil { @@ -130,8 +17,8 @@ func GetAccount(tx *Tx, accountid int64, userid int64) (*Account, error) { return &a, nil } -func GetAccounts(tx *Tx, userid int64) (*[]Account, error) { - var accounts []Account +func GetAccounts(tx *Tx, userid int64) (*[]models.Account, error) { + var accounts []models.Account _, err := tx.Select(&accounts, "SELECT * from accounts where UserId=?", userid) if err != nil { @@ -142,9 +29,9 @@ func GetAccounts(tx *Tx, userid int64) (*[]Account, error) { // Get (and attempt to create if it doesn't exist). Matches on UserId, // SecurityId, Type, Name, and ParentAccountId -func GetCreateAccount(tx *Tx, a Account) (*Account, error) { - var accounts []Account - var account Account +func GetCreateAccount(tx *Tx, a models.Account) (*models.Account, error) { + var accounts []models.Account + var account models.Account // Try to find the top-level trading account _, err := tx.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND Type=? AND Name=? AND ParentAccountId=? ORDER BY AccountId ASC LIMIT 1", a.UserId, a.SecurityId, a.Type, a.Name, a.ParentAccountId) @@ -170,9 +57,9 @@ func GetCreateAccount(tx *Tx, a Account) (*Account, error) { // Get (and attempt to create if it doesn't exist) the security/currency // trading account for the supplied security/currency -func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*Account, error) { - var tradingAccount Account - var account Account +func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*models.Account, error) { + var tradingAccount models.Account + var account models.Account user, err := GetUser(tx, userid) if err != nil { @@ -180,7 +67,7 @@ func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*Account, error) } tradingAccount.UserId = userid - tradingAccount.Type = Trading + tradingAccount.Type = models.Trading tradingAccount.Name = "Trading" tradingAccount.SecurityId = user.DefaultCurrency tradingAccount.ParentAccountId = -1 @@ -200,7 +87,7 @@ func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*Account, error) account.Name = security.Name account.ParentAccountId = ta.AccountId account.SecurityId = securityid - account.Type = Trading + account.Type = models.Trading a, err := GetCreateAccount(tx, account) if err != nil { @@ -212,9 +99,9 @@ func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*Account, error) // Get (and attempt to create if it doesn't exist) the security/currency // imbalance account for the supplied security/currency -func GetImbalanceAccount(tx *Tx, userid int64, securityid int64) (*Account, error) { - var imbalanceAccount Account - var account Account +func GetImbalanceAccount(tx *Tx, userid int64, securityid int64) (*models.Account, error) { + var imbalanceAccount models.Account + var account models.Account xxxtemplate := FindSecurityTemplate("XXX", models.Currency) if xxxtemplate == nil { return nil, errors.New("Couldn't find XXX security template") @@ -228,7 +115,7 @@ func GetImbalanceAccount(tx *Tx, userid int64, securityid int64) (*Account, erro imbalanceAccount.Name = "Imbalances" imbalanceAccount.ParentAccountId = -1 imbalanceAccount.SecurityId = xxxsecurity.SecurityId - imbalanceAccount.Type = Bank + imbalanceAccount.Type = models.Bank // Find/create the top-level trading account ia, err := GetCreateAccount(tx, imbalanceAccount) @@ -245,7 +132,7 @@ func GetImbalanceAccount(tx *Tx, userid int64, securityid int64) (*Account, erro account.Name = security.Name account.ParentAccountId = ia.AccountId account.SecurityId = securityid - account.Type = Bank + account.Type = models.Bank a, err := GetCreateAccount(tx, account) if err != nil { @@ -273,7 +160,7 @@ func (cae CircularAccountsError) Error() string { return "Would result in circular account relationship" } -func insertUpdateAccount(tx *Tx, a *Account, insert bool) error { +func insertUpdateAccount(tx *Tx, a *models.Account, insert bool) error { found := make(map[int64]bool) if !insert { found[a.AccountId] = true @@ -286,7 +173,7 @@ func insertUpdateAccount(tx *Tx, a *Account, insert bool) error { return TooMuchNestingError{} } - var a Account + var a models.Account err := tx.SelectOne(&a, "SELECT * from accounts where AccountId=?", parentid) if err != nil { return ParentAccountMissingError{} @@ -329,15 +216,15 @@ func insertUpdateAccount(tx *Tx, a *Account, insert bool) error { return nil } -func InsertAccount(tx *Tx, a *Account) error { +func InsertAccount(tx *Tx, a *models.Account) error { return insertUpdateAccount(tx, a, true) } -func UpdateAccount(tx *Tx, a *Account) error { +func UpdateAccount(tx *Tx, a *models.Account) error { return insertUpdateAccount(tx, a, false) } -func DeleteAccount(tx *Tx, a *Account) error { +func DeleteAccount(tx *Tx, a *models.Account) error { if a.ParentAccountId != -1 { // Re-parent splits to this account's parent account if this account isn't a root account _, err := tx.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", a.ParentAccountId, a.AccountId) @@ -384,7 +271,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { return AccountImportHandler(context, r, user, accountid) } - var account Account + var account models.Account if err := ReadJSON(r, &account); err != nil { return NewError(3 /*Invalid Request*/) } @@ -415,7 +302,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { } else if r.Method == "GET" { if context.LastLevel() { //Return all Accounts - var al AccountList + var al models.AccountList accounts, err := GetAccounts(context.Tx, user.UserId) if err != nil { log.Print(err) @@ -447,7 +334,7 @@ func AccountHandler(r *http.Request, context *Context) ResponseWriterWriter { return NewError(3 /*Invalid Request*/) } if r.Method == "PUT" { - var account Account + var account models.Account if err := ReadJSON(r, &account); err != nil || account.AccountId != accountid { return NewError(3 /*Invalid Request*/) } diff --git a/internal/handlers/accounts_lua.go b/internal/handlers/accounts_lua.go index 8a4d5db..5a2fc23 100644 --- a/internal/handlers/accounts_lua.go +++ b/internal/handlers/accounts_lua.go @@ -11,8 +11,8 @@ import ( const luaAccountTypeName = "account" -func luaContextGetAccounts(L *lua.LState) (map[int64]*Account, error) { - var account_map map[int64]*Account +func luaContextGetAccounts(L *lua.LState) (map[int64]*models.Account, error) { + var account_map map[int64]*models.Account ctx := L.Context() @@ -21,7 +21,7 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*Account, error) { return nil, errors.New("Couldn't find tx in lua's Context") } - account_map, ok = ctx.Value(accountsContextKey).(map[int64]*Account) + account_map, ok = ctx.Value(accountsContextKey).(map[int64]*models.Account) if !ok { user, ok := ctx.Value(userContextKey).(*models.User) if !ok { @@ -33,7 +33,7 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*Account, error) { return nil, err } - account_map = make(map[int64]*Account) + account_map = make(map[int64]*models.Account) for i := range *accounts { account_map[(*accounts)[i].AccountId] = &(*accounts)[i] } @@ -69,7 +69,7 @@ func luaRegisterAccounts(L *lua.LState) { L.SetField(mt, "__eq", L.NewFunction(luaAccount__eq)) L.SetField(mt, "__metatable", lua.LString("protected")) - for _, accttype := range AccountTypes { + for _, accttype := range models.AccountTypes { L.SetField(mt, accttype.String(), lua.LNumber(float64(accttype))) } @@ -79,7 +79,7 @@ func luaRegisterAccounts(L *lua.LState) { L.SetGlobal("get_accounts", getAccountsFn) } -func AccountToLua(L *lua.LState, account *Account) *lua.LUserData { +func AccountToLua(L *lua.LState, account *models.Account) *lua.LUserData { ud := L.NewUserData() ud.Value = account L.SetMetatable(ud, L.GetTypeMetatable(luaAccountTypeName)) @@ -87,9 +87,9 @@ func AccountToLua(L *lua.LState, account *Account) *lua.LUserData { } // Checks whether the first lua argument is a *LUserData with *Account and returns this *Account. -func luaCheckAccount(L *lua.LState, n int) *Account { +func luaCheckAccount(L *lua.LState, n int) *models.Account { ud := L.CheckUserData(n) - if account, ok := ud.Value.(*Account); ok { + if account, ok := ud.Value.(*models.Account); ok { return account } L.ArgError(n, "account expected") diff --git a/internal/handlers/accounts_test.go b/internal/handlers/accounts_test.go index 0abd029..d0b92a4 100644 --- a/internal/handlers/accounts_test.go +++ b/internal/handlers/accounts_test.go @@ -2,19 +2,20 @@ package handlers_test import ( "github.com/aclindsa/moneygo/internal/handlers" + "github.com/aclindsa/moneygo/internal/models" "net/http" "strconv" "testing" ) -func createAccount(client *http.Client, account *handlers.Account) (*handlers.Account, error) { - var a handlers.Account +func createAccount(client *http.Client, account *models.Account) (*models.Account, error) { + var a models.Account err := create(client, account, &a, "/v1/accounts/") return &a, err } -func getAccount(client *http.Client, accountid int64) (*handlers.Account, error) { - var a handlers.Account +func getAccount(client *http.Client, accountid int64) (*models.Account, error) { + var a models.Account err := read(client, &a, "/v1/accounts/"+strconv.FormatInt(accountid, 10)) if err != nil { return nil, err @@ -22,8 +23,8 @@ func getAccount(client *http.Client, accountid int64) (*handlers.Account, error) return &a, nil } -func getAccounts(client *http.Client) (*handlers.AccountList, error) { - var al handlers.AccountList +func getAccounts(client *http.Client) (*models.AccountList, error) { + var al models.AccountList err := read(client, &al, "/v1/accounts/") if err != nil { return nil, err @@ -31,8 +32,8 @@ func getAccounts(client *http.Client) (*handlers.AccountList, error) { return &al, nil } -func updateAccount(client *http.Client, account *handlers.Account) (*handlers.Account, error) { - var a handlers.Account +func updateAccount(client *http.Client, account *models.Account) (*models.Account, error) { + var a models.Account err := update(client, account, &a, "/v1/accounts/"+strconv.FormatInt(account.AccountId, 10)) if err != nil { return nil, err @@ -40,7 +41,7 @@ func updateAccount(client *http.Client, account *handlers.Account) (*handlers.Ac return &a, nil } -func deleteAccount(client *http.Client, a *handlers.Account) error { +func deleteAccount(client *http.Client, a *models.Account) error { err := remove(client, "/v1/accounts/"+strconv.FormatInt(a.AccountId, 10)) if err != nil { return err @@ -137,7 +138,7 @@ func TestUpdateAccount(t *testing.T) { curr := d.accounts[i] curr.Name = "blah" - curr.Type = handlers.Payable + curr.Type = models.Payable for _, s := range d.securities { if s.UserId == curr.UserId { curr.SecurityId = s.SecurityId diff --git a/internal/handlers/common_test.go b/internal/handlers/common_test.go index 5ff8c02..a0ef7f8 100644 --- a/internal/handlers/common_test.go +++ b/internal/handlers/common_test.go @@ -7,6 +7,7 @@ import ( "github.com/aclindsa/moneygo/internal/config" "github.com/aclindsa/moneygo/internal/db" "github.com/aclindsa/moneygo/internal/handlers" + "github.com/aclindsa/moneygo/internal/models" "io" "io/ioutil" "log" @@ -202,7 +203,7 @@ func uploadFile(client *http.Client, filename, urlsuffix string) error { return nil } -func accountBalanceHelper(t *testing.T, client *http.Client, account *handlers.Account, balance string) { +func accountBalanceHelper(t *testing.T, client *http.Client, account *models.Account, balance string) { t.Helper() transactions, err := getAccountTransactions(client, account.AccountId, 0, 0, "") if err != nil { diff --git a/internal/handlers/gnucash.go b/internal/handlers/gnucash.go index 9670f1e..f99978e 100644 --- a/internal/handlers/gnucash.go +++ b/internal/handlers/gnucash.go @@ -127,8 +127,8 @@ type GnucashXMLImport struct { type GnucashImport struct { Securities []models.Security - Accounts []Account - Transactions []Transaction + Accounts []models.Account + Transactions []models.Transaction Prices []Price } @@ -206,7 +206,7 @@ func ImportGnucash(r io.Reader) (*GnucashImport, error) { //Translate to our account format, figuring out parent relationships for guid := range accountMap { ga := accountMap[guid] - var a Account + var a models.Account a.AccountId = ga.accountid if ga.ParentAccountId == rootAccount.AccountId { @@ -229,29 +229,29 @@ func ImportGnucash(r io.Reader) (*GnucashImport, error) { //TODO find account types switch ga.Type { default: - a.Type = Bank + a.Type = models.Bank case "ASSET": - a.Type = Asset + a.Type = models.Asset case "BANK": - a.Type = Bank + a.Type = models.Bank case "CASH": - a.Type = Cash + a.Type = models.Cash case "CREDIT", "LIABILITY": - a.Type = Liability + a.Type = models.Liability case "EQUITY": - a.Type = Equity + a.Type = models.Equity case "EXPENSE": - a.Type = Expense + a.Type = models.Expense case "INCOME": - a.Type = Income + a.Type = models.Income case "PAYABLE": - a.Type = Payable + a.Type = models.Payable case "RECEIVABLE": - a.Type = Receivable + a.Type = models.Receivable case "MUTUAL", "STOCK": - a.Type = Investment + a.Type = models.Investment case "TRADING": - a.Type = Trading + a.Type = models.Trading } gncimport.Accounts = append(gncimport.Accounts, a) @@ -261,20 +261,20 @@ func ImportGnucash(r io.Reader) (*GnucashImport, error) { for i := range gncxml.Transactions { gt := gncxml.Transactions[i] - t := new(Transaction) + t := new(models.Transaction) t.Description = gt.Description t.Date = gt.DatePosted.Date.Time for j := range gt.Splits { gs := gt.Splits[j] - s := new(Split) + s := new(models.Split) switch gs.Status { default: // 'n', or not present - s.Status = Imported + s.Status = models.Imported case "c": - s.Status = Cleared + s.Status = models.Cleared case "y": - s.Status = Reconciled + s.Status = models.Reconciled } account, ok := accountMap[gs.AccountId] @@ -437,7 +437,7 @@ func GnucashImportHandler(r *http.Request, context *Context) ResponseWriterWrite } split.AccountId = acctId - exists, err := split.AlreadyImported(context.Tx) + exists, err := SplitAlreadyImported(context.Tx, split) if err != nil { log.Print("Error checking if split was already imported:", err) return NewError(999 /*Internal Error*/) diff --git a/internal/handlers/gnucash_test.go b/internal/handlers/gnucash_test.go index 7eacc03..960078f 100644 --- a/internal/handlers/gnucash_test.go +++ b/internal/handlers/gnucash_test.go @@ -1,7 +1,6 @@ package handlers_test import ( - "github.com/aclindsa/moneygo/internal/handlers" "github.com/aclindsa/moneygo/internal/models" "net/http" "testing" @@ -32,19 +31,19 @@ func TestImportGnucash(t *testing.T) { } // Next, find the Expenses/Groceries account and verify it's balance - var income, equity, liabilities, expenses, salary, creditcard, groceries, cable, openingbalances *handlers.Account + var income, equity, liabilities, expenses, salary, creditcard, groceries, cable, openingbalances *models.Account accounts, err := getAccounts(d.clients[0]) if err != nil { t.Fatalf("Error fetching accounts: %s\n", err) } for i, account := range *accounts.Accounts { - if account.Name == "Income" && account.Type == handlers.Income && account.ParentAccountId == -1 { + if account.Name == "Income" && account.Type == models.Income && account.ParentAccountId == -1 { income = &(*accounts.Accounts)[i] - } else if account.Name == "Equity" && account.Type == handlers.Equity && account.ParentAccountId == -1 { + } else if account.Name == "Equity" && account.Type == models.Equity && account.ParentAccountId == -1 { equity = &(*accounts.Accounts)[i] - } else if account.Name == "Liabilities" && account.Type == handlers.Liability && account.ParentAccountId == -1 { + } else if account.Name == "Liabilities" && account.Type == models.Liability && account.ParentAccountId == -1 { liabilities = &(*accounts.Accounts)[i] - } else if account.Name == "Expenses" && account.Type == handlers.Expense && account.ParentAccountId == -1 { + } else if account.Name == "Expenses" && account.Type == models.Expense && account.ParentAccountId == -1 { expenses = &(*accounts.Accounts)[i] } } @@ -61,15 +60,15 @@ func TestImportGnucash(t *testing.T) { t.Fatalf("Couldn't find 'Expenses' account") } for i, account := range *accounts.Accounts { - if account.Name == "Salary" && account.Type == handlers.Income && account.ParentAccountId == income.AccountId { + if account.Name == "Salary" && account.Type == models.Income && account.ParentAccountId == income.AccountId { salary = &(*accounts.Accounts)[i] - } else if account.Name == "Opening Balances" && account.Type == handlers.Equity && account.ParentAccountId == equity.AccountId { + } else if account.Name == "Opening Balances" && account.Type == models.Equity && account.ParentAccountId == equity.AccountId { openingbalances = &(*accounts.Accounts)[i] - } else if account.Name == "Credit Card" && account.Type == handlers.Liability && account.ParentAccountId == liabilities.AccountId { + } else if account.Name == "Credit Card" && account.Type == models.Liability && account.ParentAccountId == liabilities.AccountId { creditcard = &(*accounts.Accounts)[i] - } else if account.Name == "Groceries" && account.Type == handlers.Expense && account.ParentAccountId == expenses.AccountId { + } else if account.Name == "Groceries" && account.Type == models.Expense && account.ParentAccountId == expenses.AccountId { groceries = &(*accounts.Accounts)[i] - } else if account.Name == "Cable" && account.Type == handlers.Expense && account.ParentAccountId == expenses.AccountId { + } else if account.Name == "Cable" && account.Type == models.Expense && account.ParentAccountId == expenses.AccountId { cable = &(*accounts.Accounts)[i] } } diff --git a/internal/handlers/imports.go b/internal/handlers/imports.go index 442c4e3..78d5236 100644 --- a/internal/handlers/imports.go +++ b/internal/handlers/imports.go @@ -78,7 +78,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) Re // TODO Ensure all transactions have at least one split in the account // we're importing to? - var transactions []Transaction + var transactions []models.Transaction for _, transaction := range itl.Transactions { transaction.UserId = user.UserId @@ -91,7 +91,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) Re // and fixup the SecurityId to be a valid one for this user's actual // securities instead of a placeholder from the import for _, split := range transaction.Splits { - split.Status = Imported + split.Status = models.Imported if split.AccountId != -1 { if split.AccountId != importedAccount.AccountId { log.Print("Imported split's AccountId wasn't -1 but also didn't match the account") @@ -101,7 +101,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) Re } else if split.SecurityId != -1 { if sec, ok := securitymap[split.SecurityId]; ok { // TODO try to auto-match splits to existing accounts based on past transactions that look like this one - if split.ImportSplitType == TradingAccount { + if split.ImportSplitType == models.TradingAccount { // Find/make trading account if we're that type of split trading_account, err := GetTradingAccount(tx, user.UserId, sec.SecurityId) if err != nil { @@ -110,8 +110,8 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) Re } split.AccountId = trading_account.AccountId split.SecurityId = -1 - } else if split.ImportSplitType == SubAccount { - subaccount := &Account{ + } else if split.ImportSplitType == models.SubAccount { + subaccount := &models.Account{ UserId: user.UserId, Name: sec.Name, ParentAccountId: account.AccountId, @@ -138,7 +138,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) Re } } - imbalances, err := transaction.GetImbalances(tx) + imbalances, err := GetTransactionImbalances(tx, &transaction) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -155,7 +155,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) Re } // Add new split to fixup imbalance - split := new(Split) + split := new(models.Split) r := new(big.Rat) r.Neg(&imbalance) security, err := GetSecurity(tx, imbalanced_security, user.UserId) @@ -186,7 +186,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) Re split.SecurityId = -1 } - exists, err := split.AlreadyImported(tx) + exists, err := SplitAlreadyImported(tx, split) if err != nil { log.Print("Error checking if split was already imported:", err) return NewError(999 /*Internal Error*/) @@ -251,7 +251,7 @@ func OFXImportHandler(context *Context, r *http.Request, user *models.User, acco return NewError(999 /*Internal Error*/) } - if account.Type == Investment { + if account.Type == models.Investment { // Investment account statementRequest := ofxgo.InvStatementRequest{ TrnUID: *transactionuid, diff --git a/internal/handlers/ofx.go b/internal/handlers/ofx.go index befd147..a183aab 100644 --- a/internal/handlers/ofx.go +++ b/internal/handlers/ofx.go @@ -11,8 +11,8 @@ import ( type OFXImport struct { Securities []models.Security - Accounts []Account - Transactions []Transaction + Accounts []models.Account + Transactions []models.Transaction // Balances map[int64]string // map AccountIDs to ending balances } @@ -51,8 +51,8 @@ func (i *OFXImport) GetAddCurrency(isoname string) (*models.Security, error) { return &security, nil } -func (i *OFXImport) AddTransaction(tran *ofxgo.Transaction, account *Account) error { - var t Transaction +func (i *OFXImport) AddTransaction(tran *ofxgo.Transaction, account *models.Account) error { + var t models.Transaction t.Date = tran.DtPosted.UTC() @@ -70,7 +70,7 @@ func (i *OFXImport) AddTransaction(tran *ofxgo.Transaction, account *Account) er } } - var s1, s2 Split + var s1, s2 models.Split if len(tran.ExtdName) > 0 { s1.Memo = tran.ExtdName.String() } @@ -94,15 +94,15 @@ func (i *OFXImport) AddTransaction(tran *ofxgo.Transaction, account *Account) er s1.RemoteId = "ofx:" + tran.FiTID.String() // TODO CorrectFiTID/CorrectAction? - s1.ImportSplitType = ImportAccount - s2.ImportSplitType = ExternalAccount + s1.ImportSplitType = models.ImportAccount + s2.ImportSplitType = models.ExternalAccount security := i.Securities[account.SecurityId-1] s1.Amount = amt.FloatString(security.Precision) s2.Amount = amt.Neg(amt).FloatString(security.Precision) - s1.Status = Imported - s2.Status = Imported + s1.Status = models.Imported + s2.Status = models.Imported s1.AccountId = account.AccountId s2.AccountId = -1 @@ -122,12 +122,12 @@ func (i *OFXImport) importOFXBank(stmt *ofxgo.StatementResponse) error { return err } - account := Account{ + account := models.Account{ AccountId: int64(len(i.Accounts) + 1), ExternalAccountId: stmt.BankAcctFrom.AcctID.String(), SecurityId: security.SecurityId, ParentAccountId: -1, - Type: Bank, + Type: models.Bank, } if stmt.BankTranList != nil { @@ -149,12 +149,12 @@ func (i *OFXImport) importOFXCC(stmt *ofxgo.CCStatementResponse) error { return err } - account := Account{ + account := models.Account{ AccountId: int64(len(i.Accounts) + 1), ExternalAccountId: stmt.CCAcctFrom.AcctID.String(), SecurityId: security.SecurityId, ParentAccountId: -1, - Type: Liability, + Type: models.Liability, } i.Accounts = append(i.Accounts, account) @@ -208,14 +208,14 @@ func (i *OFXImport) importSecurities(seclist *ofxgo.SecurityList) error { return nil } -func (i *OFXImport) GetInvTran(invtran *ofxgo.InvTran) Transaction { - var t Transaction +func (i *OFXImport) GetInvTran(invtran *ofxgo.InvTran) models.Transaction { + var t models.Transaction t.Description = string(invtran.Memo) t.Date = invtran.DtTrade.UTC() return t } -func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *models.Security, account *Account) (*Transaction, error) { +func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *models.Security, account *models.Account) (*models.Transaction, error) { t := i.GetInvTran(&buy.InvTran) security, err := i.GetSecurityAlternateId(string(buy.SecID.UniqueID), models.Stock) @@ -254,10 +254,10 @@ func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *models.Security, ac } if num := commission.Num(); !num.IsInt64() || num.Int64() != 0 { - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: Commission, + Status: models.Imported, + ImportSplitType: models.Commission, AccountId: -1, SecurityId: curdef.SecurityId, RemoteId: "ofx:" + buy.InvTran.FiTID.String(), @@ -266,10 +266,10 @@ func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *models.Security, ac }) } if num := taxes.Num(); !num.IsInt64() || num.Int64() != 0 { - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: Taxes, + Status: models.Imported, + ImportSplitType: models.Taxes, AccountId: -1, SecurityId: curdef.SecurityId, RemoteId: "ofx:" + buy.InvTran.FiTID.String(), @@ -278,10 +278,10 @@ func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *models.Security, ac }) } if num := fees.Num(); !num.IsInt64() || num.Int64() != 0 { - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: Fees, + Status: models.Imported, + ImportSplitType: models.Fees, AccountId: -1, SecurityId: curdef.SecurityId, RemoteId: "ofx:" + buy.InvTran.FiTID.String(), @@ -290,10 +290,10 @@ func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *models.Security, ac }) } if num := load.Num(); !num.IsInt64() || num.Int64() != 0 { - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: Load, + Status: models.Imported, + ImportSplitType: models.Load, AccountId: -1, SecurityId: curdef.SecurityId, RemoteId: "ofx:" + buy.InvTran.FiTID.String(), @@ -301,20 +301,20 @@ func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *models.Security, ac Amount: load.FloatString(curdef.Precision), }) } - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: ImportAccount, + Status: models.Imported, + ImportSplitType: models.ImportAccount, AccountId: account.AccountId, SecurityId: -1, RemoteId: "ofx:" + buy.InvTran.FiTID.String(), Memo: memo, Amount: total.FloatString(curdef.Precision), }) - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: TradingAccount, + Status: models.Imported, + ImportSplitType: models.TradingAccount, AccountId: -1, SecurityId: curdef.SecurityId, RemoteId: "ofx:" + buy.InvTran.FiTID.String(), @@ -324,10 +324,10 @@ func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *models.Security, ac var units big.Rat units.Abs(&buy.Units.Rat) - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: SubAccount, + Status: models.Imported, + ImportSplitType: models.SubAccount, AccountId: -1, SecurityId: security.SecurityId, RemoteId: "ofx:" + buy.InvTran.FiTID.String(), @@ -335,10 +335,10 @@ func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *models.Security, ac Amount: units.FloatString(security.Precision), }) units.Neg(&units) - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: TradingAccount, + Status: models.Imported, + ImportSplitType: models.TradingAccount, AccountId: -1, SecurityId: security.SecurityId, RemoteId: "ofx:" + buy.InvTran.FiTID.String(), @@ -349,7 +349,7 @@ func (i *OFXImport) GetInvBuyTran(buy *ofxgo.InvBuy, curdef *models.Security, ac return &t, nil } -func (i *OFXImport) GetIncomeTran(income *ofxgo.Income, curdef *models.Security, account *Account) (*Transaction, error) { +func (i *OFXImport) GetIncomeTran(income *ofxgo.Income, curdef *models.Security, account *models.Account) (*models.Transaction, error) { t := i.GetInvTran(&income.InvTran) security, err := i.GetSecurityAlternateId(string(income.SecID.UniqueID), models.Stock) @@ -370,10 +370,10 @@ func (i *OFXImport) GetIncomeTran(income *ofxgo.Income, curdef *models.Security, total.Mul(&total, &income.Currency.CurRate.Rat) } - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: ImportAccount, + Status: models.Imported, + ImportSplitType: models.ImportAccount, AccountId: account.AccountId, SecurityId: -1, RemoteId: "ofx:" + income.InvTran.FiTID.String(), @@ -381,10 +381,10 @@ func (i *OFXImport) GetIncomeTran(income *ofxgo.Income, curdef *models.Security, Amount: total.FloatString(curdef.Precision), }) total.Neg(&total) - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: IncomeAccount, + Status: models.Imported, + ImportSplitType: models.IncomeAccount, AccountId: -1, SecurityId: curdef.SecurityId, RemoteId: "ofx:" + income.InvTran.FiTID.String(), @@ -395,7 +395,7 @@ func (i *OFXImport) GetIncomeTran(income *ofxgo.Income, curdef *models.Security, return &t, nil } -func (i *OFXImport) GetInvExpenseTran(expense *ofxgo.InvExpense, curdef *models.Security, account *Account) (*Transaction, error) { +func (i *OFXImport) GetInvExpenseTran(expense *ofxgo.InvExpense, curdef *models.Security, account *models.Account) (*models.Transaction, error) { t := i.GetInvTran(&expense.InvTran) security, err := i.GetSecurityAlternateId(string(expense.SecID.UniqueID), models.Stock) @@ -415,10 +415,10 @@ func (i *OFXImport) GetInvExpenseTran(expense *ofxgo.InvExpense, curdef *models. total.Mul(&total, &expense.Currency.CurRate.Rat) } - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: ImportAccount, + Status: models.Imported, + ImportSplitType: models.ImportAccount, AccountId: account.AccountId, SecurityId: -1, RemoteId: "ofx:" + expense.InvTran.FiTID.String(), @@ -426,10 +426,10 @@ func (i *OFXImport) GetInvExpenseTran(expense *ofxgo.InvExpense, curdef *models. Amount: total.FloatString(curdef.Precision), }) total.Neg(&total) - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: ExpenseAccount, + Status: models.Imported, + ImportSplitType: models.ExpenseAccount, AccountId: -1, SecurityId: curdef.SecurityId, RemoteId: "ofx:" + expense.InvTran.FiTID.String(), @@ -440,7 +440,7 @@ func (i *OFXImport) GetInvExpenseTran(expense *ofxgo.InvExpense, curdef *models. return &t, nil } -func (i *OFXImport) GetMarginInterestTran(marginint *ofxgo.MarginInterest, curdef *models.Security, account *Account) (*Transaction, error) { +func (i *OFXImport) GetMarginInterestTran(marginint *ofxgo.MarginInterest, curdef *models.Security, account *models.Account) (*models.Transaction, error) { t := i.GetInvTran(&marginint.InvTran) memo := string(marginint.InvTran.Memo) @@ -454,10 +454,10 @@ func (i *OFXImport) GetMarginInterestTran(marginint *ofxgo.MarginInterest, curde total.Mul(&total, &marginint.Currency.CurRate.Rat) } - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: ImportAccount, + Status: models.Imported, + ImportSplitType: models.ImportAccount, AccountId: account.AccountId, SecurityId: -1, RemoteId: "ofx:" + marginint.InvTran.FiTID.String(), @@ -465,10 +465,10 @@ func (i *OFXImport) GetMarginInterestTran(marginint *ofxgo.MarginInterest, curde Amount: total.FloatString(curdef.Precision), }) total.Neg(&total) - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: IncomeAccount, + Status: models.Imported, + ImportSplitType: models.IncomeAccount, AccountId: -1, SecurityId: curdef.SecurityId, RemoteId: "ofx:" + marginint.InvTran.FiTID.String(), @@ -479,7 +479,7 @@ func (i *OFXImport) GetMarginInterestTran(marginint *ofxgo.MarginInterest, curde return &t, nil } -func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Security, account *Account) (*Transaction, error) { +func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Security, account *models.Account) (*models.Transaction, error) { t := i.GetInvTran(&reinvest.InvTran) security, err := i.GetSecurityAlternateId(string(reinvest.SecID.UniqueID), models.Stock) @@ -518,10 +518,10 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec } if num := commission.Num(); !num.IsInt64() || num.Int64() != 0 { - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: Commission, + Status: models.Imported, + ImportSplitType: models.Commission, AccountId: -1, SecurityId: curdef.SecurityId, RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), @@ -530,10 +530,10 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec }) } if num := taxes.Num(); !num.IsInt64() || num.Int64() != 0 { - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: Taxes, + Status: models.Imported, + ImportSplitType: models.Taxes, AccountId: -1, SecurityId: curdef.SecurityId, RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), @@ -542,10 +542,10 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec }) } if num := fees.Num(); !num.IsInt64() || num.Int64() != 0 { - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: Fees, + Status: models.Imported, + ImportSplitType: models.Fees, AccountId: -1, SecurityId: curdef.SecurityId, RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), @@ -554,10 +554,10 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec }) } if num := load.Num(); !num.IsInt64() || num.Int64() != 0 { - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: Load, + Status: models.Imported, + ImportSplitType: models.Load, AccountId: -1, SecurityId: curdef.SecurityId, RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), @@ -565,10 +565,10 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec Amount: load.FloatString(curdef.Precision), }) } - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: ImportAccount, + Status: models.Imported, + ImportSplitType: models.ImportAccount, AccountId: account.AccountId, SecurityId: -1, RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), @@ -576,10 +576,10 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec Amount: total.FloatString(curdef.Precision), }) - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: IncomeAccount, + Status: models.Imported, + ImportSplitType: models.IncomeAccount, AccountId: -1, SecurityId: curdef.SecurityId, RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), @@ -587,20 +587,20 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec Amount: total.FloatString(curdef.Precision), }) total.Neg(&total) - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: ImportAccount, + Status: models.Imported, + ImportSplitType: models.ImportAccount, AccountId: account.AccountId, SecurityId: -1, RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), Memo: memo, Amount: total.FloatString(curdef.Precision), }) - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: TradingAccount, + Status: models.Imported, + ImportSplitType: models.TradingAccount, AccountId: -1, SecurityId: curdef.SecurityId, RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), @@ -610,10 +610,10 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec var units big.Rat units.Abs(&reinvest.Units.Rat) - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: SubAccount, + Status: models.Imported, + ImportSplitType: models.SubAccount, AccountId: -1, SecurityId: security.SecurityId, RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), @@ -621,10 +621,10 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec Amount: units.FloatString(security.Precision), }) units.Neg(&units) - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: TradingAccount, + Status: models.Imported, + ImportSplitType: models.TradingAccount, AccountId: -1, SecurityId: security.SecurityId, RemoteId: "ofx:" + reinvest.InvTran.FiTID.String(), @@ -635,7 +635,7 @@ func (i *OFXImport) GetReinvestTran(reinvest *ofxgo.Reinvest, curdef *models.Sec return &t, nil } -func (i *OFXImport) GetRetOfCapTran(retofcap *ofxgo.RetOfCap, curdef *models.Security, account *Account) (*Transaction, error) { +func (i *OFXImport) GetRetOfCapTran(retofcap *ofxgo.RetOfCap, curdef *models.Security, account *models.Account) (*models.Transaction, error) { t := i.GetInvTran(&retofcap.InvTran) security, err := i.GetSecurityAlternateId(string(retofcap.SecID.UniqueID), models.Stock) @@ -655,10 +655,10 @@ func (i *OFXImport) GetRetOfCapTran(retofcap *ofxgo.RetOfCap, curdef *models.Sec total.Mul(&total, &retofcap.Currency.CurRate.Rat) } - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: ImportAccount, + Status: models.Imported, + ImportSplitType: models.ImportAccount, AccountId: account.AccountId, SecurityId: -1, RemoteId: "ofx:" + retofcap.InvTran.FiTID.String(), @@ -666,10 +666,10 @@ func (i *OFXImport) GetRetOfCapTran(retofcap *ofxgo.RetOfCap, curdef *models.Sec Amount: total.FloatString(curdef.Precision), }) total.Neg(&total) - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: IncomeAccount, + Status: models.Imported, + ImportSplitType: models.IncomeAccount, AccountId: -1, SecurityId: curdef.SecurityId, RemoteId: "ofx:" + retofcap.InvTran.FiTID.String(), @@ -680,7 +680,7 @@ func (i *OFXImport) GetRetOfCapTran(retofcap *ofxgo.RetOfCap, curdef *models.Sec return &t, nil } -func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *models.Security, account *Account) (*Transaction, error) { +func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *models.Security, account *models.Account) (*models.Transaction, error) { t := i.GetInvTran(&sell.InvTran) security, err := i.GetSecurityAlternateId(string(sell.SecID.UniqueID), models.Stock) @@ -722,10 +722,10 @@ func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *models.Security, } if num := commission.Num(); !num.IsInt64() || num.Int64() != 0 { - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: Commission, + Status: models.Imported, + ImportSplitType: models.Commission, AccountId: -1, SecurityId: curdef.SecurityId, RemoteId: "ofx:" + sell.InvTran.FiTID.String(), @@ -734,10 +734,10 @@ func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *models.Security, }) } if num := taxes.Num(); !num.IsInt64() || num.Int64() != 0 { - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: Taxes, + Status: models.Imported, + ImportSplitType: models.Taxes, AccountId: -1, SecurityId: curdef.SecurityId, RemoteId: "ofx:" + sell.InvTran.FiTID.String(), @@ -746,10 +746,10 @@ func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *models.Security, }) } if num := fees.Num(); !num.IsInt64() || num.Int64() != 0 { - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: Fees, + Status: models.Imported, + ImportSplitType: models.Fees, AccountId: -1, SecurityId: curdef.SecurityId, RemoteId: "ofx:" + sell.InvTran.FiTID.String(), @@ -758,10 +758,10 @@ func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *models.Security, }) } if num := load.Num(); !num.IsInt64() || num.Int64() != 0 { - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: Load, + Status: models.Imported, + ImportSplitType: models.Load, AccountId: -1, SecurityId: curdef.SecurityId, RemoteId: "ofx:" + sell.InvTran.FiTID.String(), @@ -769,20 +769,20 @@ func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *models.Security, Amount: load.FloatString(curdef.Precision), }) } - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: ImportAccount, + Status: models.Imported, + ImportSplitType: models.ImportAccount, AccountId: account.AccountId, SecurityId: -1, RemoteId: "ofx:" + sell.InvTran.FiTID.String(), Memo: memo, Amount: total.FloatString(curdef.Precision), }) - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: TradingAccount, + Status: models.Imported, + ImportSplitType: models.TradingAccount, AccountId: -1, SecurityId: curdef.SecurityId, RemoteId: "ofx:" + sell.InvTran.FiTID.String(), @@ -792,10 +792,10 @@ func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *models.Security, var units big.Rat units.Abs(&sell.Units.Rat) - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: TradingAccount, + Status: models.Imported, + ImportSplitType: models.TradingAccount, AccountId: -1, SecurityId: security.SecurityId, RemoteId: "ofx:" + sell.InvTran.FiTID.String(), @@ -803,10 +803,10 @@ func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *models.Security, Amount: units.FloatString(security.Precision), }) units.Neg(&units) - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: SubAccount, + Status: models.Imported, + ImportSplitType: models.SubAccount, AccountId: -1, SecurityId: security.SecurityId, RemoteId: "ofx:" + sell.InvTran.FiTID.String(), @@ -817,7 +817,7 @@ func (i *OFXImport) GetInvSellTran(sell *ofxgo.InvSell, curdef *models.Security, return &t, nil } -func (i *OFXImport) GetTransferTran(transfer *ofxgo.Transfer, account *Account) (*Transaction, error) { +func (i *OFXImport) GetTransferTran(transfer *ofxgo.Transfer, account *models.Account) (*models.Transaction, error) { t := i.GetInvTran(&transfer.InvTran) security, err := i.GetSecurityAlternateId(string(transfer.SecID.UniqueID), models.Stock) @@ -834,10 +834,10 @@ func (i *OFXImport) GetTransferTran(transfer *ofxgo.Transfer, account *Account) units.Neg(&transfer.Units.Rat) } - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: SubAccount, + Status: models.Imported, + ImportSplitType: models.SubAccount, AccountId: -1, SecurityId: security.SecurityId, RemoteId: "ofx:" + transfer.InvTran.FiTID.String(), @@ -845,10 +845,10 @@ func (i *OFXImport) GetTransferTran(transfer *ofxgo.Transfer, account *Account) Amount: units.FloatString(security.Precision), }) units.Neg(&units) - t.Splits = append(t.Splits, &Split{ + t.Splits = append(t.Splits, &models.Split{ // TODO ReversalFiTID? - Status: Imported, - ImportSplitType: ExternalAccount, + Status: models.Imported, + ImportSplitType: models.ExternalAccount, AccountId: -1, SecurityId: security.SecurityId, RemoteId: "ofx:" + transfer.InvTran.FiTID.String(), @@ -859,12 +859,12 @@ func (i *OFXImport) GetTransferTran(transfer *ofxgo.Transfer, account *Account) return &t, nil } -func (i *OFXImport) AddInvTransaction(invtran *ofxgo.InvTransaction, account *Account, curdef *models.Security) error { +func (i *OFXImport) AddInvTransaction(invtran *ofxgo.InvTransaction, account *models.Account, curdef *models.Security) error { if curdef.SecurityId < 1 || curdef.SecurityId > int64(len(i.Securities)) { return errors.New("Internal error: security index not found in OFX import\n") } - var t *Transaction + var t *models.Transaction var err error if tran, ok := (*invtran).(ofxgo.BuyDebt); ok { t, err = i.GetInvBuyTran(&tran.InvBuy, curdef, account) @@ -926,12 +926,12 @@ func (i *OFXImport) importOFXInv(stmt *ofxgo.InvStatementResponse) error { return err } - account := Account{ + account := models.Account{ AccountId: int64(len(i.Accounts) + 1), ExternalAccountId: stmt.InvAcctFrom.AcctID.String(), SecurityId: security.SecurityId, ParentAccountId: -1, - Type: Investment, + Type: models.Investment, } i.Accounts = append(i.Accounts, account) diff --git a/internal/handlers/ofx_test.go b/internal/handlers/ofx_test.go index 11c3f04..4d78f25 100644 --- a/internal/handlers/ofx_test.go +++ b/internal/handlers/ofx_test.go @@ -2,7 +2,6 @@ package handlers_test import ( "fmt" - "github.com/aclindsa/moneygo/internal/handlers" "github.com/aclindsa/moneygo/internal/models" "net/http" "strconv" @@ -77,7 +76,7 @@ func findSecurity(client *http.Client, symbol string, tipe models.SecurityType) return nil, fmt.Errorf("Unable to find security: \"%s\"", symbol) } -func findAccount(client *http.Client, name string, tipe handlers.AccountType, securityid int64) (*handlers.Account, error) { +func findAccount(client *http.Client, name string, tipe models.AccountType, securityid int64) (*models.Account, error) { accounts, err := getAccounts(client) if err != nil { return nil, err @@ -105,11 +104,11 @@ func TestImportOFX401kMutualFunds(t *testing.T) { t.Fatalf("Error removing default security: %s\n", err) } - account := &handlers.Account{ + account := &models.Account{ SecurityId: d.securities[0].SecurityId, UserId: d.users[0].UserId, ParentAccountId: -1, - Type: handlers.Investment, + Type: models.Investment, Name: "401k", } @@ -130,14 +129,14 @@ func TestImportOFX401kMutualFunds(t *testing.T) { if err != nil { t.Fatalf("Error finding VANGUARD TARGET 2045 security: %s\n", err) } - tradingaccount, err := findAccount(d.clients[0], "VANGUARD TARGET 2045", handlers.Trading, security.SecurityId) + tradingaccount, err := findAccount(d.clients[0], "VANGUARD TARGET 2045", models.Trading, security.SecurityId) if err != nil { t.Fatalf("Error finding VANGUARD TARGET 2045 trading account: %s\n", err) } accountBalanceHelper(t, d.clients[0], tradingaccount, "-3.35400") // Ensure actual holding account was created and in the correct place - investmentaccount, err := findAccount(d.clients[0], "VANGUARD TARGET 2045", handlers.Investment, security.SecurityId) + investmentaccount, err := findAccount(d.clients[0], "VANGUARD TARGET 2045", models.Investment, security.SecurityId) if err != nil { t.Fatalf("Error finding VANGUARD TARGET 2045 investment account: %s\n", err) } @@ -164,11 +163,11 @@ func TestImportOFXBrokerage(t *testing.T) { } // Create the brokerage account - account := &handlers.Account{ + account := &models.Account{ SecurityId: d.securities[0].SecurityId, UserId: d.users[0].UserId, ParentAccountId: -1, - Type: handlers.Investment, + Type: models.Investment, Name: "Personal Brokerage", } @@ -185,7 +184,7 @@ func TestImportOFXBrokerage(t *testing.T) { // Make sure the USD trading account was created and has the right // value - usdtrading, err := findAccount(d.clients[0], "USD", handlers.Trading, d.users[0].DefaultCurrency) + usdtrading, err := findAccount(d.clients[0], "USD", models.Trading, d.users[0].DefaultCurrency) if err != nil { t.Fatalf("Error finding USD trading account: %s\n", err) } @@ -210,14 +209,14 @@ func TestImportOFXBrokerage(t *testing.T) { t.Fatalf("Error finding security: %s\n", err) } - account, err := findAccount(d.clients[0], check.Name, handlers.Investment, security.SecurityId) + account, err := findAccount(d.clients[0], check.Name, models.Investment, security.SecurityId) if err != nil { t.Fatalf("Error finding trading account: %s\n", err) } accountBalanceHelper(t, d.clients[0], account, check.Balance) - tradingaccount, err := findAccount(d.clients[0], check.Name, handlers.Trading, security.SecurityId) + tradingaccount, err := findAccount(d.clients[0], check.Name, models.Trading, security.SecurityId) if err != nil { t.Fatalf("Error finding trading account: %s\n", err) } diff --git a/internal/handlers/prices_lua.go b/internal/handlers/prices_lua.go index 0c3fe89..8450319 100644 --- a/internal/handlers/prices_lua.go +++ b/internal/handlers/prices_lua.go @@ -1,6 +1,7 @@ package handlers import ( + "github.com/aclindsa/moneygo/internal/models" "github.com/yuin/gopher-lua" ) @@ -59,7 +60,7 @@ func luaPrice__index(L *lua.LState) int { } L.Push(SecurityToLua(L, c)) case "Value", "value": - amt, err := GetBigAmount(p.Value) + amt, err := models.GetBigAmount(p.Value) if err != nil { panic(err) } diff --git a/internal/handlers/testdata_test.go b/internal/handlers/testdata_test.go index c01544d..2cef0cd 100644 --- a/internal/handlers/testdata_test.go +++ b/internal/handlers/testdata_test.go @@ -39,8 +39,8 @@ type TestData struct { clients []*http.Client securities []models.Security prices []handlers.Price - accounts []handlers.Account // accounts must appear after their parents in this slice - transactions []handlers.Transaction + accounts []models.Account // accounts must appear after their parents in this slice + transactions []models.Transaction reports []handlers.Report tabulations []handlers.Tabulation } @@ -113,7 +113,7 @@ func (t *TestData) Initialize() (*TestData, error) { } for i, transaction := range t.transactions { - transaction.Splits = []*handlers.Split{} + transaction.Splits = []*models.Split{} for _, s := range t.transactions[i].Splits { // Make a copy of the split since Splits is a slice of pointers so // copying the transaction doesn't @@ -246,78 +246,78 @@ var data = []TestData{ RemoteId: "USDEUR819298714", }, }, - accounts: []handlers.Account{ + accounts: []models.Account{ { UserId: 0, SecurityId: 0, ParentAccountId: -1, - Type: handlers.Asset, + Type: models.Asset, Name: "Assets", }, { UserId: 0, SecurityId: 0, ParentAccountId: 0, - Type: handlers.Bank, + Type: models.Bank, Name: "Credit Union Checking", }, { UserId: 0, SecurityId: 0, ParentAccountId: -1, - Type: handlers.Expense, + Type: models.Expense, Name: "Expenses", }, { UserId: 0, SecurityId: 0, ParentAccountId: 2, - Type: handlers.Expense, + Type: models.Expense, Name: "Groceries", }, { UserId: 0, SecurityId: 0, ParentAccountId: 2, - Type: handlers.Expense, + Type: models.Expense, Name: "Cable", }, { UserId: 1, SecurityId: 2, ParentAccountId: -1, - Type: handlers.Asset, + Type: models.Asset, Name: "Assets", }, { UserId: 1, SecurityId: 2, ParentAccountId: -1, - Type: handlers.Expense, + Type: models.Expense, Name: "Expenses", }, { UserId: 0, SecurityId: 0, ParentAccountId: -1, - Type: handlers.Liability, + Type: models.Liability, Name: "Credit Card", }, }, - transactions: []handlers.Transaction{ + transactions: []models.Transaction{ { UserId: 0, Description: "weekly groceries", Date: time.Date(2017, time.October, 15, 1, 16, 59, 0, time.UTC), - Splits: []*handlers.Split{ + Splits: []*models.Split{ { - Status: handlers.Reconciled, + Status: models.Reconciled, AccountId: 1, SecurityId: -1, Amount: "-5.6", }, { - Status: handlers.Reconciled, + Status: models.Reconciled, AccountId: 3, SecurityId: -1, Amount: "5.6", @@ -328,15 +328,15 @@ var data = []TestData{ UserId: 0, Description: "weekly groceries", Date: time.Date(2017, time.October, 31, 19, 10, 14, 0, time.UTC), - Splits: []*handlers.Split{ + Splits: []*models.Split{ { - Status: handlers.Reconciled, + Status: models.Reconciled, AccountId: 1, SecurityId: -1, Amount: "-81.59", }, { - Status: handlers.Reconciled, + Status: models.Reconciled, AccountId: 3, SecurityId: -1, Amount: "81.59", @@ -347,15 +347,15 @@ var data = []TestData{ UserId: 0, Description: "Cable", Date: time.Date(2017, time.September, 2, 0, 00, 00, 0, time.UTC), - Splits: []*handlers.Split{ + Splits: []*models.Split{ { - Status: handlers.Reconciled, + Status: models.Reconciled, AccountId: 1, SecurityId: -1, Amount: "-39.99", }, { - Status: handlers.Entered, + Status: models.Entered, AccountId: 4, SecurityId: -1, Amount: "39.99", @@ -366,15 +366,15 @@ var data = []TestData{ UserId: 1, Description: "Gas", Date: time.Date(2017, time.November, 1, 13, 19, 50, 0, time.UTC), - Splits: []*handlers.Split{ + Splits: []*models.Split{ { - Status: handlers.Reconciled, + Status: models.Reconciled, AccountId: 5, SecurityId: -1, Amount: "-24.56", }, { - Status: handlers.Entered, + Status: models.Entered, AccountId: 6, SecurityId: -1, Amount: "24.56", diff --git a/internal/handlers/transactions.go b/internal/handlers/transactions.go index 60c97b2..3795ce7 100644 --- a/internal/handlers/transactions.go +++ b/internal/handlers/transactions.go @@ -1,7 +1,6 @@ package handlers import ( - "encoding/json" "errors" "fmt" "github.com/aclindsa/moneygo/internal/models" @@ -10,141 +9,17 @@ import ( "net/http" "net/url" "strconv" - "strings" "time" ) -// Split.Status -const ( - Imported int64 = 1 - Entered = 2 - Cleared = 3 - Reconciled = 4 - Voided = 5 -) - -// Split.ImportSplitType -const ( - Default int64 = 0 - ImportAccount = 1 // This split belongs to the main account being imported - SubAccount = 2 // This split belongs to a sub-account of that being imported - ExternalAccount = 3 - TradingAccount = 4 - Commission = 5 - Taxes = 6 - Fees = 7 - Load = 8 - IncomeAccount = 9 - ExpenseAccount = 10 -) - -type Split struct { - SplitId int64 - TransactionId int64 - Status int64 - ImportSplitType int64 - - // One of AccountId and SecurityId must be -1 - // In normal splits, AccountId will be valid and SecurityId will be -1. The - // only case where this is reversed is for transactions that have been - // imported and not yet associated with an account. - AccountId int64 - SecurityId int64 - - RemoteId string // unique ID from server, for detecting duplicates - Number string // Check or reference number - Memo string - Amount string // String representation of decimal, suitable for passing to big.Rat.SetString() -} - -func GetBigAmount(amt string) (*big.Rat, error) { - var r big.Rat - _, success := r.SetString(amt) - if !success { - return nil, errors.New("Couldn't convert string amount to big.Rat via SetString()") - } - return &r, nil -} - -func (s *Split) GetAmount() (*big.Rat, error) { - return GetBigAmount(s.Amount) -} - -func (s *Split) Valid() bool { - if (s.AccountId == -1) == (s.SecurityId == -1) { - return false - } - _, err := s.GetAmount() - return err == nil -} - -func (s *Split) AlreadyImported(tx *Tx) (bool, error) { +func SplitAlreadyImported(tx *Tx, s *models.Split) (bool, error) { count, err := tx.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId) return count == 1, err } -type Transaction struct { - TransactionId int64 - UserId int64 - Description string - Date time.Time - Splits []*Split `db:"-"` -} - -type TransactionList struct { - Transactions *[]Transaction `json:"transactions"` -} - -type AccountTransactionsList struct { - Account *Account - Transactions *[]Transaction - TotalTransactions int64 - BeginningBalance string - EndingBalance string -} - -func (t *Transaction) Write(w http.ResponseWriter) error { - enc := json.NewEncoder(w) - return enc.Encode(t) -} - -func (t *Transaction) Read(json_str string) error { - dec := json.NewDecoder(strings.NewReader(json_str)) - return dec.Decode(t) -} - -func (tl *TransactionList) Write(w http.ResponseWriter) error { - enc := json.NewEncoder(w) - return enc.Encode(tl) -} - -func (tl *TransactionList) Read(json_str string) error { - dec := json.NewDecoder(strings.NewReader(json_str)) - return dec.Decode(tl) -} - -func (atl *AccountTransactionsList) Write(w http.ResponseWriter) error { - enc := json.NewEncoder(w) - return enc.Encode(atl) -} - -func (atl *AccountTransactionsList) Read(json_str string) error { - dec := json.NewDecoder(strings.NewReader(json_str)) - return dec.Decode(atl) -} - -func (t *Transaction) Valid() bool { - for i := range t.Splits { - if !t.Splits[i].Valid() { - return false - } - } - return true -} - // Return a map of security ID's to big.Rat's containing the amount that // security is imbalanced by -func (t *Transaction) GetImbalances(tx *Tx) (map[int64]big.Rat, error) { +func GetTransactionImbalances(tx *Tx, t *models.Transaction) (map[int64]big.Rat, error) { sums := make(map[int64]big.Rat) if !t.Valid() { @@ -155,7 +30,7 @@ func (t *Transaction) GetImbalances(tx *Tx) (map[int64]big.Rat, error) { securityid := t.Splits[i].SecurityId if t.Splits[i].AccountId != -1 { var err error - var account *Account + var account *models.Account account, err = GetAccount(tx, t.Splits[i].AccountId, t.UserId) if err != nil { return nil, err @@ -172,10 +47,10 @@ func (t *Transaction) GetImbalances(tx *Tx) (map[int64]big.Rat, error) { // Returns true if all securities contained in this transaction are balanced, // false otherwise -func (t *Transaction) Balanced(tx *Tx) (bool, error) { +func TransactionBalanced(tx *Tx, t *models.Transaction) (bool, error) { var zero big.Rat - sums, err := t.GetImbalances(tx) + sums, err := GetTransactionImbalances(tx, t) if err != nil { return false, err } @@ -188,8 +63,8 @@ func (t *Transaction) Balanced(tx *Tx) (bool, error) { return true, nil } -func GetTransaction(tx *Tx, transactionid int64, userid int64) (*Transaction, error) { - var t Transaction +func GetTransaction(tx *Tx, transactionid int64, userid int64) (*models.Transaction, error) { + var t models.Transaction err := tx.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid) if err != nil { @@ -204,8 +79,8 @@ func GetTransaction(tx *Tx, transactionid int64, userid int64) (*Transaction, er return &t, nil } -func GetTransactions(tx *Tx, userid int64) (*[]Transaction, error) { - var transactions []Transaction +func GetTransactions(tx *Tx, userid int64) (*[]models.Transaction, error) { + var transactions []models.Transaction _, err := tx.Select(&transactions, "SELECT * from transactions where UserId=?", userid) if err != nil { @@ -246,7 +121,7 @@ func (ame AccountMissingError) Error() string { return "Account missing" } -func InsertTransaction(tx *Tx, t *Transaction, user *models.User) error { +func InsertTransaction(tx *Tx, t *models.Transaction, user *models.User) error { // Map of any accounts with transaction splits being added a_map := make(map[int64]bool) for i := range t.Splits { @@ -296,8 +171,8 @@ func InsertTransaction(tx *Tx, t *Transaction, user *models.User) error { return nil } -func UpdateTransaction(tx *Tx, t *Transaction, user *models.User) error { - var existing_splits []*Split +func UpdateTransaction(tx *Tx, t *models.Transaction, user *models.User) error { + var existing_splits []*models.Split _, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId) if err != nil { @@ -373,7 +248,7 @@ func UpdateTransaction(tx *Tx, t *Transaction, user *models.User) error { return nil } -func DeleteTransaction(tx *Tx, t *Transaction, user *models.User) error { +func DeleteTransaction(tx *Tx, t *models.Transaction, user *models.User) error { var accountids []int64 _, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId) if err != nil { @@ -408,7 +283,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter } if r.Method == "POST" { - var transaction Transaction + var transaction models.Transaction if err := ReadJSON(r, &transaction); err != nil { return NewError(3 /*Invalid Request*/) } @@ -427,7 +302,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter } } - balanced, err := transaction.Balanced(context.Tx) + balanced, err := TransactionBalanced(context.Tx, &transaction) if err != nil { return NewError(999 /*Internal Error*/) } @@ -449,7 +324,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter } else if r.Method == "GET" { if context.LastLevel() { //Return all Transactions - var al TransactionList + var al models.TransactionList transactions, err := GetTransactions(context.Tx, user.UserId) if err != nil { log.Print(err) @@ -475,13 +350,13 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter return NewError(3 /*Invalid Request*/) } if r.Method == "PUT" { - var transaction Transaction + var transaction models.Transaction if err := ReadJSON(r, &transaction); err != nil || transaction.TransactionId != transactionid { return NewError(3 /*Invalid Request*/) } transaction.UserId = user.UserId - balanced, err := transaction.Balanced(context.Tx) + balanced, err := TransactionBalanced(context.Tx, &transaction) if err != nil { log.Print(err) return NewError(999 /*Internal Error*/) @@ -526,7 +401,7 @@ func TransactionHandler(r *http.Request, context *Context) ResponseWriterWriter return NewError(3 /*Invalid Request*/) } -func TransactionsBalanceDifference(tx *Tx, accountid int64, transactions []Transaction) (*big.Rat, error) { +func TransactionsBalanceDifference(tx *Tx, accountid int64, transactions []models.Transaction) (*big.Rat, error) { var pageDifference, tmp big.Rat for i := range transactions { _, err := tx.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId) @@ -538,7 +413,7 @@ func TransactionsBalanceDifference(tx *Tx, accountid int64, transactions []Trans // an ending balance for j := range transactions[i].Splits { if transactions[i].Splits[j].AccountId == accountid { - rat_amount, err := GetBigAmount(transactions[i].Splits[j].Amount) + rat_amount, err := models.GetBigAmount(transactions[i].Splits[j].Amount) if err != nil { return nil, err } @@ -551,7 +426,7 @@ func TransactionsBalanceDifference(tx *Tx, accountid int64, transactions []Trans } func GetAccountBalance(tx *Tx, user *models.User, accountid int64) (*big.Rat, error) { - var splits []Split + var splits []models.Split sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?" _, err := tx.Select(&splits, sql, accountid, user.UserId) @@ -561,7 +436,7 @@ func GetAccountBalance(tx *Tx, user *models.User, accountid int64) (*big.Rat, er var balance, tmp big.Rat for _, s := range splits { - rat_amount, err := GetBigAmount(s.Amount) + rat_amount, err := models.GetBigAmount(s.Amount) if err != nil { return nil, err } @@ -574,7 +449,7 @@ func GetAccountBalance(tx *Tx, user *models.User, accountid int64) (*big.Rat, er // Assumes accountid is valid and is owned by the current user func GetAccountBalanceDate(tx *Tx, user *models.User, accountid int64, date *time.Time) (*big.Rat, error) { - var splits []Split + var splits []models.Split sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?" _, err := tx.Select(&splits, sql, accountid, user.UserId, date) @@ -584,7 +459,7 @@ func GetAccountBalanceDate(tx *Tx, user *models.User, accountid int64, date *tim var balance, tmp big.Rat for _, s := range splits { - rat_amount, err := GetBigAmount(s.Amount) + rat_amount, err := models.GetBigAmount(s.Amount) if err != nil { return nil, err } @@ -596,7 +471,7 @@ func GetAccountBalanceDate(tx *Tx, user *models.User, accountid int64, date *tim } func GetAccountBalanceDateRange(tx *Tx, user *models.User, accountid int64, begin, end *time.Time) (*big.Rat, error) { - var splits []Split + var splits []models.Split sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?" _, err := tx.Select(&splits, sql, accountid, user.UserId, begin, end) @@ -606,7 +481,7 @@ func GetAccountBalanceDateRange(tx *Tx, user *models.User, accountid int64, begi var balance, tmp big.Rat for _, s := range splits { - rat_amount, err := GetBigAmount(s.Amount) + rat_amount, err := models.GetBigAmount(s.Amount) if err != nil { return nil, err } @@ -617,9 +492,9 @@ func GetAccountBalanceDateRange(tx *Tx, user *models.User, accountid int64, begi return &balance, nil } -func GetAccountTransactions(tx *Tx, user *models.User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) { - var transactions []Transaction - var atl AccountTransactionsList +func GetAccountTransactions(tx *Tx, user *models.User, accountid int64, sort string, page uint64, limit uint64) (*models.AccountTransactionsList, error) { + var transactions []models.Transaction + var atl models.AccountTransactionsList var sqlsort, balanceLimitOffset string var balanceLimitOffsetArg uint64 @@ -685,7 +560,7 @@ func GetAccountTransactions(tx *Tx, user *models.User, accountid int64, sort str var tmp, balance big.Rat for _, amount := range amounts { - rat_amount, err := GetBigAmount(amount) + rat_amount, err := models.GetBigAmount(amount) if err != nil { return nil, err } diff --git a/internal/handlers/transactions_test.go b/internal/handlers/transactions_test.go index 7d6285f..0f3b68d 100644 --- a/internal/handlers/transactions_test.go +++ b/internal/handlers/transactions_test.go @@ -3,6 +3,7 @@ package handlers_test import ( "fmt" "github.com/aclindsa/moneygo/internal/handlers" + "github.com/aclindsa/moneygo/internal/models" "net/http" "net/url" "strconv" @@ -10,14 +11,14 @@ import ( "time" ) -func createTransaction(client *http.Client, transaction *handlers.Transaction) (*handlers.Transaction, error) { - var s handlers.Transaction +func createTransaction(client *http.Client, transaction *models.Transaction) (*models.Transaction, error) { + var s models.Transaction err := create(client, transaction, &s, "/v1/transactions/") return &s, err } -func getTransaction(client *http.Client, transactionid int64) (*handlers.Transaction, error) { - var s handlers.Transaction +func getTransaction(client *http.Client, transactionid int64) (*models.Transaction, error) { + var s models.Transaction err := read(client, &s, "/v1/transactions/"+strconv.FormatInt(transactionid, 10)) if err != nil { return nil, err @@ -25,8 +26,8 @@ func getTransaction(client *http.Client, transactionid int64) (*handlers.Transac return &s, nil } -func getTransactions(client *http.Client) (*handlers.TransactionList, error) { - var tl handlers.TransactionList +func getTransactions(client *http.Client) (*models.TransactionList, error) { + var tl models.TransactionList err := read(client, &tl, "/v1/transactions/") if err != nil { return nil, err @@ -34,8 +35,8 @@ func getTransactions(client *http.Client) (*handlers.TransactionList, error) { return &tl, nil } -func getAccountTransactions(client *http.Client, accountid, page, limit int64, sort string) (*handlers.AccountTransactionsList, error) { - var atl handlers.AccountTransactionsList +func getAccountTransactions(client *http.Client, accountid, page, limit int64, sort string) (*models.AccountTransactionsList, error) { + var atl models.AccountTransactionsList params := url.Values{} query := fmt.Sprintf("/v1/accounts/%d/transactions/", accountid) @@ -57,8 +58,8 @@ func getAccountTransactions(client *http.Client, accountid, page, limit int64, s return &atl, nil } -func updateTransaction(client *http.Client, transaction *handlers.Transaction) (*handlers.Transaction, error) { - var s handlers.Transaction +func updateTransaction(client *http.Client, transaction *models.Transaction) (*models.Transaction, error) { + var s models.Transaction err := update(client, transaction, &s, "/v1/transactions/"+strconv.FormatInt(transaction.TransactionId, 10)) if err != nil { return nil, err @@ -66,7 +67,7 @@ func updateTransaction(client *http.Client, transaction *handlers.Transaction) ( return &s, nil } -func deleteTransaction(client *http.Client, s *handlers.Transaction) error { +func deleteTransaction(client *http.Client, s *models.Transaction) error { err := remove(client, "/v1/transactions/"+strconv.FormatInt(s.TransactionId, 10)) if err != nil { return err @@ -74,7 +75,7 @@ func deleteTransaction(client *http.Client, s *handlers.Transaction) error { return nil } -func ensureTransactionsMatch(t *testing.T, expected, tran *handlers.Transaction, accounts *[]handlers.Account, matchtransactionids, matchsplitids bool) { +func ensureTransactionsMatch(t *testing.T, expected, tran *models.Transaction, accounts *[]models.Account, matchtransactionids, matchsplitids bool) { t.Helper() if tran.TransactionId == 0 { @@ -136,9 +137,9 @@ func ensureTransactionsMatch(t *testing.T, expected, tran *handlers.Transaction, } } -func getAccountVersionMap(t *testing.T, client *http.Client, tran *handlers.Transaction) map[int64]*handlers.Account { +func getAccountVersionMap(t *testing.T, client *http.Client, tran *models.Transaction) map[int64]*models.Account { t.Helper() - accountMap := make(map[int64]*handlers.Account) + accountMap := make(map[int64]*models.Account) for _, split := range tran.Splits { account, err := getAccount(client, split.AccountId) if err != nil { @@ -149,7 +150,7 @@ func getAccountVersionMap(t *testing.T, client *http.Client, tran *handlers.Tran return accountMap } -func checkAccountVersionsUpdated(t *testing.T, client *http.Client, accountMap map[int64]*handlers.Account, tran *handlers.Transaction) { +func checkAccountVersionsUpdated(t *testing.T, client *http.Client, accountMap map[int64]*models.Account, tran *models.Transaction) { for _, split := range tran.Splits { account, err := getAccount(client, split.AccountId) if err != nil { @@ -177,19 +178,19 @@ func TestCreateTransaction(t *testing.T) { } // Don't allow imbalanced transactions - tran := handlers.Transaction{ + tran := models.Transaction{ UserId: d.users[0].UserId, Description: "Imbalanced", Date: time.Date(2017, time.September, 1, 0, 00, 00, 0, time.UTC), - Splits: []*handlers.Split{ + Splits: []*models.Split{ { - Status: handlers.Reconciled, + Status: models.Reconciled, AccountId: d.accounts[1].AccountId, SecurityId: -1, Amount: "-39.98", }, { - Status: handlers.Entered, + Status: models.Entered, AccountId: d.accounts[4].AccountId, SecurityId: -1, Amount: "39.99", @@ -209,7 +210,7 @@ func TestCreateTransaction(t *testing.T) { } // Don't allow transactions with 0 splits - tran.Splits = []*handlers.Split{} + tran.Splits = []*models.Split{} _, err = createTransaction(d.clients[0], &tran) if err == nil { t.Fatalf("Expected error creating with zero splits") @@ -316,9 +317,9 @@ func TestUpdateTransaction(t *testing.T) { ensureTransactionsMatch(t, &curr, tran, nil, true, true) - tran.Splits = []*handlers.Split{} + tran.Splits = []*models.Split{} for _, s := range curr.Splits { - var split handlers.Split + var split models.Split split = *s tran.Splits = append(tran.Splits, &split) } @@ -346,7 +347,7 @@ func TestUpdateTransaction(t *testing.T) { } // Don't allow transactions with 0 splits - tran.Splits = []*handlers.Split{} + tran.Splits = []*models.Split{} _, err = updateTransaction(d.clients[orig.UserId], tran) if err == nil { t.Fatalf("Expected error updating with zero splits") @@ -391,12 +392,12 @@ func TestDeleteTransaction(t *testing.T) { }) } -func helperTestAccountTransactions(t *testing.T, d *TestData, account *handlers.Account, limit int64, sort string) { +func helperTestAccountTransactions(t *testing.T, d *TestData, account *models.Account, limit int64, sort string) { if account.UserId != d.users[0].UserId { return } - var transactions []handlers.Transaction + var transactions []models.Transaction var lastFetchCount int64 for page := int64(0); page == 0 || lastFetchCount > 0; page++ { diff --git a/internal/models/accounts.go b/internal/models/accounts.go new file mode 100644 index 0000000..fdfac98 --- /dev/null +++ b/internal/models/accounts.go @@ -0,0 +1,118 @@ +package models + +import ( + "encoding/json" + "net/http" + "strings" +) + +type AccountType int64 + +const ( + Bank AccountType = 1 // start at 1 so that the default (0) is invalid + Cash = 2 + Asset = 3 + Liability = 4 + Investment = 5 + Income = 6 + Expense = 7 + Trading = 8 + Equity = 9 + Receivable = 10 + Payable = 11 +) + +var AccountTypes = []AccountType{ + Bank, + Cash, + Asset, + Liability, + Investment, + Income, + Expense, + Trading, + Equity, + Receivable, + Payable, +} + +func (t AccountType) String() string { + switch t { + case Bank: + return "Bank" + case Cash: + return "Cash" + case Asset: + return "Asset" + case Liability: + return "Liability" + case Investment: + return "Investment" + case Income: + return "Income" + case Expense: + return "Expense" + case Trading: + return "Trading" + case Equity: + return "Equity" + case Receivable: + return "Receivable" + case Payable: + return "Payable" + } + return "" +} + +type Account struct { + AccountId int64 + ExternalAccountId string + UserId int64 + SecurityId int64 + ParentAccountId int64 // -1 if this account is at the root + Type AccountType + Name string + + // monotonically-increasing account transaction version number. Used for + // allowing a client to ensure they have a consistent version when paging + // through transactions. + AccountVersion int64 `json:"Version"` + + // Optional fields specifying how to fetch transactions from a bank via OFX + OFXURL string + OFXORG string + OFXFID string + OFXUser string + OFXBankID string // OFX BankID (BrokerID if AcctType == Investment) + OFXAcctID string + OFXAcctType string // ofxgo.acctType + OFXClientUID string + OFXAppID string + OFXAppVer string + OFXVersion string + OFXNoIndent bool +} + +type AccountList struct { + Accounts *[]Account `json:"accounts"` +} + +func (a *Account) Write(w http.ResponseWriter) error { + enc := json.NewEncoder(w) + return enc.Encode(a) +} + +func (a *Account) Read(json_str string) error { + dec := json.NewDecoder(strings.NewReader(json_str)) + return dec.Decode(a) +} + +func (al *AccountList) Write(w http.ResponseWriter) error { + enc := json.NewEncoder(w) + return enc.Encode(al) +} + +func (al *AccountList) Read(json_str string) error { + dec := json.NewDecoder(strings.NewReader(json_str)) + return dec.Decode(al) +} diff --git a/internal/models/transactions.go b/internal/models/transactions.go new file mode 100644 index 0000000..8076995 --- /dev/null +++ b/internal/models/transactions.go @@ -0,0 +1,133 @@ +package models + +import ( + "encoding/json" + "errors" + "math/big" + "net/http" + "strings" + "time" +) + +// Split.Status +const ( + Imported int64 = 1 + Entered = 2 + Cleared = 3 + Reconciled = 4 + Voided = 5 +) + +// Split.ImportSplitType +const ( + Default int64 = 0 + ImportAccount = 1 // This split belongs to the main account being imported + SubAccount = 2 // This split belongs to a sub-account of that being imported + ExternalAccount = 3 + TradingAccount = 4 + Commission = 5 + Taxes = 6 + Fees = 7 + Load = 8 + IncomeAccount = 9 + ExpenseAccount = 10 +) + +type Split struct { + SplitId int64 + TransactionId int64 + Status int64 + ImportSplitType int64 + + // One of AccountId and SecurityId must be -1 + // In normal splits, AccountId will be valid and SecurityId will be -1. The + // only case where this is reversed is for transactions that have been + // imported and not yet associated with an account. + AccountId int64 + SecurityId int64 + + RemoteId string // unique ID from server, for detecting duplicates + Number string // Check or reference number + Memo string + Amount string // String representation of decimal, suitable for passing to big.Rat.SetString() +} + +func GetBigAmount(amt string) (*big.Rat, error) { + var r big.Rat + _, success := r.SetString(amt) + if !success { + return nil, errors.New("Couldn't convert string amount to big.Rat via SetString()") + } + return &r, nil +} + +func (s *Split) GetAmount() (*big.Rat, error) { + return GetBigAmount(s.Amount) +} + +func (s *Split) Valid() bool { + if (s.AccountId == -1) == (s.SecurityId == -1) { + return false + } + _, err := s.GetAmount() + return err == nil +} + +type Transaction struct { + TransactionId int64 + UserId int64 + Description string + Date time.Time + Splits []*Split `db:"-"` +} + +type TransactionList struct { + Transactions *[]Transaction `json:"transactions"` +} + +type AccountTransactionsList struct { + Account *Account + Transactions *[]Transaction + TotalTransactions int64 + BeginningBalance string + EndingBalance string +} + +func (t *Transaction) Write(w http.ResponseWriter) error { + enc := json.NewEncoder(w) + return enc.Encode(t) +} + +func (t *Transaction) Read(json_str string) error { + dec := json.NewDecoder(strings.NewReader(json_str)) + return dec.Decode(t) +} + +func (tl *TransactionList) Write(w http.ResponseWriter) error { + enc := json.NewEncoder(w) + return enc.Encode(tl) +} + +func (tl *TransactionList) Read(json_str string) error { + dec := json.NewDecoder(strings.NewReader(json_str)) + return dec.Decode(tl) +} + +func (atl *AccountTransactionsList) Write(w http.ResponseWriter) error { + enc := json.NewEncoder(w) + return enc.Encode(atl) +} + +func (atl *AccountTransactionsList) Read(json_str string) error { + dec := json.NewDecoder(strings.NewReader(json_str)) + return dec.Decode(atl) +} + +func (t *Transaction) Valid() bool { + for i := range t.Splits { + if !t.Splits[i].Valid() { + return false + } + } + return true +} From 5f296e86693a8ce0171aeaab510e1532ce239f42 Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Mon, 4 Dec 2017 21:05:17 -0500 Subject: [PATCH 5/6] Split prices into models --- internal/db/db.go | 5 ++- internal/handlers/gnucash.go | 4 +- internal/handlers/prices.go | 63 +++++++----------------------- internal/handlers/prices_lua.go | 6 +-- internal/handlers/prices_test.go | 19 ++++----- internal/handlers/testdata_test.go | 4 +- internal/models/prices.go | 41 +++++++++++++++++++ 7 files changed, 76 insertions(+), 66 deletions(-) create mode 100644 internal/models/prices.go diff --git a/internal/db/db.go b/internal/db/db.go index b92e9bd..a23a029 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -14,6 +14,9 @@ import ( "strings" ) +// luaMaxLengthBuffer is intended to be enough bytes such that a given string +// no longer than models.LuaMaxLength is sure to fit within a database +// implementation's string type specified by the same. const luaMaxLengthBuffer int = 4096 func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { @@ -40,7 +43,7 @@ func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { dbmap.AddTableWithName(models.Security{}, "securities").SetKeys(true, "SecurityId") dbmap.AddTableWithName(models.Transaction{}, "transactions").SetKeys(true, "TransactionId") dbmap.AddTableWithName(models.Split{}, "splits").SetKeys(true, "SplitId") - dbmap.AddTableWithName(handlers.Price{}, "prices").SetKeys(true, "PriceId") + dbmap.AddTableWithName(models.Price{}, "prices").SetKeys(true, "PriceId") rtable := dbmap.AddTableWithName(handlers.Report{}, "reports").SetKeys(true, "ReportId") rtable.ColMap("Lua").SetMaxSize(handlers.LuaMaxLength + luaMaxLengthBuffer) diff --git a/internal/handlers/gnucash.go b/internal/handlers/gnucash.go index f99978e..2399a6b 100644 --- a/internal/handlers/gnucash.go +++ b/internal/handlers/gnucash.go @@ -129,7 +129,7 @@ type GnucashImport struct { Securities []models.Security Accounts []models.Account Transactions []models.Transaction - Prices []Price + Prices []models.Price } func ImportGnucash(r io.Reader) (*GnucashImport, error) { @@ -161,7 +161,7 @@ func ImportGnucash(r io.Reader) (*GnucashImport, error) { // Create prices, setting security and currency IDs from securityMap for i := range gncxml.PriceDB.Prices { price := gncxml.PriceDB.Prices[i] - var p Price + var p models.Price security, ok := securityMap[price.Commodity.Name] if !ok { return nil, fmt.Errorf("Unable to find commodity '%s' for price '%s'", price.Commodity.Name, price.Id) diff --git a/internal/handlers/prices.go b/internal/handlers/prices.go index 2689378..c92eeb4 100644 --- a/internal/handlers/prices.go +++ b/internal/handlers/prices.go @@ -1,48 +1,13 @@ package handlers import ( - "encoding/json" "github.com/aclindsa/moneygo/internal/models" "log" "net/http" - "strings" "time" ) -type Price struct { - PriceId int64 - SecurityId int64 - CurrencyId int64 - Date time.Time - Value string // String representation of decimal price of Security in Currency units, suitable for passing to big.Rat.SetString() - RemoteId string // unique ID from source, for detecting duplicates -} - -type PriceList struct { - Prices *[]*Price `json:"prices"` -} - -func (p *Price) Read(json_str string) error { - dec := json.NewDecoder(strings.NewReader(json_str)) - return dec.Decode(p) -} - -func (p *Price) Write(w http.ResponseWriter) error { - enc := json.NewEncoder(w) - return enc.Encode(p) -} - -func (pl *PriceList) Read(json_str string) error { - dec := json.NewDecoder(strings.NewReader(json_str)) - return dec.Decode(pl) -} - -func (pl *PriceList) Write(w http.ResponseWriter) error { - enc := json.NewEncoder(w) - return enc.Encode(pl) -} - -func CreatePriceIfNotExist(tx *Tx, price *Price) error { +func CreatePriceIfNotExist(tx *Tx, price *models.Price) error { if len(price.RemoteId) == 0 { // Always create a new price if we can't match on the RemoteId err := tx.Insert(price) @@ -52,7 +17,7 @@ func CreatePriceIfNotExist(tx *Tx, price *Price) error { return nil } - var prices []*Price + var prices []*models.Price _, err := tx.Select(&prices, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date=? AND Value=?", price.SecurityId, price.CurrencyId, price.Date, price.Value) if err != nil { @@ -70,8 +35,8 @@ func CreatePriceIfNotExist(tx *Tx, price *Price) error { return nil } -func GetPrice(tx *Tx, priceid, securityid int64) (*Price, error) { - var p Price +func GetPrice(tx *Tx, priceid, securityid int64) (*models.Price, error) { + var p models.Price err := tx.SelectOne(&p, "SELECT * from prices where PriceId=? AND SecurityId=?", priceid, securityid) if err != nil { return nil, err @@ -79,8 +44,8 @@ func GetPrice(tx *Tx, priceid, securityid int64) (*Price, error) { return &p, nil } -func GetPrices(tx *Tx, securityid int64) (*[]*Price, error) { - var prices []*Price +func GetPrices(tx *Tx, securityid int64) (*[]*models.Price, error) { + var prices []*models.Price _, err := tx.Select(&prices, "SELECT * from prices where SecurityId=?", securityid) if err != nil { @@ -90,8 +55,8 @@ func GetPrices(tx *Tx, securityid int64) (*[]*Price, error) { } // Return the latest price for security in currency units before date -func GetLatestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*Price, error) { - var p Price +func GetLatestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { + var p models.Price err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date) if err != nil { return nil, err @@ -100,8 +65,8 @@ func GetLatestPrice(tx *Tx, security, currency *models.Security, date *time.Time } // Return the earliest price for security in currency units after date -func GetEarliestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*Price, error) { - var p Price +func GetEarliestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { + var p models.Price err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date) if err != nil { return nil, err @@ -110,7 +75,7 @@ func GetEarliestPrice(tx *Tx, security, currency *models.Security, date *time.Ti } // Return the price for security in currency closest to date -func GetClosestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*Price, error) { +func GetClosestPrice(tx *Tx, security, currency *models.Security, date *time.Time) (*models.Price, error) { earliest, _ := GetEarliestPrice(tx, security, currency, date) latest, err := GetLatestPrice(tx, security, currency, date) @@ -137,7 +102,7 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security } if r.Method == "POST" { - var price Price + var price models.Price if err := ReadJSON(r, &price); err != nil { return NewError(3 /*Invalid Request*/) } @@ -161,7 +126,7 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security } else if r.Method == "GET" { if context.LastLevel() { //Return all this security's prices - var pl PriceList + var pl models.PriceList prices, err := GetPrices(context.Tx, security.SecurityId) if err != nil { @@ -190,7 +155,7 @@ func PriceHandler(r *http.Request, context *Context, user *models.User, security return NewError(3 /*Invalid Request*/) } if r.Method == "PUT" { - var price Price + var price models.Price if err := ReadJSON(r, &price); err != nil || price.PriceId != priceid { return NewError(3 /*Invalid Request*/) } diff --git a/internal/handlers/prices_lua.go b/internal/handlers/prices_lua.go index 8450319..1ff0da2 100644 --- a/internal/handlers/prices_lua.go +++ b/internal/handlers/prices_lua.go @@ -15,7 +15,7 @@ func luaRegisterPrices(L *lua.LState) { L.SetField(mt, "__metatable", lua.LString("protected")) } -func PriceToLua(L *lua.LState, price *Price) *lua.LUserData { +func PriceToLua(L *lua.LState, price *models.Price) *lua.LUserData { ud := L.NewUserData() ud.Value = price L.SetMetatable(ud, L.GetTypeMetatable(luaPriceTypeName)) @@ -23,9 +23,9 @@ func PriceToLua(L *lua.LState, price *Price) *lua.LUserData { } // Checks whether the first lua argument is a *LUserData with *Price and returns this *Price. -func luaCheckPrice(L *lua.LState, n int) *Price { +func luaCheckPrice(L *lua.LState, n int) *models.Price { ud := L.CheckUserData(n) - if price, ok := ud.Value.(*Price); ok { + if price, ok := ud.Value.(*models.Price); ok { return price } L.ArgError(n, "price expected") diff --git a/internal/handlers/prices_test.go b/internal/handlers/prices_test.go index 1cbca93..8c44379 100644 --- a/internal/handlers/prices_test.go +++ b/internal/handlers/prices_test.go @@ -2,20 +2,21 @@ package handlers_test import ( "github.com/aclindsa/moneygo/internal/handlers" + "github.com/aclindsa/moneygo/internal/models" "net/http" "strconv" "testing" "time" ) -func createPrice(client *http.Client, price *handlers.Price) (*handlers.Price, error) { - var p handlers.Price +func createPrice(client *http.Client, price *models.Price) (*models.Price, error) { + var p models.Price err := create(client, price, &p, "/v1/securities/"+strconv.FormatInt(price.SecurityId, 10)+"/prices/") return &p, err } -func getPrice(client *http.Client, priceid, securityid int64) (*handlers.Price, error) { - var p handlers.Price +func getPrice(client *http.Client, priceid, securityid int64) (*models.Price, error) { + var p models.Price err := read(client, &p, "/v1/securities/"+strconv.FormatInt(securityid, 10)+"/prices/"+strconv.FormatInt(priceid, 10)) if err != nil { return nil, err @@ -23,8 +24,8 @@ func getPrice(client *http.Client, priceid, securityid int64) (*handlers.Price, return &p, nil } -func getPrices(client *http.Client, securityid int64) (*handlers.PriceList, error) { - var pl handlers.PriceList +func getPrices(client *http.Client, securityid int64) (*models.PriceList, error) { + var pl models.PriceList err := read(client, &pl, "/v1/securities/"+strconv.FormatInt(securityid, 10)+"/prices/") if err != nil { return nil, err @@ -32,8 +33,8 @@ func getPrices(client *http.Client, securityid int64) (*handlers.PriceList, erro return &pl, nil } -func updatePrice(client *http.Client, price *handlers.Price) (*handlers.Price, error) { - var p handlers.Price +func updatePrice(client *http.Client, price *models.Price) (*models.Price, error) { + var p models.Price err := update(client, price, &p, "/v1/securities/"+strconv.FormatInt(price.SecurityId, 10)+"/prices/"+strconv.FormatInt(price.PriceId, 10)) if err != nil { return nil, err @@ -41,7 +42,7 @@ func updatePrice(client *http.Client, price *handlers.Price) (*handlers.Price, e return &p, nil } -func deletePrice(client *http.Client, p *handlers.Price) error { +func deletePrice(client *http.Client, p *models.Price) error { err := remove(client, "/v1/securities/"+strconv.FormatInt(p.SecurityId, 10)+"/prices/"+strconv.FormatInt(p.PriceId, 10)) if err != nil { return err diff --git a/internal/handlers/testdata_test.go b/internal/handlers/testdata_test.go index 2cef0cd..65bfee8 100644 --- a/internal/handlers/testdata_test.go +++ b/internal/handlers/testdata_test.go @@ -38,7 +38,7 @@ type TestData struct { users []User clients []*http.Client securities []models.Security - prices []handlers.Price + prices []models.Price accounts []models.Account // accounts must appear after their parents in this slice transactions []models.Transaction reports []handlers.Report @@ -209,7 +209,7 @@ var data = []TestData{ AlternateId: "978", }, }, - prices: []handlers.Price{ + prices: []models.Price{ { SecurityId: 1, CurrencyId: 0, diff --git a/internal/models/prices.go b/internal/models/prices.go new file mode 100644 index 0000000..7958e52 --- /dev/null +++ b/internal/models/prices.go @@ -0,0 +1,41 @@ +package models + +import ( + "encoding/json" + "net/http" + "strings" + "time" +) + +type Price struct { + PriceId int64 + SecurityId int64 + CurrencyId int64 + Date time.Time + Value string // String representation of decimal price of Security in Currency units, suitable for passing to big.Rat.SetString() + RemoteId string // unique ID from source, for detecting duplicates +} + +type PriceList struct { + Prices *[]*Price `json:"prices"` +} + +func (p *Price) Read(json_str string) error { + dec := json.NewDecoder(strings.NewReader(json_str)) + return dec.Decode(p) +} + +func (p *Price) Write(w http.ResponseWriter) error { + enc := json.NewEncoder(w) + return enc.Encode(p) +} + +func (pl *PriceList) Read(json_str string) error { + dec := json.NewDecoder(strings.NewReader(json_str)) + return dec.Decode(pl) +} + +func (pl *PriceList) Write(w http.ResponseWriter) error { + enc := json.NewEncoder(w) + return enc.Encode(pl) +} From 3e3038295d0231b8e9cf7fbd1286c253e1a1a1e9 Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Tue, 5 Dec 2017 05:58:36 -0500 Subject: [PATCH 6/6] Split reports into models --- internal/db/db.go | 5 +- internal/handlers/reports.go | 89 +++++---------------------- internal/handlers/reports_lua.go | 21 ++++--- internal/handlers/reports_lua_test.go | 4 +- internal/handlers/reports_test.go | 31 +++++----- internal/handlers/testdata_test.go | 23 ++++--- internal/models/reports.go | 66 ++++++++++++++++++++ 7 files changed, 122 insertions(+), 117 deletions(-) create mode 100644 internal/models/reports.go diff --git a/internal/db/db.go b/internal/db/db.go index a23a029..a33fb08 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -5,7 +5,6 @@ import ( "fmt" "github.com/aclindsa/gorp" "github.com/aclindsa/moneygo/internal/config" - "github.com/aclindsa/moneygo/internal/handlers" "github.com/aclindsa/moneygo/internal/models" _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" @@ -44,8 +43,8 @@ func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { dbmap.AddTableWithName(models.Transaction{}, "transactions").SetKeys(true, "TransactionId") dbmap.AddTableWithName(models.Split{}, "splits").SetKeys(true, "SplitId") dbmap.AddTableWithName(models.Price{}, "prices").SetKeys(true, "PriceId") - rtable := dbmap.AddTableWithName(handlers.Report{}, "reports").SetKeys(true, "ReportId") - rtable.ColMap("Lua").SetMaxSize(handlers.LuaMaxLength + luaMaxLengthBuffer) + rtable := dbmap.AddTableWithName(models.Report{}, "reports").SetKeys(true, "ReportId") + rtable.ColMap("Lua").SetMaxSize(models.LuaMaxLength + luaMaxLengthBuffer) err := dbmap.CreateTablesIfNotExists() if err != nil { diff --git a/internal/handlers/reports.go b/internal/handlers/reports.go index cab32d1..bec9525 100644 --- a/internal/handlers/reports.go +++ b/internal/handlers/reports.go @@ -2,14 +2,12 @@ package handlers import ( "context" - "encoding/json" "errors" "fmt" "github.com/aclindsa/moneygo/internal/models" "github.com/yuin/gopher-lua" "log" "net/http" - "strings" "time" ) @@ -26,67 +24,8 @@ const ( const luaTimeoutSeconds time.Duration = 30 // maximum time a lua request can run for -type Report struct { - ReportId int64 - UserId int64 - Name string - Lua string -} - -// The maximum length (in bytes) the Lua code may be. This is used to set the -// max size of the database columns (with an added fudge factor) -const LuaMaxLength int = 65536 - -func (r *Report) Write(w http.ResponseWriter) error { - enc := json.NewEncoder(w) - return enc.Encode(r) -} - -func (r *Report) Read(json_str string) error { - dec := json.NewDecoder(strings.NewReader(json_str)) - return dec.Decode(r) -} - -type ReportList struct { - Reports *[]Report `json:"reports"` -} - -func (rl *ReportList) Write(w http.ResponseWriter) error { - enc := json.NewEncoder(w) - return enc.Encode(rl) -} - -func (rl *ReportList) Read(json_str string) error { - dec := json.NewDecoder(strings.NewReader(json_str)) - return dec.Decode(rl) -} - -type Series struct { - Values []float64 - Series map[string]*Series -} - -type Tabulation struct { - ReportId int64 - Title string - Subtitle string - Units string - Labels []string - Series map[string]*Series -} - -func (t *Tabulation) Write(w http.ResponseWriter) error { - enc := json.NewEncoder(w) - return enc.Encode(t) -} - -func (t *Tabulation) Read(json_str string) error { - dec := json.NewDecoder(strings.NewReader(json_str)) - return dec.Decode(t) -} - -func GetReport(tx *Tx, reportid int64, userid int64) (*Report, error) { - var r Report +func GetReport(tx *Tx, reportid int64, userid int64) (*models.Report, error) { + var r models.Report err := tx.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid) if err != nil { @@ -95,8 +34,8 @@ func GetReport(tx *Tx, reportid int64, userid int64) (*Report, error) { return &r, nil } -func GetReports(tx *Tx, userid int64) (*[]Report, error) { - var reports []Report +func GetReports(tx *Tx, userid int64) (*[]models.Report, error) { + var reports []models.Report _, err := tx.Select(&reports, "SELECT * from reports where UserId=?", userid) if err != nil { @@ -105,7 +44,7 @@ func GetReports(tx *Tx, userid int64) (*[]Report, error) { return &reports, nil } -func InsertReport(tx *Tx, r *Report) error { +func InsertReport(tx *Tx, r *models.Report) error { err := tx.Insert(r) if err != nil { return err @@ -113,7 +52,7 @@ func InsertReport(tx *Tx, r *Report) error { return nil } -func UpdateReport(tx *Tx, r *Report) error { +func UpdateReport(tx *Tx, r *models.Report) error { count, err := tx.Update(r) if err != nil { return err @@ -124,7 +63,7 @@ func UpdateReport(tx *Tx, r *Report) error { return nil } -func DeleteReport(tx *Tx, r *Report) error { +func DeleteReport(tx *Tx, r *models.Report) error { count, err := tx.Delete(r) if err != nil { return err @@ -135,7 +74,7 @@ func DeleteReport(tx *Tx, r *Report) error { return nil } -func runReport(tx *Tx, user *models.User, report *Report) (*Tabulation, error) { +func runReport(tx *Tx, user *models.User, report *models.Report) (*models.Tabulation, error) { // Create a new LState without opening the default libs for security L := lua.NewState(lua.Options{SkipOpenLibs: true}) defer L.Close() @@ -189,7 +128,7 @@ func runReport(tx *Tx, user *models.User, report *Report) (*Tabulation, error) { value := L.Get(-1) if ud, ok := value.(*lua.LUserData); ok { - if tabulation, ok := ud.Value.(*Tabulation); ok { + if tabulation, ok := ud.Value.(*models.Tabulation); ok { return tabulation, nil } else { return nil, fmt.Errorf("generate() for %s (Id: %d) didn't return a tabulation", report.Name, report.ReportId) @@ -224,14 +163,14 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter { } if r.Method == "POST" { - var report Report + var report models.Report if err := ReadJSON(r, &report); err != nil { return NewError(3 /*Invalid Request*/) } report.ReportId = -1 report.UserId = user.UserId - if len(report.Lua) >= LuaMaxLength { + if len(report.Lua) >= models.LuaMaxLength { return NewError(3 /*Invalid Request*/) } @@ -245,7 +184,7 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter { } else if r.Method == "GET" { if context.LastLevel() { //Return all Reports - var rl ReportList + var rl models.ReportList reports, err := GetReports(context.Tx, user.UserId) if err != nil { log.Print(err) @@ -278,13 +217,13 @@ func ReportHandler(r *http.Request, context *Context) ResponseWriterWriter { } if r.Method == "PUT" { - var report Report + var report models.Report if err := ReadJSON(r, &report); err != nil || report.ReportId != reportid { return NewError(3 /*Invalid Request*/) } report.UserId = user.UserId - if len(report.Lua) >= LuaMaxLength { + if len(report.Lua) >= models.LuaMaxLength { return NewError(3 /*Invalid Request*/) } diff --git a/internal/handlers/reports_lua.go b/internal/handlers/reports_lua.go index 2904a2b..51d919d 100644 --- a/internal/handlers/reports_lua.go +++ b/internal/handlers/reports_lua.go @@ -1,6 +1,7 @@ package handlers import ( + "github.com/aclindsa/moneygo/internal/models" "github.com/yuin/gopher-lua" ) @@ -21,9 +22,9 @@ func luaRegisterTabulations(L *lua.LState) { } // Checks whether the first lua argument is a *LUserData with *Tabulation and returns *Tabulation -func luaCheckTabulation(L *lua.LState, n int) *Tabulation { +func luaCheckTabulation(L *lua.LState, n int) *models.Tabulation { ud := L.CheckUserData(n) - if tabulation, ok := ud.Value.(*Tabulation); ok { + if tabulation, ok := ud.Value.(*models.Tabulation); ok { return tabulation } L.ArgError(n, "tabulation expected") @@ -31,9 +32,9 @@ func luaCheckTabulation(L *lua.LState, n int) *Tabulation { } // Checks whether the first lua argument is a *LUserData with *Series and returns *Series -func luaCheckSeries(L *lua.LState, n int) *Series { +func luaCheckSeries(L *lua.LState, n int) *models.Series { ud := L.CheckUserData(n) - if series, ok := ud.Value.(*Series); ok { + if series, ok := ud.Value.(*models.Series); ok { return series } L.ArgError(n, "series expected") @@ -43,9 +44,9 @@ func luaCheckSeries(L *lua.LState, n int) *Series { func luaTabulationNew(L *lua.LState) int { numvalues := L.CheckInt(1) ud := L.NewUserData() - ud.Value = &Tabulation{ + ud.Value = &models.Tabulation{ Labels: make([]string, numvalues), - Series: make(map[string]*Series), + Series: make(map[string]*models.Series), } L.SetMetatable(ud, L.GetTypeMetatable(luaTabulationTypeName)) L.Push(ud) @@ -94,8 +95,8 @@ func luaTabulationSeries(L *lua.LState) int { if ok { ud.Value = s } else { - tabulation.Series[name] = &Series{ - Series: make(map[string]*Series), + tabulation.Series[name] = &models.Series{ + Series: make(map[string]*models.Series), Values: make([]float64, cap(tabulation.Labels)), } ud.Value = tabulation.Series[name] @@ -175,8 +176,8 @@ func luaSeriesSeries(L *lua.LState) int { if ok { ud.Value = s } else { - parent.Series[name] = &Series{ - Series: make(map[string]*Series), + parent.Series[name] = &models.Series{ + Series: make(map[string]*models.Series), Values: make([]float64, cap(parent.Values)), } ud.Value = parent.Series[name] diff --git a/internal/handlers/reports_lua_test.go b/internal/handlers/reports_lua_test.go index 1ba1fa7..bf67f5b 100644 --- a/internal/handlers/reports_lua_test.go +++ b/internal/handlers/reports_lua_test.go @@ -2,7 +2,7 @@ package handlers_test import ( "fmt" - "github.com/aclindsa/moneygo/internal/handlers" + "github.com/aclindsa/moneygo/internal/models" "net/http" "testing" ) @@ -25,7 +25,7 @@ function generate() t:title(tostring(test())) return t end`, lt.Lua) - r := handlers.Report{ + r := models.Report{ Name: lt.Name, Lua: lua, } diff --git a/internal/handlers/reports_test.go b/internal/handlers/reports_test.go index 624cf0a..6f97b8c 100644 --- a/internal/handlers/reports_test.go +++ b/internal/handlers/reports_test.go @@ -2,19 +2,20 @@ package handlers_test import ( "github.com/aclindsa/moneygo/internal/handlers" + "github.com/aclindsa/moneygo/internal/models" "net/http" "strconv" "testing" ) -func createReport(client *http.Client, report *handlers.Report) (*handlers.Report, error) { - var r handlers.Report +func createReport(client *http.Client, report *models.Report) (*models.Report, error) { + var r models.Report err := create(client, report, &r, "/v1/reports/") return &r, err } -func getReport(client *http.Client, reportid int64) (*handlers.Report, error) { - var r handlers.Report +func getReport(client *http.Client, reportid int64) (*models.Report, error) { + var r models.Report err := read(client, &r, "/v1/reports/"+strconv.FormatInt(reportid, 10)) if err != nil { return nil, err @@ -22,8 +23,8 @@ func getReport(client *http.Client, reportid int64) (*handlers.Report, error) { return &r, nil } -func getReports(client *http.Client) (*handlers.ReportList, error) { - var rl handlers.ReportList +func getReports(client *http.Client) (*models.ReportList, error) { + var rl models.ReportList err := read(client, &rl, "/v1/reports/") if err != nil { return nil, err @@ -31,8 +32,8 @@ func getReports(client *http.Client) (*handlers.ReportList, error) { return &rl, nil } -func updateReport(client *http.Client, report *handlers.Report) (*handlers.Report, error) { - var r handlers.Report +func updateReport(client *http.Client, report *models.Report) (*models.Report, error) { + var r models.Report err := update(client, report, &r, "/v1/reports/"+strconv.FormatInt(report.ReportId, 10)) if err != nil { return nil, err @@ -40,7 +41,7 @@ func updateReport(client *http.Client, report *handlers.Report) (*handlers.Repor return &r, nil } -func deleteReport(client *http.Client, r *handlers.Report) error { +func deleteReport(client *http.Client, r *models.Report) error { err := remove(client, "/v1/reports/"+strconv.FormatInt(r.ReportId, 10)) if err != nil { return err @@ -48,8 +49,8 @@ func deleteReport(client *http.Client, r *handlers.Report) error { return nil } -func tabulateReport(client *http.Client, reportid int64) (*handlers.Tabulation, error) { - var t handlers.Tabulation +func tabulateReport(client *http.Client, reportid int64) (*models.Tabulation, error) { + var t models.Tabulation err := read(client, &t, "/v1/reports/"+strconv.FormatInt(reportid, 10)+"/tabulations") if err != nil { return nil, err @@ -73,7 +74,7 @@ func TestCreateReport(t *testing.T) { t.Errorf("Lua doesn't match") } - r.Lua = string(make([]byte, handlers.LuaMaxLength+1)) + r.Lua = string(make([]byte, models.LuaMaxLength+1)) _, err := createReport(d.clients[orig.UserId], &r) if err == nil { t.Fatalf("Expected error creating report with too-long Lua") @@ -173,7 +174,7 @@ func TestUpdateReport(t *testing.T) { t.Errorf("Lua doesn't match") } - r.Lua = string(make([]byte, handlers.LuaMaxLength+1)) + r.Lua = string(make([]byte, models.LuaMaxLength+1)) _, err = updateReport(d.clients[orig.UserId], r) if err == nil { t.Fatalf("Expected error updating report with too-long Lua") @@ -214,7 +215,7 @@ func TestDeleteReport(t *testing.T) { } }) } -func seriesEqualityHelper(t *testing.T, orig, curr map[string]*handlers.Series, name string) { +func seriesEqualityHelper(t *testing.T, orig, curr map[string]*models.Series, name string) { if orig == nil || curr == nil { if orig != nil { t.Fatalf("`%s` series unexpectedly nil", name) @@ -242,7 +243,7 @@ func seriesEqualityHelper(t *testing.T, orig, curr map[string]*handlers.Series, } } -func tabulationEqualityHelper(t *testing.T, orig, curr *handlers.Tabulation) { +func tabulationEqualityHelper(t *testing.T, orig, curr *models.Tabulation) { if orig.Title != curr.Title { t.Errorf("Tabulation Title doesn't match") } diff --git a/internal/handlers/testdata_test.go b/internal/handlers/testdata_test.go index 65bfee8..3d15888 100644 --- a/internal/handlers/testdata_test.go +++ b/internal/handlers/testdata_test.go @@ -3,7 +3,6 @@ package handlers_test import ( "encoding/json" "fmt" - "github.com/aclindsa/moneygo/internal/handlers" "github.com/aclindsa/moneygo/internal/models" "net/http" "strings" @@ -41,8 +40,8 @@ type TestData struct { prices []models.Price accounts []models.Account // accounts must appear after their parents in this slice transactions []models.Transaction - reports []handlers.Report - tabulations []handlers.Tabulation + reports []models.Report + tabulations []models.Tabulation } type TestDataFunc func(*testing.T, *TestData) @@ -382,7 +381,7 @@ var data = []TestData{ }, }, }, - reports: []handlers.Report{ + reports: []models.Report{ { UserId: 0, Name: "This Year's Monthly Expenses", @@ -440,39 +439,39 @@ function generate() end`, }, }, - tabulations: []handlers.Tabulation{ + tabulations: []models.Tabulation{ { ReportId: 0, Title: "2017 Monthly Expenses", Subtitle: "This is my subtitle", Units: "USD", Labels: []string{"2017-01-01", "2017-02-01", "2017-03-01", "2017-04-01", "2017-05-01", "2017-06-01", "2017-07-01", "2017-08-01", "2017-09-01", "2017-10-01", "2017-11-01", "2017-12-01"}, - Series: map[string]*handlers.Series{ + Series: map[string]*models.Series{ "Assets": { Values: []float64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - Series: map[string]*handlers.Series{ + Series: map[string]*models.Series{ "Credit Union Checking": { Values: []float64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - Series: map[string]*handlers.Series{}, + Series: map[string]*models.Series{}, }, }, }, "Expenses": { Values: []float64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - Series: map[string]*handlers.Series{ + Series: map[string]*models.Series{ "Groceries": { Values: []float64{0, 0, 0, 0, 0, 0, 0, 0, 0, 87.19, 0, 0}, - Series: map[string]*handlers.Series{}, + Series: map[string]*models.Series{}, }, "Cable": { Values: []float64{0, 0, 0, 0, 0, 0, 0, 0, 39.99, 0, 0, 0}, - Series: map[string]*handlers.Series{}, + Series: map[string]*models.Series{}, }, }, }, "Credit Card": { Values: []float64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - Series: map[string]*handlers.Series{}, + Series: map[string]*models.Series{}, }, }, }, diff --git a/internal/models/reports.go b/internal/models/reports.go new file mode 100644 index 0000000..493fd21 --- /dev/null +++ b/internal/models/reports.go @@ -0,0 +1,66 @@ +package models + +import ( + "encoding/json" + "net/http" + "strings" +) + +type Report struct { + ReportId int64 + UserId int64 + Name string + Lua string +} + +// The maximum length (in bytes) the Lua code may be. This is used to set the +// max size of the database columns (with an added fudge factor) +const LuaMaxLength int = 65536 + +func (r *Report) Write(w http.ResponseWriter) error { + enc := json.NewEncoder(w) + return enc.Encode(r) +} + +func (r *Report) Read(json_str string) error { + dec := json.NewDecoder(strings.NewReader(json_str)) + return dec.Decode(r) +} + +type ReportList struct { + Reports *[]Report `json:"reports"` +} + +func (rl *ReportList) Write(w http.ResponseWriter) error { + enc := json.NewEncoder(w) + return enc.Encode(rl) +} + +func (rl *ReportList) Read(json_str string) error { + dec := json.NewDecoder(strings.NewReader(json_str)) + return dec.Decode(rl) +} + +type Series struct { + Values []float64 + Series map[string]*Series +} + +type Tabulation struct { + ReportId int64 + Title string + Subtitle string + Units string + Labels []string + Series map[string]*Series +} + +func (t *Tabulation) Write(w http.ResponseWriter) error { + enc := json.NewEncoder(w) + return enc.Encode(t) +} + +func (t *Tabulation) Read(json_str string) error { + dec := json.NewDecoder(strings.NewReader(json_str)) + return dec.Decode(t) +}