Files
oauth2-proxy/main.go

108 lines
2.6 KiB
Go

package main
import (
	"context"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"sync"
	"time"

	"golang.org/x/oauth2"
)
// main starts an HTTP proxy on :5000 that attaches an OAuth2 access token to
// every request it forwards to API_URL.
//
// Configuration is read from the environment: AUTH_CONFIG (path of the token
// file, required), CLIENT_ID, CLIENT_SECRET, AUTH_URL, TOKEN_URL,
// DEVICE_AUTH_URL, REDIRECT_URI and API_URL.
//
// Routes:
//   - GET /oauth redirects to the provider's authorization page.
//   - GET /oauth/callback exchanges the returned code for a token and saves it.
//   - every other path is forwarded to API_URL using the current token.
func main() {
	authConfigPath := os.Getenv("AUTH_CONFIG")
	if authConfigPath == "" {
		log.Fatal("AUTH_CONFIG env is missing")
	}

	// savedToken always mirrors the token most recently written to
	// authConfigPath; it is compared against the live token to detect refreshes.
	savedToken, err := readConfig(authConfigPath)
	if err != nil {
		log.Fatalf("failed to read config file: %v", err)
	}

	oauthConfig := oauth2.Config{
		ClientID:     os.Getenv("CLIENT_ID"),
		ClientSecret: os.Getenv("CLIENT_SECRET"),
		Endpoint: oauth2.Endpoint{
			AuthURL:       os.Getenv("AUTH_URL"),
			TokenURL:      os.Getenv("TOKEN_URL"),
			DeviceAuthURL: os.Getenv("DEVICE_AUTH_URL"),
			AuthStyle:     oauth2.AuthStyleAutoDetect,
		},
		RedirectURL: os.Getenv("REDIRECT_URI"),
	}
	baseURL := os.Getenv("API_URL")

	// Handlers run on separate goroutines, so the token state they share is
	// guarded by mu. A single token source backs both the refresh check and
	// the HTTP client, so a token is refreshed once rather than once per user.
	var mu sync.Mutex
	tokenSource := oauthConfig.TokenSource(context.Background(), savedToken)
	oauthClient := oauth2.NewClient(context.Background(), tokenSource)

	// authorizedClient returns the client to use for an upstream call. It
	// triggers a refresh if the token has expired and persists the token
	// whenever it differs from the one last written to disk.
	authorizedClient := func() (*http.Client, error) {
		mu.Lock()
		defer mu.Unlock()

		current, err := tokenSource.Token()
		if err != nil {
			return nil, fmt.Errorf("failed to get new token: %w", err)
		}
		if savedToken == nil || current.AccessToken != savedToken.AccessToken {
			if err := writeConfig(authConfigPath, current); err != nil {
				return nil, fmt.Errorf("failed to save new token: %w", err)
			}
			savedToken = current
		}
		return oauthClient, nil
	}

	mux := http.NewServeMux()

	mux.HandleFunc("GET /oauth", func(w http.ResponseWriter, r *http.Request) {
		// NOTE(review): the state parameter is empty, so the authorization
		// flow has no CSRF protection. Fixing this needs per-session storage.
		http.Redirect(w, r, oauthConfig.AuthCodeURL(""), http.StatusTemporaryRedirect)
	})

	mux.HandleFunc("GET /oauth/callback", func(w http.ResponseWriter, r *http.Request) {
		code := r.URL.Query().Get("code")
		if code == "" {
			http.Error(w, "missing code parameter", http.StatusBadRequest)
			return
		}

		// Exchange into a local so a failed exchange cannot clobber the token
		// that the proxy handler is currently using.
		exchanged, err := oauthConfig.Exchange(r.Context(), code)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		mu.Lock()
		// The token source outlives this request, so it must not be tied to
		// the request context.
		tokenSource = oauthConfig.TokenSource(context.Background(), exchanged)
		oauthClient = oauth2.NewClient(context.Background(), tokenSource)
		err = writeConfig(authConfigPath, exchanged)
		if err == nil {
			savedToken = exchanged
		}
		mu.Unlock()

		if err != nil {
			// The new token is already live in memory; savedToken was left
			// untouched so the next proxied request retries the write.
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		fmt.Fprint(w, "ok")
	})

	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		client, err := authorizedClient()
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		// Tie the upstream call to the incoming request so it is cancelled
		// when the caller disconnects.
		req, err := http.NewRequestWithContext(r.Context(), r.Method, baseURL+r.URL.Path, r.Body)
		if err != nil {
			http.Error(w, fmt.Sprintf("failed to create request: %v", err), http.StatusInternalServerError)
			return
		}
		req.URL.RawQuery = r.URL.RawQuery
		// Without the content type the upstream cannot interpret a forwarded body.
		if contentType := r.Header.Get("Content-Type"); contentType != "" {
			req.Header.Set("Content-Type", contentType)
		}

		resp, err := client.Do(req)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		defer resp.Body.Close()

		for header, values := range resp.Header {
			for _, v := range values {
				w.Header().Add(header, v)
			}
		}
		// Relay the upstream status; otherwise every response becomes 200.
		w.WriteHeader(resp.StatusCode)

		if _, err := io.Copy(w, resp.Body); err != nil {
			// The status line is already sent, so the failure can only be logged.
			log.Printf("failed to copy upstream response: %v", err)
		}
	})

	server := &http.Server{
		Addr:    ":5000",
		Handler: mux,
		// Bound header reads so slow clients cannot hold connections open forever.
		ReadHeaderTimeout: 10 * time.Second,
	}

	log.Println("starting http server")
	if err := server.ListenAndServe(); err != nil {
		log.Fatalf("failed to start http server: %v", err)
	}
}