diff --git a/bank_test.go b/bank_test.go index bb9da4f..f9d7b9c 100644 --- a/bank_test.go +++ b/bank_test.go @@ -283,4 +283,5 @@ func TestUnmarshalBankStatementResponse(t *testing.T) { } checkResponsesEqual(t, &expected, response) + checkResponseRoundTrip(t, response) } diff --git a/common.go b/common.go index 786fe2c..62cf58c 100644 --- a/common.go +++ b/common.go @@ -3,10 +3,36 @@ package ofxgo //go:generate ./generate_constants.py import ( + "bytes" "errors" + "fmt" "github.com/aclindsa/xml" ) +func writeHeader(b *bytes.Buffer, v ofxVersion) error { + // Write the header appropriate to our version + switch v { + case OfxVersion102, OfxVersion103, OfxVersion151, OfxVersion160: + b.WriteString(`OFXHEADER:100 +DATA:OFXSGML +VERSION:` + v.String() + ` +SECURITY:NONE +ENCODING:USASCII +CHARSET:1252 +COMPRESSION:NONE +OLDFILEUID:NONE +NEWFILEUID:NONE + +`) + case OfxVersion200, OfxVersion201, OfxVersion202, OfxVersion203, OfxVersion210, OfxVersion211, OfxVersion220: + b.WriteString(`` + "\n") + b.WriteString(`` + "\n") + default: + return fmt.Errorf("%d is not a valid OFX version string", v) + } + return nil +} + // Message represents an OFX message in a message set. it is used to ease // marshalling and unmarshalling. type Message interface { diff --git a/creditcard_test.go b/creditcard_test.go index e384488..d487dd6 100644 --- a/creditcard_test.go +++ b/creditcard_test.go @@ -161,4 +161,5 @@ NEWFILEUID:NONE } checkResponsesEqual(t, &expected, response) + checkResponseRoundTrip(t, response) } diff --git a/go.mod b/go.mod index b0419f5..afc1f02 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,7 @@ module github.com/aclindsa/ofxgo require ( github.com/aclindsa/xml v0.0.0-20171002130543-5d4402bb4a20 + github.com/howeyc/gopass v0.0.0-20170109162249-bf9dde6d0d2c golang.org/x/crypto v0.0.0-20181001203147-e3636079e1a4 // indirect golang.org/x/sys v0.0.0-20180928133829-e4b3c5e90611 // indirect golang.org/x/text v0.0.0-20180911161511-905a57155faa diff --git a/go.sum b/go.sum index cf21bc7..51a8234 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/aclindsa/xml v0.0.0-20171002130543-5d4402bb4a20 h1:wN3KlzWq56AIgOqFzYLYVih4zVyPDViCUeG5uZxJHq4= github.com/aclindsa/xml v0.0.0-20171002130543-5d4402bb4a20/go.mod h1:DiEHtTD+e6zS3+R95F05Bfbcsfv13wZTi2M4LfAFLBE= +github.com/howeyc/gopass v0.0.0-20170109162249-bf9dde6d0d2c h1:kQWxfPIHVLbgLzphqk3QUflDy9QdksZR4ygR807bpy0= +github.com/howeyc/gopass v0.0.0-20170109162249-bf9dde6d0d2c/go.mod h1:lADxMC39cJJqL93Duh1xhAs4I2Zs8mKS89XWXFGp9cs= golang.org/x/crypto v0.0.0-20181001203147-e3636079e1a4 h1:Vk3wNqEZwyGyei9yq5ekj7frek2u7HUfffJ1/opblzc= golang.org/x/crypto v0.0.0-20181001203147-e3636079e1a4/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/sys v0.0.0-20180928133829-e4b3c5e90611 h1:O33LKL7WyJgjN9CvxfTIomjIClbd/Kq86/iipowHQU0= diff --git a/invstmt.go b/invstmt.go index 3adeee6..9a36e45 100644 --- a/invstmt.go +++ b/invstmt.go @@ -465,6 +465,7 @@ type InvBankTransaction struct { // security-related transactions themselves. It must be unmarshalled manually // due to the structure (don't know what kind of InvTransaction is coming next) type InvTranList struct { + XMLName xml.Name `xml:"INVTRANLIST"` DtStart Date DtEnd Date // This is the value that should be sent as in the next InvStatementRequest to ensure that no transactions are missed InvTransactions []InvTransaction @@ -630,6 +631,119 @@ func (l *InvTranList) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error } } +// MarshalXML handles marshalling an InvTranList element to an SGML/XML string +func (l *InvTranList) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + invTranListElement := xml.StartElement{Name: xml.Name{Local: "INVTRANLIST"}} + if err := e.EncodeToken(invTranListElement); err != nil { + return err + } + err := e.EncodeElement(&l.DtStart, xml.StartElement{Name: xml.Name{Local: "DTSTART"}}) + if err != nil { + return err + } + err = e.EncodeElement(&l.DtEnd, xml.StartElement{Name: xml.Name{Local: "DTEND"}}) + if err != nil { + return err + } + for _, t := range l.InvTransactions { + start := xml.StartElement{Name: xml.Name{Local: t.TransactionType()}} + switch tran := t.(type) { + case BuyDebt: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case BuyMF: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case BuyOpt: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case BuyOther: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case BuyStock: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case ClosureOpt: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case Income: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case InvExpense: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case JrnlFund: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case JrnlSec: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case MarginInterest: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case Reinvest: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case RetOfCap: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case SellDebt: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case SellMF: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case SellOpt: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case SellOther: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case SellStock: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case Split: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + case Transfer: + if err := e.EncodeElement(&tran, start); err != nil { + return err + } + default: + return errors.New("Invalid INVTRANLIST child type: " + tran.TransactionType()) + } + } + for _, tran := range l.BankTransactions { + err = e.EncodeElement(&tran, xml.StartElement{Name: xml.Name{Local: "INVBANKTRAN"}}) + if err != nil { + return err + } + } + if err := e.EncodeToken(invTranListElement.End()); err != nil { + return err + } + return nil +} + // InvPosition contains generic position information included in each of the // other *Position types type InvPosition struct { @@ -770,6 +884,45 @@ func (p *PositionList) UnmarshalXML(d *xml.Decoder, start xml.StartElement) erro } } +// MarshalXML handles marshalling a PositionList to an XML string +func (p *PositionList) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + invPosListElement := xml.StartElement{Name: xml.Name{Local: "INVPOSLIST"}} + if err := e.EncodeToken(invPosListElement); err != nil { + return err + } + for _, position := range *p { + start := xml.StartElement{Name: xml.Name{Local: position.PositionType()}} + switch pos := position.(type) { + case DebtPosition: + if err := e.EncodeElement(&pos, start); err != nil { + return err + } + case MFPosition: + if err := e.EncodeElement(&pos, start); err != nil { + return err + } + case OptPosition: + if err := e.EncodeElement(&pos, start); err != nil { + return err + } + case OtherPosition: + if err := e.EncodeElement(&pos, start); err != nil { + return err + } + case StockPosition: + if err := e.EncodeElement(&pos, start); err != nil { + return err + } + default: + return errors.New("Invalid INVPOSLIST child type: " + pos.PositionType()) + } + } + if err := e.EncodeToken(invPosListElement.End()); err != nil { + return err + } + return nil +} + // InvBalance contains three (or optionally four) specified balances as well as // a free-form list of generic balance information which may be provided by an // FI. @@ -1036,6 +1189,69 @@ func (o *OOList) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { } } +// MarshalXML handles marshalling an OOList to an XML string +func (o *OOList) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + ooListElement := xml.StartElement{Name: xml.Name{Local: "INVOOLIST"}} + if err := e.EncodeToken(ooListElement); err != nil { + return err + } + for _, openorder := range *o { + start := xml.StartElement{Name: xml.Name{Local: openorder.OrderType()}} + switch oo := openorder.(type) { + case OOBuyDebt: + if err := e.EncodeElement(&oo, start); err != nil { + return err + } + case OOBuyMF: + if err := e.EncodeElement(&oo, start); err != nil { + return err + } + case OOBuyOpt: + if err := e.EncodeElement(&oo, start); err != nil { + return err + } + case OOBuyOther: + if err := e.EncodeElement(&oo, start); err != nil { + return err + } + case OOBuyStock: + if err := e.EncodeElement(&oo, start); err != nil { + return err + } + case OOSellDebt: + if err := e.EncodeElement(&oo, start); err != nil { + return err + } + case OOSellMF: + if err := e.EncodeElement(&oo, start); err != nil { + return err + } + case OOSellOpt: + if err := e.EncodeElement(&oo, start); err != nil { + return err + } + case OOSellOther: + if err := e.EncodeElement(&oo, start); err != nil { + return err + } + case OOSellStock: + if err := e.EncodeElement(&oo, start); err != nil { + return err + } + case OOSwitchMF: + if err := e.EncodeElement(&oo, start); err != nil { + return err + } + default: + return errors.New("Invalid OOLIST child type: " + oo.OrderType()) + } + } + if err := e.EncodeToken(ooListElement.End()); err != nil { + return err + } + return nil +} + // ContribSecurity identifies current contribution allocation for a security in // a 401(k) account type ContribSecurity struct { diff --git a/invstmt_test.go b/invstmt_test.go index afe37b9..e1a1e14 100644 --- a/invstmt_test.go +++ b/invstmt_test.go @@ -602,6 +602,7 @@ func TestUnmarshalInvStatementResponse(t *testing.T) { } checkResponsesEqual(t, &expected, response) + checkResponseRoundTrip(t, response) } func TestUnmarshalInvStatementResponse102(t *testing.T) { @@ -957,6 +958,7 @@ NEWFILEUID: NONE } checkResponsesEqual(t, &expected, response) + checkResponseRoundTrip(t, response) } func TestUnmarshalInvTranList(t *testing.T) { diff --git a/profile.go b/profile.go index c28cc8a..8bacc60 100644 --- a/profile.go +++ b/profile.go @@ -3,6 +3,7 @@ package ofxgo import ( "errors" "github.com/aclindsa/xml" + "strings" ) // ProfileRequest represents a request for a server to provide a profile of its @@ -126,6 +127,35 @@ func (msl *MessageSetList) UnmarshalXML(d *xml.Decoder, start xml.StartElement) } } +// MarshalXML handles marshalling a MessageSetList element to an XML string +func (msl *MessageSetList) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + messageSetListElement := xml.StartElement{Name: xml.Name{Local: "MSGSETLIST"}} + if err := e.EncodeToken(messageSetListElement); err != nil { + return err + } + for _, messageset := range *msl { + if !strings.HasSuffix(messageset.Name, "V1") { + return errors.New("Expected MessageSet.Name to end with \"V1\"") + } + messageSetName := strings.TrimSuffix(messageset.Name, "V1") + messageSetElement := xml.StartElement{Name: xml.Name{Local: messageSetName}} + if err := e.EncodeToken(messageSetElement); err != nil { + return err + } + start := xml.StartElement{Name: xml.Name{Local: messageset.Name}} + if err := e.EncodeElement(&messageset, start); err != nil { + return err + } + if err := e.EncodeToken(messageSetElement.End()); err != nil { + return err + } + } + if err := e.EncodeToken(messageSetListElement.End()); err != nil { + return err + } + return nil +} + // ProfileResponse contains a requested profile of the server's capabilities // (which message sets and versions it supports, how to access them, which // languages and which types of synchronization they support, etc.). Note that diff --git a/profile_test.go b/profile_test.go index 71db474..b15d0c3 100644 --- a/profile_test.go +++ b/profile_test.go @@ -325,4 +325,5 @@ NEWFILEUID:NONE } checkResponsesEqual(t, &expected, response) + checkResponseRoundTrip(t, response) } diff --git a/request.go b/request.go index dc63f19..7b6171d 100644 --- a/request.go +++ b/request.go @@ -3,7 +3,6 @@ package ofxgo import ( "bytes" "errors" - "fmt" "github.com/aclindsa/xml" "time" ) @@ -35,7 +34,7 @@ type Request struct { indent bool // Whether to indent the marshaled XML } -func marshalMessageSet(e *xml.Encoder, requests []Message, set messageType, version ofxVersion) error { +func encodeMessageSet(e *xml.Encoder, requests []Message, set messageType, version ofxVersion) error { if len(requests) > 0 { messageSetElement := xml.StartElement{Name: xml.Name{Local: set.String()}} if err := e.EncodeToken(messageSetElement); err != nil { @@ -80,25 +79,7 @@ func (oq *Request) Marshal() (*bytes.Buffer, error) { var b bytes.Buffer // Write the header appropriate to our version - switch oq.Version { - case OfxVersion102, OfxVersion103, OfxVersion151, OfxVersion160: - b.WriteString(`OFXHEADER:100 -DATA:OFXSGML -VERSION:` + oq.Version.String() + ` -SECURITY:NONE -ENCODING:USASCII -CHARSET:1252 -COMPRESSION:NONE -OLDFILEUID:NONE -NEWFILEUID:NONE - -`) - case OfxVersion200, OfxVersion201, OfxVersion202, OfxVersion203, OfxVersion210, OfxVersion211, OfxVersion220: - b.WriteString(`` + "\n") - b.WriteString(`` + "\n") - default: - return nil, fmt.Errorf("%d is not a valid OFX version string", oq.Version) - } + writeHeader(&b, oq.Version) encoder := xml.NewEncoder(&b) if oq.indent { @@ -145,7 +126,7 @@ NEWFILEUID:NONE {oq.Image, ImageRq}, } for _, set := range messageSets { - if err := marshalMessageSet(encoder, set.Messages, set.Type, oq.Version); err != nil { + if err := encodeMessageSet(encoder, set.Messages, set.Type, oq.Version); err != nil { return nil, err } } diff --git a/response.go b/response.go index 1d177be..9ab4db7 100644 --- a/response.go +++ b/response.go @@ -369,3 +369,70 @@ func ParseResponse(reader io.Reader) (*Response, error) { } } } + +// Marshal this Response into its SGML/XML representation held in a bytes.Buffer +// +// If error is non-nil, this bytes.Buffer is ready to be sent to an OFX client +func (or *Response) Marshal() (*bytes.Buffer, error) { + var b bytes.Buffer + + // Write the header appropriate to our version + writeHeader(&b, or.Version) + + encoder := xml.NewEncoder(&b) + encoder.Indent("", " ") + + ofxElement := xml.StartElement{Name: xml.Name{Local: "OFX"}} + + if err := encoder.EncodeToken(ofxElement); err != nil { + return nil, err + } + + if ok, err := or.Signon.Valid(or.Version); !ok { + return nil, err + } + signonMsgSet := xml.StartElement{Name: xml.Name{Local: SignonRs.String()}} + if err := encoder.EncodeToken(signonMsgSet); err != nil { + return nil, err + } + if err := encoder.Encode(&or.Signon); err != nil { + return nil, err + } + if err := encoder.EncodeToken(signonMsgSet.End()); err != nil { + return nil, err + } + + messageSets := []struct { + Messages []Message + Type messageType + }{ + {or.Signup, SignupRs}, + {or.Bank, BankRs}, + {or.CreditCard, CreditCardRs}, + {or.Loan, LoanRs}, + {or.InvStmt, InvStmtRs}, + {or.InterXfer, InterXferRs}, + {or.WireXfer, WireXferRs}, + {or.Billpay, BillpayRs}, + {or.Email, EmailRs}, + {or.SecList, SecListRs}, + {or.PresDir, PresDirRs}, + {or.PresDlv, PresDlvRs}, + {or.Prof, ProfRs}, + {or.Image, ImageRs}, + } + for _, set := range messageSets { + if err := encodeMessageSet(encoder, set.Messages, set.Type, or.Version); err != nil { + return nil, err + } + } + + if err := encoder.EncodeToken(ofxElement.End()); err != nil { + return nil, err + } + + if err := encoder.Flush(); err != nil { + return nil, err + } + return &b, nil +} diff --git a/response_test.go b/response_test.go index 1018002..c1faa3b 100644 --- a/response_test.go +++ b/response_test.go @@ -136,6 +136,20 @@ func checkResponsesEqual(t *testing.T, expected, actual *ofxgo.Response) { checkEqual(t, "", reflect.ValueOf(expected), reflect.ValueOf(actual)) } +func checkResponseRoundTrip(t *testing.T, response *ofxgo.Response) { + b, err := response.Marshal() + if err != nil { + t.Fatalf("Unexpected error re-marshaling OFX response: %s\n", err) + } + roundtripped, err := ofxgo.ParseResponse(b) + if err != nil { + t.Fatalf("Unexpected error re-parsing OFX response: %s\n", err) + } + checkResponsesEqual(t, response, roundtripped) +} + +// Ensure that these samples both parse without errors, and can be converted +// back and forth without changing. func TestValidSamples(t *testing.T) { fn := func(path string, info os.FileInfo, err error) error { if info.IsDir() { @@ -147,10 +161,11 @@ func TestValidSamples(t *testing.T) { if err != nil { t.Fatalf("Unexpected error opening %s: %s\n", path, err) } - _, err = ofxgo.ParseResponse(file) + response, err := ofxgo.ParseResponse(file) if err != nil { t.Fatalf("Unexpected error parsing OFX response in %s: %s\n", path, err) } + checkResponseRoundTrip(t, response) return nil } filepath.Walk("samples/valid_responses", fn) diff --git a/seclist.go b/seclist.go index 8c60582..c8d3334 100644 --- a/seclist.go +++ b/seclist.go @@ -221,6 +221,7 @@ func (i StockInfo) SecurityType() string { // SecurityList is a container for Security objects containaing information // about securities type SecurityList struct { + XMLName xml.Name `xml:"SECLIST"` Securities []Security } @@ -290,3 +291,42 @@ func (r *SecurityList) UnmarshalXML(d *xml.Decoder, start xml.StartElement) erro } } } + +// MarshalXML handles marshalling a SecurityList to an SGML/XML string +func (r *SecurityList) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + secListElement := xml.StartElement{Name: xml.Name{Local: "SECLIST"}} + if err := e.EncodeToken(secListElement); err != nil { + return err + } + for _, s := range r.Securities { + start := xml.StartElement{Name: xml.Name{Local: s.SecurityType()}} + switch sec := s.(type) { + case DebtInfo: + if err := e.EncodeElement(&sec, start); err != nil { + return err + } + case MFInfo: + if err := e.EncodeElement(&sec, start); err != nil { + return err + } + case OptInfo: + if err := e.EncodeElement(&sec, start); err != nil { + return err + } + case OtherInfo: + if err := e.EncodeElement(&sec, start); err != nil { + return err + } + case StockInfo: + if err := e.EncodeElement(&sec, start); err != nil { + return err + } + default: + return errors.New("Invalid SECLIST child type: " + sec.SecurityType()) + } + } + if err := e.EncodeToken(secListElement.End()); err != nil { + return err + } + return nil +} diff --git a/signup_test.go b/signup_test.go index da2ac68..115ec1f 100644 --- a/signup_test.go +++ b/signup_test.go @@ -155,4 +155,5 @@ func TestUnmarshalAcctInfoResponse(t *testing.T) { } checkResponsesEqual(t, &expected, response) + checkResponseRoundTrip(t, response) }