170 lines
3.4 KiB
Go
170 lines
3.4 KiB
Go
|
package main
|
||
|
|
||
|
import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"os"
	"os/signal"
	"strings"
	"syscall"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/lor00x/goldap/message"
	ldap "github.com/vjeantet/ldapserver"
)
|
||
|
|
||
|
// authTimeout bounds every call to the auth service so that a slow or
// unreachable service cannot hang an LDAP bind or search handler indefinitely.
const authTimeout = 10 * time.Second

// authClient is shared by all auth calls so keep-alive connections are reused.
var authClient = &http.Client{Timeout: authTimeout}

// auth asks the external auth service at $AUTH_URL about a user.
//
// It POSTs a JSON object with the keys "Secret" (taken from $AUTH_SECRET),
// "Username", "Email" and "Password". It then decodes a JSON reply with the
// keys "Username", "Email" and "Error". Callers pass either username or email.
// password may be empty when only a lookup is wanted (search does this).
//
// err is non-nil when the request cannot be built or sent, when the reply is
// not valid JSON, when the reply carries a non-empty "Error", or when the HTTP
// status is not 2xx. When a reply was decoded, ruser and remail are filled from
// it even if err is non-nil. search depends on that: it ignores err and only
// inspects the returned values.
func auth(username, email, password string) (ruser, remail string, err error) {
	payload := struct {
		Secret   string
		Username string
		Email    string
		Password string
	}{
		Secret:   os.Getenv("AUTH_SECRET"),
		Username: username,
		Email:    email,
		Password: password,
	}

	raw, err := json.Marshal(payload)
	if err != nil {
		return "", "", err
	}

	req, err := http.NewRequest(http.MethodPost, os.Getenv("AUTH_URL"), bytes.NewReader(raw))
	if err != nil {
		return "", "", err
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := authClient.Do(req)
	if err != nil {
		return "", "", err
	}
	defer resp.Body.Close()

	ok := resp.StatusCode >= 200 && resp.StatusCode <= 299

	var result struct {
		Username string
		Email    string
		Error    string
	}
	if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
		if !ok {
			return "", "", fmt.Errorf("auth service returned %s: %w", resp.Status, err)
		}
		return "", "", err
	}

	ruser, remail = result.Username, result.Email
	switch {
	case result.Error != "":
		err = errors.New(result.Error)
	case !ok:
		// A failure status without an explicit "Error" (e.g. an error page
		// from a proxy that happens to be valid JSON) must not look like
		// success: bind treats a nil error as a successful login.
		err = fmt.Errorf("auth service returned %s", resp.Status)
	}
	return ruser, remail, err
}
|
||
|
|
||
|
func bind(w ldap.ResponseWriter, m *ldap.Message) {
|
||
|
r := m.GetBindRequest()
|
||
|
res := ldap.NewBindResponse(ldap.LDAPResultSuccess)
|
||
|
|
||
|
username := string(r.Name())
|
||
|
password := string(r.AuthenticationSimple())
|
||
|
|
||
|
if username == os.Getenv("LDAP_USER") {
|
||
|
if password == os.Getenv("LDAP_PASS") {
|
||
|
w.Write(res)
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if username == "root" {
|
||
|
res.SetResultCode(ldap.LDAPResultInvalidCredentials)
|
||
|
res.SetDiagnosticMessage("root login is disabled")
|
||
|
w.Write(res)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
_, _, err := auth(username, "", password)
|
||
|
if err == nil {
|
||
|
fmt.Println("bind:", username, "ok")
|
||
|
w.Write(res)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
fmt.Println("bind:", username, "incorrect password")
|
||
|
res.SetResultCode(ldap.LDAPResultInvalidCredentials)
|
||
|
res.SetDiagnosticMessage("invalid credentials")
|
||
|
w.Write(res)
|
||
|
}
|
||
|
|
||
|
func updateEmail(username, email string) (err error) {
|
||
|
path := "dbname=" + os.Getenv("DATABASE")
|
||
|
conn, err := pgx.Connect(context.Background(), path)
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
defer conn.Close(context.Background())
|
||
|
|
||
|
query := "update \"user\" set email=$1 where username=$2"
|
||
|
_, err = conn.Exec(context.Background(), query, email, username)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func search(w ldap.ResponseWriter, m *ldap.Message) {
|
||
|
r := m.GetSearchRequest()
|
||
|
res := ldap.NewSearchResultDoneResponse(ldap.LDAPResultSuccess)
|
||
|
defer w.Write(res)
|
||
|
|
||
|
var username, email string
|
||
|
|
||
|
s := r.FilterString()
|
||
|
s = s[6 : len(s)-1]
|
||
|
|
||
|
if strings.Contains(s, "@") {
|
||
|
username, email, _ = auth("", s, "")
|
||
|
} else {
|
||
|
username, email, _ = auth(s, "", "")
|
||
|
}
|
||
|
|
||
|
if username == "" || email == "" {
|
||
|
fmt.Println("search:", s, "not found")
|
||
|
return
|
||
|
}
|
||
|
fmt.Println("search:", s, "found", username, email)
|
||
|
|
||
|
err := updateEmail(username, email)
|
||
|
if err != nil {
|
||
|
fmt.Println(err)
|
||
|
}
|
||
|
|
||
|
e := ldap.NewSearchResultEntry(strings.ToLower(username))
|
||
|
e.AddAttribute("uid", message.AttributeValue(strings.ToLower(username)))
|
||
|
e.AddAttribute("mail", message.AttributeValue(strings.ToLower(email)))
|
||
|
w.Write(e)
|
||
|
}
|
||
|
|
||
|
func main() {
|
||
|
ldap.Logger = ldap.DiscardingLogger
|
||
|
|
||
|
server := ldap.NewServer()
|
||
|
|
||
|
routes := ldap.NewRouteMux()
|
||
|
routes.Bind(bind)
|
||
|
routes.Search(search)
|
||
|
server.Handle(routes)
|
||
|
|
||
|
go server.ListenAndServe(":10389")
|
||
|
|
||
|
ch := make(chan os.Signal)
|
||
|
signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM)
|
||
|
<-ch
|
||
|
close(ch)
|
||
|
|
||
|
server.Stop()
|
||
|
}
|