diff --git a/accounts.go b/accounts.go
index 2c947e9..faa1db4 100644
--- a/accounts.go
+++ b/accounts.go
@@ -48,7 +48,7 @@ var accountImportRE *regexp.Regexp
func init() {
accountTransactionsRE = regexp.MustCompile(`^/account/[0-9]+/transactions/?$`)
- accountImportRE = regexp.MustCompile(`^/account/[0-9]+/import/?$`)
+ accountImportRE = regexp.MustCompile(`^/account/[0-9]+/import/[a-z]+/?$`)
}
func (a *Account) Write(w http.ResponseWriter) error {
@@ -97,138 +97,98 @@ func GetAccounts(userid int64) (*[]Account, error) {
return &accounts, nil
}
-// Get (and attempt to create if it doesn't exist) the security/currency
-// trading account for the supplied security/currency
-func GetTradingAccount(userid int64, securityid int64) (*Account, error) {
- var tradingAccounts []Account //top-level 'Trading' account(s)
- var tradingAccount Account
- var accounts []Account //second-level security-specific trading account(s)
+// Get (and attempt to create if it doesn't exist). Matches on UserId,
+// SecurityId, Type, Name, and ParentAccountId
+func GetCreateAccountTx(transaction *gorp.Transaction, a Account) (*Account, error) {
+ var accounts []Account
var account Account
- transaction, err := DB.Begin()
- if err != nil {
- return nil, err
- }
-
// Try to find the top-level trading account
- _, err = transaction.Select(&tradingAccounts, "SELECT * from accounts where UserId=? AND Name='Trading' AND ParentAccountId=-1 AND Type=? ORDER BY AccountId ASC LIMIT 1", userid, Trading)
+ _, err := transaction.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND Type=? AND Name=? AND ParentAccountId=? ORDER BY AccountId ASC LIMIT 1", a.UserId, a.SecurityId, a.Type, a.Name, a.ParentAccountId)
if err != nil {
- transaction.Rollback()
- return nil, err
- }
- if len(tradingAccounts) == 1 {
- tradingAccount = tradingAccounts[0]
- } else {
- tradingAccount.UserId = userid
- tradingAccount.Name = "Trading"
- tradingAccount.ParentAccountId = -1
- tradingAccount.SecurityId = 840 /*USD*/ //FIXME SecurityId shouldn't matter for top-level trading account, but maybe we should grab the user's default
- tradingAccount.Type = Trading
-
- err = transaction.Insert(&tradingAccount)
- if err != nil {
- transaction.Rollback()
- return nil, err
- }
- }
-
- // Now, try to find the security-specific trading account
- _, err = transaction.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND ParentAccountId=? ORDER BY AccountId ASC LIMIT 1", userid, securityid, tradingAccount.AccountId)
- if err != nil {
- transaction.Rollback()
return nil, err
}
if len(accounts) == 1 {
account = accounts[0]
} else {
- security := GetSecurity(securityid)
- account.UserId = userid
- account.Name = security.Name
- account.ParentAccountId = tradingAccount.AccountId
- account.SecurityId = securityid
- account.Type = Trading
+ account.UserId = a.UserId
+ account.SecurityId = a.SecurityId
+ account.Type = a.Type
+ account.Name = a.Name
+ account.ParentAccountId = a.ParentAccountId
err = transaction.Insert(&account)
if err != nil {
- transaction.Rollback()
return nil, err
}
}
-
- err = transaction.Commit()
- if err != nil {
- transaction.Rollback()
- return nil, err
- }
-
return &account, nil
}
// Get (and attempt to create if it doesn't exist) the security/currency
-// imbalance account for the supplied security/currency
-func GetImbalanceAccount(userid int64, securityid int64) (*Account, error) {
- var imbalanceAccounts []Account //top-level imbalance account(s)
- var imbalanceAccount Account
- var accounts []Account //second-level security-specific imbalance account(s)
+// trading account for the supplied security/currency
+func GetTradingAccount(transaction *gorp.Transaction, userid int64, securityid int64) (*Account, error) {
+ var tradingAccount Account
var account Account
- transaction, err := DB.Begin()
+ tradingAccount.UserId = userid
+ tradingAccount.Type = Trading
+ tradingAccount.Name = "Trading"
+ tradingAccount.SecurityId = 840 /*USD*/ //FIXME SecurityId shouldn't matter for top-level trading account, but maybe we should grab the user's default
+ tradingAccount.ParentAccountId = -1
+
+ // Find/create the top-level trading account
+ ta, err := GetCreateAccountTx(transaction, tradingAccount)
if err != nil {
return nil, err
}
- // Try to find the top-level imbalance account
- _, err = transaction.Select(&imbalanceAccounts, "SELECT * from accounts where UserId=? AND Name='Imbalances' AND ParentAccountId=-1 AND Type=? ORDER BY AccountId ASC LIMIT 1", userid, Bank)
+ security := GetSecurity(securityid)
+ account.UserId = userid
+ account.Name = security.Name
+ account.ParentAccountId = ta.AccountId
+ account.SecurityId = securityid
+ account.Type = Trading
+
+ a, err := GetCreateAccountTx(transaction, account)
if err != nil {
- transaction.Rollback()
- return nil, err
- }
- if len(imbalanceAccounts) == 1 {
- imbalanceAccount = imbalanceAccounts[0]
- } else {
- imbalanceAccount.UserId = userid
- imbalanceAccount.Name = "Imbalances"
- imbalanceAccount.ParentAccountId = -1
- imbalanceAccount.SecurityId = 840 /*USD*/ //FIXME SecurityId shouldn't matter for top-level imbalance account, but maybe we should grab the user's default
- imbalanceAccount.Type = Bank
-
- err = transaction.Insert(&imbalanceAccount)
- if err != nil {
- transaction.Rollback()
- return nil, err
- }
- }
-
- // Now, try to find the security-specific imbalances account
- _, err = transaction.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND ParentAccountId=? ORDER BY AccountId ASC LIMIT 1", userid, securityid, imbalanceAccount.AccountId)
- if err != nil {
- transaction.Rollback()
- return nil, err
- }
- if len(accounts) == 1 {
- account = accounts[0]
- } else {
- security := GetSecurity(securityid)
- account.UserId = userid
- account.Name = security.Name
- account.ParentAccountId = imbalanceAccount.AccountId
- account.SecurityId = securityid
- account.Type = Bank
-
- err = transaction.Insert(&account)
- if err != nil {
- transaction.Rollback()
- return nil, err
- }
- }
-
- err = transaction.Commit()
- if err != nil {
- transaction.Rollback()
return nil, err
}
- return &account, nil
+ return a, nil
+}
+
+// Get (and attempt to create if it doesn't exist) the security/currency
+// imbalance account for the supplied security/currency
+func GetImbalanceAccount(transaction *gorp.Transaction, userid int64, securityid int64) (*Account, error) {
+ var imbalanceAccount Account
+ var account Account
+
+ imbalanceAccount.UserId = userid
+ imbalanceAccount.Name = "Imbalances"
+ imbalanceAccount.ParentAccountId = -1
+ imbalanceAccount.SecurityId = 840 /*USD*/ //FIXME SecurityId shouldn't matter for top-level imbalance account, but maybe we should grab the user's default
+ imbalanceAccount.Type = Bank
+
+ // Find/create the top-level trading account
+ ia, err := GetCreateAccountTx(transaction, imbalanceAccount)
+ if err != nil {
+ return nil, err
+ }
+
+ security := GetSecurity(securityid)
+ account.UserId = userid
+ account.Name = security.Name
+ account.ParentAccountId = ia.AccountId
+ account.SecurityId = securityid
+ account.Type = Bank
+
+ a, err := GetCreateAccountTx(transaction, account)
+ if err != nil {
+ return nil, err
+ }
+
+ return a, nil
}
type ParentAccountMissingError struct{}
@@ -358,14 +318,15 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) {
// import handler
if accountImportRE.MatchString(r.URL.Path) {
var accountid int64
- n, err := GetURLPieces(r.URL.Path, "/account/%d", &accountid)
+ var importtype string
+ n, err := GetURLPieces(r.URL.Path, "/account/%d/import/%s", &accountid, &importtype)
- if err != nil || n != 1 {
+ if err != nil || n != 2 {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return
}
- AccountImportHandler(w, r, user, accountid)
+ AccountImportHandler(w, r, user, accountid, importtype)
return
}
diff --git a/gnucash.go b/gnucash.go
new file mode 100644
index 0000000..7a4d96f
--- /dev/null
+++ b/gnucash.go
@@ -0,0 +1,372 @@
+package main
+
+import (
+ "encoding/xml"
+ "fmt"
+ "io"
+ "log"
+ "math"
+ "math/big"
+ "net/http"
+ "time"
+)
+
+type GnucashXMLCommodity struct {
+ Name string `xml:"http://www.gnucash.org/XML/cmdty id"`
+ Description string `xml:"http://www.gnucash.org/XML/cmdty name"`
+ Type string `xml:"http://www.gnucash.org/XML/cmdty space"`
+ Fraction int `xml:"http://www.gnucash.org/XML/cmdty fraction"`
+ XCode string `xml:"http://www.gnucash.org/XML/cmdty xcode"`
+}
+
+type GnucashCommodity struct{ Security }
+
+func (gc *GnucashCommodity) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
+ var gxc GnucashXMLCommodity
+ if err := d.DecodeElement(&gxc, &start); err != nil {
+ return err
+ }
+
+ gc.Security.Type = Stock // assumed default
+ if gxc.Type == "ISO4217" {
+ gc.Security.Type = Currency
+ }
+ gc.Name = gxc.Name
+ gc.Symbol = gxc.Name
+ gc.Description = gxc.Description
+ gc.AlternateId = gxc.XCode
+ if gxc.Fraction > 0 {
+ gc.Precision = int(math.Ceil(math.Log10(float64(gxc.Fraction))))
+ } else {
+ gc.Precision = 0
+ }
+ return nil
+}
+
+type GnucashTime struct{ time.Time }
+
+func (g *GnucashTime) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
+ var s string
+ if err := d.DecodeElement(&s, &start); err != nil {
+ return fmt.Errorf("date should be a string")
+ }
+ t, err := time.Parse("2006-01-02 15:04:05 -0700", s)
+ g.Time = t
+ return err
+}
+
+type GnucashDate struct {
+ Date GnucashTime `xml:"http://www.gnucash.org/XML/ts date"`
+}
+
+type GnucashAccount struct {
+ Version string `xml:"version,attr"`
+ accountid int64 // Used to map Gnucash guid's to integer ones
+ AccountId string `xml:"http://www.gnucash.org/XML/act id"`
+ ParentAccountId string `xml:"http://www.gnucash.org/XML/act parent"`
+ Name string `xml:"http://www.gnucash.org/XML/act name"`
+ Description string `xml:"http://www.gnucash.org/XML/act description"`
+ Type string `xml:"http://www.gnucash.org/XML/act type"`
+ Commodity GnucashXMLCommodity `xml:"http://www.gnucash.org/XML/act commodity"`
+}
+
+type GnucashTransaction struct {
+ TransactionId string `xml:"http://www.gnucash.org/XML/trn id"`
+ Description string `xml:"http://www.gnucash.org/XML/trn description"`
+ DatePosted GnucashDate `xml:"http://www.gnucash.org/XML/trn date-posted"`
+ DateEntered GnucashDate `xml:"http://www.gnucash.org/XML/trn date-entered"`
+ Commodity GnucashXMLCommodity `xml:"http://www.gnucash.org/XML/trn currency"`
+ Splits []GnucashSplit `xml:"http://www.gnucash.org/XML/trn splits>split"`
+}
+
+type GnucashSplit struct {
+ SplitId string `xml:"http://www.gnucash.org/XML/split id"`
+ AccountId string `xml:"http://www.gnucash.org/XML/split account"`
+ Memo string `xml:"http://www.gnucash.org/XML/split memo"`
+ Amount string `xml:"http://www.gnucash.org/XML/split quantity"`
+ Value string `xml:"http://www.gnucash.org/XML/split value"`
+}
+
+type GnucashXMLImport struct {
+ XMLName xml.Name `xml:"gnc-v2"`
+ Commodities []GnucashCommodity `xml:"http://www.gnucash.org/XML/gnc book>commodity"`
+ Accounts []GnucashAccount `xml:"http://www.gnucash.org/XML/gnc book>account"`
+ Transactions []GnucashTransaction `xml:"http://www.gnucash.org/XML/gnc book>transaction"`
+}
+
+type GnucashImport struct {
+ Securities []Security
+ Accounts []Account
+ Transactions []Transaction
+}
+
+func ImportGnucash(r io.Reader) (*GnucashImport, error) {
+ var gncxml GnucashXMLImport
+ var gncimport GnucashImport
+
+ // Perform initial parsing of xml into structs
+ decoder := xml.NewDecoder(r)
+ err := decoder.Decode(&gncxml)
+ if err != nil {
+ return nil, err
+ }
+
+ // Fixup securities, making a map of them as we go
+ securityMap := make(map[string]Security)
+ for i := range gncxml.Commodities {
+ s := gncxml.Commodities[i].Security
+ s.SecurityId = int64(i + 1)
+ securityMap[s.Name] = s
+
+ // Ignore gnucash's "template" commodity
+ if s.Name != "template" ||
+ s.Description != "template" ||
+ s.AlternateId != "template" {
+ gncimport.Securities = append(gncimport.Securities, s)
+ }
+ }
+
+ //find root account, while simultaneously creating map of GUID's to
+ //accounts
+ var rootAccount GnucashAccount
+ accountMap := make(map[string]GnucashAccount)
+ for i := range gncxml.Accounts {
+ gncxml.Accounts[i].accountid = int64(i + 1)
+ if gncxml.Accounts[i].Type == "ROOT" {
+ rootAccount = gncxml.Accounts[i]
+ } else {
+ accountMap[gncxml.Accounts[i].AccountId] = gncxml.Accounts[i]
+ }
+ }
+
+ //Translate to our account format, figuring out parent relationships
+ for guid := range accountMap {
+ ga := accountMap[guid]
+ var a Account
+
+ a.AccountId = ga.accountid
+ if ga.ParentAccountId == rootAccount.AccountId {
+ a.ParentAccountId = -1
+ } else {
+ parent, ok := accountMap[ga.ParentAccountId]
+ if ok {
+ a.ParentAccountId = parent.accountid
+ } else {
+ a.ParentAccountId = -1 // Ugly, but assign to top-level if we can't find its parent
+ }
+ }
+ a.Name = ga.Name
+ security, ok := securityMap[ga.Commodity.Name]
+ if ok {
+ } else {
+ return nil, fmt.Errorf("Unable to find security: %s", ga.Commodity.Name)
+ }
+ a.SecurityId = security.SecurityId
+
+ //TODO find account types
+ switch ga.Type {
+ default:
+ a.Type = Bank
+ case "ASSET":
+ a.Type = Asset
+ case "BANK":
+ a.Type = Bank
+ case "CASH":
+ a.Type = Cash
+ case "CREDIT", "LIABILITY":
+ a.Type = Liability
+ case "EQUITY":
+ a.Type = Equity
+ case "EXPENSE":
+ a.Type = Expense
+ case "INCOME":
+ a.Type = Income
+ case "PAYABLE":
+ a.Type = Payable
+ case "RECEIVABLE":
+ a.Type = Receivable
+ case "MUTUAL", "STOCK":
+ a.Type = Investment
+ case "TRADING":
+ a.Type = Trading
+ }
+
+ gncimport.Accounts = append(gncimport.Accounts, a)
+ }
+
+ //Translate transactions to our format
+ for i := range gncxml.Transactions {
+ gt := gncxml.Transactions[i]
+
+ t := new(Transaction)
+ t.Description = gt.Description
+ t.Date = gt.DatePosted.Date.Time
+ t.Status = Imported
+ for j := range gt.Splits {
+ gs := gt.Splits[j]
+ s := new(Split)
+ s.Memo = gs.Memo
+ account, ok := accountMap[gs.AccountId]
+ if !ok {
+ return nil, fmt.Errorf("Unable to find account: %s", gs.AccountId)
+ }
+ s.AccountId = account.accountid
+
+ security, ok := securityMap[account.Commodity.Name]
+ if !ok {
+ return nil, fmt.Errorf("Unable to find security: %s", account.Commodity.Name)
+ }
+ s.SecurityId = -1
+
+ var r big.Rat
+ _, ok = r.SetString(gs.Amount)
+ if ok {
+ s.Amount = r.FloatString(security.Precision)
+ } else {
+ return nil, fmt.Errorf("Can't set split Amount: %s", gs.Amount)
+ }
+
+ t.Splits = append(t.Splits, s)
+ }
+ gncimport.Transactions = append(gncimport.Transactions, *t)
+ }
+
+ return &gncimport, nil
+}
+
+func GnucashImportHandler(w http.ResponseWriter, r *http.Request) {
+ user, err := GetUserFromSession(r)
+ if err != nil {
+ WriteError(w, 1 /*Not Signed In*/)
+ return
+ }
+
+ if r.Method != "POST" {
+ WriteError(w, 3 /*Invalid Request*/)
+ return
+ }
+
+ multipartReader, err := r.MultipartReader()
+ if err != nil {
+ WriteError(w, 3 /*Invalid Request*/)
+ return
+ }
+
+ // assume there is only one 'part'
+ part, err := multipartReader.NextPart()
+ if err != nil {
+ if err == io.EOF {
+ WriteError(w, 3 /*Invalid Request*/)
+ } else {
+ WriteError(w, 999 /*Internal Error*/)
+ log.Print(err)
+ }
+ return
+ }
+
+ gnucashImport, err := ImportGnucash(part)
+ if err != nil {
+ WriteError(w, 3 /*Invalid Request*/)
+ return
+ }
+
+ sqltransaction, err := DB.Begin()
+ if err != nil {
+ WriteError(w, 999 /*Internal Error*/)
+ log.Print(err)
+ return
+ }
+
+ // Import securities, building map from Gnucash security IDs to our
+ // internal IDs
+ securityMap := make(map[int64]int64)
+ for _, security := range gnucashImport.Securities {
+ //TODO FIXME check on AlternateID also, and convert to the case
+ //where users have their own internal securities
+ s, err := GetSecurityByNameAndType(security.Name, security.Type)
+ if err != nil {
+ //TODO attempt to create security if it doesn't exist
+ sqltransaction.Rollback()
+ WriteError(w, 999 /*Internal Error*/)
+ log.Print(err)
+ return
+ }
+ securityMap[security.SecurityId] = s.SecurityId
+ }
+
+ // Get/create accounts in the database, building a map from Gnucash account
+ // IDs to our internal IDs as we go
+ accountMap := make(map[int64]int64)
+ accountsRemaining := len(gnucashImport.Accounts)
+ accountsRemainingLast := accountsRemaining
+ for accountsRemaining > 0 {
+ for _, account := range gnucashImport.Accounts {
+
+ // If the account has already been added to the map, skip it
+ _, ok := accountMap[account.AccountId]
+ if ok {
+ continue
+ }
+
+ // If it hasn't been added, but its parent has, add it to the map
+ _, ok = accountMap[account.ParentAccountId]
+ if ok || account.ParentAccountId == -1 {
+ account.UserId = user.UserId
+ if account.ParentAccountId != -1 {
+ account.ParentAccountId = accountMap[account.ParentAccountId]
+ }
+ account.SecurityId = securityMap[account.SecurityId]
+ a, err := GetCreateAccountTx(sqltransaction, account)
+ if err != nil {
+ sqltransaction.Rollback()
+ WriteError(w, 999 /*Internal Error*/)
+ log.Print(err)
+ return
+ }
+ accountMap[account.AccountId] = a.AccountId
+ accountsRemaining--
+ }
+ }
+ if accountsRemaining == accountsRemainingLast {
+ //We didn't make any progress in importing the next level of accounts, so there must be a circular parent-child relationship, so give up and tell the user they're wrong
+ sqltransaction.Rollback()
+ WriteError(w, 999 /*Internal Error*/)
+ log.Print(fmt.Errorf("Circular account parent-child relationship when importing %s", part.FileName()))
+ return
+ }
+ accountsRemainingLast = accountsRemaining
+ }
+
+ // Insert transactions, fixing up account IDs to match internal ones from
+ // above
+ for _, transaction := range gnucashImport.Transactions {
+ for _, split := range transaction.Splits {
+ acctId, ok := accountMap[split.AccountId]
+ if !ok {
+ sqltransaction.Rollback()
+ WriteError(w, 999 /*Internal Error*/)
+ log.Print(fmt.Errorf("Error: Split's AccountID Doesn't exist: %d\n", split.AccountId))
+ return
+ }
+ split.AccountId = acctId
+ fmt.Printf("Setting split AccountId to %d\n", acctId)
+ }
+ err := InsertTransactionTx(sqltransaction, &transaction, user)
+ if err != nil {
+ sqltransaction.Rollback()
+ WriteError(w, 999 /*Internal Error*/)
+ log.Print(err)
+ return
+ }
+ }
+
+ err = sqltransaction.Commit()
+ if err != nil {
+ sqltransaction.Rollback()
+ WriteError(w, 999 /*Internal Error*/)
+ log.Print(err)
+ return
+ }
+
+ WriteSuccess(w)
+}
diff --git a/imports.go b/imports.go
index d1fdb38..02aa22c 100644
--- a/imports.go
+++ b/imports.go
@@ -12,7 +12,9 @@ import (
/*
* Assumes the User is a valid, signed-in user, but accountid has not yet been validated
*/
-func AccountImportHandler(w http.ResponseWriter, r *http.Request, user *User, accountid int64) {
+func AccountImportHandler(w http.ResponseWriter, r *http.Request, user *User, accountid int64, importtype string) {
+ //TODO branch off for different importtype's
+
// Return Account with this Id
account, err := GetAccount(accountid, user.UserId)
if err != nil {
@@ -58,23 +60,32 @@ func AccountImportHandler(w http.ResponseWriter, r *http.Request, user *User, ac
itl, err := ImportOFX(tmpFilename, account)
if err != nil {
- //TODO is this necessarily an invalid request?
+ //TODO is this necessarily an invalid request (what if it was an error on our end)?
WriteError(w, 3 /*Invalid Request*/)
return
}
+ sqltransaction, err := DB.Begin()
+ if err != nil {
+ WriteError(w, 999 /*Internal Error*/)
+ log.Print(err)
+ return
+ }
+
var transactions []Transaction
for _, transaction := range *itl.Transactions {
transaction.UserId = user.UserId
transaction.Status = Imported
if !transaction.Valid() {
+ sqltransaction.Rollback()
WriteError(w, 3 /*Invalid Request*/)
return
}
- imbalances, err := transaction.GetImbalances()
+ imbalances, err := transaction.GetImbalancesTx(sqltransaction)
if err != nil {
+ sqltransaction.Rollback()
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return
@@ -95,11 +106,12 @@ func AccountImportHandler(w http.ResponseWriter, r *http.Request, user *User, ac
// If we're dealing with exactly two securities, assume any imbalances
// from imports are from trading currencies/securities
if num_imbalances == 2 {
- imbalanced_account, err = GetTradingAccount(user.UserId, imbalanced_security)
+ imbalanced_account, err = GetTradingAccount(sqltransaction, user.UserId, imbalanced_security)
} else {
- imbalanced_account, err = GetImbalanceAccount(user.UserId, imbalanced_security)
+ imbalanced_account, err = GetImbalanceAccount(sqltransaction, user.UserId, imbalanced_security)
}
if err != nil {
+ sqltransaction.Rollback()
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return
@@ -121,8 +133,9 @@ func AccountImportHandler(w http.ResponseWriter, r *http.Request, user *User, ac
// accounts
for _, split := range transaction.Splits {
if split.SecurityId != -1 || split.AccountId == -1 {
- imbalanced_account, err := GetImbalanceAccount(user.UserId, split.SecurityId)
+ imbalanced_account, err := GetImbalanceAccount(sqltransaction, user.UserId, split.SecurityId)
if err != nil {
+ sqltransaction.Rollback()
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
return
@@ -133,23 +146,24 @@ func AccountImportHandler(w http.ResponseWriter, r *http.Request, user *User, ac
}
}
- balanced, err := transaction.Balanced()
- if !balanced || err != nil {
- WriteError(w, 999 /*Internal Error*/)
- log.Print(err)
- return
- }
-
transactions = append(transactions, transaction)
}
for _, transaction := range transactions {
- err := InsertTransaction(&transaction, user)
+ err := InsertTransactionTx(sqltransaction, &transaction, user)
if err != nil {
WriteError(w, 999 /*Internal Error*/)
log.Print(err)
}
}
+ err = sqltransaction.Commit()
+ if err != nil {
+ sqltransaction.Rollback()
+ WriteError(w, 999 /*Internal Error*/)
+ log.Print(err)
+ return
+ }
+
WriteSuccess(w)
}
diff --git a/js/AccountCombobox.js b/js/AccountCombobox.js
index 67492a2..94ce5de 100644
--- a/js/AccountCombobox.js
+++ b/js/AccountCombobox.js
@@ -9,6 +9,7 @@ module.exports = React.createClass({
getDefaultProps: function() {
return {
includeRoot: true,
+ disabled: false,
rootName: "New Top-level Account"
};
},
@@ -33,6 +34,7 @@ module.exports = React.createClass({
defaultValue={this.props.value}
onChange={this.handleAccountChange}
ref="account"
+ disabled={this.props.disabled}
className={className} />
);
}
diff --git a/js/AccountRegister.js b/js/AccountRegister.js
index 1496ade..073183d 100644
--- a/js/AccountRegister.js
+++ b/js/AccountRegister.js
@@ -20,8 +20,10 @@ var ButtonToolbar = ReactBootstrap.ButtonToolbar;
var ProgressBar = ReactBootstrap.ProgressBar;
var Glyphicon = ReactBootstrap.Glyphicon;
-var DateTimePicker = require('react-widgets').DateTimePicker;
-var Combobox = require('react-widgets').Combobox;
+var ReactWidgets = require('react-widgets')
+var DateTimePicker = ReactWidgets.DateTimePicker;
+var Combobox = ReactWidgets.Combobox;
+var DropdownList = ReactWidgets.DropdownList;
var Big = require('big.js');
@@ -455,29 +457,39 @@ const AddEditTransactionModal = React.createClass({
}
});
+const ImportType = {
+ OFX: 1,
+ Gnucash: 2
+};
+var ImportTypeList = [];
+for (var type in ImportType) {
+ if (ImportType.hasOwnProperty(type)) {
+ var name = ImportType[type] == ImportType.OFX ? "OFX/QFX" : type; //QFX is a special snowflake
+ ImportTypeList.push({'TypeId': ImportType[type], 'Name': name});
+ }
+}
+
const ImportTransactionsModal = React.createClass({
getInitialState: function() {
return {
importing: false,
imported: false,
importFile: "",
+ importType: ImportType.Gnucash,
uploadProgress: -1,
error: null};
},
handleCancel: function() {
- this.setState({
- importing: false,
- imported: false,
- importFile: "",
- uploadProgress: -1,
- error: null
- });
+ this.setState(this.getInitialState());
if (this.props.onCancel != null)
this.props.onCancel();
},
- onImportChanged: function() {
+ handleImportChange: function() {
this.setState({importFile: this.refs.importfile.getValue()});
},
+ handleTypeChange: function(type) {
+ this.setState({importType: type.TypeId});
+ },
handleSubmit: function() {
if (this.props.onSubmit != null)
this.props.onSubmit(this.props.account);
@@ -493,11 +505,18 @@ const ImportTransactionsModal = React.createClass({
handleImportTransactions: function() {
var file = this.refs.importfile.getInputDOMNode().files[0];
var formData = new FormData();
- this.setState({importing: true});
formData.append('importfile', file, this.state.importFile);
+ var url = ""
+ if (this.state.importType == ImportType.OFX)
+ url = "account/"+this.props.account.AccountId+"/import/ofx";
+ else if (this.state.importType == ImportType.Gnucash)
+ url = "import/gnucash";
+
+ this.setState({importing: true});
+
$.ajax({
type: "POST",
- url: "account/"+this.props.account.AccountId+"/import",
+ url: url,
data: formData,
xhr: function() {
var xhrObject = $.ajaxSettings.xhr();
@@ -514,7 +533,7 @@ const ImportTransactionsModal = React.createClass({
if (e.isError()) {
var errString = e.ErrorString;
if (e.ErrorId == 3 /* Invalid Request */) {
- errString = "Please check that the file you uploaded is a valid OFX file for this account and try again.";
+ errString = "Please check that the file you uploaded is valid and try again.";
}
this.setState({
importing: false,
@@ -540,9 +559,11 @@ const ImportTransactionsModal = React.createClass({
});
},
render: function() {
- var accountNameLabel = ""
- if (this.props.account != null )
+ var accountNameLabel = "Performing global import:"
+ if (this.props.account != null && this.state.importType != ImportType.Gnucash)
accountNameLabel = "Importing to '" + getAccountDisplayName(this.props.account, this.props.account_map) + "' account:";
+
+ // Display the progress bar if an upload/import is in progress
var progressBar = [];
if (this.state.importing && this.state.uploadProgress == 100) {
progressBar = ();
@@ -550,6 +571,7 @@ const ImportTransactionsModal = React.createClass({
progressBar = ();
}
+ // Create panel, possibly displaying error or success messages
var panel = [];
if (this.state.error != null) {
panel = ({this.state.error});
@@ -557,16 +579,22 @@ const ImportTransactionsModal = React.createClass({
panel = (Your import is now complete.);
}
- var buttonsDisabled = (this.state.importing) ? true : false;
+ // Display proper buttons, possibly disabling them if an import is in progress
var button1 = [];
var button2 = [];
if (!this.state.imported && this.state.error == null) {
- button1 = ();
- button2 = ();
+ button1 = ();
+ button2 = ();
} else {
- button1 = ();
+ button1 = ();
}
var inputDisabled = (this.state.importing || this.state.error != null || this.state.imported) ? true : false;
+
+ // Disable OFX/QFX imports if no account is selected
+ var disabledTypes = false;
+ if (this.props.account == null)
+ disabledTypes = [ImportTypeList[ImportType.OFX - 1]];
+
return (
@@ -576,13 +604,21 @@ const ImportTransactionsModal = React.createClass({
{progressBar}
{panel}
@@ -897,8 +933,7 @@ module.exports = React.createClass({
diff --git a/libofx.go b/libofx.go
index c9c8830..a45404a 100644
--- a/libofx.go
+++ b/libofx.go
@@ -25,11 +25,11 @@ import (
)
type ImportObject struct {
- TransactionList ImportTransactionsList
+ TransactionList OFXImport
Error error
}
-type ImportTransactionsList struct {
+type OFXImport struct {
Account *Account
Transactions *[]Transaction
TotalTransactions int64
@@ -249,7 +249,7 @@ func OFXTransactionCallback(transaction_data C.struct_OfxTransactionData, data u
return 0
}
-func ImportOFX(filename string, account *Account) (*ImportTransactionsList, error) {
+func ImportOFX(filename string, account *Account) (*OFXImport, error) {
var a Account
var t []Transaction
var iobj ImportObject
diff --git a/main.go b/main.go
index 5344cf5..30b1cc6 100644
--- a/main.go
+++ b/main.go
@@ -69,6 +69,7 @@ func main() {
servemux.HandleFunc("/security/", SecurityHandler)
servemux.HandleFunc("/account/", AccountHandler)
servemux.HandleFunc("/transaction/", TransactionHandler)
+ servemux.HandleFunc("/import/gnucash", GnucashImportHandler)
listener, err := net.Listen("tcp", ":"+strconv.Itoa(port))
if err != nil {
diff --git a/securities.go b/securities.go
index 2f2db08..5da4236 100644
--- a/securities.go
+++ b/securities.go
@@ -2,7 +2,7 @@ package main
import (
"encoding/json"
- "errors"
+ "fmt"
"log"
"net/http"
)
@@ -55703,7 +55703,16 @@ func GetSecurityByName(name string) (*Security, error) {
return value, nil
}
}
- return nil, errors.New("Invalid Security Name")
+ return nil, fmt.Errorf("Invalid Security Name: \"%s\"", name)
+}
+
+func GetSecurityByNameAndType(name string, _type int64) (*Security, error) {
+ for _, value := range security_map {
+ if value.Name == name && value.Type == _type {
+ return value, nil
+ }
+ }
+ return nil, fmt.Errorf("Invalid Security Name (%s) or Type (%d)", name, _type)
}
func GetSecurities() []*Security {
diff --git a/transactions.go b/transactions.go
index 5d4de65..b772910 100644
--- a/transactions.go
+++ b/transactions.go
@@ -113,7 +113,7 @@ func (t *Transaction) Valid() bool {
// Return a map of security ID's to big.Rat's containing the amount that
// security is imbalanced by
-func (t *Transaction) GetImbalances() (map[int64]big.Rat, error) {
+func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64]big.Rat, error) {
sums := make(map[int64]big.Rat)
if !t.Valid() {
@@ -123,7 +123,13 @@ func (t *Transaction) GetImbalances() (map[int64]big.Rat, error) {
for i := range t.Splits {
securityid := t.Splits[i].SecurityId
if t.Splits[i].AccountId != -1 {
- account, err := GetAccount(t.Splits[i].AccountId, t.UserId)
+ var err error
+ var account *Account
+ if transaction != nil {
+ account, err = GetAccountTx(transaction, t.Splits[i].AccountId, t.UserId)
+ } else {
+ account, err = GetAccount(t.Splits[i].AccountId, t.UserId)
+ }
if err != nil {
return nil, err
}
@@ -137,6 +143,10 @@ func (t *Transaction) GetImbalances() (map[int64]big.Rat, error) {
return sums, nil
}
+func (t *Transaction) GetImbalances() (map[int64]big.Rat, error) {
+ return t.GetImbalancesTx(nil)
+}
+
// Returns true if all securities contained in this transaction are balanced,
// false otherwise
func (t *Transaction) Balanced() (bool, error) {
@@ -235,23 +245,16 @@ func (ame AccountMissingError) Error() string {
return "Account missing"
}
-func InsertTransaction(t *Transaction, user *User) error {
- transaction, err := DB.Begin()
- if err != nil {
- return err
- }
-
+func InsertTransactionTx(transaction *gorp.Transaction, t *Transaction, user *User) error {
// Map of any accounts with transaction splits being added
a_map := make(map[int64]bool)
for i := range t.Splits {
if t.Splits[i].AccountId != -1 {
existing, err := transaction.SelectInt("SELECT count(*) from accounts where AccountId=?", t.Splits[i].AccountId)
if err != nil {
- transaction.Rollback()
return err
}
if existing != 1 {
- transaction.Rollback()
return AccountMissingError{}
}
a_map[t.Splits[i].AccountId] = true
@@ -269,15 +272,14 @@ func InsertTransaction(t *Transaction, user *User) error {
if len(a_ids) < 1 {
return AccountMissingError{}
}
- err = incrementAccountVersions(transaction, user, a_ids)
+ err := incrementAccountVersions(transaction, user, a_ids)
if err != nil {
- transaction.Rollback()
return err
}
+ t.UserId = user.UserId
err = transaction.Insert(t)
if err != nil {
- transaction.Rollback()
return err
}
@@ -286,11 +288,24 @@ func InsertTransaction(t *Transaction, user *User) error {
t.Splits[i].SplitId = -1
err = transaction.Insert(t.Splits[i])
if err != nil {
- transaction.Rollback()
return err
}
}
+ return nil
+}
+func InsertTransaction(t *Transaction, user *User) error {
+ transaction, err := DB.Begin()
+ if err != nil {
+ return err
+ }
+
+ err = InsertTransactionTx(transaction, t, user)
+ if err != nil {
+ transaction.Rollback()
+ return err
+ }
+
err = transaction.Commit()
if err != nil {
transaction.Rollback()