Skip to content

Commit 10d110d

Browse files
committed
write ecalls
1 parent 33b8331 commit 10d110d

File tree

2 files changed

+182
-2
lines changed

2 files changed

+182
-2
lines changed

go-cosmwasm/api/ecall_record.go

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
package api
2+
3+
import (
4+
"crypto/sha256"
5+
"encoding/hex"
6+
"fmt"
7+
"os"
8+
"path/filepath"
9+
"sync"
10+
)
11+
12+
// NodeMode determines how the node handles SGX enclave calls
13+
type NodeMode string
14+
15+
const (
16+
// NodeModeSGX - Run with real SGX enclave and record outputs
17+
NodeModeSGX NodeMode = "sgx"
18+
// NodeModeReplay - Replay recorded outputs without SGX
19+
NodeModeReplay NodeMode = "replay"
20+
)
21+
22+
// EcallRecorder handles recording and replaying ecall data
23+
type EcallRecorder struct {
24+
mu sync.RWMutex
25+
mode NodeMode
26+
recordDir string
27+
}
28+
29+
var (
30+
globalRecorder *EcallRecorder
31+
recorderOnce sync.Once
32+
)
33+
34+
// GetRecorder returns the global ecall recorder instance
35+
func GetRecorder() *EcallRecorder {
36+
recorderOnce.Do(func() {
37+
mode := NodeMode(os.Getenv("SECRET_NODE_MODE"))
38+
if mode == "" {
39+
mode = NodeModeSGX // Default to SGX mode
40+
}
41+
42+
recordDir := os.Getenv("SECRET_ECALL_RECORD_DIR")
43+
if recordDir == "" {
44+
recordDir = "/tmp/secret_ecall_records"
45+
}
46+
47+
globalRecorder = &EcallRecorder{
48+
mode: mode,
49+
recordDir: recordDir,
50+
}
51+
52+
// Create record directory if it doesn't exist
53+
if err := os.MkdirAll(recordDir, 0o755); err != nil {
54+
fmt.Printf("Warning: could not create ecall record directory: %v\n", err)
55+
}
56+
57+
fmt.Printf("[EcallRecorder] Initialized in %s mode, record dir: %s\n", mode, recordDir)
58+
})
59+
return globalRecorder
60+
}
61+
62+
// Mode returns the current node mode
63+
func (r *EcallRecorder) Mode() NodeMode {
64+
return r.mode
65+
}
66+
67+
// IsSGXMode returns true if running in SGX mode
68+
func (r *EcallRecorder) IsSGXMode() bool {
69+
return r.mode == NodeModeSGX
70+
}
71+
72+
// IsReplayMode returns true if running in replay mode
73+
func (r *EcallRecorder) IsReplayMode() bool {
74+
return r.mode == NodeModeReplay
75+
}
76+
77+
// computeHash computes SHA256 hash of input for filename
78+
func computeHash(operation string, input []byte) string {
79+
h := sha256.New()
80+
h.Write([]byte(operation))
81+
h.Write(input)
82+
return hex.EncodeToString(h.Sum(nil))
83+
}
84+
85+
// getFilePath returns the file path for a given operation and hash
86+
func (r *EcallRecorder) getFilePath(operation string, hash string) string {
87+
return filepath.Join(r.recordDir, fmt.Sprintf("%s_%s.bin", operation, hash[:16]))
88+
}
89+
90+
// Record stores an ecall output to file (used in SGX mode)
91+
func (r *EcallRecorder) Record(operation string, input []byte, output []byte, err error) error {
92+
r.mu.Lock()
93+
defer r.mu.Unlock()
94+
95+
hash := computeHash(operation, input)
96+
filePath := r.getFilePath(operation, hash)
97+
98+
// If there was an error, write empty file with .err extension
99+
if err != nil {
100+
errPath := filePath + ".err"
101+
if writeErr := os.WriteFile(errPath, []byte(err.Error()), 0o644); writeErr != nil {
102+
return fmt.Errorf("failed to write error file: %w", writeErr)
103+
}
104+
fmt.Printf("[EcallRecorder] Recorded error to %s\n", errPath)
105+
return nil
106+
}
107+
108+
// Write output bytes directly to file
109+
if writeErr := os.WriteFile(filePath, output, 0o644); writeErr != nil {
110+
return fmt.Errorf("failed to write record file: %w", writeErr)
111+
}
112+
113+
fmt.Printf("[EcallRecorder] Recorded %d bytes to %s\n", len(output), filePath)
114+
return nil
115+
}
116+
117+
// Replay retrieves a recorded ecall output from file (used in replay mode)
118+
func (r *EcallRecorder) Replay(operation string, input []byte) ([]byte, error, bool) {
119+
r.mu.RLock()
120+
defer r.mu.RUnlock()
121+
122+
hash := computeHash(operation, input)
123+
filePath := r.getFilePath(operation, hash)
124+
125+
// Check for error file first
126+
errPath := filePath + ".err"
127+
if errData, err := os.ReadFile(errPath); err == nil {
128+
fmt.Printf("[EcallRecorder] Replayed error from %s\n", errPath)
129+
return nil, fmt.Errorf("%s", string(errData)), true
130+
}
131+
132+
// Read output file
133+
output, err := os.ReadFile(filePath)
134+
if err != nil {
135+
if os.IsNotExist(err) {
136+
return nil, nil, false // Not found
137+
}
138+
return nil, fmt.Errorf("failed to read record file: %w", err), true
139+
}
140+
141+
fmt.Printf("[EcallRecorder] Replayed %d bytes from %s\n", len(output), filePath)
142+
return output, nil, true
143+
}
144+
145+
// --- Wrapper functions for specific ecalls ---
146+
147+
// RecordGetEncryptedSeed records the GetEncryptedSeed ecall
148+
func RecordGetEncryptedSeed(cert []byte, output []byte, err error) error {
149+
return GetRecorder().Record("GetEncryptedSeed", cert, output, err)
150+
}
151+
152+
// ReplayGetEncryptedSeed attempts to replay a recorded GetEncryptedSeed call
153+
func ReplayGetEncryptedSeed(cert []byte) ([]byte, error, bool) {
154+
return GetRecorder().Replay("GetEncryptedSeed", cert)
155+
}

go-cosmwasm/api/lib.go

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -536,14 +536,39 @@ func CreateAttestationReport(no_epid bool, no_dcap bool, is_migration_report boo
536536
}
537537

538538
func GetEncryptedSeed(cert []byte) ([]byte, error) {
539+
recorder := GetRecorder()
540+
541+
// In replay mode, try to get from recorded data
542+
if recorder.IsReplayMode() {
543+
if output, err, found := ReplayGetEncryptedSeed(cert); found {
544+
fmt.Printf("[GetEncryptedSeed] Replay mode: returning recorded result\n")
545+
return output, err
546+
}
547+
return nil, fmt.Errorf("GetEncryptedSeed: no recorded data found for input (replay mode)")
548+
}
549+
550+
// SGX mode: call the actual enclave
539551
errmsg := C.Buffer{}
540552
certSlice := sendSlice(cert)
541553
defer freeAfterSend(certSlice)
542554
res, err := C.get_encrypted_seed(certSlice, &errmsg)
555+
556+
var output []byte
557+
var callErr error
543558
if err != nil {
544-
return nil, errorWithMessage(err, errmsg)
559+
callErr = errorWithMessage(err, errmsg)
560+
} else {
561+
output = receiveVector(res)
545562
}
546-
return receiveVector(res), nil
563+
564+
// Record the result for non-SGX nodes
565+
if recordErr := RecordGetEncryptedSeed(cert, output, callErr); recordErr != nil {
566+
fmt.Printf("[GetEncryptedSeed] Warning: failed to record ecall: %v\n", recordErr)
567+
} else {
568+
fmt.Printf("[GetEncryptedSeed] SGX mode: recorded ecall result\n")
569+
}
570+
571+
return output, callErr
547572
}
548573

549574
func GetEncryptedGenesisSeed(pk []byte) ([]byte, error) {

0 commit comments

Comments
 (0)