Begin splitting models from handlers with User

This commit is contained in:
Aaron Lindsay 2017-12-02 06:14:47 -05:00
parent 382d6ad434
commit e70be1647c
11 changed files with 82 additions and 66 deletions

View File

@ -6,6 +6,7 @@ import (
"github.com/aclindsa/gorp" "github.com/aclindsa/gorp"
"github.com/aclindsa/moneygo/internal/config" "github.com/aclindsa/moneygo/internal/config"
"github.com/aclindsa/moneygo/internal/handlers" "github.com/aclindsa/moneygo/internal/handlers"
"github.com/aclindsa/moneygo/internal/models"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq" _ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
@ -33,7 +34,7 @@ func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) {
} }
dbmap := &gorp.DbMap{Db: db, Dialect: dialect} dbmap := &gorp.DbMap{Db: db, Dialect: dialect}
dbmap.AddTableWithName(handlers.User{}, "users").SetKeys(true, "UserId") dbmap.AddTableWithName(models.User{}, "users").SetKeys(true, "UserId")
dbmap.AddTableWithName(handlers.Session{}, "sessions").SetKeys(true, "SessionId") dbmap.AddTableWithName(handlers.Session{}, "sessions").SetKeys(true, "SessionId")
dbmap.AddTableWithName(handlers.Account{}, "accounts").SetKeys(true, "AccountId") dbmap.AddTableWithName(handlers.Account{}, "accounts").SetKeys(true, "AccountId")
dbmap.AddTableWithName(handlers.Security{}, "securities").SetKeys(true, "SecurityId") dbmap.AddTableWithName(handlers.Security{}, "securities").SetKeys(true, "SecurityId")

View File

@ -3,6 +3,7 @@ package handlers
import ( import (
"context" "context"
"errors" "errors"
"github.com/aclindsa/moneygo/internal/models"
"github.com/yuin/gopher-lua" "github.com/yuin/gopher-lua"
"math/big" "math/big"
"strings" "strings"
@ -22,7 +23,7 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*Account, error) {
account_map, ok = ctx.Value(accountsContextKey).(map[int64]*Account) account_map, ok = ctx.Value(accountsContextKey).(map[int64]*Account)
if !ok { if !ok {
user, ok := ctx.Value(userContextKey).(*User) user, ok := ctx.Value(userContextKey).(*models.User)
if !ok { if !ok {
return nil, errors.New("Couldn't find User in lua's Context") return nil, errors.New("Couldn't find User in lua's Context")
} }
@ -153,7 +154,7 @@ func luaAccountBalance(L *lua.LState) int {
if !ok { if !ok {
panic("Couldn't find tx in lua's Context") panic("Couldn't find tx in lua's Context")
} }
user, ok := ctx.Value(userContextKey).(*User) user, ok := ctx.Value(userContextKey).(*models.User)
if !ok { if !ok {
panic("Couldn't find User in lua's Context") panic("Couldn't find User in lua's Context")
} }

View File

@ -2,6 +2,7 @@ package handlers
import ( import (
"github.com/aclindsa/gorp" "github.com/aclindsa/gorp"
"github.com/aclindsa/moneygo/internal/models"
"log" "log"
"net/http" "net/http"
"path" "path"
@ -16,7 +17,7 @@ type ResponseWriterWriter interface {
type Context struct { type Context struct {
Tx *Tx Tx *Tx
User *User User *models.User
remainingURL string // portion of URL path not yet reached in the hierarchy remainingURL string // portion of URL path not yet reached in the hierarchy
} }

View File

@ -2,6 +2,7 @@ package handlers
import ( import (
"encoding/json" "encoding/json"
"github.com/aclindsa/moneygo/internal/models"
"github.com/aclindsa/ofxgo" "github.com/aclindsa/ofxgo"
"io" "io"
"log" "log"
@ -22,7 +23,7 @@ func (od *OFXDownload) Read(json_str string) error {
return dec.Decode(od) return dec.Decode(od)
} }
func ofxImportHelper(tx *Tx, r io.Reader, user *User, accountid int64) ResponseWriterWriter { func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) ResponseWriterWriter {
itl, err := ImportOFX(r) itl, err := ImportOFX(r)
if err != nil { if err != nil {
@ -210,7 +211,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *User, accountid int64) ResponseW
return SuccessWriter{} return SuccessWriter{}
} }
func OFXImportHandler(context *Context, r *http.Request, user *User, accountid int64) ResponseWriterWriter { func OFXImportHandler(context *Context, r *http.Request, user *models.User, accountid int64) ResponseWriterWriter {
var ofxdownload OFXDownload var ofxdownload OFXDownload
if err := ReadJSON(r, &ofxdownload); err != nil { if err := ReadJSON(r, &ofxdownload); err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
@ -305,7 +306,7 @@ func OFXImportHandler(context *Context, r *http.Request, user *User, accountid i
return ofxImportHelper(context.Tx, response.Body, user, accountid) return ofxImportHelper(context.Tx, response.Body, user, accountid)
} }
func OFXFileImportHandler(context *Context, r *http.Request, user *User, accountid int64) ResponseWriterWriter { func OFXFileImportHandler(context *Context, r *http.Request, user *models.User, accountid int64) ResponseWriterWriter {
multipartReader, err := r.MultipartReader() multipartReader, err := r.MultipartReader()
if err != nil { if err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
@ -329,7 +330,7 @@ func OFXFileImportHandler(context *Context, r *http.Request, user *User, account
/* /*
* Assumes the User is a valid, signed-in user, but accountid has not yet been validated * Assumes the User is a valid, signed-in user, but accountid has not yet been validated
*/ */
func AccountImportHandler(context *Context, r *http.Request, user *User, accountid int64) ResponseWriterWriter { func AccountImportHandler(context *Context, r *http.Request, user *models.User, accountid int64) ResponseWriterWriter {
importType := context.NextLevel() importType := context.NextLevel()
switch importType { switch importType {

View File

@ -2,6 +2,7 @@ package handlers
import ( import (
"encoding/json" "encoding/json"
"github.com/aclindsa/moneygo/internal/models"
"log" "log"
"net/http" "net/http"
"strings" "strings"
@ -129,7 +130,7 @@ func GetClosestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Pr
} }
} }
func PriceHandler(r *http.Request, context *Context, user *User, securityid int64) ResponseWriterWriter { func PriceHandler(r *http.Request, context *Context, user *models.User, securityid int64) ResponseWriterWriter {
security, err := GetSecurity(context.Tx, securityid, user.UserId) security, err := GetSecurity(context.Tx, securityid, user.UserId)
if err != nil { if err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)

View File

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/aclindsa/moneygo/internal/models"
"github.com/yuin/gopher-lua" "github.com/yuin/gopher-lua"
"log" "log"
"net/http" "net/http"
@ -134,7 +135,7 @@ func DeleteReport(tx *Tx, r *Report) error {
return nil return nil
} }
func runReport(tx *Tx, user *User, report *Report) (*Tabulation, error) { func runReport(tx *Tx, user *models.User, report *Report) (*Tabulation, error) {
// Create a new LState without opening the default libs for security // Create a new LState without opening the default libs for security
L := lua.NewState(lua.Options{SkipOpenLibs: true}) L := lua.NewState(lua.Options{SkipOpenLibs: true})
defer L.Close() defer L.Close()
@ -198,7 +199,7 @@ func runReport(tx *Tx, user *User, report *Report) (*Tabulation, error) {
} }
} }
func ReportTabulationHandler(tx *Tx, r *http.Request, user *User, reportid int64) ResponseWriterWriter { func ReportTabulationHandler(tx *Tx, r *http.Request, user *models.User, reportid int64) ResponseWriterWriter {
report, err := GetReport(tx, reportid, user.UserId) report, err := GetReport(tx, reportid, user.UserId)
if err != nil { if err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)

View File

@ -3,6 +3,7 @@ package handlers
import ( import (
"context" "context"
"errors" "errors"
"github.com/aclindsa/moneygo/internal/models"
"github.com/yuin/gopher-lua" "github.com/yuin/gopher-lua"
) )
@ -20,7 +21,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) {
security_map, ok = ctx.Value(securitiesContextKey).(map[int64]*Security) security_map, ok = ctx.Value(securitiesContextKey).(map[int64]*Security)
if !ok { if !ok {
user, ok := ctx.Value(userContextKey).(*User) user, ok := ctx.Value(userContextKey).(*models.User)
if !ok { if !ok {
return nil, errors.New("Couldn't find User in lua's Context") return nil, errors.New("Couldn't find User in lua's Context")
} }
@ -50,7 +51,7 @@ func luaContextGetDefaultCurrency(L *lua.LState) (*Security, error) {
ctx := L.Context() ctx := L.Context()
user, ok := ctx.Value(userContextKey).(*User) user, ok := ctx.Value(userContextKey).(*models.User)
if !ok { if !ok {
return nil, errors.New("Couldn't find User in lua's Context") return nil, errors.New("Couldn't find User in lua's Context")
} }

View File

@ -5,6 +5,7 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/aclindsa/moneygo/internal/models"
"io" "io"
"log" "log"
"net/http" "net/http"
@ -120,7 +121,7 @@ func NewSession(tx *Tx, r *http.Request, userid int64) (*NewSessionWriter, error
func SessionHandler(r *http.Request, context *Context) ResponseWriterWriter { func SessionHandler(r *http.Request, context *Context) ResponseWriterWriter {
if r.Method == "POST" || r.Method == "PUT" { if r.Method == "POST" || r.Method == "PUT" {
var user User var user models.User
if err := ReadJSON(r, &user); err != nil { if err := ReadJSON(r, &user); err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }

View File

@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/aclindsa/moneygo/internal/models"
"log" "log"
"math/big" "math/big"
"net/http" "net/http"
@ -221,7 +222,7 @@ func GetTransactions(tx *Tx, userid int64) (*[]Transaction, error) {
return &transactions, nil return &transactions, nil
} }
func incrementAccountVersions(tx *Tx, user *User, accountids []int64) error { func incrementAccountVersions(tx *Tx, user *models.User, accountids []int64) error {
for i := range accountids { for i := range accountids {
account, err := GetAccount(tx, accountids[i], user.UserId) account, err := GetAccount(tx, accountids[i], user.UserId)
if err != nil { if err != nil {
@ -245,7 +246,7 @@ func (ame AccountMissingError) Error() string {
return "Account missing" return "Account missing"
} }
func InsertTransaction(tx *Tx, t *Transaction, user *User) error { func InsertTransaction(tx *Tx, t *Transaction, user *models.User) error {
// Map of any accounts with transaction splits being added // Map of any accounts with transaction splits being added
a_map := make(map[int64]bool) a_map := make(map[int64]bool)
for i := range t.Splits { for i := range t.Splits {
@ -295,7 +296,7 @@ func InsertTransaction(tx *Tx, t *Transaction, user *User) error {
return nil return nil
} }
func UpdateTransaction(tx *Tx, t *Transaction, user *User) error { func UpdateTransaction(tx *Tx, t *Transaction, user *models.User) error {
var existing_splits []*Split var existing_splits []*Split
_, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId) _, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
@ -372,7 +373,7 @@ func UpdateTransaction(tx *Tx, t *Transaction, user *User) error {
return nil return nil
} }
func DeleteTransaction(tx *Tx, t *Transaction, user *User) error { func DeleteTransaction(tx *Tx, t *Transaction, user *models.User) error {
var accountids []int64 var accountids []int64
_, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId) _, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId)
if err != nil { if err != nil {
@ -549,7 +550,7 @@ func TransactionsBalanceDifference(tx *Tx, accountid int64, transactions []Trans
return &pageDifference, nil return &pageDifference, nil
} }
func GetAccountBalance(tx *Tx, user *User, accountid int64) (*big.Rat, error) { func GetAccountBalance(tx *Tx, user *models.User, accountid int64) (*big.Rat, error) {
var splits []Split var splits []Split
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?" sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?"
@ -572,7 +573,7 @@ func GetAccountBalance(tx *Tx, user *User, accountid int64) (*big.Rat, error) {
} }
// Assumes accountid is valid and is owned by the current user // Assumes accountid is valid and is owned by the current user
func GetAccountBalanceDate(tx *Tx, user *User, accountid int64, date *time.Time) (*big.Rat, error) { func GetAccountBalanceDate(tx *Tx, user *models.User, accountid int64, date *time.Time) (*big.Rat, error) {
var splits []Split var splits []Split
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?" sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?"
@ -594,7 +595,7 @@ func GetAccountBalanceDate(tx *Tx, user *User, accountid int64, date *time.Time)
return &balance, nil return &balance, nil
} }
func GetAccountBalanceDateRange(tx *Tx, user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) { func GetAccountBalanceDateRange(tx *Tx, user *models.User, accountid int64, begin, end *time.Time) (*big.Rat, error) {
var splits []Split var splits []Split
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?" sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?"
@ -616,7 +617,7 @@ func GetAccountBalanceDateRange(tx *Tx, user *User, accountid int64, begin, end
return &balance, nil return &balance, nil
} }
func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) { func GetAccountTransactions(tx *Tx, user *models.User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) {
var transactions []Transaction var transactions []Transaction
var atl AccountTransactionsList var atl AccountTransactionsList
@ -699,7 +700,7 @@ func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, pa
// Return only those transactions which have at least one split pertaining to // Return only those transactions which have at least one split pertaining to
// an account // an account
func AccountTransactionsHandler(context *Context, r *http.Request, user *User, accountid int64) ResponseWriterWriter { func AccountTransactionsHandler(context *Context, r *http.Request, user *models.User, accountid int64) ResponseWriterWriter {
var page uint64 = 0 var page uint64 = 0
var limit uint64 = 50 var limit uint64 = 50
var sort string = "date-desc" var sort string = "date-desc"

View File

@ -1,53 +1,21 @@
package handlers package handlers
import ( import (
"crypto/sha256"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "github.com/aclindsa/moneygo/internal/models"
"log" "log"
"net/http" "net/http"
"strings"
) )
type User struct {
UserId int64
DefaultCurrency int64 // SecurityId of default currency, or ISO4217 code for it if creating new user
Name string
Username string
Password string `db:"-"`
PasswordHash string `json:"-"`
Email string
}
const BogusPassword = "password"
type UserExistsError struct{} type UserExistsError struct{}
func (ueu UserExistsError) Error() string { func (ueu UserExistsError) Error() string {
return "User exists" return "User exists"
} }
func (u *User) Write(w http.ResponseWriter) error { func GetUser(tx *Tx, userid int64) (*models.User, error) {
enc := json.NewEncoder(w) var u models.User
return enc.Encode(u)
}
func (u *User) Read(json_str string) error {
dec := json.NewDecoder(strings.NewReader(json_str))
return dec.Decode(u)
}
func (u *User) HashPassword() {
password_hasher := sha256.New()
io.WriteString(password_hasher, u.Password)
u.PasswordHash = fmt.Sprintf("%x", password_hasher.Sum(nil))
u.Password = ""
}
func GetUser(tx *Tx, userid int64) (*User, error) {
var u User
err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid) err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid)
if err != nil { if err != nil {
@ -56,8 +24,8 @@ func GetUser(tx *Tx, userid int64) (*User, error) {
return &u, nil return &u, nil
} }
func GetUserByUsername(tx *Tx, username string) (*User, error) { func GetUserByUsername(tx *Tx, username string) (*models.User, error) {
var u User var u models.User
err := tx.SelectOne(&u, "SELECT * from users where Username=?", username) err := tx.SelectOne(&u, "SELECT * from users where Username=?", username)
if err != nil { if err != nil {
@ -66,7 +34,7 @@ func GetUserByUsername(tx *Tx, username string) (*User, error) {
return &u, nil return &u, nil
} }
func InsertUser(tx *Tx, u *User) error { func InsertUser(tx *Tx, u *models.User) error {
security_template := FindCurrencyTemplate(u.DefaultCurrency) security_template := FindCurrencyTemplate(u.DefaultCurrency)
if security_template == nil { if security_template == nil {
return errors.New("Invalid ISO4217 Default Currency") return errors.New("Invalid ISO4217 Default Currency")
@ -107,7 +75,7 @@ func InsertUser(tx *Tx, u *User) error {
return nil return nil
} }
func GetUserFromSession(tx *Tx, r *http.Request) (*User, error) { func GetUserFromSession(tx *Tx, r *http.Request) (*models.User, error) {
s, err := GetSession(tx, r) s, err := GetSession(tx, r)
if err != nil { if err != nil {
return nil, err return nil, err
@ -115,7 +83,7 @@ func GetUserFromSession(tx *Tx, r *http.Request) (*User, error) {
return GetUser(tx, s.UserId) return GetUser(tx, s.UserId)
} }
func UpdateUser(tx *Tx, u *User) error { func UpdateUser(tx *Tx, u *models.User) error {
security, err := GetSecurity(tx, u.DefaultCurrency, u.UserId) security, err := GetSecurity(tx, u.DefaultCurrency, u.UserId)
if err != nil { if err != nil {
return err return err
@ -135,7 +103,7 @@ func UpdateUser(tx *Tx, u *User) error {
return nil return nil
} }
func DeleteUser(tx *Tx, u *User) error { func DeleteUser(tx *Tx, u *models.User) error {
count, err := tx.Delete(u) count, err := tx.Delete(u)
if err != nil { if err != nil {
return err return err
@ -177,7 +145,7 @@ func DeleteUser(tx *Tx, u *User) error {
func UserHandler(r *http.Request, context *Context) ResponseWriterWriter { func UserHandler(r *http.Request, context *Context) ResponseWriterWriter {
if r.Method == "POST" { if r.Method == "POST" {
var user User var user models.User
if err := ReadJSON(r, &user); err != nil { if err := ReadJSON(r, &user); err != nil {
return NewError(3 /*Invalid Request*/) return NewError(3 /*Invalid Request*/)
} }
@ -221,7 +189,7 @@ func UserHandler(r *http.Request, context *Context) ResponseWriterWriter {
} }
// If the user didn't create a new password, keep their old one // If the user didn't create a new password, keep their old one
if user.Password != BogusPassword { if user.Password != models.BogusPassword {
user.HashPassword() user.HashPassword()
} else { } else {
user.Password = "" user.Password = ""

39
internal/models/users.go Normal file
View File

@ -0,0 +1,39 @@
package models
import (
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
)
type User struct {
UserId int64
DefaultCurrency int64 // SecurityId of default currency, or ISO4217 code for it if creating new user
Name string
Username string
Password string `db:"-"`
PasswordHash string `json:"-"`
Email string
}
const BogusPassword = "password"
func (u *User) Write(w http.ResponseWriter) error {
enc := json.NewEncoder(w)
return enc.Encode(u)
}
func (u *User) Read(json_str string) error {
dec := json.NewDecoder(strings.NewReader(json_str))
return dec.Decode(u)
}
func (u *User) HashPassword() {
password_hasher := sha256.New()
io.WriteString(password_hasher, u.Password)
u.PasswordHash = fmt.Sprintf("%x", password_hasher.Sum(nil))
u.Password = ""
}