diff --git a/accounts.go b/accounts.go index 2c947e9..faa1db4 100644 --- a/accounts.go +++ b/accounts.go @@ -48,7 +48,7 @@ var accountImportRE *regexp.Regexp func init() { accountTransactionsRE = regexp.MustCompile(`^/account/[0-9]+/transactions/?$`) - accountImportRE = regexp.MustCompile(`^/account/[0-9]+/import/?$`) + accountImportRE = regexp.MustCompile(`^/account/[0-9]+/import/[a-z]+/?$`) } func (a *Account) Write(w http.ResponseWriter) error { @@ -97,138 +97,98 @@ func GetAccounts(userid int64) (*[]Account, error) { return &accounts, nil } -// Get (and attempt to create if it doesn't exist) the security/currency -// trading account for the supplied security/currency -func GetTradingAccount(userid int64, securityid int64) (*Account, error) { - var tradingAccounts []Account //top-level 'Trading' account(s) - var tradingAccount Account - var accounts []Account //second-level security-specific trading account(s) +// Get (and attempt to create if it doesn't exist). Matches on UserId, +// SecurityId, Type, Name, and ParentAccountId +func GetCreateAccountTx(transaction *gorp.Transaction, a Account) (*Account, error) { + var accounts []Account var account Account - transaction, err := DB.Begin() - if err != nil { - return nil, err - } - // Try to find the top-level trading account - _, err = transaction.Select(&tradingAccounts, "SELECT * from accounts where UserId=? AND Name='Trading' AND ParentAccountId=-1 AND Type=? ORDER BY AccountId ASC LIMIT 1", userid, Trading) + _, err := transaction.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND Type=? AND Name=? AND ParentAccountId=? ORDER BY AccountId ASC LIMIT 1", a.UserId, a.SecurityId, a.Type, a.Name, a.ParentAccountId) if err != nil { - transaction.Rollback() - return nil, err - } - if len(tradingAccounts) == 1 { - tradingAccount = tradingAccounts[0] - } else { - tradingAccount.UserId = userid - tradingAccount.Name = "Trading" - tradingAccount.ParentAccountId = -1 - tradingAccount.SecurityId = 840 /*USD*/ //FIXME SecurityId shouldn't matter for top-level trading account, but maybe we should grab the user's default - tradingAccount.Type = Trading - - err = transaction.Insert(&tradingAccount) - if err != nil { - transaction.Rollback() - return nil, err - } - } - - // Now, try to find the security-specific trading account - _, err = transaction.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND ParentAccountId=? ORDER BY AccountId ASC LIMIT 1", userid, securityid, tradingAccount.AccountId) - if err != nil { - transaction.Rollback() return nil, err } if len(accounts) == 1 { account = accounts[0] } else { - security := GetSecurity(securityid) - account.UserId = userid - account.Name = security.Name - account.ParentAccountId = tradingAccount.AccountId - account.SecurityId = securityid - account.Type = Trading + account.UserId = a.UserId + account.SecurityId = a.SecurityId + account.Type = a.Type + account.Name = a.Name + account.ParentAccountId = a.ParentAccountId err = transaction.Insert(&account) if err != nil { - transaction.Rollback() return nil, err } } - - err = transaction.Commit() - if err != nil { - transaction.Rollback() - return nil, err - } - return &account, nil } // Get (and attempt to create if it doesn't exist) the security/currency -// imbalance account for the supplied security/currency -func GetImbalanceAccount(userid int64, securityid int64) (*Account, error) { - var imbalanceAccounts []Account //top-level imbalance account(s) - var imbalanceAccount Account - var accounts []Account //second-level security-specific imbalance account(s) +// trading account for the supplied security/currency +func GetTradingAccount(transaction *gorp.Transaction, userid int64, securityid int64) (*Account, error) { + var tradingAccount Account var account Account - transaction, err := DB.Begin() + tradingAccount.UserId = userid + tradingAccount.Type = Trading + tradingAccount.Name = "Trading" + tradingAccount.SecurityId = 840 /*USD*/ //FIXME SecurityId shouldn't matter for top-level trading account, but maybe we should grab the user's default + tradingAccount.ParentAccountId = -1 + + // Find/create the top-level trading account + ta, err := GetCreateAccountTx(transaction, tradingAccount) if err != nil { return nil, err } - // Try to find the top-level imbalance account - _, err = transaction.Select(&imbalanceAccounts, "SELECT * from accounts where UserId=? AND Name='Imbalances' AND ParentAccountId=-1 AND Type=? ORDER BY AccountId ASC LIMIT 1", userid, Bank) + security := GetSecurity(securityid) + account.UserId = userid + account.Name = security.Name + account.ParentAccountId = ta.AccountId + account.SecurityId = securityid + account.Type = Trading + + a, err := GetCreateAccountTx(transaction, account) if err != nil { - transaction.Rollback() - return nil, err - } - if len(imbalanceAccounts) == 1 { - imbalanceAccount = imbalanceAccounts[0] - } else { - imbalanceAccount.UserId = userid - imbalanceAccount.Name = "Imbalances" - imbalanceAccount.ParentAccountId = -1 - imbalanceAccount.SecurityId = 840 /*USD*/ //FIXME SecurityId shouldn't matter for top-level imbalance account, but maybe we should grab the user's default - imbalanceAccount.Type = Bank - - err = transaction.Insert(&imbalanceAccount) - if err != nil { - transaction.Rollback() - return nil, err - } - } - - // Now, try to find the security-specific imbalances account - _, err = transaction.Select(&accounts, "SELECT * from accounts where UserId=? AND SecurityId=? AND ParentAccountId=? ORDER BY AccountId ASC LIMIT 1", userid, securityid, imbalanceAccount.AccountId) - if err != nil { - transaction.Rollback() - return nil, err - } - if len(accounts) == 1 { - account = accounts[0] - } else { - security := GetSecurity(securityid) - account.UserId = userid - account.Name = security.Name - account.ParentAccountId = imbalanceAccount.AccountId - account.SecurityId = securityid - account.Type = Bank - - err = transaction.Insert(&account) - if err != nil { - transaction.Rollback() - return nil, err - } - } - - err = transaction.Commit() - if err != nil { - transaction.Rollback() return nil, err } - return &account, nil + return a, nil +} + +// Get (and attempt to create if it doesn't exist) the security/currency +// imbalance account for the supplied security/currency +func GetImbalanceAccount(transaction *gorp.Transaction, userid int64, securityid int64) (*Account, error) { + var imbalanceAccount Account + var account Account + + imbalanceAccount.UserId = userid + imbalanceAccount.Name = "Imbalances" + imbalanceAccount.ParentAccountId = -1 + imbalanceAccount.SecurityId = 840 /*USD*/ //FIXME SecurityId shouldn't matter for top-level imbalance account, but maybe we should grab the user's default + imbalanceAccount.Type = Bank + + // Find/create the top-level trading account + ia, err := GetCreateAccountTx(transaction, imbalanceAccount) + if err != nil { + return nil, err + } + + security := GetSecurity(securityid) + account.UserId = userid + account.Name = security.Name + account.ParentAccountId = ia.AccountId + account.SecurityId = securityid + account.Type = Bank + + a, err := GetCreateAccountTx(transaction, account) + if err != nil { + return nil, err + } + + return a, nil } type ParentAccountMissingError struct{} @@ -358,14 +318,15 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) { // import handler if accountImportRE.MatchString(r.URL.Path) { var accountid int64 - n, err := GetURLPieces(r.URL.Path, "/account/%d", &accountid) + var importtype string + n, err := GetURLPieces(r.URL.Path, "/account/%d/import/%s", &accountid, &importtype) - if err != nil || n != 1 { + if err != nil || n != 2 { WriteError(w, 999 /*Internal Error*/) log.Print(err) return } - AccountImportHandler(w, r, user, accountid) + AccountImportHandler(w, r, user, accountid, importtype) return } diff --git a/gnucash.go b/gnucash.go new file mode 100644 index 0000000..7a4d96f --- /dev/null +++ b/gnucash.go @@ -0,0 +1,372 @@ +package main + +import ( + "encoding/xml" + "fmt" + "io" + "log" + "math" + "math/big" + "net/http" + "time" +) + +type GnucashXMLCommodity struct { + Name string `xml:"http://www.gnucash.org/XML/cmdty id"` + Description string `xml:"http://www.gnucash.org/XML/cmdty name"` + Type string `xml:"http://www.gnucash.org/XML/cmdty space"` + Fraction int `xml:"http://www.gnucash.org/XML/cmdty fraction"` + XCode string `xml:"http://www.gnucash.org/XML/cmdty xcode"` +} + +type GnucashCommodity struct{ Security } + +func (gc *GnucashCommodity) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { + var gxc GnucashXMLCommodity + if err := d.DecodeElement(&gxc, &start); err != nil { + return err + } + + gc.Security.Type = Stock // assumed default + if gxc.Type == "ISO4217" { + gc.Security.Type = Currency + } + gc.Name = gxc.Name + gc.Symbol = gxc.Name + gc.Description = gxc.Description + gc.AlternateId = gxc.XCode + if gxc.Fraction > 0 { + gc.Precision = int(math.Ceil(math.Log10(float64(gxc.Fraction)))) + } else { + gc.Precision = 0 + } + return nil +} + +type GnucashTime struct{ time.Time } + +func (g *GnucashTime) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { + var s string + if err := d.DecodeElement(&s, &start); err != nil { + return fmt.Errorf("date should be a string") + } + t, err := time.Parse("2006-01-02 15:04:05 -0700", s) + g.Time = t + return err +} + +type GnucashDate struct { + Date GnucashTime `xml:"http://www.gnucash.org/XML/ts date"` +} + +type GnucashAccount struct { + Version string `xml:"version,attr"` + accountid int64 // Used to map Gnucash guid's to integer ones + AccountId string `xml:"http://www.gnucash.org/XML/act id"` + ParentAccountId string `xml:"http://www.gnucash.org/XML/act parent"` + Name string `xml:"http://www.gnucash.org/XML/act name"` + Description string `xml:"http://www.gnucash.org/XML/act description"` + Type string `xml:"http://www.gnucash.org/XML/act type"` + Commodity GnucashXMLCommodity `xml:"http://www.gnucash.org/XML/act commodity"` +} + +type GnucashTransaction struct { + TransactionId string `xml:"http://www.gnucash.org/XML/trn id"` + Description string `xml:"http://www.gnucash.org/XML/trn description"` + DatePosted GnucashDate `xml:"http://www.gnucash.org/XML/trn date-posted"` + DateEntered GnucashDate `xml:"http://www.gnucash.org/XML/trn date-entered"` + Commodity GnucashXMLCommodity `xml:"http://www.gnucash.org/XML/trn currency"` + Splits []GnucashSplit `xml:"http://www.gnucash.org/XML/trn splits>split"` +} + +type GnucashSplit struct { + SplitId string `xml:"http://www.gnucash.org/XML/split id"` + AccountId string `xml:"http://www.gnucash.org/XML/split account"` + Memo string `xml:"http://www.gnucash.org/XML/split memo"` + Amount string `xml:"http://www.gnucash.org/XML/split quantity"` + Value string `xml:"http://www.gnucash.org/XML/split value"` +} + +type GnucashXMLImport struct { + XMLName xml.Name `xml:"gnc-v2"` + Commodities []GnucashCommodity `xml:"http://www.gnucash.org/XML/gnc book>commodity"` + Accounts []GnucashAccount `xml:"http://www.gnucash.org/XML/gnc book>account"` + Transactions []GnucashTransaction `xml:"http://www.gnucash.org/XML/gnc book>transaction"` +} + +type GnucashImport struct { + Securities []Security + Accounts []Account + Transactions []Transaction +} + +func ImportGnucash(r io.Reader) (*GnucashImport, error) { + var gncxml GnucashXMLImport + var gncimport GnucashImport + + // Perform initial parsing of xml into structs + decoder := xml.NewDecoder(r) + err := decoder.Decode(&gncxml) + if err != nil { + return nil, err + } + + // Fixup securities, making a map of them as we go + securityMap := make(map[string]Security) + for i := range gncxml.Commodities { + s := gncxml.Commodities[i].Security + s.SecurityId = int64(i + 1) + securityMap[s.Name] = s + + // Ignore gnucash's "template" commodity + if s.Name != "template" || + s.Description != "template" || + s.AlternateId != "template" { + gncimport.Securities = append(gncimport.Securities, s) + } + } + + //find root account, while simultaneously creating map of GUID's to + //accounts + var rootAccount GnucashAccount + accountMap := make(map[string]GnucashAccount) + for i := range gncxml.Accounts { + gncxml.Accounts[i].accountid = int64(i + 1) + if gncxml.Accounts[i].Type == "ROOT" { + rootAccount = gncxml.Accounts[i] + } else { + accountMap[gncxml.Accounts[i].AccountId] = gncxml.Accounts[i] + } + } + + //Translate to our account format, figuring out parent relationships + for guid := range accountMap { + ga := accountMap[guid] + var a Account + + a.AccountId = ga.accountid + if ga.ParentAccountId == rootAccount.AccountId { + a.ParentAccountId = -1 + } else { + parent, ok := accountMap[ga.ParentAccountId] + if ok { + a.ParentAccountId = parent.accountid + } else { + a.ParentAccountId = -1 // Ugly, but assign to top-level if we can't find its parent + } + } + a.Name = ga.Name + security, ok := securityMap[ga.Commodity.Name] + if ok { + } else { + return nil, fmt.Errorf("Unable to find security: %s", ga.Commodity.Name) + } + a.SecurityId = security.SecurityId + + //TODO find account types + switch ga.Type { + default: + a.Type = Bank + case "ASSET": + a.Type = Asset + case "BANK": + a.Type = Bank + case "CASH": + a.Type = Cash + case "CREDIT", "LIABILITY": + a.Type = Liability + case "EQUITY": + a.Type = Equity + case "EXPENSE": + a.Type = Expense + case "INCOME": + a.Type = Income + case "PAYABLE": + a.Type = Payable + case "RECEIVABLE": + a.Type = Receivable + case "MUTUAL", "STOCK": + a.Type = Investment + case "TRADING": + a.Type = Trading + } + + gncimport.Accounts = append(gncimport.Accounts, a) + } + + //Translate transactions to our format + for i := range gncxml.Transactions { + gt := gncxml.Transactions[i] + + t := new(Transaction) + t.Description = gt.Description + t.Date = gt.DatePosted.Date.Time + t.Status = Imported + for j := range gt.Splits { + gs := gt.Splits[j] + s := new(Split) + s.Memo = gs.Memo + account, ok := accountMap[gs.AccountId] + if !ok { + return nil, fmt.Errorf("Unable to find account: %s", gs.AccountId) + } + s.AccountId = account.accountid + + security, ok := securityMap[account.Commodity.Name] + if !ok { + return nil, fmt.Errorf("Unable to find security: %s", account.Commodity.Name) + } + s.SecurityId = -1 + + var r big.Rat + _, ok = r.SetString(gs.Amount) + if ok { + s.Amount = r.FloatString(security.Precision) + } else { + return nil, fmt.Errorf("Can't set split Amount: %s", gs.Amount) + } + + t.Splits = append(t.Splits, s) + } + gncimport.Transactions = append(gncimport.Transactions, *t) + } + + return &gncimport, nil +} + +func GnucashImportHandler(w http.ResponseWriter, r *http.Request) { + user, err := GetUserFromSession(r) + if err != nil { + WriteError(w, 1 /*Not Signed In*/) + return + } + + if r.Method != "POST" { + WriteError(w, 3 /*Invalid Request*/) + return + } + + multipartReader, err := r.MultipartReader() + if err != nil { + WriteError(w, 3 /*Invalid Request*/) + return + } + + // assume there is only one 'part' + part, err := multipartReader.NextPart() + if err != nil { + if err == io.EOF { + WriteError(w, 3 /*Invalid Request*/) + } else { + WriteError(w, 999 /*Internal Error*/) + log.Print(err) + } + return + } + + gnucashImport, err := ImportGnucash(part) + if err != nil { + WriteError(w, 3 /*Invalid Request*/) + return + } + + sqltransaction, err := DB.Begin() + if err != nil { + WriteError(w, 999 /*Internal Error*/) + log.Print(err) + return + } + + // Import securities, building map from Gnucash security IDs to our + // internal IDs + securityMap := make(map[int64]int64) + for _, security := range gnucashImport.Securities { + //TODO FIXME check on AlternateID also, and convert to the case + //where users have their own internal securities + s, err := GetSecurityByNameAndType(security.Name, security.Type) + if err != nil { + //TODO attempt to create security if it doesn't exist + sqltransaction.Rollback() + WriteError(w, 999 /*Internal Error*/) + log.Print(err) + return + } + securityMap[security.SecurityId] = s.SecurityId + } + + // Get/create accounts in the database, building a map from Gnucash account + // IDs to our internal IDs as we go + accountMap := make(map[int64]int64) + accountsRemaining := len(gnucashImport.Accounts) + accountsRemainingLast := accountsRemaining + for accountsRemaining > 0 { + for _, account := range gnucashImport.Accounts { + + // If the account has already been added to the map, skip it + _, ok := accountMap[account.AccountId] + if ok { + continue + } + + // If it hasn't been added, but its parent has, add it to the map + _, ok = accountMap[account.ParentAccountId] + if ok || account.ParentAccountId == -1 { + account.UserId = user.UserId + if account.ParentAccountId != -1 { + account.ParentAccountId = accountMap[account.ParentAccountId] + } + account.SecurityId = securityMap[account.SecurityId] + a, err := GetCreateAccountTx(sqltransaction, account) + if err != nil { + sqltransaction.Rollback() + WriteError(w, 999 /*Internal Error*/) + log.Print(err) + return + } + accountMap[account.AccountId] = a.AccountId + accountsRemaining-- + } + } + if accountsRemaining == accountsRemainingLast { + //We didn't make any progress in importing the next level of accounts, so there must be a circular parent-child relationship, so give up and tell the user they're wrong + sqltransaction.Rollback() + WriteError(w, 999 /*Internal Error*/) + log.Print(fmt.Errorf("Circular account parent-child relationship when importing %s", part.FileName())) + return + } + accountsRemainingLast = accountsRemaining + } + + // Insert transactions, fixing up account IDs to match internal ones from + // above + for _, transaction := range gnucashImport.Transactions { + for _, split := range transaction.Splits { + acctId, ok := accountMap[split.AccountId] + if !ok { + sqltransaction.Rollback() + WriteError(w, 999 /*Internal Error*/) + log.Print(fmt.Errorf("Error: Split's AccountID Doesn't exist: %d\n", split.AccountId)) + return + } + split.AccountId = acctId + fmt.Printf("Setting split AccountId to %d\n", acctId) + } + err := InsertTransactionTx(sqltransaction, &transaction, user) + if err != nil { + sqltransaction.Rollback() + WriteError(w, 999 /*Internal Error*/) + log.Print(err) + return + } + } + + err = sqltransaction.Commit() + if err != nil { + sqltransaction.Rollback() + WriteError(w, 999 /*Internal Error*/) + log.Print(err) + return + } + + WriteSuccess(w) +} diff --git a/imports.go b/imports.go index d1fdb38..02aa22c 100644 --- a/imports.go +++ b/imports.go @@ -12,7 +12,9 @@ import ( /* * Assumes the User is a valid, signed-in user, but accountid has not yet been validated */ -func AccountImportHandler(w http.ResponseWriter, r *http.Request, user *User, accountid int64) { +func AccountImportHandler(w http.ResponseWriter, r *http.Request, user *User, accountid int64, importtype string) { + //TODO branch off for different importtype's + // Return Account with this Id account, err := GetAccount(accountid, user.UserId) if err != nil { @@ -58,23 +60,32 @@ func AccountImportHandler(w http.ResponseWriter, r *http.Request, user *User, ac itl, err := ImportOFX(tmpFilename, account) if err != nil { - //TODO is this necessarily an invalid request? + //TODO is this necessarily an invalid request (what if it was an error on our end)? WriteError(w, 3 /*Invalid Request*/) return } + sqltransaction, err := DB.Begin() + if err != nil { + WriteError(w, 999 /*Internal Error*/) + log.Print(err) + return + } + var transactions []Transaction for _, transaction := range *itl.Transactions { transaction.UserId = user.UserId transaction.Status = Imported if !transaction.Valid() { + sqltransaction.Rollback() WriteError(w, 3 /*Invalid Request*/) return } - imbalances, err := transaction.GetImbalances() + imbalances, err := transaction.GetImbalancesTx(sqltransaction) if err != nil { + sqltransaction.Rollback() WriteError(w, 999 /*Internal Error*/) log.Print(err) return @@ -95,11 +106,12 @@ func AccountImportHandler(w http.ResponseWriter, r *http.Request, user *User, ac // If we're dealing with exactly two securities, assume any imbalances // from imports are from trading currencies/securities if num_imbalances == 2 { - imbalanced_account, err = GetTradingAccount(user.UserId, imbalanced_security) + imbalanced_account, err = GetTradingAccount(sqltransaction, user.UserId, imbalanced_security) } else { - imbalanced_account, err = GetImbalanceAccount(user.UserId, imbalanced_security) + imbalanced_account, err = GetImbalanceAccount(sqltransaction, user.UserId, imbalanced_security) } if err != nil { + sqltransaction.Rollback() WriteError(w, 999 /*Internal Error*/) log.Print(err) return @@ -121,8 +133,9 @@ func AccountImportHandler(w http.ResponseWriter, r *http.Request, user *User, ac // accounts for _, split := range transaction.Splits { if split.SecurityId != -1 || split.AccountId == -1 { - imbalanced_account, err := GetImbalanceAccount(user.UserId, split.SecurityId) + imbalanced_account, err := GetImbalanceAccount(sqltransaction, user.UserId, split.SecurityId) if err != nil { + sqltransaction.Rollback() WriteError(w, 999 /*Internal Error*/) log.Print(err) return @@ -133,23 +146,24 @@ func AccountImportHandler(w http.ResponseWriter, r *http.Request, user *User, ac } } - balanced, err := transaction.Balanced() - if !balanced || err != nil { - WriteError(w, 999 /*Internal Error*/) - log.Print(err) - return - } - transactions = append(transactions, transaction) } for _, transaction := range transactions { - err := InsertTransaction(&transaction, user) + err := InsertTransactionTx(sqltransaction, &transaction, user) if err != nil { WriteError(w, 999 /*Internal Error*/) log.Print(err) } } + err = sqltransaction.Commit() + if err != nil { + sqltransaction.Rollback() + WriteError(w, 999 /*Internal Error*/) + log.Print(err) + return + } + WriteSuccess(w) } diff --git a/js/AccountCombobox.js b/js/AccountCombobox.js index 67492a2..94ce5de 100644 --- a/js/AccountCombobox.js +++ b/js/AccountCombobox.js @@ -9,6 +9,7 @@ module.exports = React.createClass({ getDefaultProps: function() { return { includeRoot: true, + disabled: false, rootName: "New Top-level Account" }; }, @@ -33,6 +34,7 @@ module.exports = React.createClass({ defaultValue={this.props.value} onChange={this.handleAccountChange} ref="account" + disabled={this.props.disabled} className={className} /> ); } diff --git a/js/AccountRegister.js b/js/AccountRegister.js index 1496ade..073183d 100644 --- a/js/AccountRegister.js +++ b/js/AccountRegister.js @@ -20,8 +20,10 @@ var ButtonToolbar = ReactBootstrap.ButtonToolbar; var ProgressBar = ReactBootstrap.ProgressBar; var Glyphicon = ReactBootstrap.Glyphicon; -var DateTimePicker = require('react-widgets').DateTimePicker; -var Combobox = require('react-widgets').Combobox; +var ReactWidgets = require('react-widgets') +var DateTimePicker = ReactWidgets.DateTimePicker; +var Combobox = ReactWidgets.Combobox; +var DropdownList = ReactWidgets.DropdownList; var Big = require('big.js'); @@ -455,29 +457,39 @@ const AddEditTransactionModal = React.createClass({ } }); +const ImportType = { + OFX: 1, + Gnucash: 2 +}; +var ImportTypeList = []; +for (var type in ImportType) { + if (ImportType.hasOwnProperty(type)) { + var name = ImportType[type] == ImportType.OFX ? "OFX/QFX" : type; //QFX is a special snowflake + ImportTypeList.push({'TypeId': ImportType[type], 'Name': name}); + } +} + const ImportTransactionsModal = React.createClass({ getInitialState: function() { return { importing: false, imported: false, importFile: "", + importType: ImportType.Gnucash, uploadProgress: -1, error: null}; }, handleCancel: function() { - this.setState({ - importing: false, - imported: false, - importFile: "", - uploadProgress: -1, - error: null - }); + this.setState(this.getInitialState()); if (this.props.onCancel != null) this.props.onCancel(); }, - onImportChanged: function() { + handleImportChange: function() { this.setState({importFile: this.refs.importfile.getValue()}); }, + handleTypeChange: function(type) { + this.setState({importType: type.TypeId}); + }, handleSubmit: function() { if (this.props.onSubmit != null) this.props.onSubmit(this.props.account); @@ -493,11 +505,18 @@ const ImportTransactionsModal = React.createClass({ handleImportTransactions: function() { var file = this.refs.importfile.getInputDOMNode().files[0]; var formData = new FormData(); - this.setState({importing: true}); formData.append('importfile', file, this.state.importFile); + var url = "" + if (this.state.importType == ImportType.OFX) + url = "account/"+this.props.account.AccountId+"/import/ofx"; + else if (this.state.importType == ImportType.Gnucash) + url = "import/gnucash"; + + this.setState({importing: true}); + $.ajax({ type: "POST", - url: "account/"+this.props.account.AccountId+"/import", + url: url, data: formData, xhr: function() { var xhrObject = $.ajaxSettings.xhr(); @@ -514,7 +533,7 @@ const ImportTransactionsModal = React.createClass({ if (e.isError()) { var errString = e.ErrorString; if (e.ErrorId == 3 /* Invalid Request */) { - errString = "Please check that the file you uploaded is a valid OFX file for this account and try again."; + errString = "Please check that the file you uploaded is valid and try again."; } this.setState({ importing: false, @@ -540,9 +559,11 @@ const ImportTransactionsModal = React.createClass({ }); }, render: function() { - var accountNameLabel = "" - if (this.props.account != null ) + var accountNameLabel = "Performing global import:" + if (this.props.account != null && this.state.importType != ImportType.Gnucash) accountNameLabel = "Importing to '" + getAccountDisplayName(this.props.account, this.props.account_map) + "' account:"; + + // Display the progress bar if an upload/import is in progress var progressBar = []; if (this.state.importing && this.state.uploadProgress == 100) { progressBar = (); @@ -550,6 +571,7 @@ const ImportTransactionsModal = React.createClass({ progressBar = (); } + // Create panel, possibly displaying error or success messages var panel = []; if (this.state.error != null) { panel = ({this.state.error}); @@ -557,16 +579,22 @@ const ImportTransactionsModal = React.createClass({ panel = (Your import is now complete.); } - var buttonsDisabled = (this.state.importing) ? true : false; + // Display proper buttons, possibly disabling them if an import is in progress var button1 = []; var button2 = []; if (!this.state.imported && this.state.error == null) { - button1 = (); - button2 = (); + button1 = (); + button2 = (); } else { - button1 = (); + button1 = (); } var inputDisabled = (this.state.importing || this.state.error != null || this.state.imported) ? true : false; + + // Disable OFX/QFX imports if no account is selected + var disabledTypes = false; + if (this.props.account == null) + disabledTypes = [ImportTypeList[ImportType.OFX - 1]]; + return ( @@ -576,13 +604,21 @@ const ImportTransactionsModal = React.createClass({
+ + onChange={this.handleImportChange} /> {progressBar} {panel} @@ -897,8 +933,7 @@ module.exports = React.createClass({ diff --git a/libofx.go b/libofx.go index c9c8830..a45404a 100644 --- a/libofx.go +++ b/libofx.go @@ -25,11 +25,11 @@ import ( ) type ImportObject struct { - TransactionList ImportTransactionsList + TransactionList OFXImport Error error } -type ImportTransactionsList struct { +type OFXImport struct { Account *Account Transactions *[]Transaction TotalTransactions int64 @@ -249,7 +249,7 @@ func OFXTransactionCallback(transaction_data C.struct_OfxTransactionData, data u return 0 } -func ImportOFX(filename string, account *Account) (*ImportTransactionsList, error) { +func ImportOFX(filename string, account *Account) (*OFXImport, error) { var a Account var t []Transaction var iobj ImportObject diff --git a/main.go b/main.go index 5344cf5..30b1cc6 100644 --- a/main.go +++ b/main.go @@ -69,6 +69,7 @@ func main() { servemux.HandleFunc("/security/", SecurityHandler) servemux.HandleFunc("/account/", AccountHandler) servemux.HandleFunc("/transaction/", TransactionHandler) + servemux.HandleFunc("/import/gnucash", GnucashImportHandler) listener, err := net.Listen("tcp", ":"+strconv.Itoa(port)) if err != nil { diff --git a/securities.go b/securities.go index 2f2db08..5da4236 100644 --- a/securities.go +++ b/securities.go @@ -2,7 +2,7 @@ package main import ( "encoding/json" - "errors" + "fmt" "log" "net/http" ) @@ -55703,7 +55703,16 @@ func GetSecurityByName(name string) (*Security, error) { return value, nil } } - return nil, errors.New("Invalid Security Name") + return nil, fmt.Errorf("Invalid Security Name: \"%s\"", name) +} + +func GetSecurityByNameAndType(name string, _type int64) (*Security, error) { + for _, value := range security_map { + if value.Name == name && value.Type == _type { + return value, nil + } + } + return nil, fmt.Errorf("Invalid Security Name (%s) or Type (%d)", name, _type) } func GetSecurities() []*Security { diff --git a/transactions.go b/transactions.go index 5d4de65..b772910 100644 --- a/transactions.go +++ b/transactions.go @@ -113,7 +113,7 @@ func (t *Transaction) Valid() bool { // Return a map of security ID's to big.Rat's containing the amount that // security is imbalanced by -func (t *Transaction) GetImbalances() (map[int64]big.Rat, error) { +func (t *Transaction) GetImbalancesTx(transaction *gorp.Transaction) (map[int64]big.Rat, error) { sums := make(map[int64]big.Rat) if !t.Valid() { @@ -123,7 +123,13 @@ func (t *Transaction) GetImbalances() (map[int64]big.Rat, error) { for i := range t.Splits { securityid := t.Splits[i].SecurityId if t.Splits[i].AccountId != -1 { - account, err := GetAccount(t.Splits[i].AccountId, t.UserId) + var err error + var account *Account + if transaction != nil { + account, err = GetAccountTx(transaction, t.Splits[i].AccountId, t.UserId) + } else { + account, err = GetAccount(t.Splits[i].AccountId, t.UserId) + } if err != nil { return nil, err } @@ -137,6 +143,10 @@ func (t *Transaction) GetImbalances() (map[int64]big.Rat, error) { return sums, nil } +func (t *Transaction) GetImbalances() (map[int64]big.Rat, error) { + return t.GetImbalancesTx(nil) +} + // Returns true if all securities contained in this transaction are balanced, // false otherwise func (t *Transaction) Balanced() (bool, error) { @@ -235,23 +245,16 @@ func (ame AccountMissingError) Error() string { return "Account missing" } -func InsertTransaction(t *Transaction, user *User) error { - transaction, err := DB.Begin() - if err != nil { - return err - } - +func InsertTransactionTx(transaction *gorp.Transaction, t *Transaction, user *User) error { // Map of any accounts with transaction splits being added a_map := make(map[int64]bool) for i := range t.Splits { if t.Splits[i].AccountId != -1 { existing, err := transaction.SelectInt("SELECT count(*) from accounts where AccountId=?", t.Splits[i].AccountId) if err != nil { - transaction.Rollback() return err } if existing != 1 { - transaction.Rollback() return AccountMissingError{} } a_map[t.Splits[i].AccountId] = true @@ -269,15 +272,14 @@ func InsertTransaction(t *Transaction, user *User) error { if len(a_ids) < 1 { return AccountMissingError{} } - err = incrementAccountVersions(transaction, user, a_ids) + err := incrementAccountVersions(transaction, user, a_ids) if err != nil { - transaction.Rollback() return err } + t.UserId = user.UserId err = transaction.Insert(t) if err != nil { - transaction.Rollback() return err } @@ -286,11 +288,24 @@ func InsertTransaction(t *Transaction, user *User) error { t.Splits[i].SplitId = -1 err = transaction.Insert(t.Splits[i]) if err != nil { - transaction.Rollback() return err } } + return nil +} +func InsertTransaction(t *Transaction, user *User) error { + transaction, err := DB.Begin() + if err != nil { + return err + } + + err = InsertTransactionTx(transaction, t, user) + if err != nil { + transaction.Rollback() + return err + } + err = transaction.Commit() if err != nil { transaction.Rollback()