From 950446bbdc13d62f1f8b674cdf1b852d8497635b Mon Sep 17 00:00:00 2001 From: Sebastian Harl Date: Sun, 23 Nov 2014 12:33:27 +0100 Subject: [PATCH] Let the query helper ensure that all string arguments are escaped. Introduce a new type, identifier, to be used for strings that shall never be escaped, similar to how html/template treats the HTML, JS, etc. types. The query helper now expects printf style format and value arguments. --- server/query.go | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/server/query.go b/server/query.go index 16b250c..edeae73 100644 --- a/server/query.go +++ b/server/query.go @@ -42,7 +42,7 @@ func listAll(req request, s *Server) (*page, error) { return nil, fmt.Errorf("%s not found", strings.Title(req.cmd)) } - res, err := s.query(fmt.Sprintf("LIST %s", req.cmd)) + res, err := s.query("LIST %s", identifier(req.cmd)) if err != nil { return nil, err } @@ -54,12 +54,12 @@ func lookup(req request, s *Server) (*page, error) { if req.r.Method != "POST" { return nil, errors.New("Method not allowed") } - q := proto.EscapeString(req.r.FormValue("query")) + q := req.r.FormValue("query") if q == "''" { return nil, errors.New("Empty query") } - res, err := s.query(fmt.Sprintf("LOOKUP hosts MATCHING name =~ %s", q)) + res, err := s.query("LOOKUP hosts MATCHING name =~ %s", q) if err != nil { return nil, err } @@ -71,35 +71,46 @@ func fetch(req request, s *Server) (*page, error) { return nil, fmt.Errorf("%s not found", strings.Title(req.cmd)) } - var q string + var res interface{} + var err error switch req.cmd { case "host": if len(req.args) != 1 { return nil, fmt.Errorf("%s not found", strings.Title(req.cmd)) } - q = fmt.Sprintf("FETCH host %s", proto.EscapeString(req.args[0])) + res, err = s.query("FETCH host %s", req.args[0]) case "service", "metric": if len(req.args) != 2 { return nil, fmt.Errorf("%s not found", strings.Title(req.cmd)) } - host := proto.EscapeString(req.args[0]) - name := proto.EscapeString(req.args[1]) - q = fmt.Sprintf("FETCH %s %s.%s", req.cmd, host, name) + res, err = s.query("FETCH %s %s.%s", identifier(req.cmd), req.args[0], req.args[1]) default: panic("Unknown request: fetch(" + req.cmd + ")") } - - res, err := s.query(q) if err != nil { return nil, err } return tmpl(s.results[req.cmd], res) } -func (s *Server) query(cmd string) (interface{}, error) { +type identifier string + +func (s *Server) query(cmd string, args ...interface{}) (interface{}, error) { c := <-s.conns defer func() { s.conns <- c }() + for i, arg := range args { + switch v := arg.(type) { + case identifier: + // Nothing to do. + case string: + args[i] = proto.EscapeString(v) + default: + panic(fmt.Sprintf("query: invalid type %T", arg)) + } + } + + cmd = fmt.Sprintf(cmd, args...) m := &proto.Message{ Type: proto.ConnectionQuery, Raw: []byte(cmd), -- 2.30.2