diff --git a/asink/client.go b/asink/client.go index d241841..fccbb81 100644 --- a/asink/client.go +++ b/asink/client.go @@ -11,6 +11,7 @@ import ( "errors" "flag" "fmt" + "io" "io/ioutil" "os" "os/user" @@ -204,7 +205,19 @@ func ProcessLocalEvent(globals AsinkGlobals, event *asink.Event) { //upload file to remote storage StatStartUpload() - err = globals.storage.Put(cachedFilename, event.Hash) + uploadWriteCloser, err := globals.storage.Put(event.Hash) + if err != nil { + panic(err) + } + defer uploadWriteCloser.Close() + + uploadFile, err := os.Open(cachedFilename) + if err != nil { + panic(err) + } + defer uploadFile.Close() + + _, err = io.Copy(uploadWriteCloser, uploadFile) StatStopUpload() if err != nil { panic(err) @@ -270,9 +283,15 @@ func ProcessRemoteEvent(globals AsinkGlobals, event *asink.Event) { panic(err) //TODO handle sensibly } tmpfilename := outfile.Name() - outfile.Close() StatStartDownload() - err = globals.storage.Get(tmpfilename, event.Hash) + downloadReadCloser, err := globals.storage.Get(event.Hash) + if err != nil { + panic(err) + } + defer downloadReadCloser.Close() + _, err = io.Copy(outfile, downloadReadCloser) + + outfile.Close() StatStopDownload() if err != nil { panic(err) //TODO handle sensibly diff --git a/asink/storage.go b/asink/storage.go index 6a0e05c..d135570 100644 --- a/asink/storage.go +++ b/asink/storage.go @@ -7,11 +7,14 @@ package main import ( "code.google.com/p/goconf/conf" "errors" + "io" ) type Storage interface { - Put(filename string, hash string) error - Get(filename string, hash string) error + // Close() MUST be called on the returned io.WriteCloser + Put(hash string) (io.WriteCloser, error) + // Close() MUST be called on the returned io.ReadCloser + Get(hash string) (io.ReadCloser, error) } func GetStorage(config *conf.ConfigFile) (Storage, error) { diff --git a/asink/storage_ftp.go b/asink/storage_ftp.go index be38179..eb0c9cd 100644 --- a/asink/storage_ftp.go +++ b/asink/storage_ftp.go @@ -9,7 +9,6 @@ import ( "errors" "github.com/jlaffaye/goftp" "io" - "os" "strconv" ) @@ -58,68 +57,70 @@ func NewFTPStorage(config *conf.ConfigFile) (*FTPStorage, error) { return fs, nil } -func (fs *FTPStorage) Put(filename string, hash string) (e error) { +func (fs *FTPStorage) Put(hash string) (w io.WriteCloser, e error) { + returningNormally := false //make sure we don't flood the FTP server fs.connectionsChan <- 0 - defer func() { <-fs.connectionsChan }() - - infile, err := os.Open(filename) - if err != nil { - return err - } - defer infile.Close() + defer func() { + if !returningNormally { + <-fs.connectionsChan + } + }() connection, err := ftp.Connect(fs.server + ":" + strconv.Itoa(fs.port)) if err != nil { - return err + return nil, err } - defer connection.Quit() + defer func() { + if !returningNormally { + connection.Quit() + } + }() err = connection.Login(fs.username, fs.password) if err != nil { - return err + return nil, err } err = connection.ChangeDir(fs.directory) if err != nil { - return err + return nil, err } - return connection.Stor(hash, infile) + reader, writer := io.Pipe() + + go func() { + err := connection.Stor(hash, reader) + if err != nil { + reader.CloseWithError(err) + } + <-fs.connectionsChan + connection.Quit() + }() + + returningNormally = true + return writer, nil } -func (fs *FTPStorage) Get(filename string, hash string) error { +func (fs *FTPStorage) Get(hash string) (io.ReadCloser, error) { fs.connectionsChan <- 0 defer func() { <-fs.connectionsChan }() connection, err := ftp.Connect(fs.server + ":" + strconv.Itoa(fs.port)) if err != nil { - return err + return nil, err } defer connection.Quit() err = connection.Login(fs.username, fs.password) if err != nil { - return err + return nil, err } err = connection.ChangeDir(fs.directory) if err != nil { - return err + return nil, err } - downloadedFile, err := connection.Retr(hash) - if err != nil { - return err - } - defer downloadedFile.Close() - - outfile, err := os.Create(filename) - if err != nil { - return err - } - defer outfile.Close() - - _, err = io.Copy(outfile, downloadedFile) - return err + return connection.Retr(hash) } diff --git a/asink/storage_local.go b/asink/storage_local.go index 567e4a9..f8b6374 100644 --- a/asink/storage_local.go +++ b/asink/storage_local.go @@ -9,6 +9,7 @@ import ( "code.google.com/p/goconf/conf" "errors" "io" + "io/ioutil" "os" "path" ) @@ -41,37 +42,44 @@ func NewLocalStorage(config *conf.ConfigFile) (*LocalStorage, error) { return ls, nil } -func (ls *LocalStorage) Put(filename string, hash string) (e error) { - tmpfile, err := util.CopyToTmp(filename, ls.tmpSubdir) - if err != nil { - return err - } +type putWriteCloser struct { + outfile *os.File + filename string +} - err = os.Rename(tmpfile, path.Join(ls.storageDir, hash)) +func (wc putWriteCloser) Write(p []byte) (n int, err error) { + return wc.outfile.Write(p) +} + +func (wc putWriteCloser) Close() error { + tmpfilename := wc.outfile.Name() + wc.outfile.Close() + + err := os.Rename(tmpfilename, wc.filename) if err != nil { - err := os.Remove(tmpfile) + err := os.Remove(tmpfilename) if err != nil { return err } } - return nil } -func (ls *LocalStorage) Get(filename string, hash string) error { - infile, err := os.Open(path.Join(ls.storageDir, hash)) +func (ls *LocalStorage) Put(hash string) (w io.WriteCloser, e error) { + outfile, err := ioutil.TempFile(ls.tmpSubdir, "asink") if err != nil { - return err + return nil, err } - defer infile.Close() - outfile, err := os.Create(filename) - if err != nil { - return err - } - defer outfile.Close() + w = putWriteCloser{outfile, path.Join(ls.storageDir, hash)} - _, err = io.Copy(outfile, infile) - - return err + return +} + +func (ls *LocalStorage) Get(hash string) (r io.ReadCloser, e error) { + r, err := os.Open(path.Join(ls.storageDir, hash)) + if err != nil { + return nil, err + } + return } diff --git a/util/util.go b/util/util.go index 56dae7d..40e6a00 100644 --- a/util/util.go +++ b/util/util.go @@ -46,6 +46,21 @@ func RecursiveRemoveEmptyDirs(dir string) { } } +func CopyReaderToTmp(src io.Reader, tmpdir string) (string, error) { + outfile, err := ioutil.TempFile(tmpdir, "asink") + if err != nil { + return "", err + } + defer outfile.Close() + + _, err = io.Copy(outfile, src) + if err != nil { + return "", err + } + + return outfile.Name(), nil +} + func CopyToTmp(src string, tmpdir string) (string, error) { infile, err := os.Open(src) if err != nil { @@ -53,18 +68,7 @@ func CopyToTmp(src string, tmpdir string) (string, error) { } defer infile.Close() - outfile, err := ioutil.TempFile(tmpdir, "asink") - if err != nil { - return "", err - } - defer outfile.Close() - - _, err = io.Copy(outfile, infile) - if err != nil { - return "", err - } - - return outfile.Name(), nil + return CopyReaderToTmp(infile, tmpdir) } func ErrorFileNotFound(err error) bool {