// lor.sh/auth/auth.go — Go source (listing metadata: 187 lines, 3.7 KiB).

package main
import (
	"context"
	"crypto/subtle"
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"net/http"
	"os"
	"sync"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgtype"
	"golang.org/x/crypto/bcrypt"
)
// lookupEmail returns the e-mail address of the user that owns an account
// with the given username. The accounts query may match several ids, so each
// one is tried in order and the first non-empty e-mail found is returned.
//
// If no account has that username, it returns "" and a nil error. If accounts
// match but none yields a non-empty e-mail, it returns the e-mail and error of
// the last per-account lookup. Errors from the accounts query, including scan
// and row-iteration errors, are returned to the caller rather than dropped.
func lookupEmail(conn *pgx.Conn, username string) (email string, err error) {
	ctx := context.Background()
	rows, err := conn.Query(ctx, "select id from accounts where username=$1", username)
	if err != nil {
		return "", err
	}
	// CollectRows drains and closes the result set before the per-id queries
	// below reuse the connection, and reports any Scan or rows.Err failure.
	ids, err := pgx.CollectRows(rows, pgx.RowTo[int])
	if err != nil {
		return "", err
	}
	for _, id := range ids {
		err = conn.QueryRow(ctx, "select email from users where account_id=$1", id).Scan(&email)
		if err == nil && email != "" {
			return email, nil
		}
	}
	return email, err
}
// lookupUsername maps an e-mail address to a username in two steps: it reads
// users.account_id for the address, then accounts.username for that id.
// A failure in either query (including no matching row) is returned as-is.
func lookupUsername(conn *pgx.Conn, email string) (username string, err error) {
	ctx := context.Background()

	var accountID int
	err = conn.QueryRow(ctx, "select account_id from users where email=$1", email).Scan(&accountID)
	if err != nil {
		return "", err
	}

	err = conn.QueryRow(ctx, "select username from accounts where id=$1", accountID).Scan(&username)
	return username, err
}
// auth checks whether the user registered under email may sign in with
// password.
//
// Status checks run before the password and in a fixed order. Each failure
// returns its own error, and the order is:
//   - the user must be confirmed;
//   - the user must be approved;
//   - the user must not be disabled;
//   - the linked account must not be suspended.
//
// Database errors (for example, no user with that e-mail) are returned
// unchanged. Otherwise the result is that of bcrypt.CompareHashAndPassword on
// the stored hash: nil when the password matches.
func auth(conn *pgx.Conn, email, password string) (err error) {
	ctx := context.Background()

	const userQuery = "select account_id,confirmed_at,approved,disabled," +
		"encrypted_password from users where email=$1"
	var (
		accountID          int
		confirmedAt        pgtype.Timestamptz
		approved, disabled bool
		passwordHash       string
	)
	if err = conn.QueryRow(ctx, userQuery, email).
		Scan(&accountID, &confirmedAt, &approved, &disabled, &passwordHash); err != nil {
		return err
	}

	switch {
	case !confirmedAt.Valid:
		return errors.New("account is not confirmed")
	case !approved:
		return errors.New("account is not approved")
	case disabled:
		return errors.New("account disabled")
	}

	// A non-NULL suspended_at marks the account as suspended.
	var suspendedAt pgtype.Timestamptz
	if err = conn.QueryRow(ctx, "select suspended_at from accounts where id=$1", accountID).
		Scan(&suspendedAt); err != nil {
		return err
	}
	if suspendedAt.Valid {
		return errors.New("account suspended")
	}

	return bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(password))
}
// handler is the http.Handler for authentication requests.
//
// Callers POST a JSON object with Secret, Username, Email and Password fields.
// Once Secret matches $AUTH_SECRET, the reply is a JSON object with Username,
// Email and Error fields; an empty Error means the credentials were accepted.
type handler struct {
	// Banlist holds X-Forwarded-For values that presented a wrong secret.
	// Requests carrying a banned value are dropped without a response body.
	Banlist map[string]bool

	// mu guards Banlist. net/http calls ServeHTTP from many goroutines at
	// once, and unsynchronized concurrent map reads and writes are a data race
	// that the Go runtime aborts on.
	mu sync.Mutex
}

// isBanned reports whether client has previously presented a wrong secret.
func (h *handler) isBanned(client string) bool {
	h.mu.Lock()
	defer h.mu.Unlock()
	_, banned := h.Banlist[client]
	return banned
}

// ban records client so that its future requests are dropped.
func (h *handler) ban(client string) {
	h.mu.Lock()
	defer h.mu.Unlock()
	h.Banlist[client] = true
}

// ServeHTTP authenticates a single request.
//
// If the request is rejected before the secret is verified (banned client,
// non-POST method, undecodable body, wrong secret), nothing is written. A
// wrong secret also bans the client's X-Forwarded-For value. After a valid
// secret, the JSON response is always written, with any failure in Error.
//
// If Email or Username is missing, it is looked up from the other one before
// the password is checked.
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// NOTE(review): every request without X-Forwarded-For maps to the key "",
	// so one wrong secret from such a client bans all header-less clients.
	client := r.Header.Get("X-Forwarded-For")
	if h.isBanned(client) {
		return
	}
	if r.Method != http.MethodPost {
		fmt.Println("method is not POST", r)
		return
	}

	var params struct {
		Secret, Username, Email, Password string
	}
	if err := json.NewDecoder(r.Body).Decode(&params); err != nil {
		fmt.Println(err, r)
		return
	}
	// Compare in constant time so response timing does not reveal how much of
	// a guessed secret is correct (only a length mismatch returns early).
	secret := os.Getenv("AUTH_SECRET")
	if subtle.ConstantTimeCompare([]byte(params.Secret), []byte(secret)) != 1 {
		fmt.Println("wrong secret", r)
		h.ban(client)
		return
	}

	var response struct {
		Username, Email, Error string
	}
	// From here on every path answers with the JSON response, so failures
	// below only set response.Error and return.
	defer func() {
		// Marshal cannot fail for a struct made only of string fields.
		buf, _ := json.Marshal(response)
		w.Header().Set("Content-Type", "application/json")
		w.Write(buf)
	}()

	// One connection per request. The request context only governs the
	// connection attempt, so a client that disconnects aborts a slow connect.
	conn, err := pgx.Connect(r.Context(), "dbname="+os.Getenv("DATABASE"))
	if err != nil {
		response.Error = fmt.Sprint(err)
		return
	}
	defer conn.Close(context.Background())

	// Fill in whichever identifier the caller omitted.
	if params.Email == "" {
		params.Email, err = lookupEmail(conn, params.Username)
		if err != nil {
			response.Error = fmt.Sprint(err)
			return
		}
	}
	if params.Username == "" {
		params.Username, err = lookupUsername(conn, params.Email)
		if err != nil {
			response.Error = fmt.Sprint(err)
			return
		}
	}

	err = auth(conn, params.Email, params.Password)
	if err != nil {
		response.Error = fmt.Sprint(err)
	}
	response.Username = params.Username
	response.Email = params.Email
}
// main serves the authentication handler over HTTP on the Unix socket at
// $SOCKET.
//
// It refuses to start without $AUTH_SECRET, because an empty secret would
// match requests whose Secret field is empty or missing. Any failure to
// listen, set socket permissions, or keep serving is fatal.
func main() {
	if os.Getenv("AUTH_SECRET") == "" {
		panic("AUTH_SECRET is empty")
	}

	path := os.Getenv("SOCKET")
	socket, err := net.Listen("unix", path)
	if err != nil {
		panic(err)
	}
	// Restrict the socket to its owner and group. A failed chmod would leave
	// the permissions at whatever the umask produced, so do not ignore it.
	if err := os.Chmod(path, 0770); err != nil {
		panic(err)
	}

	server := http.Server{
		Handler: &handler{Banlist: make(map[string]bool)},
	}
	// Serve always returns a non-nil error; report it instead of exiting
	// silently with status 0.
	if err := server.Serve(socket); err != nil {
		panic(err)
	}
}