mirror of
https://github.com/aclindsa/moneygo.git
synced 2024-12-27 07:52:28 -05:00
Merge pull request #24 from aclindsa/sql_transaction_updates
Sql transaction updates
This commit is contained in:
commit
f2ce7adb52
@ -3,7 +3,6 @@ package handlers
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"gopkg.in/gorp.v1"
|
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"regexp"
|
"regexp"
|
||||||
@ -129,31 +128,20 @@ func (al *AccountList) Read(json_str string) error {
|
|||||||
return dec.Decode(al)
|
return dec.Decode(al)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAccount(db *DB, accountid int64, userid int64) (*Account, error) {
|
func GetAccount(tx *Tx, accountid int64, userid int64) (*Account, error) {
|
||||||
var a Account
|
var a Account
|
||||||
|
|
||||||
err := db.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid)
|
err := tx.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &a, nil
|
return &a, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAccountTx(transaction *gorp.Transaction, accountid int64, userid int64) (*Account, error) {
|
func GetAccounts(tx *Tx, userid int64) (*[]Account, error) {
|
||||||
var a Account
|
|
||||||
|
|
||||||
err := transaction.SelectOne(&a, "SELECT * from accounts where UserId=? AND AccountId=?", userid, accountid)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &a, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetAccounts(db *DB, userid int64) (*[]Account, error) {
|
|
||||||
var accounts []Account
|
var accounts []Account
|
||||||
|
|
||||||
_, err := db.Select(&accounts, "SELECT * from accounts where UserId=?", userid)
|
_, err := tx.Select(&accounts, "SELECT * from accounts where UserId=?", userid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -162,12 +150,12 @@ func GetAccounts(db *DB, userid int64) (*[]Account, error) {
|
|||||||
|
|
||||||
// Get (and attempt to create if it doesn't exist). Matches on UserId,
|
// Get (and attempt to create if it doesn't exist). Matches on UserId,
|
||||||
// SecurityId, Type, Name, and ParentAccountId
|
// SecurityId, Type, Name, and ParentAccountId
|
||||||
func GetCreateAccountTx(transaction *gorp.Transaction, a Account) (*Account, error) {
|
func GetCreateAccount(tx *Tx, a Account) (*Account, error) {
|
||||||
var accounts []Account
|
var accounts []Account
|
||||||
var account Account
|
var account Account
|
||||||
|
|
||||||
// Try to find the top-level trading account
|
// Try to find the top-level trading account
|
||||||
_, err := transaction.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND Type=? AND Name=? AND ParentAccountId=? ORDER BY AccountId ASC LIMIT 1", a.UserId, a.SecurityId, a.Type, a.Name, a.ParentAccountId)
|
_, err := tx.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND Type=? AND Name=? AND ParentAccountId=? ORDER BY AccountId ASC LIMIT 1", a.UserId, a.SecurityId, a.Type, a.Name, a.ParentAccountId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -180,7 +168,7 @@ func GetCreateAccountTx(transaction *gorp.Transaction, a Account) (*Account, err
|
|||||||
account.Name = a.Name
|
account.Name = a.Name
|
||||||
account.ParentAccountId = a.ParentAccountId
|
account.ParentAccountId = a.ParentAccountId
|
||||||
|
|
||||||
err = transaction.Insert(&account)
|
err = tx.Insert(&account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -190,11 +178,11 @@ func GetCreateAccountTx(transaction *gorp.Transaction, a Account) (*Account, err
|
|||||||
|
|
||||||
// Get (and attempt to create if it doesn't exist) the security/currency
|
// Get (and attempt to create if it doesn't exist) the security/currency
|
||||||
// trading account for the supplied security/currency
|
// trading account for the supplied security/currency
|
||||||
func GetTradingAccount(transaction *gorp.Transaction, userid int64, securityid int64) (*Account, error) {
|
func GetTradingAccount(tx *Tx, userid int64, securityid int64) (*Account, error) {
|
||||||
var tradingAccount Account
|
var tradingAccount Account
|
||||||
var account Account
|
var account Account
|
||||||
|
|
||||||
user, err := GetUserTx(transaction, userid)
|
user, err := GetUser(tx, userid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -206,12 +194,12 @@ func GetTradingAccount(transaction *gorp.Transaction, userid int64, securityid i
|
|||||||
tradingAccount.ParentAccountId = -1
|
tradingAccount.ParentAccountId = -1
|
||||||
|
|
||||||
// Find/create the top-level trading account
|
// Find/create the top-level trading account
|
||||||
ta, err := GetCreateAccountTx(transaction, tradingAccount)
|
ta, err := GetCreateAccount(tx, tradingAccount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
security, err := GetSecurityTx(transaction, securityid, userid)
|
security, err := GetSecurity(tx, securityid, userid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -222,7 +210,7 @@ func GetTradingAccount(transaction *gorp.Transaction, userid int64, securityid i
|
|||||||
account.SecurityId = securityid
|
account.SecurityId = securityid
|
||||||
account.Type = Trading
|
account.Type = Trading
|
||||||
|
|
||||||
a, err := GetCreateAccountTx(transaction, account)
|
a, err := GetCreateAccount(tx, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -232,14 +220,14 @@ func GetTradingAccount(transaction *gorp.Transaction, userid int64, securityid i
|
|||||||
|
|
||||||
// Get (and attempt to create if it doesn't exist) the security/currency
|
// Get (and attempt to create if it doesn't exist) the security/currency
|
||||||
// imbalance account for the supplied security/currency
|
// imbalance account for the supplied security/currency
|
||||||
func GetImbalanceAccount(transaction *gorp.Transaction, userid int64, securityid int64) (*Account, error) {
|
func GetImbalanceAccount(tx *Tx, userid int64, securityid int64) (*Account, error) {
|
||||||
var imbalanceAccount Account
|
var imbalanceAccount Account
|
||||||
var account Account
|
var account Account
|
||||||
xxxtemplate := FindSecurityTemplate("XXX", Currency)
|
xxxtemplate := FindSecurityTemplate("XXX", Currency)
|
||||||
if xxxtemplate == nil {
|
if xxxtemplate == nil {
|
||||||
return nil, errors.New("Couldn't find XXX security template")
|
return nil, errors.New("Couldn't find XXX security template")
|
||||||
}
|
}
|
||||||
xxxsecurity, err := ImportGetCreateSecurity(transaction, userid, xxxtemplate)
|
xxxsecurity, err := ImportGetCreateSecurity(tx, userid, xxxtemplate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.New("Couldn't create XXX security")
|
return nil, errors.New("Couldn't create XXX security")
|
||||||
}
|
}
|
||||||
@ -251,12 +239,12 @@ func GetImbalanceAccount(transaction *gorp.Transaction, userid int64, securityid
|
|||||||
imbalanceAccount.Type = Bank
|
imbalanceAccount.Type = Bank
|
||||||
|
|
||||||
// Find/create the top-level trading account
|
// Find/create the top-level trading account
|
||||||
ia, err := GetCreateAccountTx(transaction, imbalanceAccount)
|
ia, err := GetCreateAccount(tx, imbalanceAccount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
security, err := GetSecurityTx(transaction, securityid, userid)
|
security, err := GetSecurity(tx, securityid, userid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -267,7 +255,7 @@ func GetImbalanceAccount(transaction *gorp.Transaction, userid int64, securityid
|
|||||||
account.SecurityId = securityid
|
account.SecurityId = securityid
|
||||||
account.Type = Bank
|
account.Type = Bank
|
||||||
|
|
||||||
a, err := GetCreateAccountTx(transaction, account)
|
a, err := GetCreateAccount(tx, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -293,12 +281,7 @@ func (cae CircularAccountsError) Error() string {
|
|||||||
return "Would result in circular account relationship"
|
return "Would result in circular account relationship"
|
||||||
}
|
}
|
||||||
|
|
||||||
func insertUpdateAccount(db *DB, a *Account, insert bool) error {
|
func insertUpdateAccount(tx *Tx, a *Account, insert bool) error {
|
||||||
transaction, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
found := make(map[int64]bool)
|
found := make(map[int64]bool)
|
||||||
if !insert {
|
if !insert {
|
||||||
found[a.AccountId] = true
|
found[a.AccountId] = true
|
||||||
@ -308,14 +291,12 @@ func insertUpdateAccount(db *DB, a *Account, insert bool) error {
|
|||||||
for parentid != -1 {
|
for parentid != -1 {
|
||||||
depth += 1
|
depth += 1
|
||||||
if depth > 100 {
|
if depth > 100 {
|
||||||
transaction.Rollback()
|
|
||||||
return TooMuchNestingError{}
|
return TooMuchNestingError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
var a Account
|
var a Account
|
||||||
err := transaction.SelectOne(&a, "SELECT * from accounts where AccountId=?", parentid)
|
err := tx.SelectOne(&a, "SELECT * from accounts where AccountId=?", parentid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return ParentAccountMissingError{}
|
return ParentAccountMissingError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -327,107 +308,79 @@ func insertUpdateAccount(db *DB, a *Account, insert bool) error {
|
|||||||
found[parentid] = true
|
found[parentid] = true
|
||||||
parentid = a.ParentAccountId
|
parentid = a.ParentAccountId
|
||||||
if _, ok := found[parentid]; ok {
|
if _, ok := found[parentid]; ok {
|
||||||
transaction.Rollback()
|
|
||||||
return CircularAccountsError{}
|
return CircularAccountsError{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if insert {
|
if insert {
|
||||||
err = transaction.Insert(a)
|
err := tx.Insert(a)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
oldacct, err := GetAccountTx(transaction, a.AccountId, a.UserId)
|
oldacct, err := GetAccount(tx, a.AccountId, a.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
a.AccountVersion = oldacct.AccountVersion + 1
|
a.AccountVersion = oldacct.AccountVersion + 1
|
||||||
|
|
||||||
count, err := transaction.Update(a)
|
count, err := tx.Update(a)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if count != 1 {
|
if count != 1 {
|
||||||
transaction.Rollback()
|
|
||||||
return errors.New("Updated more than one account")
|
return errors.New("Updated more than one account")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.Commit()
|
|
||||||
if err != nil {
|
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func InsertAccount(db *DB, a *Account) error {
|
func InsertAccount(tx *Tx, a *Account) error {
|
||||||
return insertUpdateAccount(db, a, true)
|
return insertUpdateAccount(tx, a, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateAccount(db *DB, a *Account) error {
|
func UpdateAccount(tx *Tx, a *Account) error {
|
||||||
return insertUpdateAccount(db, a, false)
|
return insertUpdateAccount(tx, a, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteAccount(db *DB, a *Account) error {
|
func DeleteAccount(tx *Tx, a *Account) error {
|
||||||
transaction, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if a.ParentAccountId != -1 {
|
if a.ParentAccountId != -1 {
|
||||||
// Re-parent splits to this account's parent account if this account isn't a root account
|
// Re-parent splits to this account's parent account if this account isn't a root account
|
||||||
_, err = transaction.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", a.ParentAccountId, a.AccountId)
|
_, err := tx.Exec("UPDATE splits SET AccountId=? WHERE AccountId=?", a.ParentAccountId, a.AccountId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Delete splits if this account is a root account
|
// Delete splits if this account is a root account
|
||||||
_, err = transaction.Exec("DELETE FROM splits WHERE AccountId=?", a.AccountId)
|
_, err := tx.Exec("DELETE FROM splits WHERE AccountId=?", a.AccountId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Re-parent child accounts to this account's parent account
|
// Re-parent child accounts to this account's parent account
|
||||||
_, err = transaction.Exec("UPDATE accounts SET ParentAccountId=? WHERE ParentAccountId=?", a.ParentAccountId, a.AccountId)
|
_, err := tx.Exec("UPDATE accounts SET ParentAccountId=? WHERE ParentAccountId=?", a.ParentAccountId, a.AccountId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
count, err := transaction.Delete(a)
|
count, err := tx.Delete(a)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if count != 1 {
|
if count != 1 {
|
||||||
transaction.Rollback()
|
|
||||||
return errors.New("Was going to delete more than one account")
|
return errors.New("Was going to delete more than one account")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.Commit()
|
|
||||||
if err != nil {
|
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func AccountHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
func AccountHandler(r *http.Request, tx *Tx) ResponseWriterWriter {
|
||||||
user, err := GetUserFromSession(db, r)
|
user, err := GetUserFromSession(tx, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 1 /*Not Signed In*/)
|
return NewError(1 /*Not Signed In*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Method == "POST" {
|
if r.Method == "POST" {
|
||||||
@ -439,59 +392,46 @@ func AccountHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
|||||||
n, err := GetURLPieces(r.URL.Path, "/account/%d/import/%s", &accountid, &importtype)
|
n, err := GetURLPieces(r.URL.Path, "/account/%d/import/%s", &accountid, &importtype)
|
||||||
|
|
||||||
if err != nil || n != 2 {
|
if err != nil || n != 2 {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
AccountImportHandler(db, w, r, user, accountid, importtype)
|
return AccountImportHandler(tx, r, user, accountid, importtype)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
account_json := r.PostFormValue("account")
|
account_json := r.PostFormValue("account")
|
||||||
if account_json == "" {
|
if account_json == "" {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var account Account
|
var account Account
|
||||||
err := account.Read(account_json)
|
err := account.Read(account_json)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
account.AccountId = -1
|
account.AccountId = -1
|
||||||
account.UserId = user.UserId
|
account.UserId = user.UserId
|
||||||
account.AccountVersion = 0
|
account.AccountVersion = 0
|
||||||
|
|
||||||
security, err := GetSecurity(db, account.SecurityId, user.UserId)
|
security, err := GetSecurity(tx, account.SecurityId, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
if security == nil {
|
if security == nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = InsertAccount(db, &account)
|
err = InsertAccount(tx, &account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if _, ok := err.(ParentAccountMissingError); ok {
|
if _, ok := err.(ParentAccountMissingError); ok {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
} else {
|
} else {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
w.WriteHeader(201 /*Created*/)
|
return ResponseWrapper{201, &account}
|
||||||
err = account.Write(w)
|
|
||||||
if err != nil {
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else if r.Method == "GET" {
|
} else if r.Method == "GET" {
|
||||||
var accountid int64
|
var accountid int64
|
||||||
n, err := GetURLPieces(r.URL.Path, "/account/%d", &accountid)
|
n, err := GetURLPieces(r.URL.Path, "/account/%d", &accountid)
|
||||||
@ -499,112 +439,86 @@ func AccountHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
|||||||
if err != nil || n != 1 {
|
if err != nil || n != 1 {
|
||||||
//Return all Accounts
|
//Return all Accounts
|
||||||
var al AccountList
|
var al AccountList
|
||||||
accounts, err := GetAccounts(db, user.UserId)
|
accounts, err := GetAccounts(tx, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
al.Accounts = accounts
|
al.Accounts = accounts
|
||||||
err = (&al).Write(w)
|
return &al
|
||||||
if err != nil {
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
// if URL looks like /account/[0-9]+/transactions, use the account
|
// if URL looks like /account/[0-9]+/transactions, use the account
|
||||||
// transaction handler
|
// transaction handler
|
||||||
if accountTransactionsRE.MatchString(r.URL.Path) {
|
if accountTransactionsRE.MatchString(r.URL.Path) {
|
||||||
AccountTransactionsHandler(db, w, r, user, accountid)
|
return AccountTransactionsHandler(tx, r, user, accountid)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return Account with this Id
|
// Return Account with this Id
|
||||||
account, err := GetAccount(db, accountid, user.UserId)
|
account, err := GetAccount(tx, accountid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = account.Write(w)
|
return account
|
||||||
if err != nil {
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
accountid, err := GetURLID(r.URL.Path)
|
accountid, err := GetURLID(r.URL.Path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
if r.Method == "PUT" {
|
if r.Method == "PUT" {
|
||||||
account_json := r.PostFormValue("account")
|
account_json := r.PostFormValue("account")
|
||||||
if account_json == "" {
|
if account_json == "" {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var account Account
|
var account Account
|
||||||
err := account.Read(account_json)
|
err := account.Read(account_json)
|
||||||
if err != nil || account.AccountId != accountid {
|
if err != nil || account.AccountId != accountid {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
account.UserId = user.UserId
|
account.UserId = user.UserId
|
||||||
|
|
||||||
security, err := GetSecurity(db, account.SecurityId, user.UserId)
|
security, err := GetSecurity(tx, account.SecurityId, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
if security == nil {
|
if security == nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if account.ParentAccountId == account.AccountId {
|
if account.ParentAccountId == account.AccountId {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = UpdateAccount(db, &account)
|
err = UpdateAccount(tx, &account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if _, ok := err.(ParentAccountMissingError); ok {
|
if _, ok := err.(ParentAccountMissingError); ok {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
} else if _, ok := err.(CircularAccountsError); ok {
|
} else if _, ok := err.(CircularAccountsError); ok {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
} else {
|
} else {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = account.Write(w)
|
return &account
|
||||||
if err != nil {
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else if r.Method == "DELETE" {
|
} else if r.Method == "DELETE" {
|
||||||
account, err := GetAccount(db, accountid, user.UserId)
|
account, err := GetAccount(tx, accountid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = DeleteAccount(db, account)
|
err = DeleteAccount(tx, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
WriteSuccess(w)
|
return SuccessWriter{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
@ -15,9 +15,9 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*Account, error) {
|
|||||||
|
|
||||||
ctx := L.Context()
|
ctx := L.Context()
|
||||||
|
|
||||||
db, ok := ctx.Value(dbContextKey).(*DB)
|
tx, ok := ctx.Value(dbContextKey).(*Tx)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("Couldn't find DB in lua's Context")
|
return nil, errors.New("Couldn't find tx in lua's Context")
|
||||||
}
|
}
|
||||||
|
|
||||||
account_map, ok = ctx.Value(accountsContextKey).(map[int64]*Account)
|
account_map, ok = ctx.Value(accountsContextKey).(map[int64]*Account)
|
||||||
@ -27,7 +27,7 @@ func luaContextGetAccounts(L *lua.LState) (map[int64]*Account, error) {
|
|||||||
return nil, errors.New("Couldn't find User in lua's Context")
|
return nil, errors.New("Couldn't find User in lua's Context")
|
||||||
}
|
}
|
||||||
|
|
||||||
accounts, err := GetAccounts(db, user.UserId)
|
accounts, err := GetAccounts(tx, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -149,9 +149,9 @@ func luaAccountBalance(L *lua.LState) int {
|
|||||||
a := luaCheckAccount(L, 1)
|
a := luaCheckAccount(L, 1)
|
||||||
|
|
||||||
ctx := L.Context()
|
ctx := L.Context()
|
||||||
db, ok := ctx.Value(dbContextKey).(*DB)
|
tx, ok := ctx.Value(dbContextKey).(*Tx)
|
||||||
if !ok {
|
if !ok {
|
||||||
panic("Couldn't find DB in lua's Context")
|
panic("Couldn't find tx in lua's Context")
|
||||||
}
|
}
|
||||||
user, ok := ctx.Value(userContextKey).(*User)
|
user, ok := ctx.Value(userContextKey).(*User)
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -171,12 +171,12 @@ func luaAccountBalance(L *lua.LState) int {
|
|||||||
if date != nil {
|
if date != nil {
|
||||||
end := luaWeakCheckTime(L, 3)
|
end := luaWeakCheckTime(L, 3)
|
||||||
if end != nil {
|
if end != nil {
|
||||||
rat, err = GetAccountBalanceDateRange(db, user, a.AccountId, date, end)
|
rat, err = GetAccountBalanceDateRange(tx, user, a.AccountId, date, end)
|
||||||
} else {
|
} else {
|
||||||
rat, err = GetAccountBalanceDate(db, user, a.AccountId, date)
|
rat, err = GetAccountBalanceDate(tx, user, a.AccountId, date)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
rat, err = GetAccountBalance(db, user, a.AccountId)
|
rat, err = GetAccountBalance(tx, user, a.AccountId)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic("Failed to GetAccountBalance:" + err.Error())
|
panic("Failed to GetAccountBalance:" + err.Error())
|
||||||
|
@ -38,13 +38,17 @@ var error_codes = map[int]string{
|
|||||||
999: "Internal Error",
|
999: "Internal Error",
|
||||||
}
|
}
|
||||||
|
|
||||||
func WriteError(w http.ResponseWriter, error_code int) {
|
func NewError(error_code int) *Error {
|
||||||
msg, ok := error_codes[error_code]
|
msg, ok := error_codes[error_code]
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Printf("Error: WriteError received error code of %d", error_code)
|
log.Printf("Error: WriteError received error code of %d", error_code)
|
||||||
msg = error_codes[999]
|
msg = error_codes[999]
|
||||||
}
|
}
|
||||||
e := Error{error_code, msg}
|
return &Error{error_code, msg}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WriteError(w http.ResponseWriter, error_code int) {
|
||||||
|
e := NewError(error_code)
|
||||||
|
|
||||||
err := e.Write(w)
|
err := e.Write(w)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -308,42 +308,37 @@ func ImportGnucash(r io.Reader) (*GnucashImport, error) {
|
|||||||
return &gncimport, nil
|
return &gncimport, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
func GnucashImportHandler(r *http.Request, tx *Tx) ResponseWriterWriter {
|
||||||
user, err := GetUserFromSession(db, r)
|
user, err := GetUserFromSession(tx, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 1 /*Not Signed In*/)
|
return NewError(1 /*Not Signed In*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Method != "POST" {
|
if r.Method != "POST" {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
multipartReader, err := r.MultipartReader()
|
multipartReader, err := r.MultipartReader()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Assume there is only one 'part' and it's the one we care about
|
// Assume there is only one 'part' and it's the one we care about
|
||||||
part, err := multipartReader.NextPart()
|
part, err := multipartReader.NextPart()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
} else {
|
} else {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bufread := bufio.NewReader(part)
|
bufread := bufio.NewReader(part)
|
||||||
gzHeader, err := bufread.Peek(2)
|
gzHeader, err := bufread.Peek(2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Does this look like a gzipped file?
|
// Does this look like a gzipped file?
|
||||||
@ -351,9 +346,8 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
|||||||
if gzHeader[0] == 0x1f && gzHeader[1] == 0x8b {
|
if gzHeader[0] == 0x1f && gzHeader[1] == 0x8b {
|
||||||
gzr, err := gzip.NewReader(bufread)
|
gzr, err := gzip.NewReader(bufread)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
gnucashImport, err = ImportGnucash(gzr)
|
gnucashImport, err = ImportGnucash(gzr)
|
||||||
} else {
|
} else {
|
||||||
@ -361,15 +355,7 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sqltransaction, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Import securities, building map from Gnucash security IDs to our
|
// Import securities, building map from Gnucash security IDs to our
|
||||||
@ -377,13 +363,11 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
|||||||
securityMap := make(map[int64]int64)
|
securityMap := make(map[int64]int64)
|
||||||
for _, security := range gnucashImport.Securities {
|
for _, security := range gnucashImport.Securities {
|
||||||
securityId := security.SecurityId // save off because it could be updated
|
securityId := security.SecurityId // save off because it could be updated
|
||||||
s, err := ImportGetCreateSecurity(sqltransaction, user.UserId, &security)
|
s, err := ImportGetCreateSecurity(tx, user.UserId, &security)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqltransaction.Rollback()
|
|
||||||
WriteError(w, 6 /*Import Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
log.Print(security)
|
log.Print(security)
|
||||||
return
|
return NewError(6 /*Import Error*/)
|
||||||
}
|
}
|
||||||
securityMap[securityId] = s.SecurityId
|
securityMap[securityId] = s.SecurityId
|
||||||
}
|
}
|
||||||
@ -394,12 +378,10 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
|||||||
price.CurrencyId = securityMap[price.CurrencyId]
|
price.CurrencyId = securityMap[price.CurrencyId]
|
||||||
price.PriceId = 0
|
price.PriceId = 0
|
||||||
|
|
||||||
err := CreatePriceIfNotExist(sqltransaction, &price)
|
err := CreatePriceIfNotExist(tx, &price)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqltransaction.Rollback()
|
|
||||||
WriteError(w, 6 /*Import Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(6 /*Import Error*/)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -425,12 +407,10 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
|||||||
account.ParentAccountId = accountMap[account.ParentAccountId]
|
account.ParentAccountId = accountMap[account.ParentAccountId]
|
||||||
}
|
}
|
||||||
account.SecurityId = securityMap[account.SecurityId]
|
account.SecurityId = securityMap[account.SecurityId]
|
||||||
a, err := GetCreateAccountTx(sqltransaction, account)
|
a, err := GetCreateAccount(tx, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqltransaction.Rollback()
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
accountMap[account.AccountId] = a.AccountId
|
accountMap[account.AccountId] = a.AccountId
|
||||||
accountsRemaining--
|
accountsRemaining--
|
||||||
@ -438,10 +418,8 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
|||||||
}
|
}
|
||||||
if accountsRemaining == accountsRemainingLast {
|
if accountsRemaining == accountsRemainingLast {
|
||||||
//We didn't make any progress in importing the next level of accounts, so there must be a circular parent-child relationship, so give up and tell the user they're wrong
|
//We didn't make any progress in importing the next level of accounts, so there must be a circular parent-child relationship, so give up and tell the user they're wrong
|
||||||
sqltransaction.Rollback()
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(fmt.Errorf("Circular account parent-child relationship when importing %s", part.FileName()))
|
log.Print(fmt.Errorf("Circular account parent-child relationship when importing %s", part.FileName()))
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
accountsRemainingLast = accountsRemaining
|
accountsRemainingLast = accountsRemaining
|
||||||
}
|
}
|
||||||
@ -453,41 +431,27 @@ func GnucashImportHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
|||||||
for _, split := range transaction.Splits {
|
for _, split := range transaction.Splits {
|
||||||
acctId, ok := accountMap[split.AccountId]
|
acctId, ok := accountMap[split.AccountId]
|
||||||
if !ok {
|
if !ok {
|
||||||
sqltransaction.Rollback()
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(fmt.Errorf("Error: Split's AccountID Doesn't exist: %d\n", split.AccountId))
|
log.Print(fmt.Errorf("Error: Split's AccountID Doesn't exist: %d\n", split.AccountId))
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
split.AccountId = acctId
|
split.AccountId = acctId
|
||||||
|
|
||||||
exists, err := split.AlreadyImportedTx(sqltransaction)
|
exists, err := split.AlreadyImported(tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqltransaction.Rollback()
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print("Error checking if split was already imported:", err)
|
log.Print("Error checking if split was already imported:", err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
} else if exists {
|
} else if exists {
|
||||||
already_imported = true
|
already_imported = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !already_imported {
|
if !already_imported {
|
||||||
err := InsertTransactionTx(sqltransaction, &transaction, user)
|
err := InsertTransaction(tx, &transaction, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqltransaction.Rollback()
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = sqltransaction.Commit()
|
return SuccessWriter{}
|
||||||
if err != nil {
|
|
||||||
sqltransaction.Rollback()
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
WriteSuccess(w)
|
|
||||||
}
|
}
|
||||||
|
@ -2,30 +2,64 @@ package handlers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"gopkg.in/gorp.v1"
|
"gopkg.in/gorp.v1"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Create a closure over db, allowing the handlers to look like a
|
// But who writes the ResponseWriterWriter?
|
||||||
// http.HandlerFunc
|
type ResponseWriterWriter interface {
|
||||||
type DB = gorp.DbMap
|
Write(http.ResponseWriter) error
|
||||||
type DBHandler func(http.ResponseWriter, *http.Request, *DB)
|
}
|
||||||
|
type Tx = gorp.Transaction
|
||||||
|
type TxHandler func(*http.Request, *Tx) ResponseWriterWriter
|
||||||
|
|
||||||
func DBHandlerFunc(h DBHandler, db *DB) http.HandlerFunc {
|
func TxHandlerFunc(t TxHandler, db *gorp.DbMap) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
h(w, r, db)
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
|
panic(r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
writer := t(r, tx)
|
||||||
|
|
||||||
|
if e, ok := writer.(*Error); ok {
|
||||||
|
tx.Rollback()
|
||||||
|
e.Write(w)
|
||||||
|
} else {
|
||||||
|
err = tx.Commit()
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
|
} else {
|
||||||
|
err = writer.Write(w)
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
WriteError(w, 999 /*Internal Error*/)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetHandler(db *DB) *http.ServeMux {
|
func GetHandler(db *gorp.DbMap) *http.ServeMux {
|
||||||
servemux := http.NewServeMux()
|
servemux := http.NewServeMux()
|
||||||
servemux.HandleFunc("/session/", DBHandlerFunc(SessionHandler, db))
|
servemux.HandleFunc("/session/", TxHandlerFunc(SessionHandler, db))
|
||||||
servemux.HandleFunc("/user/", DBHandlerFunc(UserHandler, db))
|
servemux.HandleFunc("/user/", TxHandlerFunc(UserHandler, db))
|
||||||
servemux.HandleFunc("/security/", DBHandlerFunc(SecurityHandler, db))
|
servemux.HandleFunc("/security/", TxHandlerFunc(SecurityHandler, db))
|
||||||
servemux.HandleFunc("/securitytemplate/", SecurityTemplateHandler)
|
servemux.HandleFunc("/securitytemplate/", SecurityTemplateHandler)
|
||||||
servemux.HandleFunc("/account/", DBHandlerFunc(AccountHandler, db))
|
servemux.HandleFunc("/account/", TxHandlerFunc(AccountHandler, db))
|
||||||
servemux.HandleFunc("/transaction/", DBHandlerFunc(TransactionHandler, db))
|
servemux.HandleFunc("/transaction/", TxHandlerFunc(TransactionHandler, db))
|
||||||
servemux.HandleFunc("/import/gnucash", DBHandlerFunc(GnucashImportHandler, db))
|
servemux.HandleFunc("/import/gnucash", TxHandlerFunc(GnucashImportHandler, db))
|
||||||
servemux.HandleFunc("/report/", DBHandlerFunc(ReportHandler, db))
|
servemux.HandleFunc("/report/", TxHandlerFunc(ReportHandler, db))
|
||||||
|
|
||||||
return servemux
|
return servemux
|
||||||
}
|
}
|
||||||
|
@ -22,48 +22,35 @@ func (od *OFXDownload) Read(json_str string) error {
|
|||||||
return dec.Decode(od)
|
return dec.Decode(od)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, accountid int64) {
|
func ofxImportHelper(tx *Tx, r io.Reader, user *User, accountid int64) ResponseWriterWriter {
|
||||||
itl, err := ImportOFX(r)
|
itl, err := ImportOFX(r)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//TODO is this necessarily an invalid request (what if it was an error on our end)?
|
//TODO is this necessarily an invalid request (what if it was an error on our end)?
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(itl.Accounts) != 1 {
|
if len(itl.Accounts) != 1 {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
|
||||||
log.Printf("Found %d accounts when importing OFX, expected 1", len(itl.Accounts))
|
log.Printf("Found %d accounts when importing OFX, expected 1", len(itl.Accounts))
|
||||||
return
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
|
||||||
|
|
||||||
sqltransaction, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return Account with this Id
|
// Return Account with this Id
|
||||||
account, err := GetAccountTx(sqltransaction, accountid, user.UserId)
|
account, err := GetAccount(tx, accountid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqltransaction.Rollback()
|
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
importedAccount := itl.Accounts[0]
|
importedAccount := itl.Accounts[0]
|
||||||
|
|
||||||
if len(account.ExternalAccountId) > 0 &&
|
if len(account.ExternalAccountId) > 0 &&
|
||||||
account.ExternalAccountId != importedAccount.ExternalAccountId {
|
account.ExternalAccountId != importedAccount.ExternalAccountId {
|
||||||
sqltransaction.Rollback()
|
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
|
||||||
log.Printf("OFX import has \"%s\" as ExternalAccountId, but the account being imported to has\"%s\"",
|
log.Printf("OFX import has \"%s\" as ExternalAccountId, but the account being imported to has\"%s\"",
|
||||||
importedAccount.ExternalAccountId,
|
importedAccount.ExternalAccountId,
|
||||||
account.ExternalAccountId)
|
account.ExternalAccountId)
|
||||||
return
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find matching existing securities or create new ones for those
|
// Find matching existing securities or create new ones for those
|
||||||
@ -74,21 +61,17 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc
|
|||||||
// save off since ImportGetCreateSecurity overwrites SecurityId on
|
// save off since ImportGetCreateSecurity overwrites SecurityId on
|
||||||
// ofxsecurity
|
// ofxsecurity
|
||||||
oldsecurityid := ofxsecurity.SecurityId
|
oldsecurityid := ofxsecurity.SecurityId
|
||||||
security, err := ImportGetCreateSecurity(sqltransaction, user.UserId, &ofxsecurity)
|
security, err := ImportGetCreateSecurity(tx, user.UserId, &ofxsecurity)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqltransaction.Rollback()
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
securitymap[oldsecurityid] = *security
|
securitymap[oldsecurityid] = *security
|
||||||
}
|
}
|
||||||
|
|
||||||
if account.SecurityId != securitymap[importedAccount.SecurityId].SecurityId {
|
if account.SecurityId != securitymap[importedAccount.SecurityId].SecurityId {
|
||||||
sqltransaction.Rollback()
|
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
|
||||||
log.Printf("OFX import account's SecurityId (%d) does not match this account's (%d)", securitymap[importedAccount.SecurityId].SecurityId, account.SecurityId)
|
log.Printf("OFX import account's SecurityId (%d) does not match this account's (%d)", securitymap[importedAccount.SecurityId].SecurityId, account.SecurityId)
|
||||||
return
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO Ensure all transactions have at least one split in the account
|
// TODO Ensure all transactions have at least one split in the account
|
||||||
@ -99,10 +82,8 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc
|
|||||||
transaction.UserId = user.UserId
|
transaction.UserId = user.UserId
|
||||||
|
|
||||||
if !transaction.Valid() {
|
if !transaction.Valid() {
|
||||||
sqltransaction.Rollback()
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print("Unexpected invalid transaction from OFX import")
|
log.Print("Unexpected invalid transaction from OFX import")
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure that either AccountId or SecurityId is set for this split,
|
// Ensure that either AccountId or SecurityId is set for this split,
|
||||||
@ -112,10 +93,8 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc
|
|||||||
split.Status = Imported
|
split.Status = Imported
|
||||||
if split.AccountId != -1 {
|
if split.AccountId != -1 {
|
||||||
if split.AccountId != importedAccount.AccountId {
|
if split.AccountId != importedAccount.AccountId {
|
||||||
sqltransaction.Rollback()
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print("Imported split's AccountId wasn't -1 but also didn't match the account")
|
log.Print("Imported split's AccountId wasn't -1 but also didn't match the account")
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
split.AccountId = account.AccountId
|
split.AccountId = account.AccountId
|
||||||
} else if split.SecurityId != -1 {
|
} else if split.SecurityId != -1 {
|
||||||
@ -123,12 +102,10 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc
|
|||||||
// TODO try to auto-match splits to existing accounts based on past transactions that look like this one
|
// TODO try to auto-match splits to existing accounts based on past transactions that look like this one
|
||||||
if split.ImportSplitType == TradingAccount {
|
if split.ImportSplitType == TradingAccount {
|
||||||
// Find/make trading account if we're that type of split
|
// Find/make trading account if we're that type of split
|
||||||
trading_account, err := GetTradingAccount(sqltransaction, user.UserId, sec.SecurityId)
|
trading_account, err := GetTradingAccount(tx, user.UserId, sec.SecurityId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqltransaction.Rollback()
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print("Couldn't find split's SecurityId in map during OFX import")
|
log.Print("Couldn't find split's SecurityId in map during OFX import")
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
split.AccountId = trading_account.AccountId
|
split.AccountId = trading_account.AccountId
|
||||||
split.SecurityId = -1
|
split.SecurityId = -1
|
||||||
@ -140,12 +117,10 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc
|
|||||||
SecurityId: sec.SecurityId,
|
SecurityId: sec.SecurityId,
|
||||||
Type: account.Type,
|
Type: account.Type,
|
||||||
}
|
}
|
||||||
subaccount, err := GetCreateAccountTx(sqltransaction, *subaccount)
|
subaccount, err := GetCreateAccount(tx, *subaccount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqltransaction.Rollback()
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
split.AccountId = subaccount.AccountId
|
split.AccountId = subaccount.AccountId
|
||||||
split.SecurityId = -1
|
split.SecurityId = -1
|
||||||
@ -153,49 +128,39 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc
|
|||||||
split.SecurityId = sec.SecurityId
|
split.SecurityId = sec.SecurityId
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
sqltransaction.Rollback()
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print("Couldn't find split's SecurityId in map during OFX import")
|
log.Print("Couldn't find split's SecurityId in map during OFX import")
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
sqltransaction.Rollback()
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print("Neither Split.AccountId Split.SecurityId was set during OFX import")
|
log.Print("Neither Split.AccountId Split.SecurityId was set during OFX import")
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
imbalances, err := transaction.GetImbalancesTx(sqltransaction)
|
imbalances, err := transaction.GetImbalances(tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqltransaction.Rollback()
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fixup any imbalances in transactions
|
// Fixup any imbalances in transactions
|
||||||
var zero big.Rat
|
var zero big.Rat
|
||||||
for imbalanced_security, imbalance := range imbalances {
|
for imbalanced_security, imbalance := range imbalances {
|
||||||
if imbalance.Cmp(&zero) != 0 {
|
if imbalance.Cmp(&zero) != 0 {
|
||||||
imbalanced_account, err := GetImbalanceAccount(sqltransaction, user.UserId, imbalanced_security)
|
imbalanced_account, err := GetImbalanceAccount(tx, user.UserId, imbalanced_security)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqltransaction.Rollback()
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add new split to fixup imbalance
|
// Add new split to fixup imbalance
|
||||||
split := new(Split)
|
split := new(Split)
|
||||||
r := new(big.Rat)
|
r := new(big.Rat)
|
||||||
r.Neg(&imbalance)
|
r.Neg(&imbalance)
|
||||||
security, err := GetSecurityTx(sqltransaction, imbalanced_security, user.UserId)
|
security, err := GetSecurity(tx, imbalanced_security, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqltransaction.Rollback()
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
split.Amount = r.FloatString(security.Precision)
|
split.Amount = r.FloatString(security.Precision)
|
||||||
split.SecurityId = -1
|
split.SecurityId = -1
|
||||||
@ -210,24 +175,20 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc
|
|||||||
var already_imported bool
|
var already_imported bool
|
||||||
for _, split := range transaction.Splits {
|
for _, split := range transaction.Splits {
|
||||||
if split.SecurityId != -1 || split.AccountId == -1 {
|
if split.SecurityId != -1 || split.AccountId == -1 {
|
||||||
imbalanced_account, err := GetImbalanceAccount(sqltransaction, user.UserId, split.SecurityId)
|
imbalanced_account, err := GetImbalanceAccount(tx, user.UserId, split.SecurityId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqltransaction.Rollback()
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
split.AccountId = imbalanced_account.AccountId
|
split.AccountId = imbalanced_account.AccountId
|
||||||
split.SecurityId = -1
|
split.SecurityId = -1
|
||||||
}
|
}
|
||||||
|
|
||||||
exists, err := split.AlreadyImportedTx(sqltransaction)
|
exists, err := split.AlreadyImported(tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqltransaction.Rollback()
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print("Error checking if split was already imported:", err)
|
log.Print("Error checking if split was already imported:", err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
} else if exists {
|
} else if exists {
|
||||||
already_imported = true
|
already_imported = true
|
||||||
}
|
}
|
||||||
@ -239,55 +200,38 @@ func ofxImportHelper(db *DB, r io.Reader, w http.ResponseWriter, user *User, acc
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, transaction := range transactions {
|
for _, transaction := range transactions {
|
||||||
err := InsertTransactionTx(sqltransaction, &transaction, user)
|
err := InsertTransaction(tx, &transaction, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = sqltransaction.Commit()
|
return SuccessWriter{}
|
||||||
if err != nil {
|
|
||||||
sqltransaction.Rollback()
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
WriteSuccess(w)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func OFXImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64) {
|
func OFXImportHandler(tx *Tx, r *http.Request, user *User, accountid int64) ResponseWriterWriter {
|
||||||
download_json := r.PostFormValue("ofxdownload")
|
download_json := r.PostFormValue("ofxdownload")
|
||||||
if download_json == "" {
|
if download_json == "" {
|
||||||
log.Print("download_json")
|
return NewError(3 /*Invalid Request*/)
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var ofxdownload OFXDownload
|
var ofxdownload OFXDownload
|
||||||
err := ofxdownload.Read(download_json)
|
err := ofxdownload.Read(download_json)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print("ofxdownload.Read")
|
return NewError(3 /*Invalid Request*/)
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := GetAccount(db, accountid, user.UserId)
|
account, err := GetAccount(tx, accountid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print("GetAccount")
|
return NewError(3 /*Invalid Request*/)
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ofxver := ofxgo.OfxVersion203
|
ofxver := ofxgo.OfxVersion203
|
||||||
if len(account.OFXVersion) != 0 {
|
if len(account.OFXVersion) != 0 {
|
||||||
ofxver, err = ofxgo.NewOfxVersion(account.OFXVersion)
|
ofxver, err = ofxgo.NewOfxVersion(account.OFXVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print("NewOfxVersion")
|
return NewError(3 /*Invalid Request*/)
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -308,9 +252,8 @@ func OFXImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User
|
|||||||
|
|
||||||
transactionuid, err := ofxgo.RandomUID()
|
transactionuid, err := ofxgo.RandomUID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Println("Error creating uid for transaction:", err)
|
log.Println("Error creating uid for transaction:", err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
if account.Type == Investment {
|
if account.Type == Investment {
|
||||||
@ -343,8 +286,7 @@ func OFXImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User
|
|||||||
// Import generic bank transactions
|
// Import generic bank transactions
|
||||||
acctTypeEnum, err := ofxgo.NewAcctType(account.OFXAcctType)
|
acctTypeEnum, err := ofxgo.NewAcctType(account.OFXAcctType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
statementRequest := ofxgo.StatementRequest{
|
statementRequest := ofxgo.StatementRequest{
|
||||||
TrnUID: *transactionuid,
|
TrnUID: *transactionuid,
|
||||||
@ -361,49 +303,46 @@ func OFXImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User
|
|||||||
response, err := client.RequestNoParse(&query)
|
response, err := client.RequestNoParse(&query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO this could be an error talking with the OFX server...
|
// TODO this could be an error talking with the OFX server...
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
|
|
||||||
ofxImportHelper(db, response.Body, w, user, accountid)
|
return ofxImportHelper(tx, response.Body, user, accountid)
|
||||||
}
|
}
|
||||||
|
|
||||||
func OFXFileImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64) {
|
func OFXFileImportHandler(tx *Tx, r *http.Request, user *User, accountid int64) ResponseWriterWriter {
|
||||||
multipartReader, err := r.MultipartReader()
|
multipartReader, err := r.MultipartReader()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// assume there is only one 'part'
|
// assume there is only one 'part'
|
||||||
part, err := multipartReader.NextPart()
|
part, err := multipartReader.NextPart()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
log.Print("Encountered unexpected EOF")
|
log.Print("Encountered unexpected EOF")
|
||||||
} else {
|
} else {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ofxImportHelper(db, part, w, user, accountid)
|
return ofxImportHelper(tx, part, user, accountid)
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Assumes the User is a valid, signed-in user, but accountid has not yet been validated
|
* Assumes the User is a valid, signed-in user, but accountid has not yet been validated
|
||||||
*/
|
*/
|
||||||
func AccountImportHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, accountid int64, importtype string) {
|
func AccountImportHandler(tx *Tx, r *http.Request, user *User, accountid int64, importtype string) ResponseWriterWriter {
|
||||||
|
|
||||||
switch importtype {
|
switch importtype {
|
||||||
case "ofx":
|
case "ofx":
|
||||||
OFXImportHandler(db, w, r, user, accountid)
|
return OFXImportHandler(tx, r, user, accountid)
|
||||||
case "ofxfile":
|
case "ofxfile":
|
||||||
OFXFileImportHandler(db, w, r, user, accountid)
|
return OFXFileImportHandler(tx, r, user, accountid)
|
||||||
default:
|
default:
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package handlers
|
package handlers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"gopkg.in/gorp.v1"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -14,18 +13,18 @@ type Price struct {
|
|||||||
RemoteId string // unique ID from source, for detecting duplicates
|
RemoteId string // unique ID from source, for detecting duplicates
|
||||||
}
|
}
|
||||||
|
|
||||||
func InsertPriceTx(transaction *gorp.Transaction, p *Price) error {
|
func InsertPrice(tx *Tx, p *Price) error {
|
||||||
err := transaction.Insert(p)
|
err := tx.Insert(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreatePriceIfNotExist(transaction *gorp.Transaction, price *Price) error {
|
func CreatePriceIfNotExist(tx *Tx, price *Price) error {
|
||||||
if len(price.RemoteId) == 0 {
|
if len(price.RemoteId) == 0 {
|
||||||
// Always create a new price if we can't match on the RemoteId
|
// Always create a new price if we can't match on the RemoteId
|
||||||
err := InsertPriceTx(transaction, price)
|
err := InsertPrice(tx, price)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -34,7 +33,7 @@ func CreatePriceIfNotExist(transaction *gorp.Transaction, price *Price) error {
|
|||||||
|
|
||||||
var prices []*Price
|
var prices []*Price
|
||||||
|
|
||||||
_, err := transaction.Select(&prices, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date=? AND Value=?", price.SecurityId, price.CurrencyId, price.Date, price.Value)
|
_, err := tx.Select(&prices, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date=? AND Value=?", price.SecurityId, price.CurrencyId, price.Date, price.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -43,7 +42,7 @@ func CreatePriceIfNotExist(transaction *gorp.Transaction, price *Price) error {
|
|||||||
return nil // price already exists
|
return nil // price already exists
|
||||||
}
|
}
|
||||||
|
|
||||||
err = InsertPriceTx(transaction, price)
|
err = InsertPrice(tx, price)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -51,9 +50,9 @@ func CreatePriceIfNotExist(transaction *gorp.Transaction, price *Price) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Return the latest price for security in currency units before date
|
// Return the latest price for security in currency units before date
|
||||||
func GetLatestPrice(transaction *gorp.Transaction, security, currency *Security, date *time.Time) (*Price, error) {
|
func GetLatestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Price, error) {
|
||||||
var p Price
|
var p Price
|
||||||
err := transaction.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date)
|
err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date <= ? ORDER BY Date DESC LIMIT 1", security.SecurityId, currency.SecurityId, date)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -61,9 +60,9 @@ func GetLatestPrice(transaction *gorp.Transaction, security, currency *Security,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Return the earliest price for security in currency units after date
|
// Return the earliest price for security in currency units after date
|
||||||
func GetEarliestPrice(transaction *gorp.Transaction, security, currency *Security, date *time.Time) (*Price, error) {
|
func GetEarliestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Price, error) {
|
||||||
var p Price
|
var p Price
|
||||||
err := transaction.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date)
|
err := tx.SelectOne(&p, "SELECT * from prices where SecurityId=? AND CurrencyId=? AND Date >= ? ORDER BY Date ASC LIMIT 1", security.SecurityId, currency.SecurityId, date)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -71,9 +70,9 @@ func GetEarliestPrice(transaction *gorp.Transaction, security, currency *Securit
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Return the price for security in currency closest to date
|
// Return the price for security in currency closest to date
|
||||||
func GetClosestPriceTx(transaction *gorp.Transaction, security, currency *Security, date *time.Time) (*Price, error) {
|
func GetClosestPrice(tx *Tx, security, currency *Security, date *time.Time) (*Price, error) {
|
||||||
earliest, _ := GetEarliestPrice(transaction, security, currency, date)
|
earliest, _ := GetEarliestPrice(tx, security, currency, date)
|
||||||
latest, err := GetLatestPrice(transaction, security, currency, date)
|
latest, err := GetLatestPrice(tx, security, currency, date)
|
||||||
|
|
||||||
// Return early if either earliest or latest are invalid
|
// Return early if either earliest or latest are invalid
|
||||||
if earliest == nil {
|
if earliest == nil {
|
||||||
@ -90,24 +89,3 @@ func GetClosestPriceTx(transaction *gorp.Transaction, security, currency *Securi
|
|||||||
return earliest, nil
|
return earliest, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetClosestPrice(db *DB, security, currency *Security, date *time.Time) (*Price, error) {
|
|
||||||
transaction, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
price, err := GetClosestPriceTx(transaction, security, currency, date)
|
|
||||||
if err != nil {
|
|
||||||
transaction.Rollback()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = transaction.Commit()
|
|
||||||
if err != nil {
|
|
||||||
transaction.Rollback()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return price, nil
|
|
||||||
}
|
|
||||||
|
@ -77,36 +77,36 @@ func (r *Tabulation) Write(w http.ResponseWriter) error {
|
|||||||
return enc.Encode(r)
|
return enc.Encode(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetReport(db *DB, reportid int64, userid int64) (*Report, error) {
|
func GetReport(tx *Tx, reportid int64, userid int64) (*Report, error) {
|
||||||
var r Report
|
var r Report
|
||||||
|
|
||||||
err := db.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid)
|
err := tx.SelectOne(&r, "SELECT * from reports where UserId=? AND ReportId=?", userid, reportid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &r, nil
|
return &r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetReports(db *DB, userid int64) (*[]Report, error) {
|
func GetReports(tx *Tx, userid int64) (*[]Report, error) {
|
||||||
var reports []Report
|
var reports []Report
|
||||||
|
|
||||||
_, err := db.Select(&reports, "SELECT * from reports where UserId=?", userid)
|
_, err := tx.Select(&reports, "SELECT * from reports where UserId=?", userid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &reports, nil
|
return &reports, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func InsertReport(db *DB, r *Report) error {
|
func InsertReport(tx *Tx, r *Report) error {
|
||||||
err := db.Insert(r)
|
err := tx.Insert(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateReport(db *DB, r *Report) error {
|
func UpdateReport(tx *Tx, r *Report) error {
|
||||||
count, err := db.Update(r)
|
count, err := tx.Update(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -116,8 +116,8 @@ func UpdateReport(db *DB, r *Report) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteReport(db *DB, r *Report) error {
|
func DeleteReport(tx *Tx, r *Report) error {
|
||||||
count, err := db.Delete(r)
|
count, err := tx.Delete(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -127,14 +127,14 @@ func DeleteReport(db *DB, r *Report) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func runReport(db *DB, user *User, report *Report) (*Tabulation, error) {
|
func runReport(tx *Tx, user *User, report *Report) (*Tabulation, error) {
|
||||||
// Create a new LState without opening the default libs for security
|
// Create a new LState without opening the default libs for security
|
||||||
L := lua.NewState(lua.Options{SkipOpenLibs: true})
|
L := lua.NewState(lua.Options{SkipOpenLibs: true})
|
||||||
defer L.Close()
|
defer L.Close()
|
||||||
|
|
||||||
// Create a new context holding the current user with a timeout
|
// Create a new context holding the current user with a timeout
|
||||||
ctx := context.WithValue(context.Background(), userContextKey, user)
|
ctx := context.WithValue(context.Background(), userContextKey, user)
|
||||||
ctx = context.WithValue(ctx, dbContextKey, db)
|
ctx = context.WithValue(ctx, dbContextKey, tx)
|
||||||
ctx, cancel := context.WithTimeout(ctx, luaTimeoutSeconds*time.Second)
|
ctx, cancel := context.WithTimeout(ctx, luaTimeoutSeconds*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
L.SetContext(ctx)
|
L.SetContext(ctx)
|
||||||
@ -191,79 +191,60 @@ func runReport(db *DB, user *User, report *Report) (*Tabulation, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReportTabulationHandler(db *DB, w http.ResponseWriter, r *http.Request, user *User, reportid int64) {
|
func ReportTabulationHandler(tx *Tx, r *http.Request, user *User, reportid int64) ResponseWriterWriter {
|
||||||
report, err := GetReport(db, reportid, user.UserId)
|
report, err := GetReport(tx, reportid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tabulation, err := runReport(db, user, report)
|
tabulation, err := runReport(tx, user, report)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO handle different failure cases differently
|
// TODO handle different failure cases differently
|
||||||
log.Print("runReport returned:", err)
|
log.Print("runReport returned:", err)
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tabulation.ReportId = reportid
|
tabulation.ReportId = reportid
|
||||||
|
|
||||||
err = tabulation.Write(w)
|
return tabulation
|
||||||
if err != nil {
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReportHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
func ReportHandler(r *http.Request, tx *Tx) ResponseWriterWriter {
|
||||||
user, err := GetUserFromSession(db, r)
|
user, err := GetUserFromSession(tx, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 1 /*Not Signed In*/)
|
return NewError(1 /*Not Signed In*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Method == "POST" {
|
if r.Method == "POST" {
|
||||||
report_json := r.PostFormValue("report")
|
report_json := r.PostFormValue("report")
|
||||||
if report_json == "" {
|
if report_json == "" {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var report Report
|
var report Report
|
||||||
err := report.Read(report_json)
|
err := report.Read(report_json)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
report.ReportId = -1
|
report.ReportId = -1
|
||||||
report.UserId = user.UserId
|
report.UserId = user.UserId
|
||||||
|
|
||||||
err = InsertReport(db, &report)
|
err = InsertReport(tx, &report)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
w.WriteHeader(201 /*Created*/)
|
return ResponseWrapper{201, &report}
|
||||||
err = report.Write(w)
|
|
||||||
if err != nil {
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else if r.Method == "GET" {
|
} else if r.Method == "GET" {
|
||||||
if reportTabulationRE.MatchString(r.URL.Path) {
|
if reportTabulationRE.MatchString(r.URL.Path) {
|
||||||
var reportid int64
|
var reportid int64
|
||||||
n, err := GetURLPieces(r.URL.Path, "/report/%d/tabulation", &reportid)
|
n, err := GetURLPieces(r.URL.Path, "/report/%d/tabulation", &reportid)
|
||||||
if err != nil || n != 1 {
|
if err != nil || n != 1 {
|
||||||
WriteError(w, 999 /*InternalError*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*InternalError*/)
|
||||||
}
|
}
|
||||||
ReportTabulationHandler(db, w, r, user, reportid)
|
return ReportTabulationHandler(tx, r, user, reportid)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var reportid int64
|
var reportid int64
|
||||||
@ -271,84 +252,62 @@ func ReportHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
|||||||
if err != nil || n != 1 {
|
if err != nil || n != 1 {
|
||||||
//Return all Reports
|
//Return all Reports
|
||||||
var rl ReportList
|
var rl ReportList
|
||||||
reports, err := GetReports(db, user.UserId)
|
reports, err := GetReports(tx, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
rl.Reports = reports
|
rl.Reports = reports
|
||||||
err = (&rl).Write(w)
|
return &rl
|
||||||
if err != nil {
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
// Return Report with this Id
|
// Return Report with this Id
|
||||||
report, err := GetReport(db, reportid, user.UserId)
|
report, err := GetReport(tx, reportid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = report.Write(w)
|
return report
|
||||||
if err != nil {
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
reportid, err := GetURLID(r.URL.Path)
|
reportid, err := GetURLID(r.URL.Path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Method == "PUT" {
|
if r.Method == "PUT" {
|
||||||
report_json := r.PostFormValue("report")
|
report_json := r.PostFormValue("report")
|
||||||
if report_json == "" {
|
if report_json == "" {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var report Report
|
var report Report
|
||||||
err := report.Read(report_json)
|
err := report.Read(report_json)
|
||||||
if err != nil || report.ReportId != reportid {
|
if err != nil || report.ReportId != reportid {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
report.UserId = user.UserId
|
report.UserId = user.UserId
|
||||||
|
|
||||||
err = UpdateReport(db, &report)
|
err = UpdateReport(tx, &report)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = report.Write(w)
|
return &report
|
||||||
if err != nil {
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else if r.Method == "DELETE" {
|
} else if r.Method == "DELETE" {
|
||||||
report, err := GetReport(db, reportid, user.UserId)
|
report, err := GetReport(tx, reportid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = DeleteReport(db, report)
|
err = DeleteReport(tx, report)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
WriteSuccess(w)
|
return SuccessWriter{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
@ -5,7 +5,6 @@ package handlers
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"gopkg.in/gorp.v1"
|
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -103,83 +102,50 @@ func FindCurrencyTemplate(iso4217 int64) *Security {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetSecurity(db *DB, securityid int64, userid int64) (*Security, error) {
|
func GetSecurity(tx *Tx, securityid int64, userid int64) (*Security, error) {
|
||||||
var s Security
|
var s Security
|
||||||
|
|
||||||
err := db.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
|
err := tx.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &s, nil
|
return &s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetSecurityTx(transaction *gorp.Transaction, securityid int64, userid int64) (*Security, error) {
|
func GetSecurities(tx *Tx, userid int64) (*[]*Security, error) {
|
||||||
var s Security
|
|
||||||
|
|
||||||
err := transaction.SelectOne(&s, "SELECT * from securities where UserId=? AND SecurityId=?", userid, securityid)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &s, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetSecurities(db *DB, userid int64) (*[]*Security, error) {
|
|
||||||
var securities []*Security
|
var securities []*Security
|
||||||
|
|
||||||
_, err := db.Select(&securities, "SELECT * from securities where UserId=?", userid)
|
_, err := tx.Select(&securities, "SELECT * from securities where UserId=?", userid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &securities, nil
|
return &securities, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func InsertSecurity(db *DB, s *Security) error {
|
func InsertSecurity(tx *Tx, s *Security) error {
|
||||||
err := db.Insert(s)
|
err := tx.Insert(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func InsertSecurityTx(transaction *gorp.Transaction, s *Security) error {
|
func UpdateSecurity(tx *Tx, s *Security) (err error) {
|
||||||
err := transaction.Insert(s)
|
user, err := GetUser(tx, s.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func UpdateSecurity(db *DB, s *Security) error {
|
|
||||||
transaction, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := GetUserTx(transaction, s.UserId)
|
|
||||||
if err != nil {
|
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
|
||||||
} else if user.DefaultCurrency == s.SecurityId && s.Type != Currency {
|
} else if user.DefaultCurrency == s.SecurityId && s.Type != Currency {
|
||||||
transaction.Rollback()
|
|
||||||
return errors.New("Cannot change security which is user's default currency to be non-currency")
|
return errors.New("Cannot change security which is user's default currency to be non-currency")
|
||||||
}
|
}
|
||||||
|
|
||||||
count, err := transaction.Update(s)
|
count, err := tx.Update(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
return
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
if count != 1 {
|
if count != 1 {
|
||||||
transaction.Rollback()
|
|
||||||
return errors.New("Updated more than one security")
|
return errors.New("Updated more than one security")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.Commit()
|
|
||||||
if err != nil {
|
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -191,61 +157,44 @@ func (e SecurityInUseError) Error() string {
|
|||||||
return e.message
|
return e.message
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteSecurity(db *DB, s *Security) error {
|
func DeleteSecurity(tx *Tx, s *Security) error {
|
||||||
transaction, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// First, ensure no accounts are using this security
|
// First, ensure no accounts are using this security
|
||||||
accounts, err := transaction.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId)
|
accounts, err := tx.SelectInt("SELECT count(*) from accounts where UserId=? and SecurityId=?", s.UserId, s.SecurityId)
|
||||||
|
|
||||||
if accounts != 0 {
|
if accounts != 0 {
|
||||||
transaction.Rollback()
|
|
||||||
return SecurityInUseError{"One or more accounts still use this security"}
|
return SecurityInUseError{"One or more accounts still use this security"}
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := GetUserTx(transaction, s.UserId)
|
user, err := GetUser(tx, s.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
} else if user.DefaultCurrency == s.SecurityId {
|
} else if user.DefaultCurrency == s.SecurityId {
|
||||||
transaction.Rollback()
|
|
||||||
return SecurityInUseError{"Cannot delete security which is user's default currency"}
|
return SecurityInUseError{"Cannot delete security which is user's default currency"}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove all prices involving this security (either of this security, or
|
// Remove all prices involving this security (either of this security, or
|
||||||
// using it as a currency)
|
// using it as a currency)
|
||||||
_, err = transaction.Exec("DELETE FROM prices WHERE SecurityId=? OR CurrencyId=?", s.SecurityId, s.SecurityId)
|
_, err = tx.Exec("DELETE FROM prices WHERE SecurityId=? OR CurrencyId=?", s.SecurityId, s.SecurityId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
count, err := transaction.Delete(s)
|
count, err := tx.Delete(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if count != 1 {
|
if count != 1 {
|
||||||
transaction.Rollback()
|
|
||||||
return errors.New("Deleted more than one security")
|
return errors.New("Deleted more than one security")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.Commit()
|
|
||||||
if err != nil {
|
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ImportGetCreateSecurity(transaction *gorp.Transaction, userid int64, security *Security) (*Security, error) {
|
func ImportGetCreateSecurity(tx *Tx, userid int64, security *Security) (*Security, error) {
|
||||||
security.UserId = userid
|
security.UserId = userid
|
||||||
if len(security.AlternateId) == 0 {
|
if len(security.AlternateId) == 0 {
|
||||||
// Always create a new local security if we can't match on the AlternateId
|
// Always create a new local security if we can't match on the AlternateId
|
||||||
err := InsertSecurityTx(transaction, security)
|
err := InsertSecurity(tx, security)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -254,7 +203,7 @@ func ImportGetCreateSecurity(transaction *gorp.Transaction, userid int64, securi
|
|||||||
|
|
||||||
var securities []*Security
|
var securities []*Security
|
||||||
|
|
||||||
_, err := transaction.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Precision=?", userid, security.Type, security.AlternateId, security.Precision)
|
_, err := tx.Select(&securities, "SELECT * from securities where UserId=? AND Type=? AND AlternateId=? AND Precision=?", userid, security.Type, security.AlternateId, security.Precision)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -286,7 +235,7 @@ func ImportGetCreateSecurity(transaction *gorp.Transaction, userid int64, securi
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If there wasn't even one security in the list, make a new one
|
// If there wasn't even one security in the list, make a new one
|
||||||
err = InsertSecurityTx(transaction, security)
|
err = InsertSecurity(tx, security)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -294,43 +243,33 @@ func ImportGetCreateSecurity(transaction *gorp.Transaction, userid int64, securi
|
|||||||
return security, nil
|
return security, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SecurityHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
func SecurityHandler(r *http.Request, tx *Tx) ResponseWriterWriter {
|
||||||
user, err := GetUserFromSession(db, r)
|
user, err := GetUserFromSession(tx, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 1 /*Not Signed In*/)
|
return NewError(1 /*Not Signed In*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Method == "POST" {
|
if r.Method == "POST" {
|
||||||
security_json := r.PostFormValue("security")
|
security_json := r.PostFormValue("security")
|
||||||
if security_json == "" {
|
if security_json == "" {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var security Security
|
var security Security
|
||||||
err := security.Read(security_json)
|
err := security.Read(security_json)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
security.SecurityId = -1
|
security.SecurityId = -1
|
||||||
security.UserId = user.UserId
|
security.UserId = user.UserId
|
||||||
|
|
||||||
err = InsertSecurity(db, &security)
|
err = InsertSecurity(tx, &security)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
w.WriteHeader(201 /*Created*/)
|
return ResponseWrapper{201, &security}
|
||||||
err = security.Write(w)
|
|
||||||
if err != nil {
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else if r.Method == "GET" {
|
} else if r.Method == "GET" {
|
||||||
var securityid int64
|
var securityid int64
|
||||||
n, err := GetURLPieces(r.URL.Path, "/security/%d", &securityid)
|
n, err := GetURLPieces(r.URL.Path, "/security/%d", &securityid)
|
||||||
@ -339,87 +278,65 @@ func SecurityHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
|||||||
//Return all securities
|
//Return all securities
|
||||||
var sl SecurityList
|
var sl SecurityList
|
||||||
|
|
||||||
securities, err := GetSecurities(db, user.UserId)
|
securities, err := GetSecurities(tx, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
sl.Securities = securities
|
sl.Securities = securities
|
||||||
err = (&sl).Write(w)
|
return &sl
|
||||||
if err != nil {
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
security, err := GetSecurity(db, securityid, user.UserId)
|
security, err := GetSecurity(tx, securityid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = security.Write(w)
|
return security
|
||||||
if err != nil {
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
securityid, err := GetURLID(r.URL.Path)
|
securityid, err := GetURLID(r.URL.Path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
if r.Method == "PUT" {
|
if r.Method == "PUT" {
|
||||||
security_json := r.PostFormValue("security")
|
security_json := r.PostFormValue("security")
|
||||||
if security_json == "" {
|
if security_json == "" {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var security Security
|
var security Security
|
||||||
err := security.Read(security_json)
|
err := security.Read(security_json)
|
||||||
if err != nil || security.SecurityId != securityid {
|
if err != nil || security.SecurityId != securityid {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
security.UserId = user.UserId
|
security.UserId = user.UserId
|
||||||
|
|
||||||
err = UpdateSecurity(db, &security)
|
err = UpdateSecurity(tx, &security)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = security.Write(w)
|
return &security
|
||||||
if err != nil {
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else if r.Method == "DELETE" {
|
} else if r.Method == "DELETE" {
|
||||||
security, err := GetSecurity(db, securityid, user.UserId)
|
security, err := GetSecurity(tx, securityid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = DeleteSecurity(db, security)
|
err = DeleteSecurity(tx, security)
|
||||||
if _, ok := err.(SecurityInUseError); ok {
|
if _, ok := err.(SecurityInUseError); ok {
|
||||||
WriteError(w, 7 /*In Use Error*/)
|
return NewError(7 /*In Use Error*/)
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
WriteSuccess(w)
|
return SuccessWriter{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
func SecurityTemplateHandler(w http.ResponseWriter, r *http.Request) {
|
func SecurityTemplateHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -13,9 +13,9 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) {
|
|||||||
|
|
||||||
ctx := L.Context()
|
ctx := L.Context()
|
||||||
|
|
||||||
db, ok := ctx.Value(dbContextKey).(*DB)
|
tx, ok := ctx.Value(dbContextKey).(*Tx)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("Couldn't find DB in lua's Context")
|
return nil, errors.New("Couldn't find tx in lua's Context")
|
||||||
}
|
}
|
||||||
|
|
||||||
security_map, ok = ctx.Value(securitiesContextKey).(map[int64]*Security)
|
security_map, ok = ctx.Value(securitiesContextKey).(map[int64]*Security)
|
||||||
@ -25,7 +25,7 @@ func luaContextGetSecurities(L *lua.LState) (map[int64]*Security, error) {
|
|||||||
return nil, errors.New("Couldn't find User in lua's Context")
|
return nil, errors.New("Couldn't find User in lua's Context")
|
||||||
}
|
}
|
||||||
|
|
||||||
securities, err := GetSecurities(db, user.UserId)
|
securities, err := GetSecurities(tx, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -155,12 +155,12 @@ func luaClosestPrice(L *lua.LState) int {
|
|||||||
date := luaCheckTime(L, 3)
|
date := luaCheckTime(L, 3)
|
||||||
|
|
||||||
ctx := L.Context()
|
ctx := L.Context()
|
||||||
db, ok := ctx.Value(dbContextKey).(*DB)
|
tx, ok := ctx.Value(dbContextKey).(*Tx)
|
||||||
if !ok {
|
if !ok {
|
||||||
panic("Couldn't find DB in lua's Context")
|
panic("Couldn't find tx in lua's Context")
|
||||||
}
|
}
|
||||||
|
|
||||||
p, err := GetClosestPrice(db, s, c, date)
|
p, err := GetClosestPrice(tx, s, c, date)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
L.Push(lua.LNil)
|
L.Push(lua.LNil)
|
||||||
} else {
|
} else {
|
||||||
|
@ -28,7 +28,7 @@ func (s *Session) Read(json_str string) error {
|
|||||||
return dec.Decode(s)
|
return dec.Decode(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetSession(db *DB, r *http.Request) (*Session, error) {
|
func GetSession(tx *Tx, r *http.Request) (*Session, error) {
|
||||||
var s Session
|
var s Session
|
||||||
|
|
||||||
cookie, err := r.Cookie("moneygo-session")
|
cookie, err := r.Cookie("moneygo-session")
|
||||||
@ -37,18 +37,17 @@ func GetSession(db *DB, r *http.Request) (*Session, error) {
|
|||||||
}
|
}
|
||||||
s.SessionSecret = cookie.Value
|
s.SessionSecret = cookie.Value
|
||||||
|
|
||||||
err = db.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret)
|
err = tx.SelectOne(&s, "SELECT * from sessions where SessionSecret=?", s.SessionSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &s, nil
|
return &s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteSessionIfExists(db *DB, r *http.Request) error {
|
func DeleteSessionIfExists(tx *Tx, r *http.Request) error {
|
||||||
// TODO do this in one transaction
|
session, err := GetSession(tx, r)
|
||||||
session, err := GetSession(db, r)
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
_, err := db.Delete(session)
|
_, err := tx.Delete(session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -64,7 +63,17 @@ func NewSessionCookie() (string, error) {
|
|||||||
return base64.StdEncoding.EncodeToString(bits), nil
|
return base64.StdEncoding.EncodeToString(bits), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSession(db *DB, w http.ResponseWriter, r *http.Request, userid int64) (*Session, error) {
|
type NewSessionWriter struct {
|
||||||
|
session *Session
|
||||||
|
cookie *http.Cookie
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NewSessionWriter) Write(w http.ResponseWriter) error {
|
||||||
|
http.SetCookie(w, n.cookie)
|
||||||
|
return n.session.Write(w)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSession(tx *Tx, r *http.Request, userid int64) (*NewSessionWriter, error) {
|
||||||
s := Session{}
|
s := Session{}
|
||||||
|
|
||||||
session_secret, err := NewSessionCookie()
|
session_secret, err := NewSessionCookie()
|
||||||
@ -81,79 +90,66 @@ func NewSession(db *DB, w http.ResponseWriter, r *http.Request, userid int64) (*
|
|||||||
Secure: true,
|
Secure: true,
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
}
|
}
|
||||||
http.SetCookie(w, &cookie)
|
|
||||||
|
|
||||||
s.SessionSecret = session_secret
|
s.SessionSecret = session_secret
|
||||||
s.UserId = userid
|
s.UserId = userid
|
||||||
|
|
||||||
err = db.Insert(&s)
|
err = tx.Insert(&s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &s, nil
|
return &NewSessionWriter{&s, &cookie}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SessionHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
func SessionHandler(r *http.Request, tx *Tx) ResponseWriterWriter {
|
||||||
if r.Method == "POST" || r.Method == "PUT" {
|
if r.Method == "POST" || r.Method == "PUT" {
|
||||||
user_json := r.PostFormValue("user")
|
user_json := r.PostFormValue("user")
|
||||||
if user_json == "" {
|
if user_json == "" {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
user := User{}
|
user := User{}
|
||||||
err := user.Read(user_json)
|
err := user.Read(user_json)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
dbuser, err := GetUserByUsername(db, user.Username)
|
dbuser, err := GetUserByUsername(tx, user.Username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 2 /*Unauthorized Access*/)
|
return NewError(2 /*Unauthorized Access*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
user.HashPassword()
|
user.HashPassword()
|
||||||
if user.PasswordHash != dbuser.PasswordHash {
|
if user.PasswordHash != dbuser.PasswordHash {
|
||||||
WriteError(w, 2 /*Unauthorized Access*/)
|
return NewError(2 /*Unauthorized Access*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = DeleteSessionIfExists(db, r)
|
err = DeleteSessionIfExists(tx, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
session, err := NewSession(db, w, r, dbuser.UserId)
|
sessionwriter, err := NewSession(tx, r, dbuser.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = session.Write(w)
|
|
||||||
if err != nil {
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
|
return sessionwriter
|
||||||
} else if r.Method == "GET" {
|
} else if r.Method == "GET" {
|
||||||
s, err := GetSession(db, r)
|
s, err := GetSession(tx, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 1 /*Not Signed In*/)
|
return NewError(1 /*Not Signed In*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s.Write(w)
|
return s
|
||||||
} else if r.Method == "DELETE" {
|
} else if r.Method == "DELETE" {
|
||||||
err := DeleteSessionIfExists(db, r)
|
err := DeleteSessionIfExists(tx, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
WriteSuccess(w)
|
return SuccessWriter{}
|
||||||
}
|
}
|
||||||
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
@ -4,7 +4,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"gopkg.in/gorp.v1"
|
|
||||||
"log"
|
"log"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -78,8 +77,8 @@ func (s *Split) Valid() bool {
|
|||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Split) AlreadyImportedTx(transaction *gorp.Transaction) (bool, error) {
|
func (s *Split) AlreadyImported(tx *Tx) (bool, error) {
|
||||||
count, err := transaction.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId)
|
count, err := tx.SelectInt("SELECT COUNT(*) from splits where RemoteId=? and AccountId=?", s.RemoteId, s.AccountId)
|
||||||
return count == 1, err
|
return count == 1, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -134,7 +133,7 @@ func (t *Transaction) Valid() bool {
|
|||||||
|
|
||||||
// Return a map of security ID's to big.Rat's containing the amount that
|
// Return a map of security ID's to big.Rat's containing the amount that
|
||||||
// security is imbalanced by
|
// security is imbalanced by
|
||||||
func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64]big.Rat, error) {
|
func (t *Transaction) GetImbalances(tx *Tx) (map[int64]big.Rat, error) {
|
||||||
sums := make(map[int64]big.Rat)
|
sums := make(map[int64]big.Rat)
|
||||||
|
|
||||||
if !t.Valid() {
|
if !t.Valid() {
|
||||||
@ -146,7 +145,7 @@ func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64]
|
|||||||
if t.Splits[i].AccountId != -1 {
|
if t.Splits[i].AccountId != -1 {
|
||||||
var err error
|
var err error
|
||||||
var account *Account
|
var account *Account
|
||||||
account, err = GetAccountTx(transaction, t.Splits[i].AccountId, t.UserId)
|
account, err = GetAccount(tx, t.Splits[i].AccountId, t.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -162,10 +161,10 @@ func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64]
|
|||||||
|
|
||||||
// Returns true if all securities contained in this transaction are balanced,
|
// Returns true if all securities contained in this transaction are balanced,
|
||||||
// false otherwise
|
// false otherwise
|
||||||
func (t *Transaction) Balanced(transaction *gorp.Transaction) (bool, error) {
|
func (t *Transaction) Balanced(tx *Tx) (bool, error) {
|
||||||
var zero big.Rat
|
var zero big.Rat
|
||||||
|
|
||||||
sums, err := t.GetImbalancesTx(transaction)
|
sums, err := t.GetImbalances(tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
@ -178,72 +177,48 @@ func (t *Transaction) Balanced(transaction *gorp.Transaction) (bool, error) {
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetTransaction(db *DB, transactionid int64, userid int64) (*Transaction, error) {
|
func GetTransaction(tx *Tx, transactionid int64, userid int64) (*Transaction, error) {
|
||||||
var t Transaction
|
var t Transaction
|
||||||
|
|
||||||
transaction, err := db.Begin()
|
err := tx.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.SelectOne(&t, "SELECT * from transactions where UserId=? AND TransactionId=?", userid, transactionid)
|
_, err = tx.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = transaction.Select(&t.Splits, "SELECT * from splits where TransactionId=?", transactionid)
|
|
||||||
if err != nil {
|
|
||||||
transaction.Rollback()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = transaction.Commit()
|
|
||||||
if err != nil {
|
|
||||||
transaction.Rollback()
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &t, nil
|
return &t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetTransactions(db *DB, userid int64) (*[]Transaction, error) {
|
func GetTransactions(tx *Tx, userid int64) (*[]Transaction, error) {
|
||||||
var transactions []Transaction
|
var transactions []Transaction
|
||||||
|
|
||||||
transaction, err := db.Begin()
|
_, err := tx.Select(&transactions, "SELECT * from transactions where UserId=?", userid)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = transaction.Select(&transactions, "SELECT * from transactions where UserId=?", userid)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range transactions {
|
for i := range transactions {
|
||||||
_, err := transaction.Select(&transactions[i].Splits, "SELECT * from splits where TransactionId=?", transactions[i].TransactionId)
|
_, err := tx.Select(&transactions[i].Splits, "SELECT * from splits where TransactionId=?", transactions[i].TransactionId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.Commit()
|
|
||||||
if err != nil {
|
|
||||||
transaction.Rollback()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &transactions, nil
|
return &transactions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func incrementAccountVersions(transaction *gorp.Transaction, user *User, accountids []int64) error {
|
func incrementAccountVersions(tx *Tx, user *User, accountids []int64) error {
|
||||||
for i := range accountids {
|
for i := range accountids {
|
||||||
account, err := GetAccountTx(transaction, accountids[i], user.UserId)
|
account, err := GetAccount(tx, accountids[i], user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
account.AccountVersion++
|
account.AccountVersion++
|
||||||
count, err := transaction.Update(account)
|
count, err := tx.Update(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -260,12 +235,12 @@ func (ame AccountMissingError) Error() string {
|
|||||||
return "Account missing"
|
return "Account missing"
|
||||||
}
|
}
|
||||||
|
|
||||||
func InsertTransactionTx(transaction *gorp.Transaction, t *Transaction, user *User) error {
|
func InsertTransaction(tx *Tx, t *Transaction, user *User) error {
|
||||||
// Map of any accounts with transaction splits being added
|
// Map of any accounts with transaction splits being added
|
||||||
a_map := make(map[int64]bool)
|
a_map := make(map[int64]bool)
|
||||||
for i := range t.Splits {
|
for i := range t.Splits {
|
||||||
if t.Splits[i].AccountId != -1 {
|
if t.Splits[i].AccountId != -1 {
|
||||||
existing, err := transaction.SelectInt("SELECT count(*) from accounts where AccountId=?", t.Splits[i].AccountId)
|
existing, err := tx.SelectInt("SELECT count(*) from accounts where AccountId=?", t.Splits[i].AccountId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -287,13 +262,13 @@ func InsertTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us
|
|||||||
if len(a_ids) < 1 {
|
if len(a_ids) < 1 {
|
||||||
return AccountMissingError{}
|
return AccountMissingError{}
|
||||||
}
|
}
|
||||||
err := incrementAccountVersions(transaction, user, a_ids)
|
err := incrementAccountVersions(tx, user, a_ids)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
t.UserId = user.UserId
|
t.UserId = user.UserId
|
||||||
err = transaction.Insert(t)
|
err = tx.Insert(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -301,7 +276,7 @@ func InsertTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us
|
|||||||
for i := range t.Splits {
|
for i := range t.Splits {
|
||||||
t.Splits[i].TransactionId = t.TransactionId
|
t.Splits[i].TransactionId = t.TransactionId
|
||||||
t.Splits[i].SplitId = -1
|
t.Splits[i].SplitId = -1
|
||||||
err = transaction.Insert(t.Splits[i])
|
err = tx.Insert(t.Splits[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -310,31 +285,10 @@ func InsertTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func InsertTransaction(db *DB, t *Transaction, user *User) error {
|
func UpdateTransaction(tx *Tx, t *Transaction, user *User) error {
|
||||||
transaction, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = InsertTransactionTx(transaction, t, user)
|
|
||||||
if err != nil {
|
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = transaction.Commit()
|
|
||||||
if err != nil {
|
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *User) error {
|
|
||||||
var existing_splits []*Split
|
var existing_splits []*Split
|
||||||
|
|
||||||
_, err := transaction.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
|
_, err := tx.Select(&existing_splits, "SELECT * from splits where TransactionId=?", t.TransactionId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -353,7 +307,7 @@ func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us
|
|||||||
t.Splits[i].TransactionId = t.TransactionId
|
t.Splits[i].TransactionId = t.TransactionId
|
||||||
_, ok := s_map[t.Splits[i].SplitId]
|
_, ok := s_map[t.Splits[i].SplitId]
|
||||||
if ok {
|
if ok {
|
||||||
count, err := transaction.Update(t.Splits[i])
|
count, err := tx.Update(t.Splits[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -363,7 +317,7 @@ func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us
|
|||||||
delete(s_map, t.Splits[i].SplitId)
|
delete(s_map, t.Splits[i].SplitId)
|
||||||
} else {
|
} else {
|
||||||
t.Splits[i].SplitId = -1
|
t.Splits[i].SplitId = -1
|
||||||
err := transaction.Insert(t.Splits[i])
|
err := tx.Insert(t.Splits[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -380,7 +334,7 @@ func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us
|
|||||||
a_map[existing_splits[i].AccountId] = true
|
a_map[existing_splits[i].AccountId] = true
|
||||||
}
|
}
|
||||||
if ok {
|
if ok {
|
||||||
_, err := transaction.Delete(existing_splits[i])
|
_, err := tx.Delete(existing_splits[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -392,12 +346,12 @@ func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us
|
|||||||
for id := range a_map {
|
for id := range a_map {
|
||||||
a_ids = append(a_ids, id)
|
a_ids = append(a_ids, id)
|
||||||
}
|
}
|
||||||
err = incrementAccountVersions(transaction, user, a_ids)
|
err = incrementAccountVersions(tx, user, a_ids)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
count, err := transaction.Update(t)
|
count, err := tx.Update(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -408,263 +362,171 @@ func UpdateTransactionTx(transaction *gorp.Transaction, t *Transaction, user *Us
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteTransaction(db *DB, t *Transaction, user *User) error {
|
func DeleteTransaction(tx *Tx, t *Transaction, user *User) error {
|
||||||
transaction, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var accountids []int64
|
var accountids []int64
|
||||||
_, err = transaction.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId)
|
_, err := tx.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = transaction.Exec("DELETE FROM splits WHERE TransactionId=?", t.TransactionId)
|
_, err = tx.Exec("DELETE FROM splits WHERE TransactionId=?", t.TransactionId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
count, err := transaction.Delete(t)
|
count, err := tx.Delete(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if count != 1 {
|
if count != 1 {
|
||||||
transaction.Rollback()
|
|
||||||
return errors.New("Deleted more than one transaction")
|
return errors.New("Deleted more than one transaction")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = incrementAccountVersions(transaction, user, accountids)
|
err = incrementAccountVersions(tx, user, accountids)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = transaction.Commit()
|
|
||||||
if err != nil {
|
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TransactionHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
func TransactionHandler(r *http.Request, tx *Tx) ResponseWriterWriter {
|
||||||
user, err := GetUserFromSession(db, r)
|
user, err := GetUserFromSession(tx, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 1 /*Not Signed In*/)
|
return NewError(1 /*Not Signed In*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Method == "POST" {
|
if r.Method == "POST" {
|
||||||
transaction_json := r.PostFormValue("transaction")
|
transaction_json := r.PostFormValue("transaction")
|
||||||
if transaction_json == "" {
|
if transaction_json == "" {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var transaction Transaction
|
var transaction Transaction
|
||||||
err := transaction.Read(transaction_json)
|
err := transaction.Read(transaction_json)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
transaction.TransactionId = -1
|
transaction.TransactionId = -1
|
||||||
transaction.UserId = user.UserId
|
transaction.UserId = user.UserId
|
||||||
|
|
||||||
sqltx, err := db.Begin()
|
balanced, err := transaction.Balanced(tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
return NewError(999 /*Internal Error*/)
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
balanced, err := transaction.Balanced(sqltx)
|
|
||||||
if err != nil {
|
|
||||||
sqltx.Rollback()
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
if !transaction.Valid() || !balanced {
|
if !transaction.Valid() || !balanced {
|
||||||
sqltx.Rollback()
|
return NewError(3 /*Invalid Request*/)
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range transaction.Splits {
|
for i := range transaction.Splits {
|
||||||
transaction.Splits[i].SplitId = -1
|
transaction.Splits[i].SplitId = -1
|
||||||
_, err := GetAccountTx(sqltx, transaction.Splits[i].AccountId, user.UserId)
|
_, err := GetAccount(tx, transaction.Splits[i].AccountId, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqltx.Rollback()
|
return NewError(3 /*Invalid Request*/)
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = InsertTransactionTx(sqltx, &transaction, user)
|
err = InsertTransaction(tx, &transaction, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if _, ok := err.(AccountMissingError); ok {
|
if _, ok := err.(AccountMissingError); ok {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
} else {
|
} else {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
sqltx.Rollback()
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = sqltx.Commit()
|
return &transaction
|
||||||
if err != nil {
|
|
||||||
sqltx.Rollback()
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = transaction.Write(w)
|
|
||||||
if err != nil {
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else if r.Method == "GET" {
|
} else if r.Method == "GET" {
|
||||||
transactionid, err := GetURLID(r.URL.Path)
|
transactionid, err := GetURLID(r.URL.Path)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//Return all Transactions
|
//Return all Transactions
|
||||||
var al TransactionList
|
var al TransactionList
|
||||||
transactions, err := GetTransactions(db, user.UserId)
|
transactions, err := GetTransactions(tx, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
al.Transactions = transactions
|
al.Transactions = transactions
|
||||||
err = (&al).Write(w)
|
return &al
|
||||||
if err != nil {
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
//Return Transaction with this Id
|
//Return Transaction with this Id
|
||||||
transaction, err := GetTransaction(db, transactionid, user.UserId)
|
transaction, err := GetTransaction(tx, transactionid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
|
||||||
err = transaction.Write(w)
|
|
||||||
if err != nil {
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
return transaction
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
transactionid, err := GetURLID(r.URL.Path)
|
transactionid, err := GetURLID(r.URL.Path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
if r.Method == "PUT" {
|
if r.Method == "PUT" {
|
||||||
transaction_json := r.PostFormValue("transaction")
|
transaction_json := r.PostFormValue("transaction")
|
||||||
if transaction_json == "" {
|
if transaction_json == "" {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var transaction Transaction
|
var transaction Transaction
|
||||||
err := transaction.Read(transaction_json)
|
err := transaction.Read(transaction_json)
|
||||||
if err != nil || transaction.TransactionId != transactionid {
|
if err != nil || transaction.TransactionId != transactionid {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
transaction.UserId = user.UserId
|
transaction.UserId = user.UserId
|
||||||
|
|
||||||
sqltx, err := db.Begin()
|
balanced, err := transaction.Balanced(tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
|
||||||
|
|
||||||
balanced, err := transaction.Balanced(sqltx)
|
|
||||||
if err != nil {
|
|
||||||
sqltx.Rollback()
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
if !transaction.Valid() || !balanced {
|
if !transaction.Valid() || !balanced {
|
||||||
sqltx.Rollback()
|
return NewError(3 /*Invalid Request*/)
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range transaction.Splits {
|
for i := range transaction.Splits {
|
||||||
_, err := GetAccountTx(sqltx, transaction.Splits[i].AccountId, user.UserId)
|
_, err := GetAccount(tx, transaction.Splits[i].AccountId, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqltx.Rollback()
|
return NewError(3 /*Invalid Request*/)
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = UpdateTransactionTx(sqltx, &transaction, user)
|
err = UpdateTransaction(tx, &transaction, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sqltx.Rollback()
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = sqltx.Commit()
|
return &transaction
|
||||||
if err != nil {
|
|
||||||
sqltx.Rollback()
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = transaction.Write(w)
|
|
||||||
if err != nil {
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else if r.Method == "DELETE" {
|
} else if r.Method == "DELETE" {
|
||||||
transactionid, err := GetURLID(r.URL.Path)
|
transactionid, err := GetURLID(r.URL.Path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
transaction, err := GetTransaction(db, transactionid, user.UserId)
|
transaction, err := GetTransaction(tx, transactionid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = DeleteTransaction(db, transaction, user)
|
err = DeleteTransaction(tx, transaction, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
WriteSuccess(w)
|
return SuccessWriter{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TransactionsBalanceDifference(transaction *gorp.Transaction, accountid int64, transactions []Transaction) (*big.Rat, error) {
|
func TransactionsBalanceDifference(tx *Tx, accountid int64, transactions []Transaction) (*big.Rat, error) {
|
||||||
var pageDifference, tmp big.Rat
|
var pageDifference, tmp big.Rat
|
||||||
for i := range transactions {
|
for i := range transactions {
|
||||||
_, err := transaction.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId)
|
_, err := tx.Select(&transactions[i].Splits, "SELECT * FROM splits where TransactionId=?", transactions[i].TransactionId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -685,17 +547,12 @@ func TransactionsBalanceDifference(transaction *gorp.Transaction, accountid int6
|
|||||||
return &pageDifference, nil
|
return &pageDifference, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAccountBalance(db *DB, user *User, accountid int64) (*big.Rat, error) {
|
func GetAccountBalance(tx *Tx, user *User, accountid int64) (*big.Rat, error) {
|
||||||
var splits []Split
|
var splits []Split
|
||||||
transaction, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?"
|
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=?"
|
||||||
_, err = transaction.Select(&splits, sql, accountid, user.UserId)
|
_, err := tx.Select(&splits, sql, accountid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -703,34 +560,22 @@ func GetAccountBalance(db *DB, user *User, accountid int64) (*big.Rat, error) {
|
|||||||
for _, s := range splits {
|
for _, s := range splits {
|
||||||
rat_amount, err := GetBigAmount(s.Amount)
|
rat_amount, err := GetBigAmount(s.Amount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
tmp.Add(&balance, rat_amount)
|
tmp.Add(&balance, rat_amount)
|
||||||
balance.Set(&tmp)
|
balance.Set(&tmp)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.Commit()
|
|
||||||
if err != nil {
|
|
||||||
transaction.Rollback()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &balance, nil
|
return &balance, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Assumes accountid is valid and is owned by the current user
|
// Assumes accountid is valid and is owned by the current user
|
||||||
func GetAccountBalanceDate(db *DB, user *User, accountid int64, date *time.Time) (*big.Rat, error) {
|
func GetAccountBalanceDate(tx *Tx, user *User, accountid int64, date *time.Time) (*big.Rat, error) {
|
||||||
var splits []Split
|
var splits []Split
|
||||||
transaction, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?"
|
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date < ?"
|
||||||
_, err = transaction.Select(&splits, sql, accountid, user.UserId, date)
|
_, err := tx.Select(&splits, sql, accountid, user.UserId, date)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -738,33 +583,21 @@ func GetAccountBalanceDate(db *DB, user *User, accountid int64, date *time.Time)
|
|||||||
for _, s := range splits {
|
for _, s := range splits {
|
||||||
rat_amount, err := GetBigAmount(s.Amount)
|
rat_amount, err := GetBigAmount(s.Amount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
tmp.Add(&balance, rat_amount)
|
tmp.Add(&balance, rat_amount)
|
||||||
balance.Set(&tmp)
|
balance.Set(&tmp)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.Commit()
|
|
||||||
if err != nil {
|
|
||||||
transaction.Rollback()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &balance, nil
|
return &balance, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAccountBalanceDateRange(db *DB, user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) {
|
func GetAccountBalanceDateRange(tx *Tx, user *User, accountid int64, begin, end *time.Time) (*big.Rat, error) {
|
||||||
var splits []Split
|
var splits []Split
|
||||||
transaction, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?"
|
sql := "SELECT DISTINCT splits.* FROM splits INNER JOIN transactions ON transactions.TransactionId = splits.TransactionId WHERE splits.AccountId=? AND transactions.UserId=? AND transactions.Date >= ? AND transactions.Date < ?"
|
||||||
_, err = transaction.Select(&splits, sql, accountid, user.UserId, begin, end)
|
_, err := tx.Select(&splits, sql, accountid, user.UserId, begin, end)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -772,31 +605,19 @@ func GetAccountBalanceDateRange(db *DB, user *User, accountid int64, begin, end
|
|||||||
for _, s := range splits {
|
for _, s := range splits {
|
||||||
rat_amount, err := GetBigAmount(s.Amount)
|
rat_amount, err := GetBigAmount(s.Amount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
tmp.Add(&balance, rat_amount)
|
tmp.Add(&balance, rat_amount)
|
||||||
balance.Set(&tmp)
|
balance.Set(&tmp)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.Commit()
|
|
||||||
if err != nil {
|
|
||||||
transaction.Rollback()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &balance, nil
|
return &balance, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) {
|
func GetAccountTransactions(tx *Tx, user *User, accountid int64, sort string, page uint64, limit uint64) (*AccountTransactionsList, error) {
|
||||||
var transactions []Transaction
|
var transactions []Transaction
|
||||||
var atl AccountTransactionsList
|
var atl AccountTransactionsList
|
||||||
|
|
||||||
transaction, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var sqlsort, balanceLimitOffset string
|
var sqlsort, balanceLimitOffset string
|
||||||
var balanceLimitOffsetArg uint64
|
var balanceLimitOffsetArg uint64
|
||||||
if sort == "date-asc" {
|
if sort == "date-asc" {
|
||||||
@ -804,9 +625,8 @@ func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, pa
|
|||||||
balanceLimitOffset = " LIMIT ?"
|
balanceLimitOffset = " LIMIT ?"
|
||||||
balanceLimitOffsetArg = page * limit
|
balanceLimitOffsetArg = page * limit
|
||||||
} else if sort == "date-desc" {
|
} else if sort == "date-desc" {
|
||||||
numSplits, err := transaction.SelectInt("SELECT count(*) FROM splits")
|
numSplits, err := tx.SelectInt("SELECT count(*) FROM splits")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
sqlsort = " ORDER BY transactions.Date DESC"
|
sqlsort = " ORDER BY transactions.Date DESC"
|
||||||
@ -819,41 +639,35 @@ func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, pa
|
|||||||
sqloffset = fmt.Sprintf(" OFFSET %d", page*limit)
|
sqloffset = fmt.Sprintf(" OFFSET %d", page*limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := GetAccountTx(transaction, accountid, user.UserId)
|
account, err := GetAccount(tx, accountid, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
atl.Account = account
|
atl.Account = account
|
||||||
|
|
||||||
sql := "SELECT DISTINCT transactions.* FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + " LIMIT ?" + sqloffset
|
sql := "SELECT DISTINCT transactions.* FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + " LIMIT ?" + sqloffset
|
||||||
_, err = transaction.Select(&transactions, sql, user.UserId, accountid, limit)
|
_, err = tx.Select(&transactions, sql, user.UserId, accountid, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
atl.Transactions = &transactions
|
atl.Transactions = &transactions
|
||||||
|
|
||||||
pageDifference, err := TransactionsBalanceDifference(transaction, accountid, transactions)
|
pageDifference, err := TransactionsBalanceDifference(tx, accountid, transactions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
count, err := transaction.SelectInt("SELECT count(DISTINCT transactions.TransactionId) FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?", user.UserId, accountid)
|
count, err := tx.SelectInt("SELECT count(DISTINCT transactions.TransactionId) FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?", user.UserId, accountid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
atl.TotalTransactions = count
|
atl.TotalTransactions = count
|
||||||
|
|
||||||
security, err := GetSecurityTx(transaction, atl.Account.SecurityId, user.UserId)
|
security, err := GetSecurity(tx, atl.Account.SecurityId, user.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if security == nil {
|
if security == nil {
|
||||||
transaction.Rollback()
|
|
||||||
return nil, errors.New("Security not found")
|
return nil, errors.New("Security not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -861,9 +675,8 @@ func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, pa
|
|||||||
// occurred before the page we're returning
|
// occurred before the page we're returning
|
||||||
var amounts []string
|
var amounts []string
|
||||||
sql = "SELECT splits.Amount FROM splits WHERE splits.AccountId=? AND splits.TransactionId IN (SELECT DISTINCT transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ")"
|
sql = "SELECT splits.Amount FROM splits WHERE splits.AccountId=? AND splits.TransactionId IN (SELECT DISTINCT transactions.TransactionId FROM transactions INNER JOIN splits ON transactions.TransactionId = splits.TransactionId WHERE transactions.UserId=? AND splits.AccountId=?" + sqlsort + balanceLimitOffset + ")"
|
||||||
_, err = transaction.Select(&amounts, sql, accountid, user.UserId, accountid, balanceLimitOffsetArg)
|
_, err = tx.Select(&amounts, sql, accountid, user.UserId, accountid, balanceLimitOffsetArg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -871,7 +684,6 @@ func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, pa
|
|||||||
for _, amount := range amounts {
|
for _, amount := range amounts {
|
||||||
rat_amount, err := GetBigAmount(amount)
|
rat_amount, err := GetBigAmount(amount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
tmp.Add(&balance, rat_amount)
|
tmp.Add(&balance, rat_amount)
|
||||||
@ -880,20 +692,12 @@ func GetAccountTransactions(db *DB, user *User, accountid int64, sort string, pa
|
|||||||
atl.BeginningBalance = balance.FloatString(security.Precision)
|
atl.BeginningBalance = balance.FloatString(security.Precision)
|
||||||
atl.EndingBalance = tmp.Add(&balance, pageDifference).FloatString(security.Precision)
|
atl.EndingBalance = tmp.Add(&balance, pageDifference).FloatString(security.Precision)
|
||||||
|
|
||||||
err = transaction.Commit()
|
|
||||||
if err != nil {
|
|
||||||
transaction.Rollback()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &atl, nil
|
return &atl, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return only those transactions which have at least one split pertaining to
|
// Return only those transactions which have at least one split pertaining to
|
||||||
// an account
|
// an account
|
||||||
func AccountTransactionsHandler(db *DB, w http.ResponseWriter, r *http.Request,
|
func AccountTransactionsHandler(tx *Tx, r *http.Request, user *User, accountid int64) ResponseWriterWriter {
|
||||||
user *User, accountid int64) {
|
|
||||||
|
|
||||||
var page uint64 = 0
|
var page uint64 = 0
|
||||||
var limit uint64 = 50
|
var limit uint64 = 50
|
||||||
var sort string = "date-desc"
|
var sort string = "date-desc"
|
||||||
@ -904,8 +708,7 @@ func AccountTransactionsHandler(db *DB, w http.ResponseWriter, r *http.Request,
|
|||||||
if pagestring != "" {
|
if pagestring != "" {
|
||||||
p, err := strconv.ParseUint(pagestring, 10, 0)
|
p, err := strconv.ParseUint(pagestring, 10, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
page = p
|
page = p
|
||||||
}
|
}
|
||||||
@ -914,8 +717,7 @@ func AccountTransactionsHandler(db *DB, w http.ResponseWriter, r *http.Request,
|
|||||||
if limitstring != "" {
|
if limitstring != "" {
|
||||||
l, err := strconv.ParseUint(limitstring, 10, 0)
|
l, err := strconv.ParseUint(limitstring, 10, 0)
|
||||||
if err != nil || l > 100 {
|
if err != nil || l > 100 {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
limit = l
|
limit = l
|
||||||
}
|
}
|
||||||
@ -923,23 +725,16 @@ func AccountTransactionsHandler(db *DB, w http.ResponseWriter, r *http.Request,
|
|||||||
sortstring := query.Get("sort")
|
sortstring := query.Get("sort")
|
||||||
if sortstring != "" {
|
if sortstring != "" {
|
||||||
if sortstring != "date-asc" && sortstring != "date-desc" {
|
if sortstring != "date-asc" && sortstring != "date-desc" {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
sort = sortstring
|
sort = sortstring
|
||||||
}
|
}
|
||||||
|
|
||||||
accountTransactions, err := GetAccountTransactions(db, user, accountid, sort, page, limit)
|
accountTransactions, err := GetAccountTransactions(tx, user, accountid, sort, page, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = accountTransactions.Write(w)
|
return accountTransactions
|
||||||
if err != nil {
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -5,7 +5,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"gopkg.in/gorp.v1"
|
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -47,61 +46,42 @@ func (u *User) HashPassword() {
|
|||||||
u.Password = ""
|
u.Password = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUser(db *DB, userid int64) (*User, error) {
|
func GetUser(tx *Tx, userid int64) (*User, error) {
|
||||||
var u User
|
var u User
|
||||||
|
|
||||||
err := db.SelectOne(&u, "SELECT * from users where UserId=?", userid)
|
err := tx.SelectOne(&u, "SELECT * from users where UserId=?", userid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &u, nil
|
return &u, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserTx(transaction *gorp.Transaction, userid int64) (*User, error) {
|
func GetUserByUsername(tx *Tx, username string) (*User, error) {
|
||||||
var u User
|
var u User
|
||||||
|
|
||||||
err := transaction.SelectOne(&u, "SELECT * from users where UserId=?", userid)
|
err := tx.SelectOne(&u, "SELECT * from users where Username=?", username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &u, nil
|
return &u, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserByUsername(db *DB, username string) (*User, error) {
|
func InsertUser(tx *Tx, u *User) error {
|
||||||
var u User
|
|
||||||
|
|
||||||
err := db.SelectOne(&u, "SELECT * from users where Username=?", username)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &u, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func InsertUser(db *DB, u *User) error {
|
|
||||||
transaction, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
security_template := FindCurrencyTemplate(u.DefaultCurrency)
|
security_template := FindCurrencyTemplate(u.DefaultCurrency)
|
||||||
if security_template == nil {
|
if security_template == nil {
|
||||||
transaction.Rollback()
|
|
||||||
return errors.New("Invalid ISO4217 Default Currency")
|
return errors.New("Invalid ISO4217 Default Currency")
|
||||||
}
|
}
|
||||||
|
|
||||||
existing, err := transaction.SelectInt("SELECT count(*) from users where Username=?", u.Username)
|
existing, err := tx.SelectInt("SELECT count(*) from users where Username=?", u.Username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if existing > 0 {
|
if existing > 0 {
|
||||||
transaction.Rollback()
|
|
||||||
return UserExistsError{}
|
return UserExistsError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.Insert(u)
|
err = tx.Insert(u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -110,201 +90,138 @@ func InsertUser(db *DB, u *User) error {
|
|||||||
security = *security_template
|
security = *security_template
|
||||||
security.UserId = u.UserId
|
security.UserId = u.UserId
|
||||||
|
|
||||||
err = InsertSecurityTx(transaction, &security)
|
err = InsertSecurity(tx, &security)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update the user's DefaultCurrency to our new SecurityId
|
// Update the user's DefaultCurrency to our new SecurityId
|
||||||
u.DefaultCurrency = security.SecurityId
|
u.DefaultCurrency = security.SecurityId
|
||||||
count, err := transaction.Update(u)
|
count, err := tx.Update(u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
} else if count != 1 {
|
} else if count != 1 {
|
||||||
transaction.Rollback()
|
|
||||||
return errors.New("Would have updated more than one user")
|
return errors.New("Would have updated more than one user")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.Commit()
|
|
||||||
if err != nil {
|
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserFromSession(db *DB, r *http.Request) (*User, error) {
|
func GetUserFromSession(tx *Tx, r *http.Request) (*User, error) {
|
||||||
s, err := GetSession(db, r)
|
s, err := GetSession(tx, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return GetUser(db, s.UserId)
|
return GetUser(tx, s.UserId)
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateUser(db *DB, u *User) error {
|
func UpdateUser(tx *Tx, u *User) error {
|
||||||
transaction, err := db.Begin()
|
security, err := GetSecurity(tx, u.DefaultCurrency, u.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
|
||||||
|
|
||||||
security, err := GetSecurityTx(transaction, u.DefaultCurrency, u.UserId)
|
|
||||||
if err != nil {
|
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
|
||||||
} else if security.UserId != u.UserId || security.SecurityId != u.DefaultCurrency {
|
} else if security.UserId != u.UserId || security.SecurityId != u.DefaultCurrency {
|
||||||
transaction.Rollback()
|
|
||||||
return errors.New("UserId and DefaultCurrency don't match the fetched security")
|
return errors.New("UserId and DefaultCurrency don't match the fetched security")
|
||||||
} else if security.Type != Currency {
|
} else if security.Type != Currency {
|
||||||
transaction.Rollback()
|
|
||||||
return errors.New("New DefaultCurrency security is not a currency")
|
return errors.New("New DefaultCurrency security is not a currency")
|
||||||
}
|
}
|
||||||
|
|
||||||
count, err := transaction.Update(u)
|
count, err := tx.Update(u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
} else if count != 1 {
|
} else if count != 1 {
|
||||||
transaction.Rollback()
|
|
||||||
return errors.New("Would have updated more than one user")
|
return errors.New("Would have updated more than one user")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.Commit()
|
|
||||||
if err != nil {
|
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteUser(db *DB, u *User) error {
|
func DeleteUser(tx *Tx, u *User) error {
|
||||||
transaction, err := db.Begin()
|
count, err := tx.Delete(u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
count, err := transaction.Delete(u)
|
|
||||||
if err != nil {
|
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if count != 1 {
|
if count != 1 {
|
||||||
transaction.Rollback()
|
|
||||||
return fmt.Errorf("No user to delete")
|
return fmt.Errorf("No user to delete")
|
||||||
}
|
}
|
||||||
_, err = transaction.Exec("DELETE FROM prices WHERE prices.PriceId IN (SELECT prices.PriceId FROM prices INNER JOIN securities ON prices.SecurityId=securities.SecurityId WHERE securities.UserId=?)", u.UserId)
|
_, err = tx.Exec("DELETE FROM prices WHERE prices.PriceId IN (SELECT prices.PriceId FROM prices INNER JOIN securities ON prices.SecurityId=securities.SecurityId WHERE securities.UserId=?)", u.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = transaction.Exec("DELETE FROM splits WHERE splits.SplitId IN (SELECT splits.SplitId FROM splits INNER JOIN transactions ON splits.TransactionId=transactions.TransactionId WHERE transactions.UserId=?)", u.UserId)
|
_, err = tx.Exec("DELETE FROM splits WHERE splits.SplitId IN (SELECT splits.SplitId FROM splits INNER JOIN transactions ON splits.TransactionId=transactions.TransactionId WHERE transactions.UserId=?)", u.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = transaction.Exec("DELETE FROM transactions WHERE transactions.UserId=?", u.UserId)
|
_, err = tx.Exec("DELETE FROM transactions WHERE transactions.UserId=?", u.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = transaction.Exec("DELETE FROM securities WHERE securities.UserId=?", u.UserId)
|
_, err = tx.Exec("DELETE FROM securities WHERE securities.UserId=?", u.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = transaction.Exec("DELETE FROM accounts WHERE accounts.UserId=?", u.UserId)
|
_, err = tx.Exec("DELETE FROM accounts WHERE accounts.UserId=?", u.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = transaction.Exec("DELETE FROM reports WHERE reports.UserId=?", u.UserId)
|
_, err = tx.Exec("DELETE FROM reports WHERE reports.UserId=?", u.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = transaction.Exec("DELETE FROM sessions WHERE sessions.UserId=?", u.UserId)
|
_, err = tx.Exec("DELETE FROM sessions WHERE sessions.UserId=?", u.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = transaction.Commit()
|
|
||||||
if err != nil {
|
|
||||||
transaction.Rollback()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UserHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
func UserHandler(r *http.Request, tx *Tx) ResponseWriterWriter {
|
||||||
if r.Method == "POST" {
|
if r.Method == "POST" {
|
||||||
user_json := r.PostFormValue("user")
|
user_json := r.PostFormValue("user")
|
||||||
if user_json == "" {
|
if user_json == "" {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var user User
|
var user User
|
||||||
err := user.Read(user_json)
|
err := user.Read(user_json)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
user.UserId = -1
|
user.UserId = -1
|
||||||
user.HashPassword()
|
user.HashPassword()
|
||||||
|
|
||||||
err = InsertUser(db, &user)
|
err = InsertUser(tx, &user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if _, ok := err.(UserExistsError); ok {
|
if _, ok := err.(UserExistsError); ok {
|
||||||
WriteError(w, 4 /*User Exists*/)
|
return NewError(4 /*User Exists*/)
|
||||||
} else {
|
} else {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
w.WriteHeader(201 /*Created*/)
|
return ResponseWrapper{201, &user}
|
||||||
err = user.Write(w)
|
|
||||||
if err != nil {
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
user, err := GetUserFromSession(db, r)
|
user, err := GetUserFromSession(tx, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 1 /*Not Signed In*/)
|
return NewError(1 /*Not Signed In*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
userid, err := GetURLID(r.URL.Path)
|
userid, err := GetURLID(r.URL.Path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if userid != user.UserId {
|
if userid != user.UserId {
|
||||||
WriteError(w, 2 /*Unauthorized Access*/)
|
return NewError(2 /*Unauthorized Access*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Method == "GET" {
|
if r.Method == "GET" {
|
||||||
err = user.Write(w)
|
return user
|
||||||
if err != nil {
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else if r.Method == "PUT" {
|
} else if r.Method == "PUT" {
|
||||||
user_json := r.PostFormValue("user")
|
user_json := r.PostFormValue("user")
|
||||||
if user_json == "" {
|
if user_json == "" {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save old PWHash in case the new password is bogus
|
// Save old PWHash in case the new password is bogus
|
||||||
@ -312,8 +229,7 @@ func UserHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
|||||||
|
|
||||||
err = user.Read(user_json)
|
err = user.Read(user_json)
|
||||||
if err != nil || user.UserId != userid {
|
if err != nil || user.UserId != userid {
|
||||||
WriteError(w, 3 /*Invalid Request*/)
|
return NewError(3 /*Invalid Request*/)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the user didn't create a new password, keep their old one
|
// If the user didn't create a new password, keep their old one
|
||||||
@ -324,27 +240,21 @@ func UserHandler(w http.ResponseWriter, r *http.Request, db *DB) {
|
|||||||
user.PasswordHash = old_pwhash
|
user.PasswordHash = old_pwhash
|
||||||
}
|
}
|
||||||
|
|
||||||
err = UpdateUser(db, user)
|
err = UpdateUser(tx, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = user.Write(w)
|
return user
|
||||||
if err != nil {
|
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else if r.Method == "DELETE" {
|
} else if r.Method == "DELETE" {
|
||||||
err := DeleteUser(db, user)
|
err := DeleteUser(tx, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, 999 /*Internal Error*/)
|
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return
|
return NewError(999 /*Internal Error*/)
|
||||||
}
|
}
|
||||||
WriteSuccess(w)
|
return SuccessWriter{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return NewError(3 /*Invalid Request*/)
|
||||||
}
|
}
|
||||||
|
@ -18,6 +18,23 @@ func GetURLPieces(url string, format string, a ...interface{}) (int, error) {
|
|||||||
return fmt.Sscanf(url, format, a...)
|
return fmt.Sscanf(url, format, a...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ResponseWrapper struct {
|
||||||
|
Code int
|
||||||
|
Writer ResponseWriterWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r ResponseWrapper) Write(w http.ResponseWriter) error {
|
||||||
|
w.WriteHeader(r.Code)
|
||||||
|
return r.Writer.Write(w)
|
||||||
|
}
|
||||||
|
|
||||||
|
type SuccessWriter struct{}
|
||||||
|
|
||||||
|
func (s SuccessWriter) Write(w http.ResponseWriter) error {
|
||||||
|
fmt.Fprint(w, "{}")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func WriteSuccess(w http.ResponseWriter) {
|
func WriteSuccess(w http.ResponseWriter) {
|
||||||
fmt.Fprint(w, "{}")
|
fmt.Fprint(w, "{}")
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user