package main

import (
	"context"
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"sync"
	"time"

	"golang.org/x/oauth2"
)

// oauthStateCookie names the cookie that carries the OAuth2 "state" value
// from GET /oauth to GET /oauth/callback.
const oauthStateCookie = "oauth_state"

// hopByHopHeaders are connection-specific request headers that must not be
// forwarded to the upstream API.
var hopByHopHeaders = []string{
	"Connection", "Keep-Alive", "Proxy-Authenticate", "Proxy-Authorization",
	"Proxy-Connection", "Te", "Trailer", "Transfer-Encoding", "Upgrade",
}

// tokenStore owns the OAuth2 token shared by all HTTP handlers. Every access
// goes through mu, so concurrent requests never race on the token and at most
// one refresh is in flight at a time (a provider that rotates refresh tokens
// would reject a second refresh made with the same, already-used token).
type tokenStore struct {
	mu     sync.Mutex
	config *oauth2.Config
	path   string             // config file passed to writeConfig
	source oauth2.TokenSource // caches the current token and refreshes it once expired
	saved  *oauth2.Token      // last token successfully written to path
}

// newTokenStore returns a store seeded with tok, the token previously loaded
// from path. If tok is nil, Token fails (no refresh token is available) until
// Set installs a token obtained through the authorization-code flow.
func newTokenStore(config *oauth2.Config, path string, tok *oauth2.Token) *tokenStore {
	return &tokenStore{
		config: config,
		path:   path,
		source: config.TokenSource(context.Background(), tok),
		saved:  tok,
	}
}

// Token returns the current access token, refreshing it if it has expired.
// A token that differs from the last persisted one is written to the config
// file before being returned; if that write fails the error is returned and
// the write is retried on the next call.
func (s *tokenStore) Token() (*oauth2.Token, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	tok, err := s.source.Token()
	if err != nil {
		return nil, fmt.Errorf("failed to get new token: %w", err)
	}
	if s.saved == nil || tok.AccessToken != s.saved.AccessToken {
		if err := writeConfig(s.path, tok); err != nil {
			return nil, fmt.Errorf("failed to save new token: %w", err)
		}
		s.saved = tok
	}
	return tok, nil
}

// Set replaces the current token (e.g. after an authorization-code exchange)
// and persists it. The new token is used for subsequent requests even if the
// write fails; in that case Token retries the write on its next call.
func (s *tokenStore) Set(tok *oauth2.Token) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.source = s.config.TokenSource(context.Background(), tok)
	if err := writeConfig(s.path, tok); err != nil {
		return err
	}
	s.saved = tok
	return nil
}

// newOAuthState returns 32 bytes from crypto/rand, base64url-encoded, for use
// as the OAuth2 "state" parameter.
func newOAuthState() (string, error) {
	buf := make([]byte, 32)
	if _, err := rand.Read(buf); err != nil {
		return "", fmt.Errorf("generating oauth state: %w", err)
	}
	return base64.RawURLEncoding.EncodeToString(buf), nil
}

// handleOAuthStart redirects the browser to the provider's consent page. It
// stores a fresh random state in a short-lived cookie so the callback can
// verify that the flow was started by this browser (CSRF protection).
func handleOAuthStart(config *oauth2.Config) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		state, err := newOAuthState()
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		http.SetCookie(w, &http.Cookie{
			Name:     oauthStateCookie,
			Value:    state,
			Path:     "/oauth",
			MaxAge:   600, // seconds
			HttpOnly: true,
			Secure:   r.TLS != nil,
			// Lax still sends the cookie on the provider's top-level redirect back.
			SameSite: http.SameSiteLaxMode,
		})
		http.Redirect(w, r, config.AuthCodeURL(state), http.StatusTemporaryRedirect)
	}
}

// handleOAuthCallback completes the authorization-code flow. It checks the
// returned state against the cookie set by handleOAuthStart, exchanges the
// code for a token, and installs the token in store, which persists it.
func handleOAuthCallback(config *oauth2.Config, store *tokenStore) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		query := r.URL.Query()

		cookie, err := r.Cookie(oauthStateCookie)
		if err != nil || cookie.Value == "" ||
			subtle.ConstantTimeCompare([]byte(cookie.Value), []byte(query.Get("state"))) != 1 {
			http.Error(w, "invalid oauth state", http.StatusBadRequest)
			return
		}
		// The state is single-use: expire the cookie whatever happens next.
		http.SetCookie(w, &http.Cookie{Name: oauthStateCookie, Path: "/oauth", MaxAge: -1})

		if msg := query.Get("error"); msg != "" {
			http.Error(w, "authorization failed: "+msg, http.StatusBadRequest)
			return
		}
		code := query.Get("code")
		if code == "" {
			http.Error(w, "missing authorization code", http.StatusBadRequest)
			return
		}

		tok, err := config.Exchange(r.Context(), code)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		if err := store.Set(tok); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		fmt.Fprint(w, "ok")
	}
}

// handleProxy forwards the incoming request to apiURL+path. It forwards the
// method, query, end-to-end headers and body, and attaches a current bearer
// token. It streams the upstream response back, including its status code
// and headers.
func handleProxy(apiURL string, store *tokenStore, client *http.Client) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		tok, err := store.Token()
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		// Tie the upstream call to the client's request so it is cancelled if
		// the client goes away.
		req, err := http.NewRequestWithContext(r.Context(), r.Method, apiURL+r.URL.Path, r.Body)
		if err != nil {
			http.Error(w, fmt.Sprintf("failed to create request: %v", err), http.StatusInternalServerError)
			return
		}
		req.URL.RawQuery = r.URL.RawQuery
		req.ContentLength = r.ContentLength
		req.Header = r.Header.Clone()
		for _, h := range hopByHopHeaders {
			req.Header.Del(h)
		}
		tok.SetAuthHeader(req) // overrides any Authorization header sent by the client

		resp, err := client.Do(req)
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadGateway)
			return
		}
		defer resp.Body.Close()

		for name, values := range resp.Header {
			for _, v := range values {
				w.Header().Add(name, v)
			}
		}
		w.WriteHeader(resp.StatusCode)
		// The status line is already sent, so a copy failure can only be logged.
		if _, err := io.Copy(w, resp.Body); err != nil {
			log.Printf("proxying %s %s: %v", r.Method, r.URL.Path, err)
		}
	}
}

// main runs an authenticating proxy on :5000.
//
// Environment:
//   - AUTH_CONFIG (required): file the token is read from and written to.
//   - CLIENT_ID, CLIENT_SECRET, AUTH_URL, TOKEN_URL, DEVICE_AUTH_URL,
//     REDIRECT_URI: OAuth2 client configuration.
//   - API_URL: upstream base URL; every path not under /oauth is forwarded to it.
//
// GET /oauth starts the authorization-code flow and GET /oauth/callback
// completes it.
func main() {
	authConfigPath := os.Getenv("AUTH_CONFIG")
	if authConfigPath == "" {
		log.Fatal("AUTH_CONFIG env is missing")
	}
	token, err := readConfig(authConfigPath)
	if err != nil {
		log.Fatalf("failed to read config file: %v", err)
	}

	oauthConfig := &oauth2.Config{
		ClientID:     os.Getenv("CLIENT_ID"),
		ClientSecret: os.Getenv("CLIENT_SECRET"),
		Endpoint: oauth2.Endpoint{
			AuthURL:       os.Getenv("AUTH_URL"),
			TokenURL:      os.Getenv("TOKEN_URL"),
			DeviceAuthURL: os.Getenv("DEVICE_AUTH_URL"),
			AuthStyle:     oauth2.AuthStyleAutoDetect,
		},
		RedirectURL: os.Getenv("REDIRECT_URI"),
	}
	store := newTokenStore(oauthConfig, authConfigPath, token)
	apiURL := os.Getenv("API_URL")

	mux := http.NewServeMux()
	mux.Handle("GET /oauth", handleOAuthStart(oauthConfig))
	mux.Handle("GET /oauth/callback", handleOAuthCallback(oauthConfig, store))
	mux.Handle("/", handleProxy(apiURL, store, http.DefaultClient))

	server := &http.Server{
		Addr:    ":5000",
		Handler: mux,
		// Guards against clients that trickle headers. Read/WriteTimeout are
		// deliberately unset because proxied bodies may be large or slow.
		ReadHeaderTimeout: 10 * time.Second,
	}

	log.Println("starting http server")
	if err := server.ListenAndServe(); err != nil {
		log.Fatalf("failed to start http server: %v", err)
	}
}