|
| 1 | +package api |
| 2 | + |
| 3 | +import ( |
| 4 | + "crypto/sha256" |
| 5 | + "encoding/hex" |
| 6 | + "fmt" |
| 7 | + "os" |
| 8 | + "path/filepath" |
| 9 | + "sync" |
| 10 | +) |
| 11 | + |
| 12 | +// NodeMode determines how the node handles SGX enclave calls |
| 13 | +type NodeMode string |
| 14 | + |
| 15 | +const ( |
| 16 | + // NodeModeSGX - Run with real SGX enclave and record outputs |
| 17 | + NodeModeSGX NodeMode = "sgx" |
| 18 | + // NodeModeReplay - Replay recorded outputs without SGX |
| 19 | + NodeModeReplay NodeMode = "replay" |
| 20 | +) |
| 21 | + |
| 22 | +// EcallRecorder handles recording and replaying ecall data |
| 23 | +type EcallRecorder struct { |
| 24 | + mu sync.RWMutex |
| 25 | + mode NodeMode |
| 26 | + recordDir string |
| 27 | +} |
| 28 | + |
| 29 | +var ( |
| 30 | + globalRecorder *EcallRecorder |
| 31 | + recorderOnce sync.Once |
| 32 | +) |
| 33 | + |
| 34 | +// GetRecorder returns the global ecall recorder instance |
| 35 | +func GetRecorder() *EcallRecorder { |
| 36 | + recorderOnce.Do(func() { |
| 37 | + mode := NodeMode(os.Getenv("SECRET_NODE_MODE")) |
| 38 | + if mode == "" { |
| 39 | + mode = NodeModeSGX // Default to SGX mode |
| 40 | + } |
| 41 | + |
| 42 | + recordDir := os.Getenv("SECRET_ECALL_RECORD_DIR") |
| 43 | + if recordDir == "" { |
| 44 | + recordDir = "/tmp/secret_ecall_records" |
| 45 | + } |
| 46 | + |
| 47 | + globalRecorder = &EcallRecorder{ |
| 48 | + mode: mode, |
| 49 | + recordDir: recordDir, |
| 50 | + } |
| 51 | + |
| 52 | + // Create record directory if it doesn't exist |
| 53 | + if err := os.MkdirAll(recordDir, 0o755); err != nil { |
| 54 | + fmt.Printf("Warning: could not create ecall record directory: %v\n", err) |
| 55 | + } |
| 56 | + |
| 57 | + fmt.Printf("[EcallRecorder] Initialized in %s mode, record dir: %s\n", mode, recordDir) |
| 58 | + }) |
| 59 | + return globalRecorder |
| 60 | +} |
| 61 | + |
| 62 | +// Mode returns the current node mode |
| 63 | +func (r *EcallRecorder) Mode() NodeMode { |
| 64 | + return r.mode |
| 65 | +} |
| 66 | + |
| 67 | +// IsSGXMode returns true if running in SGX mode |
| 68 | +func (r *EcallRecorder) IsSGXMode() bool { |
| 69 | + return r.mode == NodeModeSGX |
| 70 | +} |
| 71 | + |
| 72 | +// IsReplayMode returns true if running in replay mode |
| 73 | +func (r *EcallRecorder) IsReplayMode() bool { |
| 74 | + return r.mode == NodeModeReplay |
| 75 | +} |
| 76 | + |
| 77 | +// computeHash computes SHA256 hash of input for filename |
| 78 | +func computeHash(operation string, input []byte) string { |
| 79 | + h := sha256.New() |
| 80 | + h.Write([]byte(operation)) |
| 81 | + h.Write(input) |
| 82 | + return hex.EncodeToString(h.Sum(nil)) |
| 83 | +} |
| 84 | + |
| 85 | +// getFilePath returns the file path for a given operation and hash |
| 86 | +func (r *EcallRecorder) getFilePath(operation string, hash string) string { |
| 87 | + return filepath.Join(r.recordDir, fmt.Sprintf("%s_%s.bin", operation, hash[:16])) |
| 88 | +} |
| 89 | + |
| 90 | +// Record stores an ecall output to file (used in SGX mode) |
| 91 | +func (r *EcallRecorder) Record(operation string, input []byte, output []byte, err error) error { |
| 92 | + r.mu.Lock() |
| 93 | + defer r.mu.Unlock() |
| 94 | + |
| 95 | + hash := computeHash(operation, input) |
| 96 | + filePath := r.getFilePath(operation, hash) |
| 97 | + |
| 98 | + // If there was an error, write empty file with .err extension |
| 99 | + if err != nil { |
| 100 | + errPath := filePath + ".err" |
| 101 | + if writeErr := os.WriteFile(errPath, []byte(err.Error()), 0o644); writeErr != nil { |
| 102 | + return fmt.Errorf("failed to write error file: %w", writeErr) |
| 103 | + } |
| 104 | + fmt.Printf("[EcallRecorder] Recorded error to %s\n", errPath) |
| 105 | + return nil |
| 106 | + } |
| 107 | + |
| 108 | + // Write output bytes directly to file |
| 109 | + if writeErr := os.WriteFile(filePath, output, 0o644); writeErr != nil { |
| 110 | + return fmt.Errorf("failed to write record file: %w", writeErr) |
| 111 | + } |
| 112 | + |
| 113 | + fmt.Printf("[EcallRecorder] Recorded %d bytes to %s\n", len(output), filePath) |
| 114 | + return nil |
| 115 | +} |
| 116 | + |
| 117 | +// Replay retrieves a recorded ecall output from file (used in replay mode) |
| 118 | +func (r *EcallRecorder) Replay(operation string, input []byte) ([]byte, error, bool) { |
| 119 | + r.mu.RLock() |
| 120 | + defer r.mu.RUnlock() |
| 121 | + |
| 122 | + hash := computeHash(operation, input) |
| 123 | + filePath := r.getFilePath(operation, hash) |
| 124 | + |
| 125 | + // Check for error file first |
| 126 | + errPath := filePath + ".err" |
| 127 | + if errData, err := os.ReadFile(errPath); err == nil { |
| 128 | + fmt.Printf("[EcallRecorder] Replayed error from %s\n", errPath) |
| 129 | + return nil, fmt.Errorf("%s", string(errData)), true |
| 130 | + } |
| 131 | + |
| 132 | + // Read output file |
| 133 | + output, err := os.ReadFile(filePath) |
| 134 | + if err != nil { |
| 135 | + if os.IsNotExist(err) { |
| 136 | + return nil, nil, false // Not found |
| 137 | + } |
| 138 | + return nil, fmt.Errorf("failed to read record file: %w", err), true |
| 139 | + } |
| 140 | + |
| 141 | + fmt.Printf("[EcallRecorder] Replayed %d bytes from %s\n", len(output), filePath) |
| 142 | + return output, nil, true |
| 143 | +} |
| 144 | + |
| 145 | +// --- Wrapper functions for specific ecalls --- |
| 146 | + |
| 147 | +// RecordGetEncryptedSeed records the GetEncryptedSeed ecall |
| 148 | +func RecordGetEncryptedSeed(cert []byte, output []byte, err error) error { |
| 149 | + return GetRecorder().Record("GetEncryptedSeed", cert, output, err) |
| 150 | +} |
| 151 | + |
| 152 | +// ReplayGetEncryptedSeed attempts to replay a recorded GetEncryptedSeed call |
| 153 | +func ReplayGetEncryptedSeed(cert []byte) ([]byte, error, bool) { |
| 154 | + return GetRecorder().Replay("GetEncryptedSeed", cert) |
| 155 | +} |
0 commit comments