Merge pull request #303 from writeas/fix-oauth-account-creation
Respect registration config on OAuth signup flow
This commit is contained in:
commit
99d86a7489
15 changed files with 238 additions and 78 deletions
11
account.go
11
account.go
|
@ -1,5 +1,5 @@
|
||||||
/*
|
/*
|
||||||
* Copyright © 2018-2019 A Bunch Tell LLC.
|
* Copyright © 2018-2020 A Bunch Tell LLC.
|
||||||
*
|
*
|
||||||
* This file is part of WriteFreely.
|
* This file is part of WriteFreely.
|
||||||
*
|
*
|
||||||
|
@ -168,11 +168,7 @@ func signupWithRegistration(app *App, signup userRegistration, w http.ResponseWr
|
||||||
|
|
||||||
// Log invite if needed
|
// Log invite if needed
|
||||||
if signup.InviteCode != "" {
|
if signup.InviteCode != "" {
|
||||||
cu, err := app.db.GetUserForAuth(signup.Alias)
|
err = app.db.CreateInvitedUser(signup.InviteCode, u.ID)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
err = app.db.CreateInvitedUser(signup.InviteCode, cu.ID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -493,6 +489,9 @@ func login(app *App, w http.ResponseWriter, r *http.Request) error {
|
||||||
return impart.HTTPError{http.StatusPreconditionFailed, "This user never added a password or email address. Please contact us for help."}
|
return impart.HTTPError{http.StatusPreconditionFailed, "This user never added a password or email address. Please contact us for help."}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if len(u.HashedPass) == 0 {
|
||||||
|
return impart.HTTPError{http.StatusUnauthorized, "This user never set a password. Perhaps try logging in via OAuth?"}
|
||||||
|
}
|
||||||
if !auth.Authenticated(u.HashedPass, []byte(signin.Pass)) {
|
if !auth.Authenticated(u.HashedPass, []byte(signin.Pass)) {
|
||||||
return impart.HTTPError{http.StatusUnauthorized, "Incorrect password."}
|
return impart.HTTPError{http.StatusUnauthorized, "Incorrect password."}
|
||||||
}
|
}
|
||||||
|
|
21
database.go
21
database.go
|
@ -132,8 +132,8 @@ type writestore interface {
|
||||||
|
|
||||||
GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
|
GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
|
||||||
RecordRemoteUserID(context.Context, int64, string, string, string, string) error
|
RecordRemoteUserID(context.Context, int64, string, string, string, string) error
|
||||||
ValidateOAuthState(context.Context, string) (string, string, int64, error)
|
ValidateOAuthState(context.Context, string) (string, string, int64, string, error)
|
||||||
GenerateOAuthState(context.Context, string, string, int64) (string, error)
|
GenerateOAuthState(context.Context, string, string, int64, string) (string, error)
|
||||||
GetOauthAccounts(ctx context.Context, userID int64) ([]oauthAccountInfo, error)
|
GetOauthAccounts(ctx context.Context, userID int64) ([]oauthAccountInfo, error)
|
||||||
RemoveOauth(ctx context.Context, userID int64, provider string, clientID string, remoteUserID string) error
|
RemoveOauth(ctx context.Context, userID int64, provider string, clientID string, remoteUserID string) error
|
||||||
|
|
||||||
|
@ -178,6 +178,7 @@ func (db *datastore) dateSub(l int, unit string) string {
|
||||||
return fmt.Sprintf("DATE_SUB(NOW(), INTERVAL %d %s)", l, unit)
|
return fmt.Sprintf("DATE_SUB(NOW(), INTERVAL %d %s)", l, unit)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateUser creates a new user in the database from the given User, UPDATING it in the process with the user's ID.
|
||||||
func (db *datastore) CreateUser(cfg *config.Config, u *User, collectionTitle string) error {
|
func (db *datastore) CreateUser(cfg *config.Config, u *User, collectionTitle string) error {
|
||||||
if db.PostIDExists(u.Username) {
|
if db.PostIDExists(u.Username) {
|
||||||
return impart.HTTPError{http.StatusConflict, "Invalid collection name."}
|
return impart.HTTPError{http.StatusConflict, "Invalid collection name."}
|
||||||
|
@ -2516,24 +2517,26 @@ func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) {
|
||||||
return &t, nil
|
return &t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *datastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUser int64) (string, error) {
|
func (db *datastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUser int64, inviteCode string) (string, error) {
|
||||||
state := store.Generate62RandomString(24)
|
state := store.Generate62RandomString(24)
|
||||||
attachUserVal := sql.NullInt64{Valid: attachUser > 0, Int64: attachUser}
|
attachUserVal := sql.NullInt64{Valid: attachUser > 0, Int64: attachUser}
|
||||||
_, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at, attach_user_id) VALUES (?, ?, ?, FALSE, "+db.now()+", ?)", state, provider, clientID, attachUserVal)
|
inviteCodeVal := sql.NullString{Valid: inviteCode != "", String: inviteCode}
|
||||||
|
_, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at, attach_user_id, invite_code) VALUES (?, ?, ?, FALSE, "+db.now()+", ?, ?)", state, provider, clientID, attachUserVal, inviteCodeVal)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("unable to record oauth client state: %w", err)
|
return "", fmt.Errorf("unable to record oauth client state: %w", err)
|
||||||
}
|
}
|
||||||
return state, nil
|
return state, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, error) {
|
func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, string, error) {
|
||||||
var provider string
|
var provider string
|
||||||
var clientID string
|
var clientID string
|
||||||
var attachUserID sql.NullInt64
|
var attachUserID sql.NullInt64
|
||||||
|
var inviteCode sql.NullString
|
||||||
err := wf_db.RunTransactionWithOptions(ctx, db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
|
err := wf_db.RunTransactionWithOptions(ctx, db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
|
||||||
err := tx.
|
err := tx.
|
||||||
QueryRowContext(ctx, "SELECT provider, client_id, attach_user_id FROM oauth_client_states WHERE state = ? AND used = FALSE", state).
|
QueryRowContext(ctx, "SELECT provider, client_id, attach_user_id, invite_code FROM oauth_client_states WHERE state = ? AND used = FALSE", state).
|
||||||
Scan(&provider, &clientID, &attachUserID)
|
Scan(&provider, &clientID, &attachUserID, &inviteCode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -2552,9 +2555,9 @@ func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (stri
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", 0, nil
|
return "", "", 0, "", nil
|
||||||
}
|
}
|
||||||
return provider, clientID, attachUserID.Int64, nil
|
return provider, clientID, attachUserID.Int64, inviteCode.String, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error {
|
func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error {
|
||||||
|
|
|
@ -18,13 +18,13 @@ func TestOAuthDatastore(t *testing.T) {
|
||||||
driverName: "",
|
driverName: "",
|
||||||
}
|
}
|
||||||
|
|
||||||
state, err := ds.GenerateOAuthState(ctx, "test", "development", 0)
|
state, err := ds.GenerateOAuthState(ctx, "test", "development", 0, "")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Len(t, state, 24)
|
assert.Len(t, state, 24)
|
||||||
|
|
||||||
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = false", state)
|
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = false", state)
|
||||||
|
|
||||||
_, _, _, err = ds.ValidateOAuthState(ctx, state)
|
_, _, _, _, err = ds.ValidateOAuthState(ctx, state)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = true", state)
|
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = true", state)
|
||||||
|
|
16
invites.go
16
invites.go
|
@ -1,5 +1,5 @@
|
||||||
/*
|
/*
|
||||||
* Copyright © 2019 A Bunch Tell LLC.
|
* Copyright © 2019-2020 A Bunch Tell LLC.
|
||||||
*
|
*
|
||||||
* This file is part of WriteFreely.
|
* This file is part of WriteFreely.
|
||||||
*
|
*
|
||||||
|
@ -42,6 +42,18 @@ func (i Invite) Expired() bool {
|
||||||
return i.Expires != nil && i.Expires.Before(time.Now())
|
return i.Expires != nil && i.Expires.Before(time.Now())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i Invite) Active(db *datastore) bool {
|
||||||
|
if i.Expired() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if i.MaxUses.Valid && i.MaxUses.Int64 > 0 {
|
||||||
|
if c := db.GetUsersInvitedCount(i.ID); c >= i.MaxUses.Int64 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (i Invite) ExpiresFriendly() string {
|
func (i Invite) ExpiresFriendly() string {
|
||||||
return i.Expires.Format("January 2, 2006, 3:04 PM")
|
return i.Expires.Format("January 2, 2006, 3:04 PM")
|
||||||
}
|
}
|
||||||
|
@ -161,9 +173,11 @@ func handleViewInvite(app *App, w http.ResponseWriter, r *http.Request) error {
|
||||||
Error string
|
Error string
|
||||||
Flashes []template.HTML
|
Flashes []template.HTML
|
||||||
Invite string
|
Invite string
|
||||||
|
OAuth *OAuthButtons
|
||||||
}{
|
}{
|
||||||
StaticPage: pageForReq(app, r),
|
StaticPage: pageForReq(app, r),
|
||||||
Invite: inviteCode,
|
Invite: inviteCode,
|
||||||
|
OAuth: NewOAuthButtons(app.cfg),
|
||||||
}
|
}
|
||||||
|
|
||||||
if expired {
|
if expired {
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
@import "post-temp";
|
@import "post-temp";
|
||||||
@import "effects";
|
@import "effects";
|
||||||
@import "admin";
|
@import "admin";
|
||||||
|
@import "login";
|
||||||
@import "pages/error";
|
@import "pages/error";
|
||||||
@import "lib/elements";
|
@import "lib/elements";
|
||||||
@import "lib/material";
|
@import "lib/material";
|
||||||
|
|
45
less/login.less
Normal file
45
less/login.less
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
/*
|
||||||
|
* Copyright © 2020 A Bunch Tell LLC.
|
||||||
|
*
|
||||||
|
* This file is part of WriteFreely.
|
||||||
|
*
|
||||||
|
* WriteFreely is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU Affero General Public License, included
|
||||||
|
* in the LICENSE file in this source code package.
|
||||||
|
*/
|
||||||
|
|
||||||
|
.row.signinbtns {
|
||||||
|
justify-content: space-evenly;
|
||||||
|
font-size: 1em;
|
||||||
|
margin-top: 2em;
|
||||||
|
margin-bottom: 1em;
|
||||||
|
|
||||||
|
.loginbtn {
|
||||||
|
height: 40px;
|
||||||
|
}
|
||||||
|
|
||||||
|
#writeas-login, #gitlab-login {
|
||||||
|
box-sizing: border-box;
|
||||||
|
font-size: 17px;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.or {
|
||||||
|
text-align: center;
|
||||||
|
margin-bottom: 3.5em;
|
||||||
|
|
||||||
|
p {
|
||||||
|
display: inline-block;
|
||||||
|
background-color: white;
|
||||||
|
padding: 0 1em;
|
||||||
|
}
|
||||||
|
|
||||||
|
hr {
|
||||||
|
margin-top: -1.6em;
|
||||||
|
margin-bottom: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
hr.short {
|
||||||
|
max-width: 30rem;
|
||||||
|
}
|
||||||
|
}
|
|
@ -62,7 +62,8 @@ var migrations = []Migration{
|
||||||
New("support oauth", oauth), // V3 -> V4
|
New("support oauth", oauth), // V3 -> V4
|
||||||
New("support slack oauth", oauthSlack), // V4 -> v5
|
New("support slack oauth", oauthSlack), // V4 -> v5
|
||||||
New("support ActivityPub mentions", supportActivityPubMentions), // V5 -> V6
|
New("support ActivityPub mentions", supportActivityPubMentions), // V5 -> V6
|
||||||
New("support oauth attach", oauthAttach), // V6 -> V7 (v0.12.0)
|
New("support oauth attach", oauthAttach), // V6 -> V7
|
||||||
|
New("support oauth via invite", oauthInvites), // V7 -> V8 (v0.12.0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CurrentVer returns the current migration version the application is on
|
// CurrentVer returns the current migration version the application is on
|
||||||
|
|
45
migrations/v8.go
Normal file
45
migrations/v8.go
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
/*
|
||||||
|
* Copyright © 2020 A Bunch Tell LLC.
|
||||||
|
*
|
||||||
|
* This file is part of WriteFreely.
|
||||||
|
*
|
||||||
|
* WriteFreely is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU Affero General Public License, included
|
||||||
|
* in the LICENSE file in this source code package.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package migrations
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
wf_db "github.com/writeas/writefreely/db"
|
||||||
|
)
|
||||||
|
|
||||||
|
func oauthInvites(db *datastore) error {
|
||||||
|
dialect := wf_db.DialectMySQL
|
||||||
|
if db.driverName == driverSQLite {
|
||||||
|
dialect = wf_db.DialectSQLite
|
||||||
|
}
|
||||||
|
return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
|
||||||
|
builders := []wf_db.SQLBuilder{
|
||||||
|
dialect.
|
||||||
|
AlterTable("oauth_client_states").
|
||||||
|
AddColumn(dialect.Column("invite_code", wf_db.ColumnTypeChar, wf_db.OptionalInt{
|
||||||
|
Set: true,
|
||||||
|
Value: 6,
|
||||||
|
}).SetNullable(true)),
|
||||||
|
}
|
||||||
|
for _, builder := range builders {
|
||||||
|
query, err := builder.ToSQL()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := tx.ExecContext(ctx, query); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
57
oauth.go
57
oauth.go
|
@ -1,3 +1,13 @@
|
||||||
|
/*
|
||||||
|
* Copyright © 2019-2020 A Bunch Tell LLC.
|
||||||
|
*
|
||||||
|
* This file is part of WriteFreely.
|
||||||
|
*
|
||||||
|
* WriteFreely is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU Affero General Public License, included
|
||||||
|
* in the LICENSE file in this source code package.
|
||||||
|
*/
|
||||||
|
|
||||||
package writefreely
|
package writefreely
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -15,10 +25,27 @@ import (
|
||||||
"github.com/gorilla/sessions"
|
"github.com/gorilla/sessions"
|
||||||
"github.com/writeas/impart"
|
"github.com/writeas/impart"
|
||||||
"github.com/writeas/web-core/log"
|
"github.com/writeas/web-core/log"
|
||||||
|
|
||||||
"github.com/writeas/writefreely/config"
|
"github.com/writeas/writefreely/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// OAuthButtons holds display information for different OAuth providers we support.
|
||||||
|
type OAuthButtons struct {
|
||||||
|
SlackEnabled bool
|
||||||
|
WriteAsEnabled bool
|
||||||
|
GitLabEnabled bool
|
||||||
|
GitLabDisplayName string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOAuthButtons creates a new OAuthButtons struct based on our app configuration.
|
||||||
|
func NewOAuthButtons(cfg *config.Config) *OAuthButtons {
|
||||||
|
return &OAuthButtons{
|
||||||
|
SlackEnabled: cfg.SlackOauth.ClientID != "",
|
||||||
|
WriteAsEnabled: cfg.WriteAsOauth.ClientID != "",
|
||||||
|
GitLabEnabled: cfg.GitlabOauth.ClientID != "",
|
||||||
|
GitLabDisplayName: config.OrDefaultString(cfg.GitlabOauth.DisplayName, gitlabDisplayName),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TokenResponse contains data returned when a token is created either
|
// TokenResponse contains data returned when a token is created either
|
||||||
// through a code exchange or using a refresh token.
|
// through a code exchange or using a refresh token.
|
||||||
type TokenResponse struct {
|
type TokenResponse struct {
|
||||||
|
@ -61,8 +88,8 @@ type OAuthDatastoreProvider interface {
|
||||||
type OAuthDatastore interface {
|
type OAuthDatastore interface {
|
||||||
GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
|
GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
|
||||||
RecordRemoteUserID(context.Context, int64, string, string, string, string) error
|
RecordRemoteUserID(context.Context, int64, string, string, string, string) error
|
||||||
ValidateOAuthState(context.Context, string) (string, string, int64, error)
|
ValidateOAuthState(context.Context, string) (string, string, int64, string, error)
|
||||||
GenerateOAuthState(context.Context, string, string, int64) (string, error)
|
GenerateOAuthState(context.Context, string, string, int64, string) (string, error)
|
||||||
|
|
||||||
CreateUser(*config.Config, *User, string) error
|
CreateUser(*config.Config, *User, string) error
|
||||||
GetUserByID(int64) (*User, error)
|
GetUserByID(int64) (*User, error)
|
||||||
|
@ -108,7 +135,7 @@ func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Req
|
||||||
attachUser = user.ID
|
attachUser = user.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID(), attachUser)
|
state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID(), attachUser, r.FormValue("invite_code"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("viewOauthInit error: %s", err)
|
log.Error("viewOauthInit error: %s", err)
|
||||||
return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
|
return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
|
||||||
|
@ -228,7 +255,7 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
|
||||||
code := r.FormValue("code")
|
code := r.FormValue("code")
|
||||||
state := r.FormValue("state")
|
state := r.FormValue("state")
|
||||||
|
|
||||||
provider, clientID, attachUserID, err := h.DB.ValidateOAuthState(ctx, state)
|
provider, clientID, attachUserID, inviteCode, err := h.DB.ValidateOAuthState(ctx, state)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Unable to ValidateOAuthState: %s", err)
|
log.Error("Unable to ValidateOAuthState: %s", err)
|
||||||
return impart.HTTPError{http.StatusInternalServerError, err.Error()}
|
return impart.HTTPError{http.StatusInternalServerError, err.Error()}
|
||||||
|
@ -240,7 +267,7 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
|
||||||
return impart.HTTPError{http.StatusInternalServerError, err.Error()}
|
return impart.HTTPError{http.StatusInternalServerError, err.Error()}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now that we have the access token, let's use it real quick to make sur
|
// Now that we have the access token, let's use it real quick to make sure
|
||||||
// it really really works.
|
// it really really works.
|
||||||
tokenInfo, err := h.oauthClient.inspectOauthAccessToken(ctx, tokenResponse.AccessToken)
|
tokenInfo, err := h.oauthClient.inspectOauthAccessToken(ctx, tokenResponse.AccessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -262,6 +289,7 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
|
||||||
}
|
}
|
||||||
|
|
||||||
if localUserID != -1 {
|
if localUserID != -1 {
|
||||||
|
// Existing user, so log in now
|
||||||
user, err := h.DB.GetUserByID(localUserID)
|
user, err := h.DB.GetUserByID(localUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Unable to GetUserByID %d: %s", localUserID, err)
|
log.Error("Unable to GetUserByID %d: %s", localUserID, err)
|
||||||
|
@ -282,6 +310,22 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
|
||||||
return impart.HTTPError{http.StatusFound, "/me/settings"}
|
return impart.HTTPError{http.StatusFound, "/me/settings"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// New user registration below.
|
||||||
|
// First, verify that user is allowed to register
|
||||||
|
if inviteCode != "" {
|
||||||
|
// Verify invite code is valid
|
||||||
|
i, err := app.db.GetUserInvite(inviteCode)
|
||||||
|
if err != nil {
|
||||||
|
return impart.HTTPError{http.StatusInternalServerError, err.Error()}
|
||||||
|
}
|
||||||
|
if !i.Active(app.db) {
|
||||||
|
return impart.HTTPError{http.StatusNotFound, "Invite link has expired."}
|
||||||
|
}
|
||||||
|
} else if !app.cfg.App.OpenRegistration {
|
||||||
|
addSessionFlash(app, w, r, ErrUserNotFound.Error(), nil)
|
||||||
|
return impart.HTTPError{http.StatusFound, "/login"}
|
||||||
|
}
|
||||||
|
|
||||||
displayName := tokenInfo.DisplayName
|
displayName := tokenInfo.DisplayName
|
||||||
if len(displayName) == 0 {
|
if len(displayName) == 0 {
|
||||||
displayName = tokenInfo.Username
|
displayName = tokenInfo.Username
|
||||||
|
@ -295,6 +339,7 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
|
||||||
TokenRemoteUser: tokenInfo.UserID,
|
TokenRemoteUser: tokenInfo.UserID,
|
||||||
Provider: provider,
|
Provider: provider,
|
||||||
ClientID: clientID,
|
ClientID: clientID,
|
||||||
|
InviteCode: inviteCode,
|
||||||
}
|
}
|
||||||
tp.TokenHash = tp.HashTokenParams(h.Config.Server.HashSeed)
|
tp.TokenHash = tp.HashTokenParams(h.Config.Server.HashSeed)
|
||||||
|
|
||||||
|
|
|
@ -38,6 +38,7 @@ type viewOauthSignupVars struct {
|
||||||
Provider string
|
Provider string
|
||||||
ClientID string
|
ClientID string
|
||||||
TokenHash string
|
TokenHash string
|
||||||
|
InviteCode string
|
||||||
|
|
||||||
LoginUsername string
|
LoginUsername string
|
||||||
Alias string // TODO: rename this to match the data it represents: the collection title
|
Alias string // TODO: rename this to match the data it represents: the collection title
|
||||||
|
@ -57,6 +58,7 @@ const (
|
||||||
oauthParamAlias = "alias"
|
oauthParamAlias = "alias"
|
||||||
oauthParamEmail = "email"
|
oauthParamEmail = "email"
|
||||||
oauthParamPassword = "password"
|
oauthParamPassword = "password"
|
||||||
|
oauthParamInviteCode = "invite_code"
|
||||||
)
|
)
|
||||||
|
|
||||||
type oauthSignupPageParams struct {
|
type oauthSignupPageParams struct {
|
||||||
|
@ -68,6 +70,7 @@ type oauthSignupPageParams struct {
|
||||||
ClientID string
|
ClientID string
|
||||||
Provider string
|
Provider string
|
||||||
TokenHash string
|
TokenHash string
|
||||||
|
InviteCode string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p oauthSignupPageParams) HashTokenParams(key string) string {
|
func (p oauthSignupPageParams) HashTokenParams(key string) string {
|
||||||
|
@ -92,6 +95,7 @@ func (h oauthHandler) viewOauthSignup(app *App, w http.ResponseWriter, r *http.R
|
||||||
TokenRemoteUser: r.FormValue(oauthParamTokenRemoteUserID),
|
TokenRemoteUser: r.FormValue(oauthParamTokenRemoteUserID),
|
||||||
ClientID: r.FormValue(oauthParamClientID),
|
ClientID: r.FormValue(oauthParamClientID),
|
||||||
Provider: r.FormValue(oauthParamProvider),
|
Provider: r.FormValue(oauthParamProvider),
|
||||||
|
InviteCode: r.FormValue(oauthParamInviteCode),
|
||||||
}
|
}
|
||||||
if tp.HashTokenParams(h.Config.Server.HashSeed) != r.FormValue(oauthParamHash) {
|
if tp.HashTokenParams(h.Config.Server.HashSeed) != r.FormValue(oauthParamHash) {
|
||||||
return impart.HTTPError{Status: http.StatusBadRequest, Message: "Request has been tampered with."}
|
return impart.HTTPError{Status: http.StatusBadRequest, Message: "Request has been tampered with."}
|
||||||
|
@ -128,6 +132,14 @@ func (h oauthHandler) viewOauthSignup(app *App, w http.ResponseWriter, r *http.R
|
||||||
return h.showOauthSignupPage(app, w, r, tp, err)
|
return h.showOauthSignupPage(app, w, r, tp, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Log invite if needed
|
||||||
|
if tp.InviteCode != "" {
|
||||||
|
err = app.db.CreateInvitedUser(tp.InviteCode, newUser.ID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
err = h.DB.RecordRemoteUserID(r.Context(), newUser.ID, r.FormValue(oauthParamTokenRemoteUserID), r.FormValue(oauthParamProvider), r.FormValue(oauthParamClientID), r.FormValue(oauthParamAccessToken))
|
err = h.DB.RecordRemoteUserID(r.Context(), newUser.ID, r.FormValue(oauthParamTokenRemoteUserID), r.FormValue(oauthParamProvider), r.FormValue(oauthParamClientID), r.FormValue(oauthParamAccessToken))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return h.showOauthSignupPage(app, w, r, tp, err)
|
return h.showOauthSignupPage(app, w, r, tp, err)
|
||||||
|
@ -195,6 +207,7 @@ func (h oauthHandler) showOauthSignupPage(app *App, w http.ResponseWriter, r *ht
|
||||||
Provider: tp.Provider,
|
Provider: tp.Provider,
|
||||||
ClientID: tp.ClientID,
|
ClientID: tp.ClientID,
|
||||||
TokenHash: tp.TokenHash,
|
TokenHash: tp.TokenHash,
|
||||||
|
InviteCode: tp.InviteCode,
|
||||||
|
|
||||||
LoginUsername: username,
|
LoginUsername: username,
|
||||||
Alias: collTitle,
|
Alias: collTitle,
|
||||||
|
|
|
@ -13,8 +13,6 @@ package writefreely
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"github.com/writeas/nerds/store"
|
|
||||||
"github.com/writeas/slug"
|
"github.com/writeas/slug"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -167,7 +165,7 @@ func (c slackOauthClient) inspectOauthAccessToken(ctx context.Context, accessTok
|
||||||
func (resp slackUserIdentityResponse) InspectResponse() *InspectResponse {
|
func (resp slackUserIdentityResponse) InspectResponse() *InspectResponse {
|
||||||
return &InspectResponse{
|
return &InspectResponse{
|
||||||
UserID: resp.User.ID,
|
UserID: resp.User.ID,
|
||||||
Username: fmt.Sprintf("%s-%s", slug.Make(resp.User.Name), store.GenerateRandomString("0123456789bcdfghjklmnpqrstvwxyz", 5)),
|
Username: slug.Make(resp.User.Name),
|
||||||
DisplayName: resp.User.Name,
|
DisplayName: resp.User.Name,
|
||||||
Email: resp.User.Email,
|
Email: resp.User.Email,
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,8 +22,8 @@ type MockOAuthDatastoreProvider struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type MockOAuthDatastore struct {
|
type MockOAuthDatastore struct {
|
||||||
DoGenerateOAuthState func(context.Context, string, string, int64) (string, error)
|
DoGenerateOAuthState func(context.Context, string, string, int64, string) (string, error)
|
||||||
DoValidateOAuthState func(context.Context, string) (string, string, int64, error)
|
DoValidateOAuthState func(context.Context, string) (string, string, int64, string, error)
|
||||||
DoGetIDForRemoteUser func(context.Context, string, string, string) (int64, error)
|
DoGetIDForRemoteUser func(context.Context, string, string, string) (int64, error)
|
||||||
DoCreateUser func(*config.Config, *User, string) error
|
DoCreateUser func(*config.Config, *User, string) error
|
||||||
DoRecordRemoteUserID func(context.Context, int64, string, string, string, string) error
|
DoRecordRemoteUserID func(context.Context, int64, string, string, string, string) error
|
||||||
|
@ -86,11 +86,11 @@ func (m *MockOAuthDatastoreProvider) Config() *config.Config {
|
||||||
return cfg
|
return cfg
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, error) {
|
func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, string, error) {
|
||||||
if m.DoValidateOAuthState != nil {
|
if m.DoValidateOAuthState != nil {
|
||||||
return m.DoValidateOAuthState(ctx, state)
|
return m.DoValidateOAuthState(ctx, state)
|
||||||
}
|
}
|
||||||
return "", "", 0, nil
|
return "", "", 0, "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) {
|
func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) {
|
||||||
|
@ -119,15 +119,13 @@ func (m *MockOAuthDatastore) GetUserByID(userID int64) (*User, error) {
|
||||||
if m.DoGetUserByID != nil {
|
if m.DoGetUserByID != nil {
|
||||||
return m.DoGetUserByID(userID)
|
return m.DoGetUserByID(userID)
|
||||||
}
|
}
|
||||||
user := &User{
|
user := &User{}
|
||||||
|
|
||||||
}
|
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUserID int64) (string, error) {
|
func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUserID int64, inviteCode string) (string, error) {
|
||||||
if m.DoGenerateOAuthState != nil {
|
if m.DoGenerateOAuthState != nil {
|
||||||
return m.DoGenerateOAuthState(ctx, provider, clientID, attachUserID)
|
return m.DoGenerateOAuthState(ctx, provider, clientID, attachUserID, inviteCode)
|
||||||
}
|
}
|
||||||
return store.Generate62RandomString(14), nil
|
return store.Generate62RandomString(14), nil
|
||||||
}
|
}
|
||||||
|
@ -173,7 +171,7 @@ func TestViewOauthInit(t *testing.T) {
|
||||||
app := &MockOAuthDatastoreProvider{
|
app := &MockOAuthDatastoreProvider{
|
||||||
DoDB: func() OAuthDatastore {
|
DoDB: func() OAuthDatastore {
|
||||||
return &MockOAuthDatastore{
|
return &MockOAuthDatastore{
|
||||||
DoGenerateOAuthState: func(ctx context.Context, provider, clientID string, attachUserID int64) (string, error) {
|
DoGenerateOAuthState: func(ctx context.Context, provider, clientID string, attachUserID int64, inviteCode string) (string, error) {
|
||||||
return "", fmt.Errorf("pretend unable to write state error")
|
return "", fmt.Errorf("pretend unable to write state error")
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,39 +3,6 @@
|
||||||
<meta itemprop="description" content="Log in to {{.SiteName}}.">
|
<meta itemprop="description" content="Log in to {{.SiteName}}.">
|
||||||
<style>
|
<style>
|
||||||
input{margin-bottom:0.5em;}
|
input{margin-bottom:0.5em;}
|
||||||
.or {
|
|
||||||
text-align: center;
|
|
||||||
margin-bottom: 3.5em;
|
|
||||||
}
|
|
||||||
.or p {
|
|
||||||
display: inline-block;
|
|
||||||
background-color: white;
|
|
||||||
padding: 0 1em;
|
|
||||||
}
|
|
||||||
.or hr {
|
|
||||||
margin-top: -1.6em;
|
|
||||||
margin-bottom: 0;
|
|
||||||
}
|
|
||||||
hr.short {
|
|
||||||
max-width: 30rem;
|
|
||||||
}
|
|
||||||
.row.signinbtns {
|
|
||||||
justify-content: space-evenly;
|
|
||||||
font-size: 1em;
|
|
||||||
margin-top: 3em;
|
|
||||||
margin-bottom: 2em;
|
|
||||||
}
|
|
||||||
.loginbtn {
|
|
||||||
height: 40px;
|
|
||||||
}
|
|
||||||
#writeas-login {
|
|
||||||
box-sizing: border-box;
|
|
||||||
font-size: 17px;
|
|
||||||
}
|
|
||||||
#gitlab-login {
|
|
||||||
box-sizing: border-box;
|
|
||||||
font-size: 17px;
|
|
||||||
}
|
|
||||||
</style>
|
</style>
|
||||||
{{end}}
|
{{end}}
|
||||||
{{define "content"}}
|
{{define "content"}}
|
||||||
|
|
|
@ -1,6 +1,4 @@
|
||||||
{{define "head"}}<title>Log in — {{.SiteName}}</title>
|
{{define "head"}}<title>Finish Creating Account — {{.SiteName}}</title>
|
||||||
<meta name="description" content="Log in to {{.SiteName}}.">
|
|
||||||
<meta itemprop="description" content="Log in to {{.SiteName}}.">
|
|
||||||
<style>input{margin-bottom:0.5em;}</style>
|
<style>input{margin-bottom:0.5em;}</style>
|
||||||
<style type="text/css">
|
<style type="text/css">
|
||||||
h2 {
|
h2 {
|
||||||
|
@ -58,7 +56,7 @@ form dd {
|
||||||
{{end}}
|
{{end}}
|
||||||
{{define "content"}}
|
{{define "content"}}
|
||||||
<div id="pricing" class="tight content-container">
|
<div id="pricing" class="tight content-container">
|
||||||
<h1>Log in to {{.SiteName}}</h1>
|
<h1>Finish creating account</h1>
|
||||||
|
|
||||||
{{if .Flashes}}<ul class="errors">
|
{{if .Flashes}}<ul class="errors">
|
||||||
{{range .Flashes}}<li class="urgent">{{.}}</li>{{end}}
|
{{range .Flashes}}<li class="urgent">{{.}}</li>{{end}}
|
||||||
|
@ -74,6 +72,7 @@ form dd {
|
||||||
<input type="hidden" name="provider" value="{{ .Provider }}" />
|
<input type="hidden" name="provider" value="{{ .Provider }}" />
|
||||||
<input type="hidden" name="client_id" value="{{ .ClientID }}" />
|
<input type="hidden" name="client_id" value="{{ .ClientID }}" />
|
||||||
<input type="hidden" name="signature" value="{{ .TokenHash }}" />
|
<input type="hidden" name="signature" value="{{ .TokenHash }}" />
|
||||||
|
{{if .InviteCode}}<input type="hidden" name="invite_code" value="{{ .InviteCode }}" />{{end}}
|
||||||
|
|
||||||
<dl class="billing">
|
<dl class="billing">
|
||||||
<label>
|
<label>
|
||||||
|
@ -96,7 +95,7 @@ form dd {
|
||||||
</dd>
|
</dd>
|
||||||
</label>
|
</label>
|
||||||
<dt>
|
<dt>
|
||||||
<input type="submit" id="btn-login" value="Login" />
|
<input type="submit" id="btn-login" value="Next" />
|
||||||
</dt>
|
</dt>
|
||||||
</dl>
|
</dl>
|
||||||
</form>
|
</form>
|
||||||
|
@ -129,7 +128,7 @@ var $aliasSite = document.getElementById('alias-site');
|
||||||
var aliasOK = true;
|
var aliasOK = true;
|
||||||
var typingTimer;
|
var typingTimer;
|
||||||
var doneTypingInterval = 750;
|
var doneTypingInterval = 750;
|
||||||
var doneTyping = function() {
|
var doneTyping = function(genID) {
|
||||||
// Check on username
|
// Check on username
|
||||||
var alias = $alias.el.value;
|
var alias = $alias.el.value;
|
||||||
if (alias != "") {
|
if (alias != "") {
|
||||||
|
@ -152,6 +151,11 @@ var doneTyping = function() {
|
||||||
$aliasSite.className = $aliasSite.className.replace(/(?:^|\s)error(?!\S)/g, '');
|
$aliasSite.className = $aliasSite.className.replace(/(?:^|\s)error(?!\S)/g, '');
|
||||||
$aliasSite.innerHTML = '{{ if .Federation }}@<strong>' + data.data + '</strong>@{{.FriendlyHost}}{{ else }}{{.FriendlyHost}}/<strong>' + data.data + '</strong>/{{ end }}';
|
$aliasSite.innerHTML = '{{ if .Federation }}@<strong>' + data.data + '</strong>@{{.FriendlyHost}}{{ else }}{{.FriendlyHost}}/<strong>' + data.data + '</strong>/{{ end }}';
|
||||||
} else {
|
} else {
|
||||||
|
if (genID === true) {
|
||||||
|
$alias.el.value = alias + "-" + randStr(4);
|
||||||
|
doneTyping();
|
||||||
|
return;
|
||||||
|
}
|
||||||
aliasOK = false;
|
aliasOK = false;
|
||||||
$alias.setClass('error');
|
$alias.setClass('error');
|
||||||
$aliasSite.className = 'error';
|
$aliasSite.className = 'error';
|
||||||
|
@ -169,6 +173,14 @@ $alias.on('keyup input', function() {
|
||||||
clearTimeout(typingTimer);
|
clearTimeout(typingTimer);
|
||||||
typingTimer = setTimeout(doneTyping, doneTypingInterval);
|
typingTimer = setTimeout(doneTyping, doneTypingInterval);
|
||||||
});
|
});
|
||||||
doneTyping();
|
function randStr(len) {
|
||||||
|
var res = '';
|
||||||
|
var chars = '23456789bcdfghjklmnpqrstvwxyz';
|
||||||
|
for (var i=0; i<len; i++) {
|
||||||
|
res += chars.charAt(Math.floor(Math.random() * chars.length));
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
doneTyping(true);
|
||||||
</script>
|
</script>
|
||||||
{{end}}
|
{{end}}
|
||||||
|
|
|
@ -70,6 +70,25 @@ form dd {
|
||||||
</ul>{{end}}
|
</ul>{{end}}
|
||||||
|
|
||||||
<div id="billing">
|
<div id="billing">
|
||||||
|
{{ if or .OAuth.SlackEnabled .OAuth.WriteAsEnabled .OAuth.GitLabEnabled }}
|
||||||
|
<div class="row content-container signinbtns">
|
||||||
|
{{ if .OAuth.SlackEnabled }}
|
||||||
|
<a class="loginbtn" href="/oauth/slack{{if .Invite}}?invite_code={{.Invite}}{{end}}"><img alt="Sign in with Slack" height="40" width="172" src="/img/sign_in_with_slack.png" srcset="/img/sign_in_with_slack.png 1x, /img/sign_in_with_slack@2x.png 2x" /></a>
|
||||||
|
{{ end }}
|
||||||
|
{{ if .OAuth.WriteAsEnabled }}
|
||||||
|
<a class="btn cta loginbtn" id="writeas-login" href="/oauth/write.as{{if .Invite}}?invite_code={{.Invite}}{{end}}">Sign in with <strong>Write.as</strong></a>
|
||||||
|
{{ end }}
|
||||||
|
{{ if .OAuth.GitLabEnabled }}
|
||||||
|
<a class="btn cta loginbtn" id="gitlab-login" href="/oauth/gitlab{{if .Invite}}?invite_code={{.Invite}}{{end}}">Sign in with <strong>{{.OAuth.GitLabDisplayName}}</strong></a>
|
||||||
|
{{ end }}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="or">
|
||||||
|
<p>or</p>
|
||||||
|
<hr class="short" />
|
||||||
|
</div>
|
||||||
|
{{ end }}
|
||||||
|
|
||||||
<form action="/auth/signup" method="POST" id="signup-form" onsubmit="return signup()">
|
<form action="/auth/signup" method="POST" id="signup-form" onsubmit="return signup()">
|
||||||
<input type="hidden" name="invite_code" value="{{.Invite}}" />
|
<input type="hidden" name="invite_code" value="{{.Invite}}" />
|
||||||
<dl class="billing">
|
<dl class="billing">
|
||||||
|
|
Loading…
Add table
Reference in a new issue