diff --git a/internal/handlers/common_test.go b/internal/handlers/common_test.go new file mode 100644 index 0000000..8966a18 --- /dev/null +++ b/internal/handlers/common_test.go @@ -0,0 +1,66 @@ +package handlers_test + +import ( + "database/sql" + "github.com/aclindsa/moneygo/internal/config" + "github.com/aclindsa/moneygo/internal/db" + "github.com/aclindsa/moneygo/internal/handlers" + "io/ioutil" + "log" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path" + "strings" + "testing" +) + +var server *httptest.Server + +func Delete(client *http.Client, url string) (*http.Response, error) { + request, err := http.NewRequest(http.MethodDelete, url, nil) + if err != nil { + return nil, err + } + return client.Do(request) +} + +func PutForm(client *http.Client, url string, data url.Values) (*http.Response, error) { + request, err := http.NewRequest(http.MethodPut, url, strings.NewReader(data.Encode())) + if err != nil { + return nil, err + } + request.Header.Set("Content-Type", "application/x-www-form-urlencoded") + return client.Do(request) +} + +func RunTests(m *testing.M) int { + tmpdir, err := ioutil.TempDir("./", "handlertest") + if err != nil { + log.Fatal(err) + } + defer os.RemoveAll(tmpdir) + + dbpath := path.Join(tmpdir, "moneygo.sqlite") + database, err := sql.Open("sqlite3", "file:"+dbpath+"?cache=shared&mode=rwc") + if err != nil { + log.Fatal(err) + } + defer database.Close() + + dbmap, err := db.GetDbMap(database, config.SQLite) + if err != nil { + log.Fatal(err) + } + + servemux := handlers.GetHandler(dbmap) + server = httptest.NewTLSServer(servemux) + defer server.Close() + + return m.Run() +} + +func TestMain(m *testing.M) { + os.Exit(RunTests(m)) +} diff --git a/internal/handlers/security_template_test.go b/internal/handlers/security_templates_test.go similarity index 64% rename from internal/handlers/security_template_test.go rename to internal/handlers/security_templates_test.go index 5bad9b4..db32d79 100644 --- a/internal/handlers/security_template_test.go +++ b/internal/handlers/security_templates_test.go @@ -1,54 +1,14 @@ package handlers_test import ( - "database/sql" - "github.com/aclindsa/moneygo/internal/config" - "github.com/aclindsa/moneygo/internal/db" "github.com/aclindsa/moneygo/internal/handlers" "io/ioutil" - "log" - "net/http" - "net/http/httptest" - "os" - "path" "testing" ) -var server *httptest.Server - -func RunTests(m *testing.M) int { - tmpdir, err := ioutil.TempDir("./", "handlertest") - if err != nil { - log.Fatal(err) - } - defer os.RemoveAll(tmpdir) - - dbpath := path.Join(tmpdir, "moneygo.sqlite") - database, err := sql.Open("sqlite3", "file:"+dbpath+"?cache=shared&mode=rwc") - if err != nil { - log.Fatal(err) - } - defer database.Close() - - dbmap, err := db.GetDbMap(database, config.SQLite) - if err != nil { - log.Fatal(err) - } - - servemux := handlers.GetHandler(dbmap) - server = httptest.NewServer(servemux) - defer server.Close() - - return m.Run() -} - -func TestMain(m *testing.M) { - os.Exit(RunTests(m)) -} - func TestSecurityTemplates(t *testing.T) { var sl handlers.SecurityList - response, err := http.Get(server.URL + "/securitytemplate/?search=USD&type=currency") + response, err := server.Client().Get(server.URL + "/securitytemplate/?search=USD&type=currency") if err != nil { t.Fatal(err) } @@ -85,7 +45,7 @@ func TestSecurityTemplates(t *testing.T) { func TestSecurityTemplateLimit(t *testing.T) { var sl handlers.SecurityList - response, err := http.Get(server.URL + "/securitytemplate/?search=e&limit=5") + response, err := server.Client().Get(server.URL + "/securitytemplate/?search=e&limit=5") if err != nil { t.Fatal(err) } @@ -111,7 +71,7 @@ func TestSecurityTemplateLimit(t *testing.T) { func TestSecurityTemplateInvalidType(t *testing.T) { var e handlers.Error - response, err := http.Get(server.URL + "/securitytemplate/?search=e&type=blah") + response, err := server.Client().Get(server.URL + "/securitytemplate/?search=e&type=blah") if err != nil { t.Fatal(err) } @@ -134,7 +94,7 @@ func TestSecurityTemplateInvalidType(t *testing.T) { func TestSecurityTemplateInvalidLimit(t *testing.T) { var e handlers.Error - response, err := http.Get(server.URL + "/securitytemplate/?search=e&type=Currency&limit=foo") + response, err := server.Client().Get(server.URL + "/securitytemplate/?search=e&type=Currency&limit=foo") if err != nil { t.Fatal(err) } diff --git a/internal/handlers/sessions.go b/internal/handlers/sessions.go index 732085d..bbc5ee8 100644 --- a/internal/handlers/sessions.go +++ b/internal/handlers/sessions.go @@ -8,6 +8,7 @@ import ( "io" "log" "net/http" + "strings" "time" ) @@ -22,6 +23,11 @@ func (s *Session) Write(w http.ResponseWriter) error { return enc.Encode(s) } +func (s *Session) Read(json_str string) error { + dec := json.NewDecoder(strings.NewReader(json_str)) + return dec.Decode(s) +} + func GetSession(db *DB, r *http.Request) (*Session, error) { var s Session diff --git a/internal/handlers/sessions_test.go b/internal/handlers/sessions_test.go new file mode 100644 index 0000000..5e0702e --- /dev/null +++ b/internal/handlers/sessions_test.go @@ -0,0 +1,144 @@ +package handlers_test + +import ( + "encoding/json" + "fmt" + "github.com/aclindsa/moneygo/internal/handlers" + "io/ioutil" + "net/http" + "net/http/cookiejar" + "net/url" + "testing" +) + +func newSession(user *User) (*http.Client, error) { + var u User + var e handlers.Error + + jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: nil}) + if err != nil { + return nil, err + } + + client := server.Client() + client.Jar = jar + + bytes, err := json.Marshal(user) + if err != nil { + return nil, err + } + response, err := client.PostForm(server.URL+"/session/", url.Values{"user": {string(bytes)}}) + if err != nil { + return nil, err + } + + body, err := ioutil.ReadAll(response.Body) + response.Body.Close() + if err != nil { + return nil, err + } + + err = (&u).Read(string(body)) + if err != nil { + return nil, err + } + + err = (&e).Read(string(body)) + if err != nil { + return nil, err + } + + if e.ErrorId != 0 || len(e.ErrorString) != 0 { + return nil, fmt.Errorf("Unexpected error when creating session %+v", e) + } + + return client, nil +} + +func getSession(client *http.Client) (*handlers.Session, error) { + var s handlers.Session + response, err := client.Get(server.URL + "/session/") + if err != nil { + return nil, err + } + + body, err := ioutil.ReadAll(response.Body) + response.Body.Close() + if err != nil { + return nil, err + } + + err = (&s).Read(string(body)) + if err != nil { + return nil, err + } + + return &s, nil +} + +func sessionExistsOrError(c *http.Client) error { + + url, err := url.Parse(server.URL) + if err != nil { + return err + } + cookies := c.Jar.Cookies(url) + + var found_session bool = false + for _, cookie := range cookies { + if cookie.Name == "moneygo-session" { + found_session = true + } + } + if found_session { + return nil + } + return fmt.Errorf("Didn't find 'moneygo-session' cookie in CookieJar") +} + +func TestCreateSession(t *testing.T) { + u, err := createUser(&users[0]) + if err != nil { + t.Fatal(err) + } + u.Password = users[0].Password + + client, err := newSession(u) + if err != nil { + t.Fatal(err) + } + defer deleteUser(client, u) + if err := sessionExistsOrError(client); err != nil { + t.Fatal(err) + } +} + +func TestGetSession(t *testing.T) { + u, err := createUser(&users[0]) + if err != nil { + t.Fatal(err) + } + u.Password = users[0].Password + + client, err := newSession(u) + if err != nil { + t.Fatal(err) + } + defer deleteUser(client, u) + session, err := getSession(client) + if err != nil { + t.Fatal(err) + } + + if len(session.SessionSecret) != 0 { + t.Error("Session.SessionSecret should not be passed back in JSON") + } + + if session.UserId != u.UserId { + t.Errorf("session's UserId (%d) should equal user's UserID (%d)", session.UserId, u.UserId) + } + + if session.SessionId == 0 { + t.Error("session's SessionId should not be 0") + } +} diff --git a/internal/handlers/testdata_test.go b/internal/handlers/testdata_test.go new file mode 100644 index 0000000..1cacd61 --- /dev/null +++ b/internal/handlers/testdata_test.go @@ -0,0 +1,39 @@ +package handlers_test + +import ( + "encoding/json" + "net/http" + "strings" +) + +// Needed because handlers.User doesn't allow Password to be written to JSON + +type User struct { + UserId int64 + DefaultCurrency int64 // SecurityId of default currency, or ISO4217 code for it if creating new user + Name string + Username string + Password string + PasswordHash string + Email string +} + +func (u *User) Write(w http.ResponseWriter) error { + enc := json.NewEncoder(w) + return enc.Encode(u) +} + +func (u *User) Read(json_str string) error { + dec := json.NewDecoder(strings.NewReader(json_str)) + return dec.Decode(u) +} + +var users = []User{ + User{ + DefaultCurrency: 840, // USD + Name: "John Smith", + Username: "jsmith", + Password: "hunter2", + Email: "jsmith@example.com", + }, +} diff --git a/internal/handlers/users_test.go b/internal/handlers/users_test.go new file mode 100644 index 0000000..31ba708 --- /dev/null +++ b/internal/handlers/users_test.go @@ -0,0 +1,218 @@ +package handlers_test + +import ( + "encoding/json" + "fmt" + "github.com/aclindsa/moneygo/internal/handlers" + "io/ioutil" + "net/http" + "net/url" + "strconv" + "testing" +) + +func createUser(user *User) (*User, error) { + bytes, err := json.Marshal(user) + if err != nil { + return nil, err + } + response, err := server.Client().PostForm(server.URL+"/user/", url.Values{"user": {string(bytes)}}) + if err != nil { + return nil, err + } + + body, err := ioutil.ReadAll(response.Body) + response.Body.Close() + if err != nil { + return nil, err + } + + var e handlers.Error + err = (&e).Read(string(body)) + if err != nil { + return nil, err + } + if e.ErrorId != 0 || len(e.ErrorString) != 0 { + return nil, fmt.Errorf("Error when creating user %+v", e) + } + + var u User + err = (&u).Read(string(body)) + if err != nil { + return nil, err + } + + if u.UserId == 0 || len(u.Username) == 0 { + return nil, fmt.Errorf("Unable to create user: %+v", user) + } + + return &u, nil +} + +func updateUser(client *http.Client, user *User) (*User, error) { + bytes, err := json.Marshal(user) + if err != nil { + return nil, err + } + response, err := PutForm(client, server.URL+"/user/"+strconv.FormatInt(user.UserId, 10), url.Values{"user": {string(bytes)}}) + if err != nil { + return nil, err + } + + body, err := ioutil.ReadAll(response.Body) + response.Body.Close() + if err != nil { + return nil, err + } + + var e handlers.Error + err = (&e).Read(string(body)) + if err != nil { + return nil, err + } + if e.ErrorId != 0 || len(e.ErrorString) != 0 { + return nil, fmt.Errorf("Error when updating user %+v", e) + } + + var u User + err = (&u).Read(string(body)) + if err != nil { + return nil, err + } + + if u.UserId == 0 || len(u.Username) == 0 { + return nil, fmt.Errorf("Unable to update user: %+v", user) + } + + return &u, nil +} + +func deleteUser(client *http.Client, u *User) error { + response, err := Delete(client, server.URL+"/user/"+strconv.FormatInt(u.UserId, 10)) + if err != nil { + return err + } + + body, err := ioutil.ReadAll(response.Body) + response.Body.Close() + if err != nil { + return err + } + + var e handlers.Error + err = (&e).Read(string(body)) + if err != nil { + return err + } + if e.ErrorId != 0 || len(e.ErrorString) != 0 { + return fmt.Errorf("Error when deleting user %+v", e) + } + + return nil +} + +func getUser(client *http.Client, userid int64) (*User, error) { + response, err := client.Get(server.URL + "/user/" + strconv.FormatInt(userid, 10)) + if err != nil { + return nil, err + } + + body, err := ioutil.ReadAll(response.Body) + response.Body.Close() + if err != nil { + return nil, err + } + + var e handlers.Error + err = (&e).Read(string(body)) + if err != nil { + return nil, err + } + if e.ErrorId != 0 || len(e.ErrorString) != 0 { + return nil, fmt.Errorf("Error when get user %+v", e) + } + + var u User + err = (&u).Read(string(body)) + if err != nil { + return nil, err + } + + if u.UserId == 0 || len(u.Username) == 0 { + return nil, fmt.Errorf("Unable to get userid: %d", userid) + } + + return &u, nil +} + +func TestCreateUser(t *testing.T) { + u, err := createUser(&users[0]) + if err != nil { + t.Fatal(err) + } + + if len(u.Password) != 0 || len(u.PasswordHash) != 0 { + t.Error("Never send password, only send password hash when necessary") + } + + u.Password = users[0].Password + + client, err := newSession(u) + if err != nil { + t.Fatalf("Error creating new session, user not deleted (may cause errors in other tests): %s", err) + } + defer deleteUser(client, u) +} + +func TestGetUser(t *testing.T) { + origu, err := createUser(&users[0]) + if err != nil { + t.Fatal(err) + } + origu.Password = users[0].Password + + client, err := newSession(origu) + if err != nil { + t.Fatalf("Error creating new session, user not deleted (may cause errors in other tests): %s", err) + } + defer deleteUser(client, origu) + + u, err := getUser(client, origu.UserId) + if err != nil { + t.Fatalf("Error fetching user: %s\n", err) + } + if u.UserId != origu.UserId { + t.Errorf("UserId doesn't match") + } +} + +func TestUpdateUser(t *testing.T) { + origu, err := createUser(&users[0]) + if err != nil { + t.Fatal(err) + } + origu.Password = users[0].Password + + client, err := newSession(origu) + if err != nil { + t.Fatalf("Error creating new session, user not deleted (may cause errors in other tests): %s", err) + } + defer deleteUser(client, origu) + + origu.Name = "Bob" + origu.Email = "bob@example.com" + + u, err := updateUser(client, origu) + if err != nil { + t.Fatalf("Error updating user: %s\n", err) + } + if u.UserId != origu.UserId { + t.Errorf("UserId doesn't match") + } + if u.Name != origu.Name { + t.Errorf("UserId doesn't match") + } + if u.Email != origu.Email { + t.Errorf("UserId doesn't match") + } +}