From 58c7c17727b857697b7c2c4666e8e685d062f16c Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Tue, 2 Feb 2016 21:46:27 -0500 Subject: [PATCH] Initial pass at OFX imports Still needs some fixups: * UI is incomplete * Investment transactions are unbalanced initially * OFX imports don't detect if one of the description fields for a transaction is empty (to fall back on another) * I'm sure plenty of other issues I haven't discovered yet --- accounts.go | 30 +++- imports.go | 91 ++++++++++++ libofx.c | 14 ++ libofx.go | 283 +++++++++++++++++++++++++++++++++++++ main.go | 2 + securities.go | 18 ++- static/account_register.js | 141 +++++++++++++++++- static/models.js | 30 ++-- transactions.go | 77 ++++++---- 9 files changed, 638 insertions(+), 48 deletions(-) create mode 100644 imports.go create mode 100644 libofx.c create mode 100644 libofx.go diff --git a/accounts.go b/accounts.go index ab319ca..b4d113c 100644 --- a/accounts.go +++ b/accounts.go @@ -21,12 +21,13 @@ const ( ) type Account struct { - AccountId int64 - UserId int64 - SecurityId int64 - ParentAccountId int64 // -1 if this account is at the root - Type int64 - Name string + AccountId int64 + ExternalAccountId string + UserId int64 + SecurityId int64 + ParentAccountId int64 // -1 if this account is at the root + Type int64 + Name string // monotonically-increasing account transaction version number. Used for // allowing a client to ensure they have a consistent version when paging @@ -39,9 +40,11 @@ type AccountList struct { } var accountTransactionsRE *regexp.Regexp +var accountImportRE *regexp.Regexp func init() { accountTransactionsRE = regexp.MustCompile(`^/account/[0-9]+/transactions/?$`) + accountImportRE = regexp.MustCompile(`^/account/[0-9]+/import/?$`) } func (a *Account) Write(w http.ResponseWriter) error { @@ -213,6 +216,21 @@ func AccountHandler(w http.ResponseWriter, r *http.Request) { } if r.Method == "POST" { + // if URL looks like /account/[0-9]+/import, use the account + // import handler + if accountImportRE.MatchString(r.URL.Path) { + var accountid int64 + n, err := GetURLPieces(r.URL.Path, "/account/%d", &accountid) + + if err != nil || n != 1 { + WriteError(w, 999 /*Internal Error*/) + log.Print(err) + return + } + AccountImportHandler(w, r, user, accountid) + return + } + account_json := r.PostFormValue("account") if account_json == "" { WriteError(w, 3 /*Invalid Request*/) diff --git a/imports.go b/imports.go new file mode 100644 index 0000000..7d25c44 --- /dev/null +++ b/imports.go @@ -0,0 +1,91 @@ +package main + +import ( + "io" + "io/ioutil" + "log" + "net/http" + "os" +) + +/* + * Assumes the User is a valid, signed-in user, but accountid has not yet been validated + */ +func AccountImportHandler(w http.ResponseWriter, r *http.Request, user *User, accountid int64) { + // Return Account with this Id + account, err := GetAccount(accountid, user.UserId) + if err != nil { + WriteError(w, 3 /*Invalid Request*/) + return + } + + multipartReader, err := r.MultipartReader() + if err != nil { + WriteError(w, 3 /*Invalid Request*/) + return + } + + // assume there is only one 'part' + part, err := multipartReader.NextPart() + if err != nil { + if err == io.EOF { + WriteError(w, 3 /*Invalid Request*/) + } else { + WriteError(w, 999 /*Internal Error*/) + log.Print(err) + } + return + } + + f, err := ioutil.TempFile(tmpDir, user.Username+"_"+account.Name) + if err != nil { + WriteError(w, 999 /*Internal Error*/) + log.Print(err) + return + } + tmpFilename := f.Name() + defer os.Remove(tmpFilename) + + _, err = io.Copy(f, part) + f.Close() + if err != nil { + WriteError(w, 999 /*Internal Error*/) + log.Print(err) + return + } + + itl, err := ImportOFX(tmpFilename, account) + + if err != nil { + //TODO is this necessarily an invalid request? + WriteError(w, 3 /*Invalid Request*/) + return + } + + for _, transaction := range *itl.Transactions { + if !transaction.Valid() { + WriteError(w, 3 /*Invalid Request*/) + return + } + + // TODO check if transactions are balanced too + // balanced, err := transaction.Balanced() + // if !balanced || err != nil { + // WriteError(w, 3 /*Invalid Request*/) + // return + // } + } + + /////////////////////// TODO //////////////////////// + for _, transaction := range *itl.Transactions { + transaction.UserId = user.UserId + transaction.Status = Imported + err := InsertTransaction(&transaction, user) + if err != nil { + WriteError(w, 999 /*Internal Error*/) + log.Print(err) + } + } + + WriteSuccess(w) +} diff --git a/libofx.c b/libofx.c new file mode 100644 index 0000000..3fb2397 --- /dev/null +++ b/libofx.c @@ -0,0 +1,14 @@ +#include +#include "_cgo_export.h" + +int ofx_statement_callback(const struct OfxStatementData statement_data, void *data) { + return OFXStatementCallback(statement_data, data); +} + +int ofx_account_callback(const struct OfxAccountData account_data, void *data) { + return OFXAccountCallback(account_data, data); +} + +int ofx_transaction_callback(const struct OfxTransactionData transaction_data, void *data) { + return OFXTransactionCallback(transaction_data, data); +} diff --git a/libofx.go b/libofx.go new file mode 100644 index 0000000..c9c8830 --- /dev/null +++ b/libofx.go @@ -0,0 +1,283 @@ +package main + +//#cgo LDFLAGS: -lofx +// +//#include +// +// //The next line disables the definition of static variables to allow for it to +// //be included here (see libofx commit bd24df15531e52a2858f70487443af8b9fa407f4) +//#define OFX_AQUAMANIAC_UGLY_HACK1 +//#include +// +// typedef int (*ofx_statement_cb_fn) (const struct OfxStatementData, void *); +// extern int ofx_statement_callback(const struct OfxStatementData, void *); +// typedef int (*ofx_account_cb_fn) (const struct OfxAccountData, void *); +// extern int ofx_account_callback(const struct OfxAccountData, void *); +// typedef int (*ofx_transaction_cb_fn) (const struct OfxTransactionData, void *); +// extern int ofx_transaction_callback(const struct OfxTransactionData, void *); +import "C" + +import ( + "errors" + "math/big" + "time" + "unsafe" +) + +type ImportObject struct { + TransactionList ImportTransactionsList + Error error +} + +type ImportTransactionsList struct { + Account *Account + Transactions *[]Transaction + TotalTransactions int64 + BeginningBalance string + EndingBalance string +} + +func init() { + // Turn off all libofx info/debug messages + C.ofx_PARSER_msg = 0 + C.ofx_DEBUG_msg = 0 + C.ofx_DEBUG1_msg = 0 + C.ofx_DEBUG2_msg = 0 + C.ofx_DEBUG3_msg = 0 + C.ofx_DEBUG4_msg = 0 + C.ofx_DEBUG5_msg = 0 + C.ofx_STATUS_msg = 0 + C.ofx_INFO_msg = 0 + C.ofx_WARNING_msg = 0 + C.ofx_ERROR_msg = 0 +} + +//export OFXStatementCallback +func OFXStatementCallback(statement_data C.struct_OfxStatementData, data unsafe.Pointer) C.int { + // import := (*ImportObject)(data) + return 0 +} + +//export OFXAccountCallback +func OFXAccountCallback(account_data C.struct_OfxAccountData, data unsafe.Pointer) C.int { + iobj := (*ImportObject)(data) + itl := iobj.TransactionList + if account_data.account_id_valid != 0 { + account_name := C.GoString(&account_data.account_name[0]) + account_id := C.GoString(&account_data.account_id[0]) + itl.Account.Name = account_name + itl.Account.ExternalAccountId = account_id + } else { + if iobj.Error == nil { + iobj.Error = errors.New("OFX account ID invalid") + } + return 1 + } + if account_data.account_type_valid != 0 { + switch account_data.account_type { + case C.OFX_CHECKING, C.OFX_SAVINGS, C.OFX_MONEYMRKT, C.OFX_CMA: + itl.Account.Type = Bank + case C.OFX_CREDITLINE, C.OFX_CREDITCARD: + itl.Account.Type = Liability + case C.OFX_INVESTMENT: + itl.Account.Type = Investment + } + } else { + if iobj.Error == nil { + iobj.Error = errors.New("OFX account type invalid") + } + return 1 + } + if account_data.currency_valid != 0 { + currency_name := C.GoString(&account_data.currency[0]) + currency, err := GetSecurityByName(currency_name) + if err != nil { + if iobj.Error == nil { + iobj.Error = err + } + return 1 + } + itl.Account.SecurityId = currency.SecurityId + } else { + if iobj.Error == nil { + iobj.Error = errors.New("OFX account currency invalid") + } + return 1 + } + return 0 +} + +//export OFXTransactionCallback +func OFXTransactionCallback(transaction_data C.struct_OfxTransactionData, data unsafe.Pointer) C.int { + iobj := (*ImportObject)(data) + itl := iobj.TransactionList + transaction := new(Transaction) + + if transaction_data.name_valid != 0 { + transaction.Description = C.GoString(&transaction_data.name[0]) + } + // if transaction_data.reference_number_valid != 0 { + // fmt.Println("reference_number: ", C.GoString(&transaction_data.reference_number[0])) + // } + if transaction_data.date_posted_valid != 0 { + transaction.Date = time.Unix(int64(transaction_data.date_posted), 0) + } else if transaction_data.date_initiated_valid != 0 { + transaction.Date = time.Unix(int64(transaction_data.date_initiated), 0) + } + if transaction_data.fi_id_valid != 0 { + transaction.RemoteId = C.GoString(&transaction_data.fi_id[0]) + } + + if transaction_data.amount_valid != 0 { + split := new(Split) + r := new(big.Rat) + r.SetFloat64(float64(transaction_data.amount)) + security := GetSecurity(itl.Account.SecurityId) + split.Amount = r.FloatString(security.Precision) + if transaction_data.memo_valid != 0 { + split.Memo = C.GoString(&transaction_data.memo[0]) + } + if transaction_data.check_number_valid != 0 { + split.Number = C.GoString(&transaction_data.check_number[0]) + } + split.SecurityId = -1 + split.AccountId = itl.Account.AccountId + transaction.Splits = append(transaction.Splits, split) + } else { + if iobj.Error == nil { + iobj.Error = errors.New("OFX transaction amount invalid") + } + return 1 + } + + var security *Security + split := new(Split) + units := new(big.Rat) + + if transaction_data.units_valid != 0 { + units.SetFloat64(float64(transaction_data.units)) + if transaction_data.security_data_valid != 0 { + security_data := transaction_data.security_data_ptr + if security_data.ticker_valid != 0 { + s, err := GetSecurityByName(C.GoString(&security_data.ticker[0])) + if err != nil { + if iobj.Error == nil { + iobj.Error = errors.New("Failed to find OFX transaction security: " + C.GoString(&security_data.ticker[0])) + } + return 1 + } + security = s + } else { + if iobj.Error == nil { + iobj.Error = errors.New("OFX security ticker invalid") + } + return 1 + } + if security.Type == Stock && security_data.unique_id_valid != 0 && security_data.unique_id_type_valid != 0 && C.GoString(&security_data.unique_id_type[0]) == "CUSIP" { + // Validate the security CUSIP, if possible + if security.AlternateId != C.GoString(&security_data.unique_id[0]) { + if iobj.Error == nil { + iobj.Error = errors.New("OFX transaction security CUSIP failed to validate") + } + return 1 + } + } + } else { + security = GetSecurity(itl.Account.SecurityId) + } + } else { + // Calculate units from other available fields if its not present + // units = - (amount + various fees) / unitprice + units.SetFloat64(float64(transaction_data.amount)) + fees := new(big.Rat) + if transaction_data.fees_valid != 0 { + fees.SetFloat64(float64(-transaction_data.fees)) + } + if transaction_data.commission_valid != 0 { + commission := new(big.Rat) + commission.SetFloat64(float64(-transaction_data.commission)) + fees.Add(fees, commission) + } + units.Add(units, fees) + units.Neg(units) + if transaction_data.unitprice_valid != 0 && transaction_data.unitprice != 0 { + unitprice := new(big.Rat) + unitprice.SetFloat64(float64(transaction_data.unitprice)) + units.Quo(units, unitprice) + } + + // If 'units' wasn't present, assume we're using the account's security + security = GetSecurity(itl.Account.SecurityId) + } + + split.Amount = units.FloatString(security.Precision) + split.SecurityId = security.SecurityId + split.AccountId = -1 + transaction.Splits = append(transaction.Splits, split) + + if transaction_data.fees_valid != 0 { + split := new(Split) + r := new(big.Rat) + r.SetFloat64(float64(-transaction_data.fees)) + security := GetSecurity(itl.Account.SecurityId) + split.Amount = r.FloatString(security.Precision) + split.Memo = "fees" + split.SecurityId = itl.Account.SecurityId + split.AccountId = -1 + transaction.Splits = append(transaction.Splits, split) + } + + if transaction_data.commission_valid != 0 { + split := new(Split) + r := new(big.Rat) + r.SetFloat64(float64(-transaction_data.commission)) + security := GetSecurity(itl.Account.SecurityId) + split.Amount = r.FloatString(security.Precision) + split.Memo = "commission" + split.SecurityId = itl.Account.SecurityId + split.AccountId = -1 + transaction.Splits = append(transaction.Splits, split) + } + + // if transaction_data.payee_id_valid != 0 { + // fmt.Println("payee_id: ", C.GoString(&transaction_data.payee_id[0])) + // } + + transaction_list := append(*itl.Transactions, *transaction) + iobj.TransactionList.Transactions = &transaction_list + + return 0 +} + +func ImportOFX(filename string, account *Account) (*ImportTransactionsList, error) { + var a Account + var t []Transaction + var iobj ImportObject + iobj.TransactionList.Account = &a + iobj.TransactionList.Transactions = &t + + a.AccountId = account.AccountId + + context := C.libofx_get_new_context() + defer C.libofx_free_context(context) + + C.ofx_set_statement_cb(context, C.ofx_statement_cb_fn(C.ofx_statement_callback), unsafe.Pointer(&iobj)) + C.ofx_set_account_cb(context, C.ofx_account_cb_fn(C.ofx_account_callback), unsafe.Pointer(&iobj)) + C.ofx_set_transaction_cb(context, C.ofx_transaction_cb_fn(C.ofx_transaction_callback), unsafe.Pointer(&iobj)) + + filename_cstring := C.CString(filename) + defer C.free(unsafe.Pointer(filename_cstring)) + C.libofx_proc_file(context, filename_cstring, C.OFX) // unconditionally returns 0. + + iobj.TransactionList.TotalTransactions = int64(len(*iobj.TransactionList.Transactions)) + + if iobj.TransactionList.TotalTransactions == 0 { + return nil, errors.New("No OFX transactions found") + } + + if iobj.Error != nil { + return nil, iobj.Error + } else { + return &iobj.TransactionList, nil + } +} diff --git a/main.go b/main.go index 28f8db2..f279cf9 100644 --- a/main.go +++ b/main.go @@ -14,6 +14,7 @@ import ( var serveFcgi bool var baseDir string +var tmpDir string var port int var smtpServer string var smtpPort int @@ -23,6 +24,7 @@ var reminderEmail string func init() { flag.StringVar(&baseDir, "base", "./", "Base directory for server") + flag.StringVar(&tmpDir, "tmp", "/tmp", "Directory to create temporary files in") flag.IntVar(&port, "port", 80, "Port to serve API/files on") flag.StringVar(&smtpServer, "smtp.server", "smtp.example.com", "SMTP server to send reminder emails from.") flag.IntVar(&smtpPort, "smtp.port", 587, "SMTP server port to connect to") diff --git a/securities.go b/securities.go index bf0b3ac..8998e4b 100644 --- a/securities.go +++ b/securities.go @@ -2,15 +2,14 @@ package main import ( "encoding/json" + "errors" "log" "net/http" ) const ( - Banknote int64 = 1 - Bond = 2 - Stock = 3 - MutualFund = 4 + Currency int64 = 1 + Stock = 2 ) type Security struct { @@ -22,6 +21,8 @@ type Security struct { // security is precise to Precision int Type int64 + // AlternateId is CUSIP for Type=Stock + AlternateId string } type SecurityList struct { @@ -1303,6 +1304,15 @@ func GetSecurity(securityid int64) *Security { return nil } +func GetSecurityByName(name string) (*Security, error) { + for _, value := range security_map { + if value.Name == name { + return value, nil + } + } + return nil, errors.New("Invalid Security Name") +} + func GetSecurities() []*Security { return security_list } diff --git a/static/account_register.js b/static/account_register.js index 8284135..d52d6a1 100644 --- a/static/account_register.js +++ b/static/account_register.js @@ -13,6 +13,8 @@ var Col = ReactBootstrap.Col; var Button = ReactBootstrap.Button; var ButtonToolbar = ReactBootstrap.ButtonToolbar; +var ProgressBar = ReactBootstrap.ProgressBar; + var DateTimePicker = ReactWidgets.DateTimePicker; const TransactionRow = React.createClass({ @@ -45,7 +47,11 @@ const TransactionRow = React.createClass({ var otherSplit = this.props.transaction.Splits[0]; if (otherSplit.AccountId == this.props.account.AccountId) var otherSplit = this.props.transaction.Splits[1]; - var accountName = getAccountDisplayName(this.props.account_map[otherSplit.AccountId], this.props.account_map); + + if (otherSplit.AccountId == -1) + var accountName = "Unbalanced " + this.props.security_map[otherSplit.SecurityId].Symbol + " transaction"; + else + var accountName = getAccountDisplayName(this.props.account_map[otherSplit.AccountId], this.props.account_map); } else { accountName = "--Split Transaction--"; } @@ -224,6 +230,7 @@ const AddEditTransactionModal = React.createClass({ handleUpdateAccount: function(account, split) { var transaction = this.state.transaction; transaction.Splits[split] = React.addons.update(transaction.Splits[split], { + SecurityId: {$set: -1}, AccountId: {$set: account.AccountId} }); this.setState({ @@ -290,11 +297,14 @@ const AddEditTransactionModal = React.createClass({ var accountValidation = ""; if (s.AccountId in this.props.account_map) { security = this.props.security_map[this.props.account_map[s.AccountId].SecurityId]; - if (security.SecurityId in imbalancedSecurityMap) - amountValidation = "error"; } else { + if (s.SecurityId in this.props.security_map) { + security = this.props.security_map[s.SecurityId]; + } accountValidation = "has-error"; } + if (security != null && security.SecurityId in imbalancedSecurityMap) + amountValidation = "error"; // Define all closures for calling split-updating functions var deleteSplitFn = (function() { @@ -423,9 +433,108 @@ const AddEditTransactionModal = React.createClass({ } }); +const ImportTransactionsModal = React.createClass({ + getInitialState: function() { + return { + importFile: "", + uploadProgress: -1}; + }, + handleCancel: function() { + this.setState({ + importFile: "", + uploadProgress: -1 + }); + if (this.props.onCancel != null) + this.props.onCancel(); + }, + onImportChanged: function() { + this.setState({importFile: this.refs.importfile.getValue()}); + }, + handleSubmit: function() { + if (this.props.onSubmit != null) + this.props.onSubmit(this.props.account); + }, + handleSetProgress: function(e) { + if (e.lengthComputable) { + var pct = Math.round(e.loaded/e.total*100); + this.setState({uploadProgress: pct}); + } else { + this.setState({uploadProgress: 50}); + } + }, + handleImportTransactions: function() { + var file = this.refs.importfile.getInputDOMNode().files[0]; + var formData = new FormData(); + formData.append('importfile', file, this.state.importFile); + $.ajax({ + type: "POST", + url: "account/"+this.props.account.AccountId+"/import", + data: formData, + xhr: function() { + var xhrObject = $.ajaxSettings.xhr(); + if (xhrObject.upload) { + xhrObject.upload.addEventListener('progress', this.handleSetProgress, false); + } else { + console.log("File upload failed because !xhr.upload") + } + return xhrObject; + }.bind(this), + beforeSend: function() { + console.log("before send"); + }, + success: function() { + this.setState({uploadProgress: 100}); + console.log("success"); + }.bind(this), + error: function(e) { + console.log("error handler", e); + }, + // So jQuery doesn't try to process teh data or content-type + cache: false, + contentType: false, + processData: false + }); + }, + render: function() { + var accountNameLabel = "" + if (this.props.account != null ) + accountNameLabel = "Import File to '" + getAccountDisplayName(this.props.account, this.props.account_map) + "':"; + var progressBar = []; + if (this.state.uploadProgress != -1) + progressBar = (); + return ( + + + Import Transactions + + +
+ +
+ {progressBar} +
+ + + + + + +
+ ); + } +}); + const AccountRegister = React.createClass({ getInitialState: function() { return { + importingTransactions: false, editingTransaction: false, selectedTransaction: new Transaction(), transactions: [], @@ -468,6 +577,16 @@ const AccountRegister = React.createClass({ selectedTransaction: newTransaction }); }, + handleImportClicked: function() { + this.setState({ + importingTransactions: true + }); + }, + handleImportingCancel: function() { + this.setState({ + importingTransactions: false + }); + }, ajaxError: function(jqXHR, status, error) { var e = new Error(); e.ErrorId = 5; @@ -593,6 +712,9 @@ const AccountRegister = React.createClass({ error: this.ajaxError }); }, + handleImportComplete: function() { + this.setState({importingTransactions: false}); + }, handleDeleteTransaction: function(transaction) { this.setState({ editingTransaction: false @@ -676,6 +798,13 @@ const AccountRegister = React.createClass({ onDelete={this.handleDeleteTransaction} securities={this.props.securities} security_map={this.props.security_map}/> +
Transactions for '{name}' @@ -695,6 +824,12 @@ const AccountRegister = React.createClass({ disabled={disabled}> New Transaction +
diff --git a/static/models.js b/static/models.js index 38dfa88..3a9060a 100644 --- a/static/models.js +++ b/static/models.js @@ -77,10 +77,8 @@ Session.prototype.isSession = function() { } const SecurityType = { - Banknote: 1, - Bond: 2, - Stock: 3, - MutualFund: 4 + Currency: 1, + Stock: 2 } var SecurityTypeList = []; for (var type in SecurityType) { @@ -197,6 +195,7 @@ function Split() { this.SplitId = -1; this.TransactionId = -1; this.AccountId = -1; + this.SecurityId = -1; this.Number = ""; this.Memo = ""; this.Amount = new Big(0.0); @@ -208,6 +207,7 @@ Split.prototype.toJSONobj = function() { json_obj.SplitId = this.SplitId; json_obj.TransactionId = this.TransactionId; json_obj.AccountId = this.AccountId; + json_obj.SecurityId = this.SecurityId; json_obj.Number = this.Number; json_obj.Memo = this.Memo; json_obj.Amount = this.Amount.toFixed(); @@ -222,6 +222,8 @@ Split.prototype.fromJSONobj = function(json_obj) { this.TransactionId = json_obj.TransactionId; if (json_obj.hasOwnProperty("AccountId")) this.AccountId = json_obj.AccountId; + if (json_obj.hasOwnProperty("SecurityId")) + this.SecurityId = json_obj.SecurityId; if (json_obj.hasOwnProperty("Number")) this.Number = json_obj.Number; if (json_obj.hasOwnProperty("Memo")) @@ -236,14 +238,16 @@ Split.prototype.isSplit = function() { var empty_split = new Split(); return this.SplitId != empty_split.SplitId || this.TransactionId != empty_split.TransactionId || - this.AccountId != empty_split.AccountId; + this.AccountId != empty_split.AccountId || + this.SecurityId != empty_split.SecurityId; } const TransactionStatus = { - Entered: 1, - Cleared: 2, - Reconciled: 3, - Voided: 4 + Imported: 1, + Entered: 2, + Cleared: 3, + Reconciled: 4, + Voided: 5 } var TransactionStatusList = []; for (var type in TransactionStatus) { @@ -331,10 +335,14 @@ Transaction.prototype.imbalancedSplitSecurities = function(account_map) { const emptySplit = new Split(); for (var i = 0; i < this.Splits.length; i++) { split = this.Splits[i]; - if (split.AccountId == emptySplit.AccountId) { + var securityId = -1; + if (split.AccountId != emptySplit.AccountId) { + securityId = account_map[split.AccountId].SecurityId; + } else if (split.SecurityId != emptySplit.SecurityId) { + securityId = split.SecurityId; + } else { continue; } - var securityId = account_map[split.AccountId].SecurityId; if (securityId in splitBalances) { splitBalances[securityId] = split.Amount.plus(splitBalances[securityId]); } else { diff --git a/transactions.go b/transactions.go index f9427e5..57722b6 100644 --- a/transactions.go +++ b/transactions.go @@ -17,10 +17,17 @@ import ( type Split struct { SplitId int64 TransactionId int64 - AccountId int64 - Number string // Check or reference number - Memo string - Amount string // String representation of decimal, suitable for passing to big.Rat.SetString() + + // One of AccountId and SecurityId must be -1 + // In normal splits, AccountId will be valid and SecurityId will be -1. The + // only case where this is reversed is for transactions that have been + // imported and not yet associated with an account. + AccountId int64 + SecurityId int64 + + Number string // Check or reference number + Memo string + Amount string // String representation of decimal, suitable for passing to big.Rat.SetString() } func GetBigAmount(amt string) (*big.Rat, error) { @@ -37,20 +44,26 @@ func (s *Split) GetAmount() (*big.Rat, error) { } func (s *Split) Valid() bool { + if (s.AccountId == -1 && s.SecurityId == -1) || + (s.AccountId != -1 && s.SecurityId != -1) { + return false + } _, err := s.GetAmount() return err == nil } const ( - Entered int64 = 1 - Cleared = 2 - Reconciled = 3 - Voided = 4 + Imported int64 = 1 + Entered = 2 + Cleared = 3 + Reconciled = 4 + Voided = 5 ) type Transaction struct { TransactionId int64 UserId int64 + RemoteId string // unique ID from server, for detecting duplicates Description string Status int64 Date time.Time @@ -106,14 +119,18 @@ func (t *Transaction) Balanced() (bool, error) { return false, errors.New("Transaction invalid") } for i := range t.Splits { - account, err := GetAccount(t.Splits[i].AccountId, t.UserId) - if err != nil { - return false, err + securityid := t.Splits[i].SecurityId + if t.Splits[i].AccountId != -1 { + account, err := GetAccount(t.Splits[i].AccountId, t.UserId) + if err != nil { + return false, err + } + securityid = account.SecurityId } amount, _ := t.Splits[i].GetAmount() - sum := sums[account.SecurityId] + sum := sums[securityid] (&sum).Add(&sum, amount) - sums[account.SecurityId] = sum + sums[securityid] = sum } for _, security_sum := range sums { if security_sum.Cmp(&zero) != 0 { @@ -212,16 +229,20 @@ func InsertTransaction(t *Transaction, user *User) error { // Map of any accounts with transaction splits being added a_map := make(map[int64]bool) for i := range t.Splits { - existing, err := transaction.SelectInt("SELECT count(*) from accounts where AccountId=?", t.Splits[i].AccountId) - if err != nil { - transaction.Rollback() - return err - } - if existing != 1 { - transaction.Rollback() + if t.Splits[i].AccountId != -1 { + existing, err := transaction.SelectInt("SELECT count(*) from accounts where AccountId=?", t.Splits[i].AccountId) + if err != nil { + transaction.Rollback() + return err + } + if existing != 1 { + transaction.Rollback() + return AccountMissingError{} + } + a_map[t.Splits[i].AccountId] = true + } else if t.Splits[i].SecurityId == -1 { return AccountMissingError{} } - a_map[t.Splits[i].AccountId] = true } //increment versions for all accounts @@ -229,6 +250,10 @@ func InsertTransaction(t *Transaction, user *User) error { for id := range a_map { a_ids = append(a_ids, id) } + // ensure at least one of the splits is associated with an actual account + if len(a_ids) < 1 { + return AccountMissingError{} + } err = incrementAccountVersions(transaction, user, a_ids) if err != nil { transaction.Rollback() @@ -305,13 +330,17 @@ func UpdateTransaction(t *Transaction, user *User) error { return err } } - a_map[t.Splits[i].AccountId] = true + if t.Splits[i].AccountId != -1 { + a_map[t.Splits[i].AccountId] = true + } } // Delete any remaining pre-existing splits for i := range existing_splits { _, ok := s_map[existing_splits[i].SplitId] - a_map[existing_splits[i].AccountId] = true + if existing_splits[i].AccountId != -1 { + a_map[existing_splits[i].AccountId] = true + } if ok { _, err := transaction.Delete(existing_splits[i]) if err != nil { @@ -358,7 +387,7 @@ func DeleteTransaction(t *Transaction, user *User) error { } var accountids []int64 - _, err = transaction.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=?", t.TransactionId) + _, err = transaction.Select(&accountids, "SELECT DISTINCT AccountId FROM splits WHERE TransactionId=? AND AccountId != -1", t.TransactionId) if err != nil { transaction.Rollback() return err