187 lines
3.7 KiB
Go
187 lines
3.7 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgtype"
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
func lookupEmail(conn *pgx.Conn, username string) (email string, err error) {
|
|
query := "select id from accounts where username=$1"
|
|
rows, err := conn.Query(context.Background(), query, username)
|
|
if err != nil {
|
|
return
|
|
}
|
|
var ids []int
|
|
for rows.Next() {
|
|
var id int
|
|
rows.Scan(&id)
|
|
ids = append(ids, id)
|
|
}
|
|
rows.Close()
|
|
if rows.Err() != nil {
|
|
return
|
|
}
|
|
|
|
for _, id := range ids {
|
|
query = "select email from users where account_id=$1"
|
|
err = conn.QueryRow(context.Background(), query, id).Scan(&email)
|
|
if err == nil && email != "" {
|
|
return
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
func lookupUsername(conn *pgx.Conn, email string) (username string, err error) {
|
|
query := "select account_id from users where email=$1"
|
|
var id int
|
|
err = conn.QueryRow(context.Background(), query, email).Scan(&id)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
query = "select username from accounts where id=$1"
|
|
err = conn.QueryRow(context.Background(), query, id).Scan(&username)
|
|
return
|
|
}
|
|
|
|
func auth(conn *pgx.Conn, email, password string) (err error) {
|
|
query := "select account_id,confirmed_at,approved,disabled," +
|
|
"encrypted_password from users where email=$1"
|
|
|
|
var id int
|
|
var confirmed pgtype.Timestamptz
|
|
var approved, disabled bool
|
|
var hash string
|
|
|
|
err = conn.QueryRow(context.Background(), query, email).
|
|
Scan(&id, &confirmed, &approved, &disabled, &hash)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
if !confirmed.Valid {
|
|
return errors.New("account is not confirmed")
|
|
}
|
|
if !approved {
|
|
return errors.New("account is not approved")
|
|
}
|
|
if disabled {
|
|
return errors.New("account disabled")
|
|
}
|
|
|
|
query = "select suspended_at from accounts where id=$1"
|
|
|
|
var suspended pgtype.Timestamptz
|
|
|
|
err = conn.QueryRow(context.Background(), query, id).Scan(&suspended)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
if suspended.Valid {
|
|
return errors.New("account suspended")
|
|
}
|
|
|
|
return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
|
}
|
|
|
|
type handler struct {
|
|
Banlist map[string]bool
|
|
}
|
|
|
|
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
_, banned := h.Banlist[r.Header.Get("X-Forwarded-For")]
|
|
if banned {
|
|
return
|
|
}
|
|
|
|
if r.Method != "POST" {
|
|
fmt.Println("method is not POST", r)
|
|
return
|
|
}
|
|
|
|
var params struct {
|
|
Secret, Username, Email, Password string
|
|
}
|
|
|
|
err := json.NewDecoder(r.Body).Decode(¶ms)
|
|
if err != nil {
|
|
fmt.Println(err, r)
|
|
return
|
|
}
|
|
|
|
if params.Secret != os.Getenv("AUTH_SECRET") {
|
|
fmt.Println("wrong secret", r)
|
|
h.Banlist[r.Header.Get("X-Forwarded-For")] = true
|
|
return
|
|
}
|
|
|
|
var response struct {
|
|
Username, Email, Error string
|
|
}
|
|
|
|
defer func() {
|
|
buf, _ := json.Marshal(response)
|
|
w.Write(buf)
|
|
}()
|
|
|
|
path := "dbname=" + os.Getenv("DATABASE")
|
|
conn, err := pgx.Connect(context.Background(), path)
|
|
if err != nil {
|
|
response.Error = fmt.Sprint(err)
|
|
return
|
|
}
|
|
defer conn.Close(context.Background())
|
|
|
|
if params.Email == "" {
|
|
params.Email, err = lookupEmail(conn, params.Username)
|
|
if err != nil {
|
|
response.Error = fmt.Sprint(err)
|
|
return
|
|
}
|
|
}
|
|
|
|
if params.Username == "" {
|
|
params.Username, err = lookupUsername(conn, params.Email)
|
|
if err != nil {
|
|
response.Error = fmt.Sprint(err)
|
|
return
|
|
}
|
|
}
|
|
|
|
err = auth(conn, params.Email, params.Password)
|
|
if err != nil {
|
|
response.Error = fmt.Sprint(err)
|
|
}
|
|
|
|
response.Username = params.Username
|
|
response.Email = params.Email
|
|
}
|
|
|
|
func main() {
|
|
if os.Getenv("AUTH_SECRET") == "" {
|
|
panic("AUTH_SECRET is empty")
|
|
}
|
|
|
|
socket, err := net.Listen("unix", os.Getenv("SOCKET"))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
os.Chmod(os.Getenv("SOCKET"), 0770)
|
|
|
|
server := http.Server{
|
|
Handler: &handler{Banlist: make(map[string]bool)},
|
|
}
|
|
server.Serve(socket)
|
|
}
|