Code

Let the query helper ensure that all string arguments are escaped.
authorSebastian Harl <sh@tokkee.org>
Sun, 23 Nov 2014 11:33:27 +0000 (12:33 +0100)
committerSebastian Harl <sh@tokkee.org>
Sun, 23 Nov 2014 11:33:27 +0000 (12:33 +0100)
Introduce a new type, identifier, to be used for strings that shall never be
escaped, similar to how html/template treats the HTML, JS, etc. types. The
query helper now expects printf style format and value arguments.

server/query.go

index 16b250c22128551541cf003b5a55de115bb664f7..edeae731baa5f9519618e25e2e9321dae2c7591b 100644 (file)
@@ -42,7 +42,7 @@ func listAll(req request, s *Server) (*page, error) {
                return nil, fmt.Errorf("%s not found", strings.Title(req.cmd))
        }
 
-       res, err := s.query(fmt.Sprintf("LIST %s", req.cmd))
+       res, err := s.query("LIST %s", identifier(req.cmd))
        if err != nil {
                return nil, err
        }
@@ -54,12 +54,12 @@ func lookup(req request, s *Server) (*page, error) {
        if req.r.Method != "POST" {
                return nil, errors.New("Method not allowed")
        }
-       q := proto.EscapeString(req.r.FormValue("query"))
+       q := req.r.FormValue("query")
        if q == "''" {
                return nil, errors.New("Empty query")
        }
 
-       res, err := s.query(fmt.Sprintf("LOOKUP hosts MATCHING name =~ %s", q))
+       res, err := s.query("LOOKUP hosts MATCHING name =~ %s", q)
        if err != nil {
                return nil, err
        }
@@ -71,35 +71,46 @@ func fetch(req request, s *Server) (*page, error) {
                return nil, fmt.Errorf("%s not found", strings.Title(req.cmd))
        }
 
-       var q string
+       var res interface{}
+       var err error
        switch req.cmd {
        case "host":
                if len(req.args) != 1 {
                        return nil, fmt.Errorf("%s not found", strings.Title(req.cmd))
                }
-               q = fmt.Sprintf("FETCH host %s", proto.EscapeString(req.args[0]))
+               res, err = s.query("FETCH host %s", req.args[0])
        case "service", "metric":
                if len(req.args) != 2 {
                        return nil, fmt.Errorf("%s not found", strings.Title(req.cmd))
                }
-               host := proto.EscapeString(req.args[0])
-               name := proto.EscapeString(req.args[1])
-               q = fmt.Sprintf("FETCH %s %s.%s", req.cmd, host, name)
+               res, err = s.query("FETCH %s %s.%s", identifier(req.cmd), req.args[0], req.args[1])
        default:
                panic("Unknown request: fetch(" + req.cmd + ")")
        }
-
-       res, err := s.query(q)
        if err != nil {
                return nil, err
        }
        return tmpl(s.results[req.cmd], res)
 }
 
-func (s *Server) query(cmd string) (interface{}, error) {
+type identifier string
+
+func (s *Server) query(cmd string, args ...interface{}) (interface{}, error) {
        c := <-s.conns
        defer func() { s.conns <- c }()
 
+       for i, arg := range args {
+               switch v := arg.(type) {
+               case identifier:
+                       // Nothing to do.
+               case string:
+                       args[i] = proto.EscapeString(v)
+               default:
+                       panic(fmt.Sprintf("query: invalid type %T", arg))
+               }
+       }
+
+       cmd = fmt.Sprintf(cmd, args...)
        m := &proto.Message{
                Type: proto.ConnectionQuery,
                Raw:  []byte(cmd),