mirror of
https://github.com/aclindsa/moneygo.git
synced 2024-10-31 16:00:05 -04:00
Begin splitting models from handlers with User
This commit is contained in:
parent
382d6ad434
commit
e70be1647c
@ -6,6 +6,7 @@ import (
|
|||||||
"github.com/aclindsa/gorp"
|
"github.com/aclindsa/gorp"
|
||||||
"github.com/aclindsa/moneygo/internal/config"
|
"github.com/aclindsa/moneygo/internal/config"
|
||||||
"github.com/aclindsa/moneygo/internal/handlers"
|
"github.com/aclindsa/moneygo/internal/handlers"
|
||||||
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
_ "github.com/go-sql-driver/mysql"
|
_ "github.com/go-sql-driver/mysql"
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
@ -33,7 +34,7 @@ func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dbmap := &gorp.DbMap{Db: db, Dialect: dialect}
|
dbmap := &gorp.DbMap{Db: db, Dialect: dialect}
|
||||||
dbmap.AddTableWithName(handlers.User{}, "users").SetKeys(true, "UserId")
|
dbmap.AddTableWithName(models.User{}, "users").SetKeys(true, "UserId")
|
||||||
dbmap.AddTableWithName(handlers.Session{}, "sessions").SetKeys(true, "SessionId")
|
dbmap.AddTableWithName(handlers.Session{}, "sessions").SetKeys(true, "SessionId")
|
||||||
dbmap.AddTableWithName(handlers.Account{}, "accounts").SetKeys(true, "AccountId")
|
dbmap.AddTableWithName(handlers.Account{}, "accounts").SetKeys(true, "AccountId")
|
||||||
dbmap.AddTableWithName(handlers.Security{}, "securities").SetKeys(true, "SecurityId")
|
dbmap.AddTableWithName(handlers.Security{}, "securities").SetKeys(true, "SecurityId")
|
||||||
|
@ -3,6 +3,7 @@ package handlers
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
"github.com/yuin/gopher-lua"
|
"github.com/yuin/gopher-lua"
|
||||||
"math/big"
|
"math/big"
|
||||||
"strings"
|
"strings"
|
||||||
@ -22,7 +23,7 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*Account, error) {
|
|||||||
|
|
||||||
account_map, ok = ctx.Value(accountsContextKey).(map[int64]*Account)
|
account_map, ok = ctx.Value(accountsContextKey).(map[int64]*Account)
|
||||||
if !ok {
|
if !ok {
|
||||||
user, ok := ctx.Value(userContextKey).(*User)
|
user, ok := ctx.Value(userContextKey).(*models.User)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("Couldn't find User in lua's Context")
|
return nil, errors.New("Couldn't find User in lua's Context")
|
||||||
}
|
}
|
||||||
@ -153,7 +154,7 @@ func luaAccountBalance(L *lua.LState) int {
|
|||||||
if !ok {
|
if !ok {
|
||||||
panic("Couldn't find tx in lua's Context")
|
panic("Couldn't find tx in lua's Context")
|
||||||
}
|
}
|
||||||
user, ok := ctx.Value(userContextKey).(*User)
|
user, ok := ctx.Value(userContextKey).(*models.User)
|
||||||
if !ok {
|
if !ok {
|
||||||
panic("Couldn't find User in lua's Context")
|
panic("Couldn't find User in lua's Context")
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ package handlers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/aclindsa/gorp"
|
"github.com/aclindsa/gorp"
|
||||||
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"path"
|
"path"
|
||||||
@ -16,7 +17,7 @@ type ResponseWriterWriter interface {
|
|||||||
|
|
||||||
type Context struct {
|
type Context struct {
|
||||||
Tx *Tx
|
Tx *Tx
|
||||||
User *User
|
User *models.User
|
||||||
remainingURL string // portion of URL path not yet reached in the hierarchy
|
remainingURL string // portion of URL path not yet reached in the hierarchy
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ package handlers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
"github.com/aclindsa/ofxgo"
|
"github.com/aclindsa/ofxgo"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
@ -22,7 +23,7 @@ func (od *OFXDownload) Read(json_str string) error {
|
|||||||
return dec.Decode(od)
|
return dec.Decode(od)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ofxImportHelper(tx *Tx, r io.Reader, user *User, accountid int64) ResponseWriterWriter {
|
func ofxImportHelper(tx *Tx, r io.Reader, user *models.User, accountid int64) ResponseWriterWriter {
|
||||||
itl, err := ImportOFX(r)
|
itl, err := ImportOFX(r)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -210,7 +211,7 @@ func ofxImportHelper(tx *Tx, r io.Reader, user *User, accountid int64) ResponseW
|
|||||||
return SuccessWriter{}
|
return SuccessWriter{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func OFXImportHandler(context *Context, r *http.Request, user *User, accountid int64) ResponseWriterWriter {
|
func OFXImportHandler(context *Context, r *http.Request, user *models.User, accountid int64) ResponseWriterWriter {
|
||||||
var ofxdownload OFXDownload
|
var ofxdownload OFXDownload
|
||||||
if err := ReadJSON(r, &ofxdownload); err != nil {
|
if err := ReadJSON(r, &ofxdownload); err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
@ -305,7 +306,7 @@ func OFXImportHandler(context *Context, r *http.Request, user *User, accountid i
|
|||||||
return ofxImportHelper(context.Tx, response.Body, user, accountid)
|
return ofxImportHelper(context.Tx, response.Body, user, accountid)
|
||||||
}
|
}
|
||||||
|
|
||||||
func OFXFileImportHandler(context *Context, r *http.Request, user *User, accountid int64) ResponseWriterWriter {
|
func OFXFileImportHandler(context *Context, r *http.Request, user *models.User, accountid int64) ResponseWriterWriter {
|
||||||
multipartReader, err := r.MultipartReader()
|
multipartReader, err := r.MultipartReader()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
@ -329,7 +330,7 @@ func OFXFileImportHandler(context *Context, r *http.Request, user *User, account
|
|||||||
/*
|
/*
|
||||||
* Assumes the User is a valid, signed-in user, but accountid has not yet been validated
|
* Assumes the User is a valid, signed-in user, but accountid has not yet been validated
|
||||||
*/
|
*/
|
||||||
func AccountImportHandler(context *Context, r *http.Request, user *User, accountid int64) ResponseWriterWriter {
|
func AccountImportHandler(context *Context, r *http.Request, user *models.User, accountid int64) ResponseWriterWriter {
|
||||||
|
|
||||||
importType := context.NextLevel()
|
importType := context.NextLevel()
|
||||||
switch importType {
|
switch importType {
|
||||||
|
@ -2,6 +2,7 @@ package handlers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
@ -129,7 +130,7 @@ func GetClosestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Pr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func PriceHandler(r *http.Request, context *Context, user *User, securityid int64) ResponseWriterWriter {
|
func PriceHandler(r *http.Request, context *Context, user *models.User, securityid int64) ResponseWriterWriter {
|
||||||
security, err := GetSecurity(context.Tx, securityid, user.UserId)
|
security, err := GetSecurity(context.Tx, securityid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
"github.com/yuin/gopher-lua"
|
"github.com/yuin/gopher-lua"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -134,7 +135,7 @@ func DeleteReport(tx *Tx, r *Report) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func runReport(tx *Tx, user *User, report *Report) (*Tabulation, error) {
|
func runReport(tx *Tx, user *models.User, report *Report) (*Tabulation, error) {
|
||||||
// Create a new LState without opening the default libs for security
|
// Create a new LState without opening the default libs for security
|
||||||
L := lua.NewState(lua.Options{SkipOpenLibs: true})
|
L := lua.NewState(lua.Options{SkipOpenLibs: true})
|
||||||
defer L.Close()
|
defer L.Close()
|
||||||
@ -198,7 +199,7 @@ func runReport(tx *Tx, user *User, report *Report) (*Tabulation, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReportTabulationHandler(tx *Tx, r *http.Request, user *User, reportid int64) ResponseWriterWriter {
|
func ReportTabulationHandler(tx *Tx, r *http.Request, user *models.User, reportid int64) ResponseWriterWriter {
|
||||||
report, err := GetReport(tx, reportid, user.UserId)
|
report, err := GetReport(tx, reportid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
|
@ -3,6 +3,7 @@ package handlers
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
"github.com/yuin/gopher-lua"
|
"github.com/yuin/gopher-lua"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -20,7 +21,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) {
|
|||||||
|
|
||||||
security_map, ok = ctx.Value(securitiesContextKey).(map[int64]*Security)
|
security_map, ok = ctx.Value(securitiesContextKey).(map[int64]*Security)
|
||||||
if !ok {
|
if !ok {
|
||||||
user, ok := ctx.Value(userContextKey).(*User)
|
user, ok := ctx.Value(userContextKey).(*models.User)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("Couldn't find User in lua's Context")
|
return nil, errors.New("Couldn't find User in lua's Context")
|
||||||
}
|
}
|
||||||
@ -50,7 +51,7 @@ func luaContextGetDefaultCurrency(L *lua.LState) (*Security, error) {
|
|||||||
|
|
||||||
ctx := L.Context()
|
ctx := L.Context()
|
||||||
|
|
||||||
user, ok := ctx.Value(userContextKey).(*User)
|
user, ok := ctx.Value(userContextKey).(*models.User)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("Couldn't find User in lua's Context")
|
return nil, errors.New("Couldn't find User in lua's Context")
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -120,7 +121,7 @@ func NewSession(tx *Tx, r *http.Request, userid int64) (*NewSessionWriter, error
|
|||||||
|
|
||||||
func SessionHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
func SessionHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
||||||
if r.Method == "POST" || r.Method == "PUT" {
|
if r.Method == "POST" || r.Method == "PUT" {
|
||||||
var user User
|
var user models.User
|
||||||
if err := ReadJSON(r, &user); err != nil {
|
if err := ReadJSON(r, &user); err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
"log"
|
"log"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -221,7 +222,7 @@ func GetTransactions(tx *Tx, userid int64) (*[]Transaction, error) {
|
|||||||
return &transactions, nil
|
return &transactions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func incrementAccountVersions(tx *Tx, user *User, accountids []int64) error {
|
func incrementAccountVersions(tx *Tx, user *models.User, accountids []int64) error {
|
||||||
for i := range accountids {
|
for i := range accountids {
|
||||||
account, err := GetAccount(tx, accountids[i], user.UserId)
|
account, err := GetAccount(tx, accountids[i], user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -245,7 +246,7 @@ func (ame AccountMissingError) Error() string {
|
|||||||
return "Account missing"
|
return "Account missing"
|
||||||
}
|
}
|
||||||
|
|
||||||
func InsertTransaction(tx *Tx, t *Transaction, user *User) error {
|
func InsertTransaction(tx *Tx, t *Transaction, user *models.User) error {
|
||||||
// Map of any accounts with transaction splits being added
|
// Map of any accounts with transaction splits being added
|
||||||
a_map := make(map[int64]bool)
|
a_map := make(map[int64]bool)
|
||||||
for i := range t.Splits {
|
for i := range t.Splits {
|
||||||
@ -295,7 +296,7 @@ func InsertTransaction(tx *Tx, t *Transaction, user *User) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateTransaction(tx *Tx, t *Transaction, user *User) error {
|
func UpdateTransaction(tx *Tx, t *Transaction, user *models.User) error {
|
||||||
var existing_splits []*Split
|
var existing_splits []*Split
|
||||||
|
|
||||||
_, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
|
_, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
|
||||||
@ -372,7 +373,7 @@ func UpdateTransaction(tx *Tx, t *Transaction, user *User) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteTransaction(tx *Tx, t *Transaction, user *User) error {
|
func DeleteTransaction(tx *Tx, t *Transaction, user *models.User) error {
|
||||||
var accountids []int64
|
var accountids []int64
|
||||||
_, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId)
|
_, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -549,7 +550,7 @@ func TransactionsBalanceDifference(tx *Tx, accountid int64, transactions []Trans
|
|||||||
return &pageDifference, nil
|
return &pageDifference, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAccountBalance(tx *Tx, user *User, accountid int64) (*big.Rat, error) {
|
func GetAccountBalance(tx *Tx, user *models.User, accountid int64) (*big.Rat, error) {
|
||||||
var splits []Split
|
var splits []Split
|
||||||
|
|
||||||
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?"
|
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?"
|
||||||
@ -572,7 +573,7 @@ func GetAccountBalance(tx *Tx, user *User, accountid int64) (*big.Rat, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Assumes accountid is valid and is owned by the current user
|
// Assumes accountid is valid and is owned by the current user
|
||||||
func GetAccountBalanceDate(tx *Tx, user *User, accountid int64, date *time.Time) (*big.Rat, error) {
|
func GetAccountBalanceDate(tx *Tx, user *models.User, accountid int64, date *time.Time) (*big.Rat, error) {
|
||||||
var splits []Split
|
var splits []Split
|
||||||
|
|
||||||
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?"
|
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?"
|
||||||
@ -594,7 +595,7 @@ func GetAccountBalanceDate(tx *Tx, user *User, accountid int64, date *time.Time)
|
|||||||
return &balance, nil
|
return &balance, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAccountBalanceDateRange(tx *Tx, user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) {
|
func GetAccountBalanceDateRange(tx *Tx, user *models.User, accountid int64, begin, end *time.Time) (*big.Rat, error) {
|
||||||
var splits []Split
|
var splits []Split
|
||||||
|
|
||||||
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?"
|
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?"
|
||||||
@ -616,7 +617,7 @@ func GetAccountBalanceDateRange(tx *Tx, user *User, accountid int64, begin, end
|
|||||||
return &balance, nil
|
return &balance, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) {
|
func GetAccountTransactions(tx *Tx, user *models.User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) {
|
||||||
var transactions []Transaction
|
var transactions []Transaction
|
||||||
var atl AccountTransactionsList
|
var atl AccountTransactionsList
|
||||||
|
|
||||||
@ -699,7 +700,7 @@ func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, pa
|
|||||||
|
|
||||||
// Return only those transactions which have at least one split pertaining to
|
// Return only those transactions which have at least one split pertaining to
|
||||||
// an account
|
// an account
|
||||||
func AccountTransactionsHandler(context *Context, r *http.Request, user *User, accountid int64) ResponseWriterWriter {
|
func AccountTransactionsHandler(context *Context, r *http.Request, user *models.User, accountid int64) ResponseWriterWriter {
|
||||||
var page uint64 = 0
|
var page uint64 = 0
|
||||||
var limit uint64 = 50
|
var limit uint64 = 50
|
||||||
var sort string = "date-desc"
|
var sort string = "date-desc"
|
||||||
|
@ -1,53 +1,21 @@
|
|||||||
package handlers
|
package handlers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"github.com/aclindsa/moneygo/internal/models"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type User struct {
|
|
||||||
UserId int64
|
|
||||||
DefaultCurrency int64 // SecurityId of default currency, or ISO4217 code for it if creating new user
|
|
||||||
Name string
|
|
||||||
Username string
|
|
||||||
Password string `db:"-"`
|
|
||||||
PasswordHash string `json:"-"`
|
|
||||||
Email string
|
|
||||||
}
|
|
||||||
|
|
||||||
const BogusPassword = "password"
|
|
||||||
|
|
||||||
type UserExistsError struct{}
|
type UserExistsError struct{}
|
||||||
|
|
||||||
func (ueu UserExistsError) Error() string {
|
func (ueu UserExistsError) Error() string {
|
||||||
return "User exists"
|
return "User exists"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) Write(w http.ResponseWriter) error {
|
func GetUser(tx *Tx, userid int64) (*models.User, error) {
|
||||||
enc := json.NewEncoder(w)
|
var u models.User
|
||||||
return enc.Encode(u)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *User) Read(json_str string) error {
|
|
||||||
dec := json.NewDecoder(strings.NewReader(json_str))
|
|
||||||
return dec.Decode(u)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *User) HashPassword() {
|
|
||||||
password_hasher := sha256.New()
|
|
||||||
io.WriteString(password_hasher, u.Password)
|
|
||||||
u.PasswordHash = fmt.Sprintf("%x", password_hasher.Sum(nil))
|
|
||||||
u.Password = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetUser(tx *Tx, userid int64) (*User, error) {
|
|
||||||
var u User
|
|
||||||
|
|
||||||
err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid)
|
err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -56,8 +24,8 @@ func GetUser(tx *Tx, userid int64) (*User, error) {
|
|||||||
return &u, nil
|
return &u, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserByUsername(tx *Tx, username string) (*User, error) {
|
func GetUserByUsername(tx *Tx, username string) (*models.User, error) {
|
||||||
var u User
|
var u models.User
|
||||||
|
|
||||||
err := tx.SelectOne(&u, "SELECT * from users where Username=?", username)
|
err := tx.SelectOne(&u, "SELECT * from users where Username=?", username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -66,7 +34,7 @@ func GetUserByUsername(tx *Tx, username string) (*User, error) {
|
|||||||
return &u, nil
|
return &u, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func InsertUser(tx *Tx, u *User) error {
|
func InsertUser(tx *Tx, u *models.User) error {
|
||||||
security_template := FindCurrencyTemplate(u.DefaultCurrency)
|
security_template := FindCurrencyTemplate(u.DefaultCurrency)
|
||||||
if security_template == nil {
|
if security_template == nil {
|
||||||
return errors.New("Invalid ISO4217 Default Currency")
|
return errors.New("Invalid ISO4217 Default Currency")
|
||||||
@ -107,7 +75,7 @@ func InsertUser(tx *Tx, u *User) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserFromSession(tx *Tx, r *http.Request) (*User, error) {
|
func GetUserFromSession(tx *Tx, r *http.Request) (*models.User, error) {
|
||||||
s, err := GetSession(tx, r)
|
s, err := GetSession(tx, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -115,7 +83,7 @@ func GetUserFromSession(tx *Tx, r *http.Request) (*User, error) {
|
|||||||
return GetUser(tx, s.UserId)
|
return GetUser(tx, s.UserId)
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateUser(tx *Tx, u *User) error {
|
func UpdateUser(tx *Tx, u *models.User) error {
|
||||||
security, err := GetSecurity(tx, u.DefaultCurrency, u.UserId)
|
security, err := GetSecurity(tx, u.DefaultCurrency, u.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -135,7 +103,7 @@ func UpdateUser(tx *Tx, u *User) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteUser(tx *Tx, u *User) error {
|
func DeleteUser(tx *Tx, u *models.User) error {
|
||||||
count, err := tx.Delete(u)
|
count, err := tx.Delete(u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -177,7 +145,7 @@ func DeleteUser(tx *Tx, u *User) error {
|
|||||||
|
|
||||||
func UserHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
func UserHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
||||||
if r.Method == "POST" {
|
if r.Method == "POST" {
|
||||||
var user User
|
var user models.User
|
||||||
if err := ReadJSON(r, &user); err != nil {
|
if err := ReadJSON(r, &user); err != nil {
|
||||||
return NewError(3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
@ -221,7 +189,7 @@ func UserHandler(r *http.Request, context *Context) ResponseWriterWriter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If the user didn't create a new password, keep their old one
|
// If the user didn't create a new password, keep their old one
|
||||||
if user.Password != BogusPassword {
|
if user.Password != models.BogusPassword {
|
||||||
user.HashPassword()
|
user.HashPassword()
|
||||||
} else {
|
} else {
|
||||||
user.Password = ""
|
user.Password = ""
|
||||||
|
39
internal/models/users.go
Normal file
39
internal/models/users.go
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
package models
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
UserId int64
|
||||||
|
DefaultCurrency int64 // SecurityId of default currency, or ISO4217 code for it if creating new user
|
||||||
|
Name string
|
||||||
|
Username string
|
||||||
|
Password string `db:"-"`
|
||||||
|
PasswordHash string `json:"-"`
|
||||||
|
Email string
|
||||||
|
}
|
||||||
|
|
||||||
|
const BogusPassword = "password"
|
||||||
|
|
||||||
|
func (u *User) Write(w http.ResponseWriter) error {
|
||||||
|
enc := json.NewEncoder(w)
|
||||||
|
return enc.Encode(u)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) Read(json_str string) error {
|
||||||
|
dec := json.NewDecoder(strings.NewReader(json_str))
|
||||||
|
return dec.Decode(u)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) HashPassword() {
|
||||||
|
password_hasher := sha256.New()
|
||||||
|
io.WriteString(password_hasher, u.Password)
|
||||||
|
u.PasswordHash = fmt.Sprintf("%x", password_hasher.Sum(nil))
|
||||||
|
u.Password = ""
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user