diff --git a/response.go b/response.go index 75907c0..f88ee58 100644 --- a/response.go +++ b/response.go @@ -187,6 +187,7 @@ const guessVersionCheckBytes = 1024 // Defaults to XML if it can't determine the version or if there is any // ambiguity +// Returns false for SGML, true (for XML) otherwise. func guessVersion(r *bufio.Reader) (bool, error) { b, _ := r.Peek(guessVersionCheckBytes) if b == nil { @@ -236,14 +237,13 @@ func decodeMessageSet(d *xml.Decoder, start xml.StartElement, msgs *[]Message, v if !ok { return errors.New("Invalid message set: " + start.Name.Local) } - var errs ErrInvalid for { tok, err := nextNonWhitespaceToken(d) if err != nil { return err } else if end, ok := tok.(xml.EndElement); ok && end.Name.Local == start.Name.Local { // If we found the end of our starting element, we're done parsing - return errs.ErrOrNil() + return nil } else if startElement, ok := tok.(xml.StartElement); ok { responseType, ok := setTypes[startElement.Name.Local] if !ok { @@ -258,9 +258,6 @@ func decodeMessageSet(d *xml.Decoder, start xml.StartElement, msgs *[]Message, v if err := d.DecodeElement(responseMessage, &startElement); err != nil { return err } - if ok, err := responseMessage.Valid(version); !ok { - errs.AddErr(err) - } *msgs = append(*msgs, responseMessage) } else { return errors.New("Didn't find an opening element") @@ -268,14 +265,25 @@ func decodeMessageSet(d *xml.Decoder, start xml.StartElement, msgs *[]Message, v } } -// ParseResponse parses an OFX response in SGML or XML into a Response object -// from the given io.Reader +// ParseResponse parses and validates an OFX response in SGML or XML into a +// Response object from the given io.Reader // // It is commonly used as part of Client.Request(), but may be used on its own // to parse already-downloaded OFX files (such as those from 'Web Connect'). It // performs version autodetection if it can and attempts to be as forgiving as // possible about the input format. func ParseResponse(reader io.Reader) (*Response, error) { + resp, err := DecodeResponse(reader) + if err != nil { + return nil, err + } + _, err = resp.Valid() + return resp, err +} + +// DecodeResponse parses an OFX response in SGML or XML into a Response object +// from the given io.Reader +func DecodeResponse(reader io.Reader) (*Response, error) { var or Response r := bufio.NewReaderSize(reader, guessVersionCheckBytes) @@ -326,17 +334,12 @@ func ParseResponse(reader io.Reader) (*Response, error) { return nil, errors.New("Missing opening SIGNONMSGSRSV1 xml element") } - var errs ErrInvalid - tok, err = nextNonWhitespaceToken(decoder) if err != nil { return nil, err } else if signonEnd, ok := tok.(xml.EndElement); !ok || signonEnd.Name.Local != SignonRs.String() { return nil, errors.New("Missing closing SIGNONMSGSRSV1 xml element") } - if ok, err := or.Signon.Valid(or.Version); !ok { - errs.AddErr(err) - } var messageSlices = map[string]*[]Message{ SignupRs.String(): &or.Signup, @@ -360,17 +363,14 @@ func ParseResponse(reader io.Reader) (*Response, error) { if err != nil { return nil, err } else if ofxEnd, ok := tok.(xml.EndElement); ok && ofxEnd.Name.Local == "OFX" { - return &or, errs.ErrOrNil() // found closing XML element, so we're done + return &or, nil // found closing XML element, so we're done } else if start, ok := tok.(xml.StartElement); ok { slice, ok := messageSlices[start.Name.Local] if !ok { return nil, errors.New("Invalid message set: " + start.Name.Local) } if err := decodeMessageSet(decoder, start, slice, or.Version); err != nil { - if _, ok := err.(ErrInvalid); !ok { - return nil, err - } - errs.AddErr(err) + return nil, err } } else { return nil, errors.New("Found unexpected token") @@ -378,6 +378,38 @@ func ParseResponse(reader io.Reader) (*Response, error) { } } +// Valid returns whether the Response is valid according to the OFX spec +func (or *Response) Valid() (bool, error) { + var errs ErrInvalid + if ok, err := or.Signon.Valid(or.Version); !ok { + errs.AddErr(err) + } + for _, messageSet := range [][]Message{ + or.Signup, + or.Bank, + or.CreditCard, + or.Loan, + or.InvStmt, + or.InterXfer, + or.WireXfer, + or.Billpay, + or.Email, + or.SecList, + or.PresDir, + or.PresDlv, + or.Prof, + or.Image, + } { + for _, message := range messageSet { + if ok, err := message.Valid(or.Version); !ok { + errs.AddErr(err) + } + } + } + err := errs.ErrOrNil() + return err == nil, err +} + // Marshal this Response into its SGML/XML representation held in a bytes.Buffer // // If error is non-nil, this bytes.Buffer is ready to be sent to an OFX client diff --git a/response_test.go b/response_test.go index ee24498..54f6895 100644 --- a/response_test.go +++ b/response_test.go @@ -177,7 +177,7 @@ func TestValidSamples(t *testing.T) { func TestInvalidResponse(t *testing.T) { // in this example, the severity is invalid due to mixed upper and lower case letters - resp, err := ofxgo.ParseResponse(bytes.NewReader([]byte(` + const invalidResponse = ` OFXHEADER:100 DATA:OFXSGML VERSION:102 @@ -208,20 +208,58 @@ NEWFILEUID:NONE -`))) - expectedErr := "Validation failed: Invalid STATUS>SEVERITY; Invalid STATUS>SEVERITY" - if err == nil { - t.Fatalf("ParseResponse should fail with %q, found nil", expectedErr) - } - if _, ok := err.(ofxgo.ErrInvalid); !ok { - t.Errorf("ParseResponse should return an error with type ErrInvalid, found %T", err) - } - if err.Error() != expectedErr { - t.Errorf("ParseResponse should fail with %q, found %v", expectedErr, err) - } - if resp == nil { - t.Errorf("Response must not be nil if only validation errors are present") - } +` + const expectedErr = "Validation failed: Invalid STATUS>SEVERITY; Invalid STATUS>SEVERITY" + + t.Run("parse response", func(t *testing.T) { + resp, err := ofxgo.ParseResponse(bytes.NewReader([]byte(invalidResponse))) + expectedErr := "Validation failed: Invalid STATUS>SEVERITY; Invalid STATUS>SEVERITY" + if err == nil { + t.Fatalf("ParseResponse should fail with %q, found nil", expectedErr) + } + if _, ok := err.(ofxgo.ErrInvalid); !ok { + t.Errorf("ParseResponse should return an error with type ErrInvalid, found %T", err) + } + if err.Error() != expectedErr { + t.Errorf("ParseResponse should fail with %q, found %v", expectedErr, err) + } + if resp == nil { + t.Errorf("Response must not be nil if only validation errors are present") + } + }) + + t.Run("parse failed", func(t *testing.T) { + resp, err := ofxgo.ParseResponse(bytes.NewReader(nil)) + if err == nil { + t.Error("ParseResponse should fail to decode") + } + if resp != nil { + t.Errorf("ParseResponse should return a nil response, found: %v", resp) + } + }) + + t.Run("decode, then validate response", func(t *testing.T) { + resp, err := ofxgo.DecodeResponse(bytes.NewReader([]byte(invalidResponse))) + if err != nil { + t.Errorf("Unexpected error: %s", err.Error()) + } + if resp == nil { + t.Fatal("Response should not be nil from successful decode") + } + valid, err := resp.Valid() + if valid { + t.Error("Response should not be valid") + } + if err == nil { + t.Fatalf("response.Valid() should fail with %q, found nil", expectedErr) + } + if _, ok := err.(ofxgo.ErrInvalid); !ok { + t.Errorf("response.Valid() should return an error of type ErrInvalid, found: %T", err) + } + if err.Error() != expectedErr { + t.Errorf("response.Valid() should return an error with message %q, but found %q", expectedErr, err.Error()) + } + }) } func TestErrInvalidError(t *testing.T) {