// This program serves a single HTTP endpoint on a Unix domain socket that
// checks user credentials against a PostgreSQL database.
//
// Configuration is read from the environment:
//   - AUTH_SECRET: shared secret every request must present (must be non-empty)
//   - DATABASE: PostgreSQL database name, used as "dbname=<DATABASE>"
//   - SOCKET: path of the Unix socket to listen on
package main

import (
	"context"
	"crypto/subtle"
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"net/http"
	"os"
	"sync"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgtype"
	"golang.org/x/crypto/bcrypt"
)

// maxRequestBody caps how many bytes of a request body are decoded, so a
// client cannot make the server buffer an arbitrarily large JSON document.
const maxRequestBody = 1 << 20

// lookupEmail returns the first non-empty email found in users for any
// accounts row whose username matches.
//
// When no account matches, or no matching user row yields a non-empty email,
// it returns "" together with the error of the last per-account query (nil if
// there was none). Errors from the accounts query itself, from scanning an id,
// or from row iteration are returned immediately.
func lookupEmail(conn *pgx.Conn, username string) (email string, err error) {
	ctx := context.Background()

	rows, err := conn.Query(ctx, "select id from accounts where username=$1", username)
	if err != nil {
		return "", err
	}
	var ids []int
	for rows.Next() {
		var id int
		if err := rows.Scan(&id); err != nil {
			rows.Close()
			return "", err
		}
		ids = append(ids, id)
	}
	// The result set is closed before the loop below issues further queries
	// on the same connection; ids are collected first for that reason.
	rows.Close()
	if err := rows.Err(); err != nil {
		return "", err
	}

	for _, id := range ids {
		err = conn.QueryRow(ctx, "select email from users where account_id=$1", id).Scan(&email)
		if err == nil && email != "" {
			return email, nil
		}
	}
	return "", err
}

// lookupUsername resolves email to the username of the account it belongs to:
// it reads users.account_id for the email, then accounts.username for that id.
// The error of whichever query fails is returned unchanged.
func lookupUsername(conn *pgx.Conn, email string) (username string, err error) {
	ctx := context.Background()

	var id int
	err = conn.QueryRow(ctx, "select account_id from users where email=$1", email).Scan(&id)
	if err != nil {
		return "", err
	}
	err = conn.QueryRow(ctx, "select username from accounts where id=$1", id).Scan(&username)
	return username, err
}

// auth checks that the user identified by email may log in with password.
//
// It returns nil only when the user row exists, is confirmed, approved and not
// disabled, the owning account is not suspended, and password matches the
// stored bcrypt hash. Otherwise it returns the database error, one of the
// descriptive errors below, or the bcrypt comparison error.
func auth(conn *pgx.Conn, email, password string) (err error) {
	ctx := context.Background()

	var (
		id                 int
		confirmed          pgtype.Timestamptz
		approved, disabled bool
		hash               string
	)
	query := "select account_id,confirmed_at,approved,disabled," +
		"encrypted_password from users where email=$1"
	err = conn.QueryRow(ctx, query, email).
		Scan(&id, &confirmed, &approved, &disabled, &hash)
	if err != nil {
		return err
	}
	// These messages are sent to the client verbatim via response.Error.
	switch {
	case !confirmed.Valid: // confirmed_at is NULL
		return errors.New("account is not confirmed")
	case !approved:
		return errors.New("account is not approved")
	case disabled:
		return errors.New("account disabled")
	}

	var suspended pgtype.Timestamptz
	err = conn.QueryRow(ctx, "select suspended_at from accounts where id=$1", id).Scan(&suspended)
	if err != nil {
		return err
	}
	if suspended.Valid { // suspended_at is set
		return errors.New("account suspended")
	}
	return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
}

// handler is the HTTP handler for authentication requests.
//
// Banlist holds the X-Forwarded-For values of clients that presented a wrong
// secret; only key presence matters. It is guarded by mu because net/http
// calls ServeHTTP from multiple goroutines.
//
// NOTE(review): the ban key is a client-supplied header and entries are never
// removed, so the map can grow without bound, and a request without the header
// is tracked under the empty key. Confirm a trusted proxy always sets it.
type handler struct {
	mu      sync.Mutex
	Banlist map[string]bool
}

// isBanned reports whether key is present in the ban list.
func (h *handler) isBanned(key string) bool {
	h.mu.Lock()
	defer h.mu.Unlock()
	_, ok := h.Banlist[key]
	return ok
}

// ban adds key to the ban list, allocating the map on first use.
func (h *handler) ban(key string) {
	h.mu.Lock()
	defer h.mu.Unlock()
	if h.Banlist == nil {
		h.Banlist = make(map[string]bool)
	}
	h.Banlist[key] = true
}

// ServeHTTP handles one authentication request.
//
// The request must be a POST whose JSON body carries Secret, Password and at
// least one of Username or Email; the missing one is looked up in the
// database. Requests that are banned, use another method, cannot be decoded,
// or carry the wrong secret get an empty body with a 403/405/400 status, and a
// wrong secret also bans the client's X-Forwarded-For value. Every other
// request gets a JSON object with Username, Email and Error, where an empty
// Error means the credentials were accepted.
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	client := r.Header.Get("X-Forwarded-For")
	if h.isBanned(client) {
		w.WriteHeader(http.StatusForbidden)
		return
	}
	if r.Method != http.MethodPost {
		fmt.Println("method is not POST", r)
		w.WriteHeader(http.StatusMethodNotAllowed)
		return
	}

	var params struct {
		Secret, Username, Email, Password string
	}
	body := http.MaxBytesReader(w, r.Body, maxRequestBody)
	if err := json.NewDecoder(body).Decode(&params); err != nil {
		fmt.Println(err, r)
		w.WriteHeader(http.StatusBadRequest)
		return
	}
	// Constant-time comparison so response timing does not reveal how much of
	// the secret was guessed correctly.
	secret := os.Getenv("AUTH_SECRET")
	if subtle.ConstantTimeCompare([]byte(params.Secret), []byte(secret)) != 1 {
		fmt.Println("wrong secret", r)
		h.ban(client)
		w.WriteHeader(http.StatusForbidden)
		return
	}

	var response struct {
		Username, Email, Error string
	}
	// From here on every exit path replies with the JSON-encoded response.
	defer func() {
		buf, err := json.Marshal(response)
		if err != nil {
			fmt.Println(err)
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		if _, err := w.Write(buf); err != nil {
			fmt.Println(err)
		}
	}()

	conn, err := pgx.Connect(context.Background(), "dbname="+os.Getenv("DATABASE"))
	if err != nil {
		response.Error = fmt.Sprint(err)
		return
	}
	defer conn.Close(context.Background())

	if params.Email == "" {
		params.Email, err = lookupEmail(conn, params.Username)
		if err != nil {
			response.Error = fmt.Sprint(err)
			return
		}
	}
	if params.Username == "" {
		params.Username, err = lookupUsername(conn, params.Email)
		if err != nil {
			response.Error = fmt.Sprint(err)
			return
		}
	}

	if err := auth(conn, params.Email, params.Password); err != nil {
		response.Error = fmt.Sprint(err)
	}
	// The resolved identity is reported even when authentication failed.
	response.Username = params.Username
	response.Email = params.Email
}

// main validates the configuration, listens on the Unix socket named by
// SOCKET, restricts the socket file to mode 0770, and serves requests until
// the server stops. Any startup or serve failure panics.
func main() {
	if os.Getenv("AUTH_SECRET") == "" {
		panic("AUTH_SECRET is empty")
	}

	path := os.Getenv("SOCKET")
	socket, err := net.Listen("unix", path)
	if err != nil {
		panic(err)
	}
	if err := os.Chmod(path, 0770); err != nil {
		// Do not serve on a socket whose permissions could not be restricted.
		socket.Close()
		panic(err)
	}

	server := http.Server{
		Handler: &handler{Banlist: make(map[string]bool)},
	}
	if err := server.Serve(socket); err != nil && !errors.Is(err, http.ErrServerClosed) {
		panic(err)
	}
}