diff --git a/account_import.go b/account_import.go index 88a9a91..579770b 100644 --- a/account_import.go +++ b/account_import.go @@ -5,6 +5,7 @@ import ( "html/template" "io" "io/ioutil" + "mime/multipart" "net/http" "os" "path/filepath" @@ -55,36 +56,58 @@ func viewImport(app *App, u *User, w http.ResponseWriter, r *http.Request) error func handleImport(app *App, u *User, w http.ResponseWriter, r *http.Request) error { // limit 10MB per submission + // TODO: increase? r.ParseMultipartForm(10 << 20) files := r.MultipartForm.File["files"] - var fileErrs []error filesSubmitted := len(files) - var filesImported int + var filesImported, collsImported int + var errs []error + // TODO: support multiple zip uploads at once + if filesSubmitted == 1 && files[0].Header.Get("Content-Type") == "application/zip" { + filesSubmitted, filesImported, collsImported, errs = importZipPosts(app, w, r, files[0], u) + } else { + filesImported, errs = importFilePosts(app, w, r, files, u) + } + + if len(errs) != 0 { + _ = addSessionFlash(app, w, r, multierror.ListFormatFunc(errs), nil) + } + if filesImported == filesSubmitted { + postAdj := "posts" + if filesSubmitted == 1 { + postAdj = "post" + } + if collsImported != 0 { + collAdj := "collections" + if collsImported == 1 { + collAdj = "collection" + } + _ = addSessionFlash(app, w, r, fmt.Sprintf( + "SUCCESS: Import complete, %d %s imported across %d %s.", + filesImported, + postAdj, + collsImported, + collAdj, + ), nil) + } else { + _ = addSessionFlash(app, w, r, fmt.Sprintf("SUCCESS: Import complete, %d %s imported.", filesImported, postAdj), nil) + } + } else if filesImported > 0 { + _ = addSessionFlash(app, w, r, fmt.Sprintf("INFO: %d of %d posts imported, see details below.", filesImported, filesSubmitted), nil) + } + return impart.HTTPError{http.StatusFound, "/me/import"} +} + +func importFilePosts(app *App, w http.ResponseWriter, r *http.Request, files []*multipart.FileHeader, u *User) (int, []error) { + var fileErrs []error + var count int for _, formFile := range files { - file, err := formFile.Open() - if err != nil { - fileErrs = append(fileErrs, fmt.Errorf("failed to open form file: %s", formFile.Filename)) - log.Error("import textfile: open from form: %v", err) + if filepath.Ext(formFile.Filename) == ".zip" { + fileErrs = append(fileErrs, fmt.Errorf("zips are supported as a single upload only: %s", formFile.Filename)) + log.Info("zip included in bulk files, skipping") continue } - defer file.Close() - - tempFile, err := ioutil.TempFile("", "post-upload-*.txt") - if err != nil { - fileErrs = append(fileErrs, fmt.Errorf("failed to create temporary file for: %s", formFile.Filename)) - log.Error("import textfile: create temp file: %v", err) - continue - } - defer tempFile.Close() - - _, err = io.Copy(tempFile, file) - if err != nil { - fileErrs = append(fileErrs, fmt.Errorf("failed to copy file into temporary location: %s", formFile.Filename)) - log.Error("import textfile: copy to temp: %v", err) - continue - } - - info, err := tempFile.Stat() + info, err := formFileToTemp(formFile) if err != nil { fileErrs = append(fileErrs, fmt.Errorf("failed to get file info of: %s", formFile.Filename)) log.Error("import textfile: stat temp file: %v", err) @@ -142,20 +165,95 @@ func handleImport(app *App, u *User, w http.ResponseWriter, r *http.Request) err false, ) } - filesImported++ + count++ } - if len(fileErrs) != 0 { - _ = addSessionFlash(app, w, r, multierror.ListFormatFunc(fileErrs), nil) + return count, fileErrs +} + +func importZipPosts(app *App, w http.ResponseWriter, r *http.Request, file *multipart.FileHeader, u *User) (filesSubmitted, importedPosts, importedColls int, errs []error) { + info, err := formFileToTemp(file) + if err != nil { + errs = append(errs, fmt.Errorf("upload temp file: %v", err)) + return } - if filesImported == filesSubmitted { - verb := "posts" - if filesSubmitted == 1 { - verb = "post" - } - _ = addSessionFlash(app, w, r, fmt.Sprintf("SUCCESS: Import complete, %d %s imported.", filesImported, verb), nil) - } else if filesImported > 0 { - _ = addSessionFlash(app, w, r, fmt.Sprintf("INFO: %d of %d posts imported, see details below.", filesImported, filesSubmitted), nil) + postMap, err := wfimport.FromZipDirs(filepath.Join(os.TempDir(), info.Name())) + if err != nil { + errs = append(errs, fmt.Errorf("parse posts and collections from zip: %v", err)) + return } - return impart.HTTPError{http.StatusFound, "/me/import"} + + for collKey, posts := range postMap { + // TODO: will posts ever be 0? should skip if so + collObj := CollectionObj{} + importedColls++ + if collKey != wfimport.DraftsKey { + coll, err := app.db.GetCollection(collKey) + if err == ErrCollectionNotFound { + coll, err = app.db.CreateCollection(app.cfg, collKey, collKey, u.ID) + if err != nil { + errs = append(errs, fmt.Errorf("create non existent collection: %v", err)) + continue + } + coll.hostName = app.cfg.App.Host + collObj.Collection = *coll + } else if err != nil { + errs = append(errs, fmt.Errorf("get collection: %v", err)) + continue + } + collObj.Collection = *coll + } + + for _, post := range posts { + if post != nil { + filesSubmitted++ + created := post.Created.Format("2006-01-02T15:04:05Z") + submittedPost := SubmittedPost{ + Title: &post.Title, + Content: &post.Content, + Font: "norm", + Created: &created, + } + rp, err := app.db.CreatePost(u.ID, collObj.Collection.ID, &submittedPost) + if err != nil { + errs = append(errs, fmt.Errorf("create post: %v", err)) + } + + if collObj.Collection.ID != 0 && app.cfg.App.Federation { + go federatePost( + app, + &PublicPost{ + Post: rp, + Collection: &collObj, + }, + collObj.Collection.ID, + false, + ) + } + importedPosts++ + } + } + } + return +} + +func formFileToTemp(formFile *multipart.FileHeader) (os.FileInfo, error) { + file, err := formFile.Open() + if err != nil { + return nil, fmt.Errorf("failed to open form file: %s", formFile.Filename) + } + defer file.Close() + + tempFile, err := ioutil.TempFile("", fmt.Sprintf("upload-*%s", filepath.Ext(formFile.Filename))) + if err != nil { + return nil, fmt.Errorf("failed to create temporary file for: %s", formFile.Filename) + } + defer tempFile.Close() + + _, err = io.Copy(tempFile, file) + if err != nil { + return nil, fmt.Errorf("failed to copy file into temporary location: %s", formFile.Filename) + } + + return tempFile.Stat() } diff --git a/go.sum b/go.sum index de39ce4..64400b9 100644 --- a/go.sum +++ b/go.sum @@ -61,6 +61,7 @@ github.com/gogs/minwinsvc v0.0.0-20170301035411-95be6356811a/go.mod h1:TUIZ+29jo github.com/golang/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= github.com/golang/lint v0.0.0-20181217174547-8f45f776aaf1 h1:6DVPu65tee05kY0/rciBQ47ue+AnuY8KTayV6VHikIo= github.com/golang/lint v0.0.0-20181217174547-8f45f776aaf1/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= +github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/shlex v0.0.0-20181106134648-c34317bd91bf h1:7+FW5aGwISbqUtkfmIpZJGRgNFg2ioYPvFaUxdqpDsg= github.com/google/shlex v0.0.0-20181106134648-c34317bd91bf/go.mod h1:RpwtwJQFrIEPstU94h88MWPXP2ektJZ8cZ0YntAmXiE= @@ -169,6 +170,8 @@ github.com/writeas/import v0.0.0-20190815214647-baae8acd8d06 h1:S6oKKP8GhSoyZUvV github.com/writeas/import v0.0.0-20190815214647-baae8acd8d06/go.mod h1:f3K8z7YnJwKnPIT4h7980n9C6cQb4DIB2QcxVCTB7lE= github.com/writeas/import v0.0.0-20190815235139-628d10daaa9e h1:31PkvDTWkjzC1nGzWw9uAE92ZfcVyFX/K9L9ejQjnEs= github.com/writeas/import v0.0.0-20190815235139-628d10daaa9e/go.mod h1:f3K8z7YnJwKnPIT4h7980n9C6cQb4DIB2QcxVCTB7lE= +github.com/writeas/import v0.1.0 h1:ZbAOb6QL24tgZOEAmEJ/hk59fF4UGOX8sQLaYE+yNiA= +github.com/writeas/import v0.1.0/go.mod h1:f3K8z7YnJwKnPIT4h7980n9C6cQb4DIB2QcxVCTB7lE= github.com/writeas/import v0.1.1 h1:SbYltT+nxrJBUe0xQWJqeKMHaupbxV0a6K3RtwcE4yY= github.com/writeas/import v0.1.1/go.mod h1:gFe0Pl7ZWYiXbI0TJxeMMyylPGZmhVvCfQxhMEc8CxM= github.com/writeas/import v0.2.0 h1:Ov23JW9Rnjxk06rki1Spar45bNX647HhwhAZj3flJiY= diff --git a/templates/user/import.tmpl b/templates/user/import.tmpl index 0fd8a93..0c04f06 100644 --- a/templates/user/import.tmpl +++ b/templates/user/import.tmpl @@ -12,7 +12,7 @@