diff --git a/internal/db/db.go b/internal/db/db.go index 68cfd6d..fbb7b9d 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -6,6 +6,7 @@ import ( "github.com/aclindsa/gorp" "github.com/aclindsa/moneygo/internal/config" "github.com/aclindsa/moneygo/internal/handlers" + "github.com/aclindsa/moneygo/internal/models" _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" @@ -33,7 +34,7 @@ func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { } dbmap := &gorp.DbMap{Db: db, Dialect: dialect} - dbmap.AddTableWithName(handlers.User{}, "users").SetKeys(true, "UserId") + dbmap.AddTableWithName(models.User{}, "users").SetKeys(true, "UserId") dbmap.AddTableWithName(handlers.Session{}, "sessions").SetKeys(true, "SessionId") dbmap.AddTableWithName(handlers.Account{}, "accounts").SetKeys(true, "AccountId") dbmap.AddTableWithName(handlers.Security{}, "securities").SetKeys(true, "SecurityId") diff --git a/internal/handlers/accounts_lua.go b/internal/handlers/accounts_lua.go index 4edf440..8a4d5db 100644 --- a/internal/handlers/accounts_lua.go +++ b/internal/handlers/accounts_lua.go @@ -3,6 +3,7 @@ package handlers import ( "context" "errors" + "github.com/aclindsa/moneygo/internal/models" "github.com/yuin/gopher-lua" "math/big" "strings" @@ -22,7 +23,7 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*Account, error) { account_map, ok = ctx.Value(accountsContextKey).(map[int64]*Account) if !ok { - user, ok := ctx.Value(userContextKey).(*User) + user, ok := ctx.Value(userContextKey).(*models.User) if !ok { return nil, errors.New("Couldn't find User in lua's Context") } @@ -153,7 +154,7 @@ func luaAccountBalance(L *lua.LState) int { if !ok { panic("Couldn't find tx in lua's Context") } - user, ok := ctx.Value(userContextKey).(*User) + user, ok := ctx.Value(userContextKey).(*models.User) if !ok { panic("Couldn't find User in lua's Context") } diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index e42419d..309eed1 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -2,6 +2,7 @@ package handlers import ( "github.com/aclindsa/gorp" + "github.com/aclindsa/moneygo/internal/models" "log" "net/http" "path" @@ -16,7 +17,7 @@ type ResponseWriterWriter interface { type Context struct { Tx *Tx - User *User + User *models.User remainingURL string // portion of URL path not yet reached in the hierarchy } diff --git a/internal/handlers/imports.go b/internal/handlers/imports.go index f83ab0c..a90fb8b 100644 --- a/internal/handlers/imports.go +++ b/internal/handlers/imports.go @@ -2,6 +2,7 @@ package handlers import ( "encoding/json" + "github.com/aclindsa/moneygo/internal/models" "github.com/aclindsa/ofxgo" "io" "log" @@ -22,7 +23,7 @@ func (od *OFXDownload) Read(json_str string) error { return dec.Decode(od) } -func ofxImportHelper(tx *Tx, r io.Reader, user *User, accountid int64) ResponseWriterWriter { +func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) ResponseWriterWriter { itl, err := ImportOFX(r) if err != nil { @@ -210,7 +211,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *User, accountid int64) ResponseW return SuccessWriter{} } -func OFXImportHandler(context *Context, r *http.Request, user *User, accountid int64) ResponseWriterWriter { +func OFXImportHandler(context *Context, r *http.Request, user *models.User, accountid int64) ResponseWriterWriter { var ofxdownload OFXDownload if err := ReadJSON(r, &ofxdownload); err != nil { return NewError(3 /*Invalid Request*/) @@ -305,7 +306,7 @@ func OFXImportHandler(context *Context, r *http.Request, user *User, accountid i return ofxImportHelper(context.Tx, response.Body, user, accountid) } -func OFXFileImportHandler(context *Context, r *http.Request, user *User, accountid int64) ResponseWriterWriter { +func OFXFileImportHandler(context *Context, r *http.Request, user *models.User, accountid int64) ResponseWriterWriter { multipartReader, err := r.MultipartReader() if err != nil { return NewError(3 /*Invalid Request*/) @@ -329,7 +330,7 @@ func OFXFileImportHandler(context *Context, r *http.Request, user *User, account /* * Assumes the User is a valid, signed-in user, but accountid has not yet been validated */ -func AccountImportHandler(context *Context, r *http.Request, user *User, accountid int64) ResponseWriterWriter { +func AccountImportHandler(context *Context, r *http.Request, user *models.User, accountid int64) ResponseWriterWriter { importType := context.NextLevel() switch importType { diff --git a/internal/handlers/prices.go b/internal/handlers/prices.go index 2027497..fa2c058 100644 --- a/internal/handlers/prices.go +++ b/internal/handlers/prices.go @@ -2,6 +2,7 @@ package handlers import ( "encoding/json" + "github.com/aclindsa/moneygo/internal/models" "log" "net/http" "strings" @@ -129,7 +130,7 @@ func GetClosestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Pr } } -func PriceHandler(r *http.Request, context *Context, user *User, securityid int64) ResponseWriterWriter { +func PriceHandler(r *http.Request, context *Context, user *models.User, securityid int64) ResponseWriterWriter { security, err := GetSecurity(context.Tx, securityid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) diff --git a/internal/handlers/reports.go b/internal/handlers/reports.go index 97c4b64..cab32d1 100644 --- a/internal/handlers/reports.go +++ b/internal/handlers/reports.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/aclindsa/moneygo/internal/models" "github.com/yuin/gopher-lua" "log" "net/http" @@ -134,7 +135,7 @@ func DeleteReport(tx *Tx, r *Report) error { return nil } -func runReport(tx *Tx, user *User, report *Report) (*Tabulation, error) { +func runReport(tx *Tx, user *models.User, report *Report) (*Tabulation, error) { // Create a new LState without opening the default libs for security L := lua.NewState(lua.Options{SkipOpenLibs: true}) defer L.Close() @@ -198,7 +199,7 @@ func runReport(tx *Tx, user *User, report *Report) (*Tabulation, error) { } } -func ReportTabulationHandler(tx *Tx, r *http.Request, user *User, reportid int64) ResponseWriterWriter { +func ReportTabulationHandler(tx *Tx, r *http.Request, user *models.User, reportid int64) ResponseWriterWriter { report, err := GetReport(tx, reportid, user.UserId) if err != nil { return NewError(3 /*Invalid Request*/) diff --git a/internal/handlers/securities_lua.go b/internal/handlers/securities_lua.go index 555bdc4..b294c1c 100644 --- a/internal/handlers/securities_lua.go +++ b/internal/handlers/securities_lua.go @@ -3,6 +3,7 @@ package handlers import ( "context" "errors" + "github.com/aclindsa/moneygo/internal/models" "github.com/yuin/gopher-lua" ) @@ -20,7 +21,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) { security_map, ok = ctx.Value(securitiesContextKey).(map[int64]*Security) if !ok { - user, ok := ctx.Value(userContextKey).(*User) + user, ok := ctx.Value(userContextKey).(*models.User) if !ok { return nil, errors.New("Couldn't find User in lua's Context") } @@ -50,7 +51,7 @@ func luaContextGetDefaultCurrency(L *lua.LState) (*Security, error) { ctx := L.Context() - user, ok := ctx.Value(userContextKey).(*User) + user, ok := ctx.Value(userContextKey).(*models.User) if !ok { return nil, errors.New("Couldn't find User in lua's Context") } diff --git a/internal/handlers/sessions.go b/internal/handlers/sessions.go index 55c9317..81954b1 100644 --- a/internal/handlers/sessions.go +++ b/internal/handlers/sessions.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "encoding/json" "fmt" + "github.com/aclindsa/moneygo/internal/models" "io" "log" "net/http" @@ -120,7 +121,7 @@ func NewSession(tx *Tx, r *http.Request, userid int64) (*NewSessionWriter, error func SessionHandler(r *http.Request, context *Context) ResponseWriterWriter { if r.Method == "POST" || r.Method == "PUT" { - var user User + var user models.User if err := ReadJSON(r, &user); err != nil { return NewError(3 /*Invalid Request*/) } diff --git a/internal/handlers/transactions.go b/internal/handlers/transactions.go index a6adc8e..60c97b2 100644 --- a/internal/handlers/transactions.go +++ b/internal/handlers/transactions.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/aclindsa/moneygo/internal/models" "log" "math/big" "net/http" @@ -221,7 +222,7 @@ func GetTransactions(tx *Tx, userid int64) (*[]Transaction, error) { return &transactions, nil } -func incrementAccountVersions(tx *Tx, user *User, accountids []int64) error { +func incrementAccountVersions(tx *Tx, user *models.User, accountids []int64) error { for i := range accountids { account, err := GetAccount(tx, accountids[i], user.UserId) if err != nil { @@ -245,7 +246,7 @@ func (ame AccountMissingError) Error() string { return "Account missing" } -func InsertTransaction(tx *Tx, t *Transaction, user *User) error { +func InsertTransaction(tx *Tx, t *Transaction, user *models.User) error { // Map of any accounts with transaction splits being added a_map := make(map[int64]bool) for i := range t.Splits { @@ -295,7 +296,7 @@ func InsertTransaction(tx *Tx, t *Transaction, user *User) error { return nil } -func UpdateTransaction(tx *Tx, t *Transaction, user *User) error { +func UpdateTransaction(tx *Tx, t *Transaction, user *models.User) error { var existing_splits []*Split _, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId) @@ -372,7 +373,7 @@ func UpdateTransaction(tx *Tx, t *Transaction, user *User) error { return nil } -func DeleteTransaction(tx *Tx, t *Transaction, user *User) error { +func DeleteTransaction(tx *Tx, t *Transaction, user *models.User) error { var accountids []int64 _, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId) if err != nil { @@ -549,7 +550,7 @@ func TransactionsBalanceDifference(tx *Tx, accountid int64, transactions []Trans return &pageDifference, nil } -func GetAccountBalance(tx *Tx, user *User, accountid int64) (*big.Rat, error) { +func GetAccountBalance(tx *Tx, user *models.User, accountid int64) (*big.Rat, error) { var splits []Split sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?" @@ -572,7 +573,7 @@ func GetAccountBalance(tx *Tx, user *User, accountid int64) (*big.Rat, error) { } // Assumes accountid is valid and is owned by the current user -func GetAccountBalanceDate(tx *Tx, user *User, accountid int64, date *time.Time) (*big.Rat, error) { +func GetAccountBalanceDate(tx *Tx, user *models.User, accountid int64, date *time.Time) (*big.Rat, error) { var splits []Split sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?" @@ -594,7 +595,7 @@ func GetAccountBalanceDate(tx *Tx, user *User, accountid int64, date *time.Time) return &balance, nil } -func GetAccountBalanceDateRange(tx *Tx, user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) { +func GetAccountBalanceDateRange(tx *Tx, user *models.User, accountid int64, begin, end *time.Time) (*big.Rat, error) { var splits []Split sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?" @@ -616,7 +617,7 @@ func GetAccountBalanceDateRange(tx *Tx, user *User, accountid int64, begin, end return &balance, nil } -func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) { +func GetAccountTransactions(tx *Tx, user *models.User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) { var transactions []Transaction var atl AccountTransactionsList @@ -699,7 +700,7 @@ func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, pa // Return only those transactions which have at least one split pertaining to // an account -func AccountTransactionsHandler(context *Context, r *http.Request, user *User, accountid int64) ResponseWriterWriter { +func AccountTransactionsHandler(context *Context, r *http.Request, user *models.User, accountid int64) ResponseWriterWriter { var page uint64 = 0 var limit uint64 = 50 var sort string = "date-desc" diff --git a/internal/handlers/users.go b/internal/handlers/users.go index 1dfc5d0..66b5737 100644 --- a/internal/handlers/users.go +++ b/internal/handlers/users.go @@ -1,53 +1,21 @@ package handlers import ( - "crypto/sha256" - "encoding/json" "errors" "fmt" - "io" + "github.com/aclindsa/moneygo/internal/models" "log" "net/http" - "strings" ) -type User struct { - UserId int64 - DefaultCurrency int64 // SecurityId of default currency, or ISO4217 code for it if creating new user - Name string - Username string - Password string `db:"-"` - PasswordHash string `json:"-"` - Email string -} - -const BogusPassword = "password" - type UserExistsError struct{} func (ueu UserExistsError) Error() string { return "User exists" } -func (u *User) Write(w http.ResponseWriter) error { - enc := json.NewEncoder(w) - return enc.Encode(u) -} - -func (u *User) Read(json_str string) error { - dec := json.NewDecoder(strings.NewReader(json_str)) - return dec.Decode(u) -} - -func (u *User) HashPassword() { - password_hasher := sha256.New() - io.WriteString(password_hasher, u.Password) - u.PasswordHash = fmt.Sprintf("%x", password_hasher.Sum(nil)) - u.Password = "" -} - -func GetUser(tx *Tx, userid int64) (*User, error) { - var u User +func GetUser(tx *Tx, userid int64) (*models.User, error) { + var u models.User err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid) if err != nil { @@ -56,8 +24,8 @@ func GetUser(tx *Tx, userid int64) (*User, error) { return &u, nil } -func GetUserByUsername(tx *Tx, username string) (*User, error) { - var u User +func GetUserByUsername(tx *Tx, username string) (*models.User, error) { + var u models.User err := tx.SelectOne(&u, "SELECT * from users where Username=?", username) if err != nil { @@ -66,7 +34,7 @@ func GetUserByUsername(tx *Tx, username string) (*User, error) { return &u, nil } -func InsertUser(tx *Tx, u *User) error { +func InsertUser(tx *Tx, u *models.User) error { security_template := FindCurrencyTemplate(u.DefaultCurrency) if security_template == nil { return errors.New("Invalid ISO4217 Default Currency") @@ -107,7 +75,7 @@ func InsertUser(tx *Tx, u *User) error { return nil } -func GetUserFromSession(tx *Tx, r *http.Request) (*User, error) { +func GetUserFromSession(tx *Tx, r *http.Request) (*models.User, error) { s, err := GetSession(tx, r) if err != nil { return nil, err @@ -115,7 +83,7 @@ func GetUserFromSession(tx *Tx, r *http.Request) (*User, error) { return GetUser(tx, s.UserId) } -func UpdateUser(tx *Tx, u *User) error { +func UpdateUser(tx *Tx, u *models.User) error { security, err := GetSecurity(tx, u.DefaultCurrency, u.UserId) if err != nil { return err @@ -135,7 +103,7 @@ func UpdateUser(tx *Tx, u *User) error { return nil } -func DeleteUser(tx *Tx, u *User) error { +func DeleteUser(tx *Tx, u *models.User) error { count, err := tx.Delete(u) if err != nil { return err @@ -177,7 +145,7 @@ func DeleteUser(tx *Tx, u *User) error { func UserHandler(r *http.Request, context *Context) ResponseWriterWriter { if r.Method == "POST" { - var user User + var user models.User if err := ReadJSON(r, &user); err != nil { return NewError(3 /*Invalid Request*/) } @@ -221,7 +189,7 @@ func UserHandler(r *http.Request, context *Context) ResponseWriterWriter { } // If the user didn't create a new password, keep their old one - if user.Password != BogusPassword { + if user.Password != models.BogusPassword { user.HashPassword() } else { user.Password = "" diff --git a/internal/models/users.go b/internal/models/users.go new file mode 100644 index 0000000..c5f45ce --- /dev/null +++ b/internal/models/users.go @@ -0,0 +1,39 @@ +package models + +import ( + "crypto/sha256" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" +) + +type User struct { + UserId int64 + DefaultCurrency int64 // SecurityId of default currency, or ISO4217 code for it if creating new user + Name string + Username string + Password string `db:"-"` + PasswordHash string `json:"-"` + Email string +} + +const BogusPassword = "password" + +func (u *User) Write(w http.ResponseWriter) error { + enc := json.NewEncoder(w) + return enc.Encode(u) +} + +func (u *User) Read(json_str string) error { + dec := json.NewDecoder(strings.NewReader(json_str)) + return dec.Decode(u) +} + +func (u *User) HashPassword() { + password_hasher := sha256.New() + io.WriteString(password_hasher, u.Password) + u.PasswordHash = fmt.Sprintf("%x", password_hasher.Sum(nil)) + u.Password = "" +}