Commit a1f34bec authored by Hrishikesh Barman's avatar Hrishikesh Barman Committed by Brian Brazil
Browse files

Added CORS Origin flag (#5011)


Signed-off-by: default avatarHrishikesh Barman <hrishikeshbman@gmail.com>
parent c44cd7e1
......@@ -26,6 +26,7 @@ import (
"os"
"os/signal"
"path/filepath"
"regexp"
"runtime"
"strings"
"sync"
......@@ -50,6 +51,7 @@ import (
"github.com/prometheus/prometheus/discovery"
sd_config "github.com/prometheus/prometheus/discovery/config"
"github.com/prometheus/prometheus/notifier"
"github.com/prometheus/prometheus/pkg/relabel"
"github.com/prometheus/prometheus/promql"
"github.com/prometheus/prometheus/rules"
"github.com/prometheus/prometheus/scrape"
......@@ -99,7 +101,8 @@ func main() {
queryMaxSamples int
RemoteFlushDeadline model.Duration
prometheusURL string
prometheusURL string
corsRegexString string
promlogConfig promlog.Config
}{
......@@ -209,6 +212,9 @@ func main() {
a.Flag("query.max-samples", "Maximum number of samples a single query can load into memory. Note that queries will fail if they would load more samples than this into memory, so this also limits the number of samples a query can return.").
Default("50000000").IntVar(&cfg.queryMaxSamples)
a.Flag("web.cors.origin", `Regex for CORS origin. It is fully anchored. Eg. 'https?://(domain1|domain2)\.com'`).
Default(".*").StringVar(&cfg.corsRegexString)
promlogflag.AddFlags(a, &cfg.promlogConfig)
_, err := a.Parse(os.Args[1:])
......@@ -224,6 +230,12 @@ func main() {
os.Exit(2)
}
cfg.web.CORSOrigin, err = compileCORSRegexString(cfg.corsRegexString)
if err != nil {
fmt.Fprintln(os.Stderr, errors.Wrapf(err, "could not compile CORS regex string %q", cfg.corsRegexString))
os.Exit(2)
}
cfg.web.ReadTimeout = time.Duration(cfg.webTimeout)
// Default -web.route-prefix to path of -web.external-url.
if cfg.web.RoutePrefix == "" {
......@@ -674,6 +686,15 @@ func startsOrEndsWithQuote(s string) bool {
strings.HasSuffix(s, "\"") || strings.HasSuffix(s, "'")
}
// compileCORSRegexString compiles given string and adds anchors
func compileCORSRegexString(s string) (*regexp.Regexp, error) {
r, err := relabel.NewRegexp(s)
if err != nil {
return nil, err
}
return r.Regexp, nil
}
// computeExternalURL computes a sanitized external URL from a raw input. It infers unset
// URL parts from the OS and the given listen address.
func computeExternalURL(u, listenAddr string) (*url.URL, error) {
......
// Copyright 2013 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httputil
import (
"net/http"
"regexp"
)
var corsHeaders = map[string]string{
"Access-Control-Allow-Headers": "Accept, Authorization, Content-Type, Origin",
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
"Access-Control-Expose-Headers": "Date",
"Vary": "Origin",
}
// Enables cross-site script calls.
func SetCORS(w http.ResponseWriter, o *regexp.Regexp, r *http.Request) {
origin := r.Header.Get("Origin")
if origin == "" {
return
}
for k, v := range corsHeaders {
w.Header().Set(k, v)
}
if o.String() == ".*" {
w.Header().Set("Access-Control-Allow-Origin", "*")
return
}
if o.MatchString(origin) {
w.Header().Set("Access-Control-Allow-Origin", origin)
}
}
// Copyright 2016 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httputil
import (
"net/http"
"regexp"
"testing"
)
func getCORSHandlerFunc() http.Handler {
hf := func(w http.ResponseWriter, r *http.Request) {
reg := regexp.MustCompile(`^https://foo\.com$`)
SetCORS(w, reg, r)
w.WriteHeader(http.StatusOK)
}
return http.HandlerFunc(hf)
}
func TestCORSHandler(t *testing.T) {
tearDown := setup()
defer tearDown()
client := &http.Client{}
ch := getCORSHandlerFunc()
mux.Handle("/any_path", ch)
dummyOrigin := "https://foo.com"
// OPTIONS with legit origin
req, err := http.NewRequest("OPTIONS", server.URL+"/any_path", nil)
if err != nil {
t.Error("could not create request")
}
req.Header.Set("Origin", dummyOrigin)
resp, err := client.Do(req)
if err != nil {
t.Error("client get failed with unexpected error")
}
AccessControlAllowOrigin := resp.Header.Get("Access-Control-Allow-Origin")
if AccessControlAllowOrigin != dummyOrigin {
t.Fatalf("%q does not match %q", dummyOrigin, AccessControlAllowOrigin)
}
// OPTIONS with bad origin
req, err = http.NewRequest("OPTIONS", server.URL+"/any_path", nil)
if err != nil {
t.Error("could not create request")
}
req.Header.Set("Origin", "https://not-foo.com")
resp, err = client.Do(req)
if err != nil {
t.Error("client get failed with unexpected error")
}
AccessControlAllowOrigin = resp.Header.Get("Access-Control-Allow-Origin")
if AccessControlAllowOrigin != "" {
t.Fatalf("Access-Control-Allow-Origin should not exist but it was set to: %q", AccessControlAllowOrigin)
}
}
......@@ -23,6 +23,7 @@ import (
"net/url"
"os"
"path/filepath"
"regexp"
"sort"
"strconv"
"time"
......@@ -76,13 +77,6 @@ const (
errorNotFound errorType = "not_found"
)
var corsHeaders = map[string]string{
"Access-Control-Allow-Headers": "Accept, Authorization, Content-Type, Origin",
"Access-Control-Allow-Methods": "GET, OPTIONS",
"Access-Control-Allow-Origin": "*",
"Access-Control-Expose-Headers": "Date",
}
var remoteReadQueries = prometheus.NewGauge(prometheus.GaugeOpts{
Namespace: namespace,
Subsystem: subsystem,
......@@ -129,13 +123,6 @@ type apiFuncResult struct {
finalizer func()
}
// Enables cross-site script calls.
func setCORS(w http.ResponseWriter) {
for h, v := range corsHeaders {
w.Header().Set(h, v)
}
}
type apiFunc func(r *http.Request) apiFuncResult
// TSDBAdmin defines the tsdb interfaces used by the v1 API for admin operations.
......@@ -165,6 +152,7 @@ type API struct {
logger log.Logger
remoteReadSampleLimit int
remoteReadGate *gate.Gate
CORSOrigin *regexp.Regexp
}
func init() {
......@@ -187,6 +175,7 @@ func NewAPI(
rr rulesRetriever,
remoteReadSampleLimit int,
remoteReadConcurrencyLimit int,
CORSOrigin *regexp.Regexp,
) *API {
return &API{
QueryEngine: qe,
......@@ -204,6 +193,7 @@ func NewAPI(
remoteReadSampleLimit: remoteReadSampleLimit,
remoteReadGate: gate.New(remoteReadConcurrencyLimit),
logger: logger,
CORSOrigin: CORSOrigin,
}
}
......@@ -211,7 +201,7 @@ func NewAPI(
func (api *API) Register(r *route.Router) {
wrap := func(f apiFunc) http.HandlerFunc {
hf := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
setCORS(w)
httputil.SetCORS(w, api.CORSOrigin, r)
result := f(r)
if result.err != nil {
api.respondError(w, result.err, result.data)
......
......@@ -1416,12 +1416,6 @@ func TestOptionsMethod(t *testing.T) {
if resp.StatusCode != http.StatusNoContent {
t.Fatalf("Expected status %d, got %d", http.StatusNoContent, resp.StatusCode)
}
for h, v := range corsHeaders {
if resp.Header.Get(h) != v {
t.Fatalf("Expected %q for header %q, got %q", v, h, resp.Header.Get(h))
}
}
}
func TestRespond(t *testing.T) {
......
......@@ -28,6 +28,7 @@ import (
"os"
"path"
"path/filepath"
"regexp"
"runtime"
"sort"
"strings"
......@@ -173,6 +174,7 @@ type Options struct {
Flags map[string]string
ListenAddress string
CORSOrigin *regexp.Regexp
ReadTimeout time.Duration
MaxConnections int
ExternalURL *url.URL
......@@ -259,6 +261,7 @@ func New(logger log.Logger, o *Options) *Handler {
h.ruleManager,
h.options.RemoteReadSampleLimit,
h.options.RemoteReadConcurrencyLimit,
h.options.CORSOrigin,
)
if o.RoutePrefix != "/" {
......@@ -340,20 +343,6 @@ func New(logger log.Logger, o *Options) *Handler {
return h
}
var corsHeaders = map[string]string{
"Access-Control-Allow-Headers": "Accept, Authorization, Content-Type, Origin",
"Access-Control-Allow-Methods": "GET, OPTIONS",
"Access-Control-Allow-Origin": "*",
"Access-Control-Expose-Headers": "Date",
}
// Enables cross-site script calls.
func setCORS(w http.ResponseWriter) {
for h, v := range corsHeaders {
w.Header().Set(h, v)
}
}
func serveDebug(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
subpath := route.Param(ctx, "subpath")
......@@ -474,7 +463,7 @@ func (h *Handler) Run(ctx context.Context) error {
mux.Handle(apiPath+"/", http.StripPrefix(apiPath,
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
setCORS(w)
httputil.SetCORS(w, h.options.CORSOrigin, r)
hhFunc(w, r)
}),
))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment