diff --git a/plugins/environment.go b/plugins/environment.go index e47912df..fa6a4b2a 100644 --- a/plugins/environment.go +++ b/plugins/environment.go @@ -9,6 +9,7 @@ import ( "log" "os" "path" + "path/filepath" "strings" "github.com/golang/protobuf/proto" @@ -184,15 +185,47 @@ func HandleResponse(response *Response, outputLocation string) error { return fmt.Errorf("unable to overwrite %s", outputLocation) default: // write files into a directory named by outputLocation if !isDirectory(outputLocation) { - os.Mkdir(outputLocation, 0o755) + if err := os.Mkdir(outputLocation, 0o755); err != nil { + return fmt.Errorf("os.Mkdir(%q): %w", outputLocation, err) + } + } + // Resolve the canonical absolute path of the output directory once so + // that each file.Name can be validated against it, preventing a plugin + // from writing files outside the declared output location. + resolvedOut, err := filepath.EvalSymlinks(outputLocation) + if err != nil { + return fmt.Errorf("filepath.EvalSymlinks(%q): %w", outputLocation, err) + } + absOut, err := filepath.Abs(resolvedOut) + if err != nil { + return fmt.Errorf("filepath.Abs(%q): %w", resolvedOut, err) } for _, file := range response.Files { - p := outputLocation + "/" + file.Name - dir := path.Dir(p) - os.MkdirAll(dir, 0o755) - f, _ := os.Create(p) + // Sanitize file.Name to prevent path traversal. + cleanName := filepath.Clean(file.Name) + if filepath.IsAbs(cleanName) { + return fmt.Errorf("plugin returned absolute file path %q", file.Name) + } + p := filepath.Join(outputLocation, cleanName) + absP, err := filepath.Abs(p) + if err != nil { + return fmt.Errorf("filepath.Abs(%q): %w", p, err) + } + if absP != absOut && !strings.HasPrefix(absP, absOut+string(filepath.Separator)) { + return fmt.Errorf("plugin file path %q escapes the output directory", file.Name) + } + dir := filepath.Dir(p) + if err := os.MkdirAll(dir, 0o755); err != nil { + return fmt.Errorf("os.MkdirAll(%q): %w", dir, err) + } + f, err := os.Create(p) + if err != nil { + return fmt.Errorf("os.Create(%q): %w", p, err) + } defer f.Close() - f.Write(file.Data) + if _, err := f.Write(file.Data); err != nil { + return fmt.Errorf("writing %q: %w", p, err) + } } } return nil