diff --git a/internal/db/db.go b/internal/db/db.go index 988fd82..f1d976c 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -11,19 +11,19 @@ import ( "gopkg.in/gorp.v1" ) -func GetDbMap(db *sql.DB, cfg *config.Config) (*gorp.DbMap, error) { +func GetDbMap(db *sql.DB, dbtype config.DbType) (*gorp.DbMap, error) { var dialect gorp.Dialect - if cfg.MoneyGo.DBType == config.SQLite { + if dbtype == config.SQLite { dialect = gorp.SqliteDialect{} - } else if cfg.MoneyGo.DBType == config.MySQL { + } else if dbtype == config.MySQL { dialect = gorp.MySQLDialect{ Engine: "InnoDB", Encoding: "UTF8", } - } else if cfg.MoneyGo.DBType == config.Postgres { + } else if dbtype == config.Postgres { dialect = gorp.PostgresDialect{} } else { - return nil, fmt.Errorf("Don't know gorp dialect to go with '%s' DB type", cfg.MoneyGo.DBType.String()) + return nil, fmt.Errorf("Don't know gorp dialect to go with '%s' DB type", dbtype.String()) } dbmap := &gorp.DbMap{Db: db, Dialect: dialect} diff --git a/internal/handlers/securities.go b/internal/handlers/securities.go index 73ee50e..acbc18c 100644 --- a/internal/handlers/securities.go +++ b/internal/handlers/securities.go @@ -56,6 +56,11 @@ func (s *Security) Write(w http.ResponseWriter) error { return enc.Encode(s) } +func (sl *SecurityList) Read(json_str string) error { + dec := json.NewDecoder(strings.NewReader(json_str)) + return dec.Decode(sl) +} + func (sl *SecurityList) Write(w http.ResponseWriter) error { enc := json.NewEncoder(w) return enc.Encode(sl) @@ -415,7 +420,16 @@ func SecurityTemplateHandler(w http.ResponseWriter, r *http.Request) { var limit int64 = -1 search := query.Get("search") - _type := GetSecurityType(query.Get("type")) + + var _type int64 = 0 + typestring := query.Get("type") + if len(typestring) > 0 { + _type = GetSecurityType(typestring) + if _type == 0 { + WriteError(w, 3 /*Invalid Request*/) + return + } + } limitstring := query.Get("limit") if limitstring != "" { diff --git a/internal/handlers/security_template_test.go b/internal/handlers/security_template_test.go new file mode 100644 index 0000000..a8fb354 --- /dev/null +++ b/internal/handlers/security_template_test.go @@ -0,0 +1,81 @@ +package handlers_test + +import ( + "database/sql" + "github.com/aclindsa/moneygo/internal/config" + "github.com/aclindsa/moneygo/internal/db" + "github.com/aclindsa/moneygo/internal/handlers" + "io/ioutil" + "log" + "net/http" + "net/http/httptest" + "os" + "path" + "testing" +) + +var server *httptest.Server + +func RunTests(m *testing.M) int { + tmpdir, err := ioutil.TempDir("./", "handlertest") + if err != nil { + log.Fatal(err) + } + defer os.RemoveAll(tmpdir) + + dbpath := path.Join(tmpdir, "moneygo.sqlite") + database, err := sql.Open("sqlite3", "file:"+dbpath+"?cache=shared&mode=rwc") + if err != nil { + log.Fatal(err) + } + defer database.Close() + + dbmap, err := db.GetDbMap(database, config.SQLite) + if err != nil { + log.Fatal(err) + } + + servemux := handlers.GetHandler(dbmap) + server = httptest.NewServer(servemux) + defer server.Close() + + return m.Run() +} + +func TestMain(m *testing.M) { + os.Exit(RunTests(m)) +} + +func TestSecurityTemplates(t *testing.T) { + var sl handlers.SecurityList + response, err := http.Get(server.URL + "/securitytemplate/?search=USD&type=currency") + if err != nil { + t.Error(err) + } + + body, err := ioutil.ReadAll(response.Body) + response.Body.Close() + if err != nil { + t.Error(err) + } + + err = (&sl).Read(string(body)) + if err != nil { + t.Error(err) + } + + num_usd := 0 + for _, s := range *sl.Securities { + if s.Type != handlers.Currency { + t.Fatalf("Requested Currency-only security templates, received a non-Currency template for %s", s.Name) + } + + if s.Name == "USD" && s.AlternateId == "840" { + num_usd++ + } + } + + if num_usd != 1 { + t.Fatalf("Expected one USD security template, found %d\n", num_usd) + } +} diff --git a/main.go b/main.go index 65583c6..b426656 100644 --- a/main.go +++ b/main.go @@ -72,7 +72,7 @@ func main() { } defer database.Close() - dbmap, err := db.GetDbMap(database, cfg) + dbmap, err := db.GetDbMap(database, cfg.MoneyGo.DBType) if err != nil { log.Fatal(err) }