diff --git a/client/asink.go b/client/asink.go index 0722fdb..31fdc48 100644 --- a/client/asink.go +++ b/client/asink.go @@ -121,7 +121,7 @@ func ProcessLocalEvent(globals AsinkGlobals, event *asink.Event) { if err != nil { //bail out if the file we are trying to upload already got deleted if util.ErrorFileNotFound(err) { - event.Status |= asink.DISCARDED + event.LocalStatus |= asink.DISCARDED return } panic(err) @@ -132,7 +132,7 @@ func ProcessLocalEvent(globals AsinkGlobals, event *asink.Event) { if err != nil { //bail out if the file we are trying to upload already got deleted if util.ErrorFileNotFound(err) { - event.Status |= asink.DISCARDED + event.LocalStatus |= asink.DISCARDED return } panic(err) @@ -150,7 +150,7 @@ func ProcessLocalEvent(globals AsinkGlobals, event *asink.Event) { //If the file didn't actually change, squash this event if latestLocal != nil && event.Hash == latestLocal.Hash { os.Remove(tmpfilename) - event.Status |= asink.DISCARDED + event.LocalStatus |= asink.DISCARDED return } @@ -173,7 +173,7 @@ func ProcessLocalEvent(globals AsinkGlobals, event *asink.Event) { } else { //if we're trying to delete a file that we thought was already deleted, there's no need to delete it again if latestLocal != nil && latestLocal.IsDelete() { - event.Status |= asink.DISCARDED + event.LocalStatus |= asink.DISCARDED return } } @@ -195,13 +195,15 @@ func ProcessRemoteEvent(globals AsinkGlobals, event *asink.Event) { //if we already have this event, or if it is older than our most recent event, bail out if latestLocal != nil { if event.Timestamp < latestLocal.Timestamp || event.IsSameEvent(latestLocal) { - event.Status |= asink.DISCARDED + event.LocalStatus |= asink.DISCARDED return } if latestLocal.Hash != event.Predecessor && latestLocal.Hash != event.Hash { - panic("conflict") - //TODO handle conflict + fmt.Printf("conflict:\n") + fmt.Printf("OLD %+v\n", latestLocal) + fmt.Printf("NEW %+v\n", event) + //TODO handle conflict? } } diff --git a/client/database.go b/client/database.go index c48097c..906168b 100644 --- a/client/database.go +++ b/client/database.go @@ -37,7 +37,7 @@ func GetAndInitDB(config *conf.ConfigFile) (*AsinkDB, error) { } if !rows.Next() { //if this is false, it means no rows were returned - tx.Exec("CREATE TABLE events (id INTEGER, localid INTEGER PRIMARY KEY ASC, type INTEGER, status INTEGER, path TEXT, hash TEXT, predecessor TEXT, timestamp INTEGER, permissions INTEGER);") + tx.Exec("CREATE TABLE events (id INTEGER, localid INTEGER PRIMARY KEY ASC, type INTEGER, localstatus INTEGER, path TEXT, hash TEXT, predecessor TEXT, timestamp INTEGER, permissions INTEGER);") // tx.Exec("CREATE INDEX IF NOT EXISTS localididx on events (localid)") tx.Exec("CREATE INDEX IF NOT EXISTS ididx on events (id);") tx.Exec("CREATE INDEX IF NOT EXISTS pathidx on events (path);") @@ -66,7 +66,7 @@ func (adb *AsinkDB) DatabaseAddEvent(e *asink.Event) (err error) { adb.lock.Unlock() }() - result, err := tx.Exec("INSERT INTO events (id, type, status, path, hash, predecessor, timestamp, permissions) VALUES (?,?,?,?,?,?,?,?);", e.Id, e.Type, e.Status, e.Path, e.Hash, e.Predecessor, e.Timestamp, e.Permissions) + result, err := tx.Exec("INSERT INTO events (id, type, localstatus, path, hash, predecessor, timestamp, permissions) VALUES (?,?,?,?,?,?,?,?);", e.Id, e.Type, e.LocalStatus, e.Path, e.Hash, e.Predecessor, e.Timestamp, e.Permissions) if err != nil { return err } @@ -102,7 +102,7 @@ func (adb *AsinkDB) DatabaseUpdateEvent(e *asink.Event) (err error) { adb.lock.Unlock() }() - result, err := tx.Exec("UPDATE events SET id=?, type=?, status=?, path=?, hash=?, prececessor=?, timestamp=?, permissions=? WHERE localid=?;", e.Id, e.Type, e.Status, e.Path, e.Hash, e.Predecessor, e.Timestamp, e.Permissions, e.LocalId) + result, err := tx.Exec("UPDATE events SET id=?, type=?, localstatus=?, path=?, hash=?, prececessor=?, timestamp=?, permissions=? WHERE localid=?;", e.Id, e.Type, e.LocalStatus, e.Path, e.Hash, e.Predecessor, e.Timestamp, e.Permissions, e.LocalId) if err != nil { return err } @@ -127,10 +127,10 @@ func (adb *AsinkDB) DatabaseLatestEventForPath(path string) (event *asink.Event, //make sure the database gets unlocked defer adb.lock.Unlock() - row := adb.db.QueryRow("SELECT id, localid, type, status, path, hash, predecessor, timestamp, permissions FROM events WHERE path == ? ORDER BY timestamp DESC LIMIT 1;", path) + row := adb.db.QueryRow("SELECT id, localid, type, localstatus, path, hash, predecessor, timestamp, permissions FROM events WHERE path == ? ORDER BY timestamp DESC LIMIT 1;", path) event = new(asink.Event) - err = row.Scan(&event.Id, &event.LocalId, &event.Type, &event.Status, &event.Path, &event.Hash, &event.Predecessor, &event.Timestamp, &event.Permissions) + err = row.Scan(&event.Id, &event.LocalId, &event.Type, &event.LocalStatus, &event.Path, &event.Hash, &event.Predecessor, &event.Timestamp, &event.Permissions) switch { case err == sql.ErrNoRows: @@ -148,10 +148,10 @@ func (adb *AsinkDB) DatabaseLatestRemoteEvent() (event *asink.Event, err error) //make sure the database gets unlocked defer adb.lock.Unlock() - row := adb.db.QueryRow("SELECT id, localid, type, status, path, hash, predecessor, timestamp, permissions FROM events WHERE id > 0 ORDER BY id DESC LIMIT 1;") + row := adb.db.QueryRow("SELECT id, localid, type, localstatus, path, hash, predecessor, timestamp, permissions FROM events WHERE id > 0 ORDER BY id DESC LIMIT 1;") event = new(asink.Event) - err = row.Scan(&event.Id, &event.LocalId, &event.Type, &event.Status, &event.Path, &event.Hash, &event.Predecessor, &event.Timestamp, &event.Permissions) + err = row.Scan(&event.Id, &event.LocalId, &event.Type, &event.LocalStatus, &event.Path, &event.Hash, &event.Predecessor, &event.Timestamp, &event.Permissions) switch { case err == sql.ErrNoRows: diff --git a/client/path_map.go b/client/path_map.go index 55c110a..4a859e5 100644 --- a/client/path_map.go +++ b/client/path_map.go @@ -33,7 +33,7 @@ func PathLocker(db *AsinkDB) { case event = <-pathUnlockerChan: if v, ok = m[event.Path]; ok != false { //only update status in data structures if the event hasn't been discarded - if event.Status&asink.DISCARDED == 0 { + if event.LocalStatus&asink.DISCARDED == 0 { if v.latestEvent == nil || !v.latestEvent.IsSameEvent(event) { err := db.DatabaseAddEvent(event) if err != nil { diff --git a/client/watcher.go b/client/watcher.go index 9bb6fe6..9f33597 100644 --- a/client/watcher.go +++ b/client/watcher.go @@ -21,6 +21,12 @@ func StartWatching(watchDir string, fileUpdates chan *asink.Event) { if err != nil { panic("Failed to watch " + path) } + } else if info.Mode().IsRegular() { + event := new(asink.Event) + event.Path = path + event.Type = asink.UPDATE + event.Timestamp = info.ModTime().UnixNano() + fileUpdates <- event } return nil } diff --git a/events.go b/events.go index 4837895..6485844 100644 --- a/events.go +++ b/events.go @@ -17,20 +17,22 @@ type EventStatus uint32 const ( //Local event status flags - DISCARDED = 1 << iota //event is to be discarded because it errored or is duplicate + DISCARDED = 1 << iota //event is to be discarded because it errored or is duplicate ) type Event struct { Id int64 - LocalId int64 Type EventType - Status EventStatus Path string Hash string Predecessor string Timestamp int64 Permissions os.FileMode - InDB bool `json:"-"` //defaults to false. Omitted from json marshalling. + Username string + Sharename string + LocalStatus EventStatus `json:"-"` + LocalId int64 `json:"-"` + InDB bool `json:"-"` //defaults to false. Omitted from json marshalling. } func (e *Event) IsUpdate() bool { diff --git a/server/admin/admin.go b/server/admin/admin.go new file mode 100644 index 0000000..daf4063 --- /dev/null +++ b/server/admin/admin.go @@ -0,0 +1,48 @@ +package main + +import ( + "fmt" + "os" +) + +type AdminCommand struct { + cmd string + fn func(args []string) + explanation string +} + +var commands []AdminCommand = []AdminCommand{ + AdminCommand{ + cmd: "useradd", + fn: UserAdd, + explanation: "Add a user", + }, + AdminCommand{ + cmd: "userdel", + fn: UserDel, + explanation: "Remove a user", + }, + AdminCommand{ + cmd: "usermod", + fn: UserMod, + explanation: "Modify a user", + }, +} + +func main() { + if len(os.Args) > 1 { + cmd := os.Args[1] + for _, c := range commands { + if c.cmd == cmd { + c.fn(os.Args[2:]) + return + } + } + fmt.Println("Invalid subcommand specified, please pick from the following:") + } else { + fmt.Println("No subcommand specified, please pick one from the following:") + } + for _, c := range commands { + fmt.Printf("\t%s\t\t%s\n", c.cmd, c.explanation) + } +} diff --git a/server/admin/rpc.go b/server/admin/rpc.go new file mode 100644 index 0000000..f970008 --- /dev/null +++ b/server/admin/rpc.go @@ -0,0 +1,27 @@ +package main + +import ( + "net" + "net/rpc" + "log" + "syscall" +) + +func RPCCall(method string, args interface{}, reply interface{}) error { + socket := "/tmp/asink.sock" + client, err := rpc.DialHTTP("unix", socket) + if err != nil { + if err2, ok := err.(*net.OpError); ok { + if err2.Err == syscall.ENOENT { + log.Fatal("The socket ("+socket+") was not found") + } else if err2.Err == syscall.ECONNREFUSED { + log.Fatal("A connection was refused to "+socket+". Please check the permissions and ensure the server is running.") + } + } + return err + } + defer client.Close() + + err = client.Call(method, args, reply) + return err +} diff --git a/server/admin/users.go b/server/admin/users.go new file mode 100644 index 0000000..f694270 --- /dev/null +++ b/server/admin/users.go @@ -0,0 +1,153 @@ +package main + +import ( + "asink/server" + "code.google.com/p/gopass" + "flag" + "fmt" + "os" + "strconv" + "reflect" +) + +type boolIsSetFlag struct { + Value bool + IsSet bool //true if explicitly set from the command-line, false otherwise +} + +func newBoolIsSetFlag(defaultValue bool) *boolIsSetFlag { + b := new(boolIsSetFlag) + b.Value = defaultValue + return b +} + +func (b *boolIsSetFlag) Set(value string) error { + v, err := strconv.ParseBool(value) + b.Value = v + b.IsSet = true + return err +} + +func (b *boolIsSetFlag) String() string { return fmt.Sprintf("%v", *b) } + +func (b *boolIsSetFlag) IsBoolFlag() bool { return true } + +func UserAdd(args []string) { + flags := flag.NewFlagSet("useradd", flag.ExitOnError) + admin := flags.Bool("admin", false, "User should be an administrator") + flags.Parse(args) + + if flags.NArg() != 1 { + fmt.Println("Error: please supply a username (and only one)") + os.Exit(1) + } + + passwordOne, err := gopass.GetPass("Enter password for new user: ") + if err != nil { + panic(err) + } + passwordTwo, err := gopass.GetPass("Enter the same password again: ") + if err != nil { + panic(err) + } + + if passwordOne != passwordTwo { + fmt.Println("Error: Passwords do not match. Please try again.") + os.Exit(1) + } + + user := new(server.User) + + if *admin { + user.Role = server.ADMIN + } else { + user.Role = server.NORMAL + } + user.Username = flags.Arg(0) + user.PWHash = server.HashPassword(passwordOne) + + fmt.Println(user) + i := 99 + err = RPCCall("UserModifier.AddUser", user, &i) + if err != nil { + panic(err) + } +} + +func UserDel(args []string) { + if len(args) != 1 { + fmt.Println("Error: please supply a username (and only one)") + os.Exit(1) + } + + user := new(server.User) + user.Username = args[0] + + fmt.Println(user) + i := 99 + err := RPCCall("UserModifier.RemoveUser", user, &i) + if err != nil { + panic(err) + } +} + +func UserMod(args []string) { + rpcargs := new(server.UserModifierArgs) + rpcargs.Current = new(server.User) + rpcargs.Updated = new(server.User) + + admin := newBoolIsSetFlag(false) + + flags := flag.NewFlagSet("usermod", flag.ExitOnError) + flags.Var(admin, "admin", "User should be an administrator") + flags.BoolVar(&rpcargs.UpdatePassword, "password", false, "Change the user's password") + flags.BoolVar(&rpcargs.UpdatePassword, "p", false, "Change the user's password (short version)") + flags.BoolVar(&rpcargs.UpdateLogin, "login", false, "Change the user's username") + flags.BoolVar(&rpcargs.UpdateLogin, "l", false, "Change the user's username (short version)") + flags.Parse(args) + + //set the UpdateAdmin flag based on whether it was present on the command-line + + if flags.NArg() != 1 { + fmt.Println("Error: please supply a username (and only one)") + os.Exit(1) + } + rpcargs.Current.Username = flags.Arg(0) + + if rpcargs.UpdateLogin == true { + fmt.Print("New login: ") + fmt.Scanf("%s", &rpcargs.Updated.Username) + } + + if rpcargs.UpdatePassword { + passwordOne, err := gopass.GetPass("Enter new password for user: ") + if err != nil { + panic(err) + } + passwordTwo, err := gopass.GetPass("Enter the same password again: ") + if err != nil { + panic(err) + } + + if passwordOne != passwordTwo { + fmt.Println("Error: Passwords do not match. Please try again.") + os.Exit(1) + } + rpcargs.Updated.PWHash = server.HashPassword(passwordOne) + } + + rpcargs.UpdateRole = admin.IsSet + if admin.Value { + rpcargs.Updated.Role = server.ADMIN + } else { + rpcargs.Updated.Role = server.NORMAL + } + + fmt.Println(rpcargs) + i := 99 + err := RPCCall("UserModifier.ModifyUser", rpcargs, &i) + if err != nil { + fmt.Println(reflect.TypeOf(err)) + panic(err) + } +} diff --git a/server/admin_rpc.go b/server/admin_rpc.go new file mode 100644 index 0000000..20da3aa --- /dev/null +++ b/server/admin_rpc.go @@ -0,0 +1,58 @@ +package server + +import ( + "fmt" + "net" + "net/http" + "net/rpc" +) + +type UserModifier int + +type UserModifierArgs struct { + Current *User + Updated *User + UpdateLogin bool + UpdateRole bool + UpdatePassword bool +} + +func (u *UserModifier) AddUser(user *User, result *int) error { + fmt.Println("adding user: ", user) + ret := 0 + result = &ret + return nil +} + +func (u *UserModifier) ModifyUser(args *UserModifierArgs, result *int) error { + fmt.Println("modifying user: ", args) + fmt.Println("from: ", args.Current) + fmt.Println("to: ", args.Updated) + ret := 0 + result = &ret + return nil +} + +func (u *UserModifier) RemoveUser(user *User, result *int) error { + fmt.Println("removing user: ", user) + ret := 0 + result = &ret + return nil +} + +func StartRPC(townDown chan int) { + defer func() { townDown <- 0 }() //the main thread waits for this to ensure the socket is closed + + usermod := new(UserModifier) + rpc.Register(usermod) + rpc.HandleHTTP() + l, err := net.Listen("unix", "/tmp/asink.sock") + if err != nil { + panic(err) + } + defer l.Close() + + go http.Serve(l, nil) + + WaitOnExit() +} diff --git a/server/database.go b/server/database.go deleted file mode 100644 index 2a842de..0000000 --- a/server/database.go +++ /dev/null @@ -1,100 +0,0 @@ -package main - -import ( - "asink" - "database/sql" - _ "github.com/mattn/go-sqlite3" - "sync" -) - -type AsinkDB struct { - db *sql.DB - lock sync.Mutex -} - -func GetAndInitDB() (*AsinkDB, error) { - dbLocation := "asink-server.db" //TODO make me configurable - - db, err := sql.Open("sqlite3", "file:"+dbLocation+"?cache=shared&mode=rwc") - if err != nil { - return nil, err - } - - //make sure the events table is created - tx, err := db.Begin() - if err != nil { - return nil, err - } - rows, err := tx.Query("SELECT name FROM sqlite_master WHERE type='table' AND name='events';") - if err != nil { - return nil, err - } - if !rows.Next() { - //if this is false, it means no rows were returned - tx.Exec("CREATE TABLE events (id INTEGER PRIMARY KEY ASC, localid INTEGER, type INTEGER, status INTEGER, path TEXT, hash TEXT, predecessor TEXT, timestamp INTEGER, permissions INTEGER);") - tx.Exec("CREATE INDEX IF NOT EXISTS pathidx on events (path);") - } - err = tx.Commit() - if err != nil { - return nil, err - } - - ret := new(AsinkDB) - ret.db = db - return ret, nil -} - -func (adb *AsinkDB) DatabaseAddEvent(e *asink.Event) (err error) { - adb.lock.Lock() - tx, err := adb.db.Begin() - if err != nil { - return err - } - - //make sure the transaction gets rolled back on error, and the database gets unlocked - defer func() { - if err != nil { - tx.Rollback() - } - adb.lock.Unlock() - }() - - result, err := tx.Exec("INSERT INTO events (localid, type, status, path, hash, predecessor, timestamp, permissions) VALUES (?,?,?,?,?,?,?,?);", e.LocalId, e.Type, e.Status, e.Path, e.Hash, e.Predecessor, e.Timestamp, e.Permissions) - if err != nil { - return err - } - id, err := result.LastInsertId() - if err != nil { - return err - } - err = tx.Commit() - if err != nil { - return err - } - - e.Id = id - e.InDB = true - return nil -} - -func (adb *AsinkDB) DatabaseRetrieveEvents(firstId uint64, maxEvents uint) (events []*asink.Event, err error) { - adb.lock.Lock() - //make sure the database gets unlocked on return - defer func() { - adb.lock.Unlock() - }() - rows, err := adb.db.Query("SELECT id, localid, type, status, path, hash, predecessor, timestamp, permissions FROM events WHERE id >= ? ORDER BY id ASC LIMIT ?;", firstId, maxEvents) - if err != nil { - return nil, err - } - for rows.Next() { - var event asink.Event - err = rows.Scan(&event.Id, &event.LocalId, &event.Type, &event.Status, &event.Path, &event.Hash, &event.Predecessor, &event.Timestamp, &event.Permissions) - if err != nil { - return nil, err - } - events = append(events, &event) - } - - return events, nil -} diff --git a/server/exit.go b/server/exit.go new file mode 100644 index 0000000..fad888b --- /dev/null +++ b/server/exit.go @@ -0,0 +1,43 @@ +package server + +import ( + "os" + "os/signal" + "sync/atomic" +) + +var exitWaiterCount int32 +var exitCalled chan int +var exitWaiterChan chan int + +func init() { + exitWaiterCount = 0 + exitWaiterChan = make(chan int) + go setupCleanExitOnSignals() +} + +func setupCleanExitOnSignals() { + //wait to properly close the socket when we're exiting + exitCode := 0 + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt) + defer signal.Stop(sig) + + select { + case <-sig: + case exitCode = <-exitCalled: + } + + for c := atomic.AddInt32(&exitWaiterCount, -1); c >= 0; c = atomic.AddInt32(&exitWaiterCount, -1) { + exitWaiterChan <- exitCode + } +} + +func Exit(exitCode int) { + exitCalled <- exitCode +} + +func WaitOnExit() int { + atomic.AddInt32(&exitWaiterCount, 1) + return <-exitWaiterChan +} diff --git a/server/server/database.go b/server/server/database.go new file mode 100644 index 0000000..a92742a --- /dev/null +++ b/server/server/database.go @@ -0,0 +1,171 @@ +package main + +import ( + "asink" + "asink/server" + "database/sql" + _ "github.com/mattn/go-sqlite3" + "sync" +) + +type AsinkDB struct { + db *sql.DB + lock sync.Mutex +} + +func GetAndInitDB() (*AsinkDB, error) { + dbLocation := "asink-server.db" //TODO make me configurable + + db, err := sql.Open("sqlite3", "file:"+dbLocation+"?cache=shared&mode=rwc") + if err != nil { + return nil, err + } + + //make sure all the tables are created + tx, err := db.Begin() + if err != nil { + return nil, err + } + + rows, err := tx.Query("SELECT name FROM sqlite_master WHERE type='table' AND name='events';") + if err != nil { + return nil, err + } + if !rows.Next() { + //if this is false, it means no rows were returned + tx.Exec("CREATE TABLE events (id INTEGER PRIMARY KEY ASC, userid INTEGER, type INTEGER, path TEXT, hash TEXT, predecessor TEXT, timestamp INTEGER, permissions INTEGER);") + tx.Exec("CREATE INDEX IF NOT EXISTS pathidx on events (path);") + tx.Exec("CREATE INDEX IF NOT EXISTS timestampidx on events (timestamp);") + } else { + rows.Close() + } + + rows, err = tx.Query("SELECT name FROM sqlite_master WHERE type='table' AND name='users';") + if err != nil { + return nil, err + } + if !rows.Next() { + //if this is false, it means no rows were returned + tx.Exec("CREATE TABLE user (id INTEGER PRIMARY KEY ASC, username TEXT, pwhash TEXT, role INTEGER);") + } else { + rows.Close() + } + + err = tx.Commit() + if err != nil { + return nil, err + } + + ret := new(AsinkDB) + ret.db = db + return ret, nil +} + +func (adb *AsinkDB) DatabaseAddEvent(e *asink.Event) (err error) { + adb.lock.Lock() + tx, err := adb.db.Begin() + if err != nil { + return err + } + + //make sure the transaction gets rolled back on error, and the database gets unlocked + defer func() { + if err != nil { + tx.Rollback() + } + adb.lock.Unlock() + }() + + result, err := tx.Exec("INSERT INTO events (userid, type, path, hash, predecessor, timestamp, permissions) VALUES (?,?,?,?,?,?,?,?);", e.Type, e.Path, e.Hash, e.Predecessor, e.Timestamp, e.Permissions) + if err != nil { + return err + } + id, err := result.LastInsertId() + if err != nil { + return err + } + err = tx.Commit() + if err != nil { + return err + } + + e.Id = id + e.InDB = true + return nil +} + +func (adb *AsinkDB) DatabaseRetrieveEvents(firstId uint64, maxEvents uint) (events []*asink.Event, err error) { + adb.lock.Lock() + //make sure the database gets unlocked on return + defer func() { + adb.lock.Unlock() + }() + rows, err := adb.db.Query("SELECT id, type, path, hash, predecessor, timestamp, permissions FROM events WHERE id >= ? ORDER BY id ASC LIMIT ?;", firstId, maxEvents) + if err != nil { + return nil, err + } + for rows.Next() { + var event asink.Event + err = rows.Scan(&event.Id, &event.Type, &event.Path, &event.Hash, &event.Predecessor, &event.Timestamp, &event.Permissions) + if err != nil { + return nil, err + } + events = append(events, &event) + } + + return events, nil +} + +func (adb *AsinkDB) DatabaseAddUser(u *server.User) (err error) { + adb.lock.Lock() + tx, err := adb.db.Begin() + if err != nil { + return err + } + + //make sure the transaction gets rolled back on error, and the database gets unlocked + defer func() { + if err != nil { + tx.Rollback() + } + adb.lock.Unlock() + }() + + result, err := tx.Exec("INSERT INTO users (username, pwhash, role) VALUES (?,?);", u.Username, u.PWHash, u.Role) + if err != nil { + return err + } + id, err := result.LastInsertId() + if err != nil { + return err + } + err = tx.Commit() + if err != nil { + return err + } + + u.Id = id + return nil +} + +func (adb *AsinkDB) DatabaseGetUser(username string) (user *server.User, err error) { + adb.lock.Lock() + //make sure the database gets unlocked + defer adb.lock.Unlock() + + row := adb.db.QueryRow("SELECT id, username, pwhash, role FROM users WHERE username == ?;", username) + + user = new(server.User) + err = row.Scan(&user.Id, &user.Username, &user.PWHash, &user.Role) + + switch { + case err == sql.ErrNoRows: + return nil, nil + case err != nil: + return nil, err + default: + return user, nil + } +} + + diff --git a/server/longpolling.go b/server/server/longpolling.go similarity index 100% rename from server/longpolling.go rename to server/server/longpolling.go diff --git a/server/server.go b/server/server/server.go similarity index 73% rename from server/server.go rename to server/server/server.go index 30562b8..2549625 100644 --- a/server/server.go +++ b/server/server/server.go @@ -2,13 +2,17 @@ package main import ( "asink" + "asink/server" + "encoding/base64" "encoding/json" "flag" "fmt" "io/ioutil" + "net" "net/http" "regexp" "strconv" + "strings" ) //global variables @@ -34,15 +38,24 @@ func init() { func main() { flag.Parse() + rpcTornDown := make(chan int) + go server.StartRPC(rpcTornDown) + http.HandleFunc("/", rootHandler) http.HandleFunc("/events", eventHandler) http.HandleFunc("/events/", eventHandler) - //TODO replace with http://golang.org/pkg/net/http/#ListenAndServeTLS - err := http.ListenAndServe(fmt.Sprintf(":%d", port), nil) + //TODO add HTTPS, something like http://golang.org/pkg/net/http/#ListenAndServeTLS + l, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) if err != nil { - fmt.Println(err) + panic(err) } + defer l.Close() + go http.Serve(l, nil) + //TODO handle errors from http.Serve? + + server.WaitOnExit() + <-rpcTornDown } func rootHandler(w http.ResponseWriter, r *http.Request) { @@ -139,6 +152,19 @@ func putEvents(w http.ResponseWriter, r *http.Request) { } func eventHandler(w http.ResponseWriter, r *http.Request) { + user := AuthenticateUser(r) + if user == nil { + apiresponse := asink.APIResponse{ + Status: asink.ERROR, + Explanation: "This operation requires user authentication", + } + b, err := json.Marshal(apiresponse) + if err != nil { + b = []byte(err.Error()) + } + w.Write(b) + return + } if r.Method == "GET" { //if GET, return any events later than (and including) the event id passed in if sm := eventsRegexp.FindStringSubmatch(r.RequestURI); sm != nil { @@ -164,3 +190,34 @@ func eventHandler(w http.ResponseWriter, r *http.Request) { w.Write(b) } } + +func AuthenticateUser(r *http.Request) (user *server.User) { + h, ok := r.Header["Authorization"] + if !ok { + return nil + } + authparts := strings.Split(h[0], " ") + if len(authparts) != 2 || authparts[0] != "Basic" { + return nil + } + + userpass, err := base64.StdEncoding.DecodeString(authparts[1]) + if err != nil { + return nil + } + splituserpass := strings.Split(string(userpass), ":") + if len(splituserpass) != 2 { + return nil + } + + user, err = adb.DatabaseGetUser(splituserpass[0]) + if err != nil || user == nil { + return nil + } + + if user.ValidPassword(splituserpass[1]) { + return user + } else { + return nil + } +} diff --git a/server/users.go b/server/users.go new file mode 100644 index 0000000..66be0c5 --- /dev/null +++ b/server/users.go @@ -0,0 +1,35 @@ +package server + +import ( + "crypto/sha256" + "fmt" +) + +type UserRole uint32 + +const ( + //User roles + ADMIN = 1 << iota + NORMAL +) + +type User struct { + Id int64 + Username string + PWHash string + Role UserRole +} + +func HashPassword(pw string) string { + hashfn := sha256.New() + hashfn.Write([]byte(pw)) + return fmt.Sprintf("%x", hashfn.Sum(nil)) +} + +func (u *User) ValidPassword(pw string) bool { + return HashPassword(pw) == u.PWHash +} + +func (u *User) IsAdmin() bool { + return u.Role&ADMIN == ADMIN +}