diff --git a/go.mod b/go.mod index 02c9550..7455170 100644 --- a/go.mod +++ b/go.mod @@ -5,9 +5,9 @@ go 1.25.0 toolchain go1.25.6 require ( - github.com/beeper/ai-bridge v0.0.0-20260524021151-5c8086351a72 + github.com/beeper/ai-bridge v0.0.0-20260531201429-3d0bf92ccf00 github.com/rs/zerolog v1.35.1 - go.mau.fi/util v0.9.9-0.20260511124621-9241e81bdf25 + go.mau.fi/util v0.9.9 maunium.net/go/mautrix v0.27.1-0.20260513120123-5fba7e3afae4 ) @@ -28,12 +28,12 @@ require ( github.com/tidwall/sjson v1.2.5 // indirect github.com/yuin/goldmark v1.8.2 // indirect go.mau.fi/zeroconfig v0.2.0 // indirect - golang.org/x/crypto v0.50.0 // indirect - golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f // indirect - golang.org/x/net v0.53.0 // indirect + golang.org/x/crypto v0.51.0 // indirect + golang.org/x/exp v0.0.0-20260508232706-74f9aab9d74a // indirect + golang.org/x/net v0.54.0 // indirect golang.org/x/sync v0.20.0 // indirect - golang.org/x/sys v0.43.0 // indirect - golang.org/x/text v0.36.0 // indirect + golang.org/x/sys v0.44.0 // indirect + golang.org/x/text v0.37.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect maunium.net/go/mauflag v1.0.0 // indirect diff --git a/go.sum b/go.sum index 7ec830b..6463929 100644 --- a/go.sum +++ b/go.sum @@ -1,47 +1,29 @@ -filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= -filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo= filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= -github.com/beeper/ai-bridge v0.0.0-20260524021151-5c8086351a72 h1:Pw2qyz5mizv/UL4JTKiK1sbYfUl6o8dk/KcNyFlSFG0= -github.com/beeper/ai-bridge v0.0.0-20260524021151-5c8086351a72/go.mod h1:Uf2M1ogzy7VGB6uUzzHjZL2eaYt79DK0Py8I6xZl3r0= +github.com/beeper/ai-bridge v0.0.0-20260531201429-3d0bf92ccf00 h1:RIdSWhnzWxhNpt9evjb5kmCNjfgj6Hrl+Kd75yut43c= +github.com/beeper/ai-bridge v0.0.0-20260531201429-3d0bf92ccf00/go.mod h1:+icZV4D9wnp0NTP8bsfS/WXrf/8plzmnp/3bhQEnL3E= github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= -github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= -github.com/coreos/go-systemd/v22 v22.6.0 h1:aGVa/v8B7hpb0TKl0MWoAavPDmHvobFe5R5zn0bCJWo= -github.com/coreos/go-systemd/v22 v22.6.0/go.mod h1:iG+pp635Fo7ZmV/j14KUcmEyWF+0X7Lua8rrTWzYgWU= github.com/coreos/go-systemd/v22 v22.7.0 h1:LAEzFkke61DFROc7zNLX/WA2i5J8gYqe0rSj9KI28KA= github.com/coreos/go-systemd/v22 v22.7.0/go.mod h1:xNUYtjHu2EDXbsxz1i41wouACIwT7Ybq9o0BQhMwD0w= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= -github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.12.3 h1:tTWxr2YLKwIvK90ZXEw8GP7UFHtcbTtty8zsI+YjrfQ= github.com/lib/pq v1.12.3/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= -github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= -github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0= -github.com/mattn/go-sqlite3 v1.14.33/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mattn/go-sqlite3 v1.14.44 h1:3VSe+xafpbzsLbdr2AWlAZk9yRHiBhTBakioXaCKTF8= github.com/mattn/go-sqlite3 v1.14.44/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= -github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 h1:KPpdlQLZcHfTMQRi6bFQ7ogNO0ltFT4PmtwTLW4W+14= -github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/petermattis/goid v0.0.0-20260330135022-df67b199bc81 h1:WDsQxOJDy0N1VRAjXLpi8sCEZRSGarLWQevDxpTBRrM= github.com/petermattis/goid v0.0.0-20260330135022-df67b199bc81/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= -github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= -github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/rs/zerolog v1.35.1 h1:m7xQeoiLIiV0BCEY4Hs+j2NG4Gp2o2KPKmhnnLiazKI= github.com/rs/zerolog v1.35.1/go.mod h1:EjML9kdfa/RMA7h/6z6pYmq1ykOuA8/mjWaEvGI+jcw= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= @@ -58,43 +40,25 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/yuin/goldmark v1.7.16 h1:n+CJdUxaFMiDUNnWC3dMWCIQJSkxH4uz3ZwQBkAlVNE= -github.com/yuin/goldmark v1.7.16/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= github.com/yuin/goldmark v1.8.2 h1:kEGpgqJXdgbkhcOgBxkC0X0PmoPG1ZyoZ117rDVp4zE= github.com/yuin/goldmark v1.8.2/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.9.5 h1:7AoWPCIZJGv4jvtFEuCe3GhAbI7uF9ckIooaXvwlIR4= -go.mau.fi/util v0.9.5/go.mod h1:g1uvZ03VQhtTt2BgaRGVytS/Zj67NV0YNIECch0sQCQ= -go.mau.fi/util v0.9.9-0.20260511124621-9241e81bdf25 h1:YPEmc+li7TF6C9AdRTcSLMb6yCHdF27/wNT7kFLIVNg= -go.mau.fi/util v0.9.9-0.20260511124621-9241e81bdf25/go.mod h1:jE9FfhbgEgAwxei6lomO9v8zdCIATcquONUu4vjRwSs= +go.mau.fi/util v0.9.9 h1:ujDeXCo07HBor5oQLyO1tHklupmqVmPgasc53d7q/NE= +go.mau.fi/util v0.9.9/go.mod h1:pqt4Vcrt+5gcH/CgrHZg11qSx+b34o6mknGzOEA6waY= go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= -golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= -golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= -golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= -golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= -golang.org/x/exp v0.0.0-20260112195511-716be5621a96 h1:Z/6YuSHTLOHfNFdb8zVZomZr7cqNgTJvA8+Qz75D8gU= -golang.org/x/exp v0.0.0-20260112195511-716be5621a96/go.mod h1:nzimsREAkjBCIEFtHiYkrJyT+2uy9YZJB7H1k68CXZU= -golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f h1:W3F4c+6OLc6H2lb//N1q4WpJkhzJCK5J6kUi1NTVXfM= -golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f/go.mod h1:J1xhfL/vlindoeF/aINzNzt2Bket5bjo9sdOYzOsU80= -golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= -golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= -golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= -golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI= +golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8= +golang.org/x/exp v0.0.0-20260508232706-74f9aab9d74a h1:+3jdDGGB8NGb1Zktc737jlt3/A5f6UlwSzmvqUuufxw= +golang.org/x/exp v0.0.0-20260508232706-74f9aab9d74a/go.mod h1:d2fgXJLVs4dYDHUk5lwMIfzRzSrWCfGZb0ZqeLa/Vcw= +golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w= +golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= -golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= -golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= -golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= -golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= -golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= -golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= -golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= +golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ= +golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= +golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= @@ -103,7 +67,5 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= -maunium.net/go/mautrix v0.26.3-0.20260119125818-e28f7170bc4b h1:OaZ5Y1l4XACFlgy4BmZcCLdYPJZzgZWqZJnpdSITmoM= -maunium.net/go/mautrix v0.26.3-0.20260119125818-e28f7170bc4b/go.mod h1:CUxSZcjPtQNxsZLRQqETAxg2hiz7bjWT+L1HCYoMMKo= maunium.net/go/mautrix v0.27.1-0.20260513120123-5fba7e3afae4 h1:zNC9eVAhw8FhKpM3AxNAh/iy75UEYX91uJUvqqAYlvo= maunium.net/go/mautrix v0.27.1-0.20260513120123-5fba7e3afae4/go.mod h1:3sOGhXi3P1V6/NruTA0gujkvTypXVUraWktCuTGyDuM= diff --git a/pkg/connector/ai_commands.go b/pkg/connector/ai_commands.go index 26fd54f..5ce5c8b 100644 --- a/pkg/connector/ai_commands.go +++ b/pkg/connector/ai_commands.go @@ -36,7 +36,7 @@ func helpText() string { return strings.Join([]string{ "DummyBridge demo commands:", "help", - "stream [seconds] [--runs=N] [--profile=balanced|tools|errors|artifacts] [--seed=N] [--chars=N] [--terminal=stop|length|abort|error] [--delay-ms=min:max] [--stagger-ms=min:max] [--actions=N] [--no-approval] [--allow-abort] [--allow-error]", + "stream [seconds] [--runs=N] [--profile=balanced|tools|errors|artifacts] [--seed=N] [--chars=N] [--terminal=stop|length|tool_calls|content_filter|other|abort|error] [--delay-ms=min:max] [--stagger-ms=min:max] [--actions=N] [--no-approval] [--allow-abort] [--allow-error]", "stream-tools ... [common options]", "Notes: stream enables approval requests by default; approval-tagged tools emit a separate Matrix approval event with reaction options.", }, "\n") @@ -183,7 +183,7 @@ func parseStreamLikeCommand(tokens []string, cmd *randomCommand, deriveActions b cmd.Terminal = "finish" case "abort", "error": cmd.Terminal = strings.ToLower(value) - case "length", "tool-calls", "content-filter", "other": + case "length", "tool_calls", "content_filter", "other": cmd.Terminal = agui.NormalizeFinishReason(value) default: return nil, fmt.Errorf("unknown terminal %q", value) @@ -348,6 +348,9 @@ func parseCommonOptions(tokens []string) (commonCommandOptions, error) { return opts, fmt.Errorf("%s requires a value", token) } opts.FinishReason = agui.NormalizeFinishReason(value) + if !agui.ValidFinishReason(opts.FinishReason) { + return opts, fmt.Errorf("unknown finish reason %q", value) + } case "abort": opts.Abort = true case "error": diff --git a/pkg/connector/ai_parse_helpers.go b/pkg/connector/ai_parse_helpers.go index 9b06d7d..a942bef 100644 --- a/pkg/connector/ai_parse_helpers.go +++ b/pkg/connector/ai_parse_helpers.go @@ -146,16 +146,89 @@ func sliceByStep(text string, parts, index int) string { if parts <= 1 || text == "" { return text } - start := 0 - for i := 0; i < index; i++ { - start += splitCount(len(text), parts, i) + units := naturalTextUnits(text) + if len(units) == 0 { + return "" + } + if parts >= len(units) { + if index >= 0 && index < len(units) { + return units[index] + } + return "" + } + + cumulative := make([]int, len(units)+1) + for i, unit := range units { + cumulative[i+1] = cumulative[i] + len(unit) + 2 + } + boundary := func(step int) int { + if step <= 0 { + return 0 + } + if step >= parts { + return len(units) + } + target := cumulative[len(units)] * step / parts + out := 0 + for out < len(units) && cumulative[out] < target { + out++ + } + if out < step { + out = step + } + maxBoundary := len(units) - (parts - step) + if out > maxBoundary { + out = maxBoundary + } + return out } - length := splitCount(len(text), parts, index) - if start >= len(text) || length <= 0 { + start := boundary(index) + end := boundary(index + 1) + if start >= end || start < 0 || end > len(units) { return "" } - end := min(start+length, len(text)) - return text[start:end] + return strings.Join(units[start:end], "\n\n") +} + +func naturalTextUnits(text string) []string { + var units []string + for _, block := range strings.Split(text, "\n\n") { + block = strings.TrimSpace(block) + if block == "" { + continue + } + if isMarkdownSensitiveBlock(block) { + units = append(units, block) + continue + } + units = append(units, splitSentences(block)...) + } + return units +} + +func splitSentences(text string) []string { + var sentences []string + start := 0 + for i := 0; i < len(text); i++ { + switch text[i] { + case '.', '!', '?': + if i+1 < len(text) && text[i+1] != ' ' && text[i+1] != '\n' { + continue + } + sentence := strings.TrimSpace(text[start : i+1]) + if sentence != "" { + sentences = append(sentences, sentence) + } + start = i + 1 + for start < len(text) && (text[start] == ' ' || text[start] == '\n') { + start++ + } + } + } + if tail := strings.TrimSpace(text[start:]); tail != "" { + sentences = append(sentences, tail) + } + return sentences } func sanitizeToolName(name string) string { diff --git a/pkg/connector/ai_plans.go b/pkg/connector/ai_plans.go index b96987b..4d7d736 100644 --- a/pkg/connector/ai_plans.go +++ b/pkg/connector/ai_plans.go @@ -67,7 +67,7 @@ func hasSeedFlag(input string) bool { return false } -func buildAIRunFromCommandWithApprovals(ctx context.Context, runID, threadID string, now time.Time, cmd *parsedCommand, agentID, agentName string, approvals map[string]agui.ToolApprovalResponse) (*aistream.Run, error) { +func buildAIRunFromCommandWithApprovals(ctx context.Context, runID, threadID string, now time.Time, cmd *parsedCommand, agentID, agentName string, approvals map[string]aistream.ToolApprovalResponse) (*aistream.Run, error) { runtime := virtualAIRuntime(now) run := aistream.NewRun(runID, threadID, aistream.DefaultModel, agentID, agentName, now) writer := aistream.NewWriter(run, runtime.now) @@ -87,6 +87,7 @@ func buildAIRunFromCommandWithApprovals(ctx context.Context, runID, threadID str err = runner.runRandom(ctx, writer, *cmd.Random) } if errors.Is(err, errApprovalRequested) { + writer.Interrupt() err = nil } if err != nil { diff --git a/pkg/connector/ai_runner.go b/pkg/connector/ai_runner.go index 6fc497c..aa202e7 100644 --- a/pkg/connector/ai_runner.go +++ b/pkg/connector/ai_runner.go @@ -204,9 +204,9 @@ func (r aiRunner) runToolSpec(ctx context.Context, w *aistream.Writer, spec tool toolCallID := fmt.Sprintf("dummy-tool-%d-%s", spec.SequenceIndex, sanitizeToolName(spec.Name)) input := toolRequestInput(spec) approvalID := approvalIDForRun(w.Run.RunID, toolCallID) - var approval *agui.ToolApproval + var approval *aistream.ToolApproval if spec.Approval { - approval = &agui.ToolApproval{ID: approvalID, NeedsApproval: true} + approval = &aistream.ToolApproval{ID: approvalID, NeedsApproval: true} } displayMetadata := toolDisplayMetadata(spec.Name) w.ToolStartWithMetadata(toolCallID, spec.Name, spec.SequenceIndex-1, approval, displayMetadata) @@ -255,7 +255,14 @@ func (r aiRunner) runToolSpec(ctx context.Context, w *aistream.Writer, spec tool } w.ToolApprovalInputComplete(toolCallID, spec.Name, input) annotateProviderRawEvent(w, spec, "tool_call_input_complete") - w.ToolApprovalRequestedWithMetadata(toolCallID, spec.Name, input, *approval, displayMetadata) + w.ToolApprovalRequestedWithRequest(aistream.ApprovalRequest{ + ID: approvalID, + ToolCallID: toolCallID, + ToolName: spec.Name, + Input: input, + Approval: *approval, + Metadata: displayMetadata, + }) annotateProviderRawEvent(w, spec, "approval_requested") return errApprovalRequested case spec.Deny: @@ -330,13 +337,13 @@ func annotateProviderRawEvent(w *aistream.Writer, spec toolSpec, stage string) { if !spec.Provider || w == nil || w.Run == nil || len(w.Run.Events) == 0 { return } - w.Run.Events[len(w.Run.Events)-1]["rawEvent"] = map[string]any{ + w.Run.Events[len(w.Run.Events)-1].Set("rawEvent", map[string]any{ "provider": "dummybridge", "stage": stage, "tool": spec.Name, "sequence": spec.SequenceIndex, "tags": spec.Tags, - } + }) } func jsonToolInput(input any) string { diff --git a/pkg/connector/ai_runtime_test.go b/pkg/connector/ai_runtime_test.go index a94d520..51b0752 100644 --- a/pkg/connector/ai_runtime_test.go +++ b/pkg/connector/ai_runtime_test.go @@ -12,6 +12,7 @@ import ( "github.com/beeper/ai-bridge/pkg/ag-ui" "github.com/beeper/ai-bridge/pkg/ai-stream" + aimatrix "github.com/beeper/ai-bridge/pkg/ai-stream/matrix" "maunium.net/go/mautrix/id" ) @@ -71,19 +72,19 @@ func TestBuildAIRunLoremIncludesArtifactsStateAndMetadata(t *testing.T) { } seen := map[string]bool{} for _, evt := range run.Events { - switch evt["type"] { + switch evt.Type() { case agui.EventTextMessageContent, agui.EventStepStarted, agui.EventStepFinished: - seen[evt["type"].(string)] = true + seen[evt.Type()] = true case agui.EventStateDelta: - seen[evt["type"].(string)] = true - if _, ok := evt["delta"].([]map[string]any); !ok { - t.Fatalf("STATE_DELTA should use JSON Patch array, got %#v", evt["delta"]) + seen[evt.Type()] = true + if _, ok := evt.Get("delta").([]map[string]any); !ok { + t.Fatalf("STATE_DELTA should use JSON Patch array, got %#v", evt.Get("delta")) } case agui.EventCustom: - name, _ := evt["name"].(string) + name, _ := evt.Get("name").(string) seen[name] = true if name == "com.beeper.data" { - value := evt["value"].(map[string]any) + value := evt.Get("value").(map[string]any) if value["name"] == "temp" { t.Fatal("transient data must not persist as metadata") } @@ -95,13 +96,12 @@ func TestBuildAIRunLoremIncludesArtifactsStateAndMetadata(t *testing.T) { t.Fatalf("missing %s in events", key) } } - metadata := run.Metadata() - if metadata["model"] == "" || metadata["threadId"] != "thread-1" || metadata["runId"] != "run-1" { - t.Fatalf("bad metadata: %#v", metadata) + payload := run.AI(aistream.AIKindFinal) + if payload.Model == "" || payload.ThreadID != "thread-1" || payload.RunID != "run-1" { + t.Fatalf("bad AI payload: %#v", payload) } - data := metadata["data"].(map[string]any) - if _, ok := data["temp"]; ok { - t.Fatalf("transient data leaked into final metadata: %#v", data) + if _, ok := payload.Data["temp"]; ok { + t.Fatalf("transient data leaked into final AI payload: %#v", payload.Data) } } @@ -117,70 +117,68 @@ func TestBuildAIRunToolsApprovalUsesAGUIApprovalAndPrompt(t *testing.T) { t.Fatalf("approval prompt ID = %q, want run-scoped ID", run.Prompts[0].ID) } foundToolStart := false - seenApprovalStateBeforeCustom := false + seenToolCallEndBeforeInterrupt := false + seenInterrupt := false for _, evt := range run.Events { - if evt["type"] == agui.EventToolCallStart { - if evt["state"] != agui.ToolStateApprovalRequested { - t.Fatalf("expected approval-requested tool state, got %#v", evt) + if evt.Type() == agui.EventToolCallStart { + if evt.Get("state") != agui.ToolStateAwaitingInput { + t.Fatalf("tool start should stay a normal AG-UI tool call, got %#v", evt) } - approval, ok := evt["approval"].(*agui.ToolApproval) - if !ok { - t.Fatalf("expected tool start approval metadata, got %#v", evt["approval"]) - } - if approval.ID != "approval-run-1-dummy-tool-1-shell" || !approval.NeedsApproval { - t.Fatalf("bad approval metadata: %#v", approval) + if evt.Has("approval") { + t.Fatalf("tool start must not carry Beeper approval metadata: %#v", evt) } - metadata, ok := evt["metadata"].(map[string]any) + metadata, ok := evt.Get("metadata").(map[string]any) if !ok || metadata["displayName"] != "Run Command" { - t.Fatalf("bad tool display metadata: %#v", evt["metadata"]) + t.Fatalf("bad tool display metadata: %#v", evt.Get("metadata")) } foundToolStart = true } - if evt["type"] == agui.EventToolCallEnd { - if evt["state"] == agui.ToolStateInputComplete { - t.Fatalf("approval tool must not downgrade to input-complete: %#v", evt) + if evt.Type() == agui.EventToolCallEnd { + if evt.Get("state") != agui.ToolStateInputComplete { + t.Fatalf("tool call should finish normally before AG-UI interrupt: %#v", evt) } - if evt["state"] == agui.ToolStateApprovalRequested { - if evt["input"] != nil { - t.Fatalf("approval input-complete event should omit placeholder input: %#v", evt) - } - seenApprovalStateBeforeCustom = true + if evt.Get("input") != nil { + t.Fatalf("approval input-complete event should omit placeholder input: %#v", evt) } + seenToolCallEndBeforeInterrupt = true } - if evt["type"] == agui.EventCustom && evt["name"] == agui.ApprovalCustomRequested { - if !seenApprovalStateBeforeCustom { - t.Fatalf("approval custom event should be emitted after approval state update: %#v", run.Events) + if evt.Type() == agui.EventCustom { + t.Fatalf("approval must use AG-UI interrupt outcome, not custom event: %#v", evt) + } + if evt.Type() == agui.EventRunFinished { + if !seenToolCallEndBeforeInterrupt { + t.Fatalf("approval interrupt should be emitted after approval state update: %#v", run.Events) } - value := evt["value"].(map[string]any) - if _, hasOptions := value["options"]; hasOptions { - t.Fatalf("AG-UI approval event must not embed Matrix reaction options: %#v", value) + interrupts := eventInterrupts(t, evt) + if len(interrupts) != 1 { + t.Fatalf("approval run should finish with one interrupt: %#v", evt) } - if value["approvalMessageId"] != "approval-run-1-dummy-tool-1-shell" { - t.Fatalf("approval event should name the Matrix reaction target: %#v", value) + interrupt := interrupts[0] + if interrupt.ID != "approval-run-1-dummy-tool-1-shell" || interrupt.Reason != agui.InterruptReasonToolCall || interrupt.ToolCallID != "dummy-tool-1-shell" { + t.Fatalf("bad approval interrupt: %#v", interrupt) } - metadata, ok := value["metadata"].(map[string]any) + metadata, ok := interrupt.Metadata["metadata"].(map[string]any) if !ok || metadata["displayName"] != "Run Command" { - t.Fatalf("approval event should carry tool display metadata: %#v", value["metadata"]) + t.Fatalf("approval interrupt should carry tool display metadata: %#v", interrupt.Metadata) } - choices, ok := value["choices"].([]aistream.ApprovalChoice) - if !ok || len(choices) == 0 || choices[0].Key != aistream.ApprovalChoiceApprove { - t.Fatalf("approval event should duplicate renderer choices: %#v", value["choices"]) + choices := approvalChoicesFromMetadata(t, interrupt.Metadata) + if len(choices) == 0 || choices[0].Key != aistream.ApprovalChoiceApprove { + t.Fatalf("approval interrupt should duplicate renderer choices: %#v", interrupt.Metadata["choices"]) } - if value["input"] != nil { - t.Fatalf("approval event should omit placeholder tool input: %#v", value) + if interrupt.Metadata["input"] != nil { + t.Fatalf("approval interrupt should omit placeholder tool input: %#v", interrupt.Metadata) } + seenInterrupt = true } } if !foundToolStart { t.Fatal("missing tool start event") } - if run.Status.State != "streaming" { - t.Fatalf("approval request should pause the run without terminal status, got %#v", run.Status) + if run.Status.State != "interrupted" { + t.Fatalf("approval request should interrupt the run, got %#v", run.Status) } - for _, evt := range run.Events { - if evt["type"] == agui.EventRunFinished { - t.Fatalf("approval request should not finish the run before response: %#v", run.Events) - } + if !seenInterrupt { + t.Fatalf("approval request missing AG-UI interrupt outcome: %#v", run.Events) } } @@ -204,11 +202,11 @@ func TestApprovalPromptSeqStartsAtNextPackedCarrierSeq(t *testing.T) { if err != nil { t.Fatal(err) } - carriers, err := aistream.PackRunFromSeq(*run, "$anchor", aistream.CarrierBudgetBytes, 1) + carriers, err := aistream.PackRunByTimeFromSeq(*run, 1, demoStreamCarrierMaxSpan) if err != nil { t.Fatal(err) } - nextSeq := aistream.NextSeq(splitCarriersForTimedEmission(carriers)) + nextSeq := aistream.NextSeq(carriers) if nextSeq <= 1 { t.Fatalf("expected initial stream to consume carrier sequence numbers, got %d", nextSeq) } @@ -228,14 +226,14 @@ func TestApprovalPromptSeqStartsAtNextPackedCarrierSeq(t *testing.T) { AgentName: run.AgentName, SeqStart: prompt.SeqStart, } - continuation, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]agui.ToolApprovalResponse{approvalCtx.ID: { + continuation, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]aistream.ToolApprovalResponse{approvalCtx.ID: { ID: prompt.ID, Approved: true, }}, time.Unix(20, 0)) if err != nil { t.Fatal(err) } - continuationCarriers, err := aistream.PackRunFromSeq(continuation, "$anchor", aistream.CarrierBudgetBytes, approvalCtx.SeqStart) + continuationCarriers, err := aistream.PackRunByTimeFromSeq(continuation, approvalCtx.SeqStart, demoStreamCarrierMaxSpan) if err != nil { t.Fatal(err) } @@ -258,11 +256,11 @@ func TestApprovalLifecycleCarriesNoticeTargetAndContinuation(t *testing.T) { } sizingRun := *run annotateApprovalEventIDs(&sizingRun, approvalEventIDPlaceholders(sizingRun.Prompts)) - initialCarriers, err := aistream.PackRunFromSeq(sizingRun, "$anchor", aistream.CarrierBudgetBytes, 1) + initialCarriers, err := aistream.PackRunByTimeFromSeq(sizingRun, 1, demoStreamCarrierMaxSpan) if err != nil { t.Fatal(err) } - initialCarriers = splitCarriersForTimedEmission(initialCarriers) + initialCarriers = initialCarriers nextSeq := aistream.NextSeq(initialCarriers) if nextSeq <= 1 { t.Fatalf("expected initial carriers to advance sequence, got %d", nextSeq) @@ -289,39 +287,43 @@ func TestApprovalLifecycleCarriesNoticeTargetAndContinuation(t *testing.T) { } annotateApprovalEventIDs(run, map[string]id.EventID{prompt.ID: "$approval"}) - annotatedCarriers, err := aistream.PackRunFromSeq(*run, "$anchor", aistream.CarrierBudgetBytes, 1) + annotatedCarriers, err := aistream.PackRunByTimeFromSeq(*run, 1, demoStreamCarrierMaxSpan) if err != nil { t.Fatal(err) } - var annotatedValue map[string]any + var annotatedInterrupt *agui.Interrupt for _, carrier := range annotatedCarriers { for _, env := range carrier.Envelopes { - if env.Part["type"] != agui.EventCustom || env.Part["name"] != agui.ApprovalCustomRequested { + if env.Event.Type() != agui.EventRunFinished { continue } - annotatedValue, _ = env.Part["value"].(map[string]any) + interrupts := eventInterrupts(t, env.Event) + if len(interrupts) > 0 { + interrupt := interrupts[0] + annotatedInterrupt = &interrupt + } } } - if annotatedValue == nil || annotatedValue["approvalMessageId"] != prompt.ID { - t.Fatalf("approval-requested stream event missing approval message id: %#v", annotatedValue) + if annotatedInterrupt == nil || annotatedInterrupt.Metadata["approvalMessageId"] != prompt.ID { + t.Fatalf("approval interrupt missing approval message id: %#v", annotatedInterrupt) } - if annotatedValue["approvalEventId"] != "$approval" { - t.Fatalf("approval-requested stream event missing Matrix event target: %#v", annotatedValue) + if annotatedInterrupt.Metadata["approvalEventId"] != "$approval" { + t.Fatalf("approval interrupt missing Matrix event target: %#v", annotatedInterrupt) } - annotatedCarriers = splitCarriersForTimedEmission(annotatedCarriers) + annotatedCarriers = annotatedCarriers if annotatedNextSeq := aistream.NextSeq(annotatedCarriers); annotatedNextSeq != nextSeq { t.Fatalf("approval event target changed stream sequence: initial=%d annotated=%d", nextSeq, annotatedNextSeq) } - choices, ok := annotatedValue["choices"].([]any) - if !ok || len(choices) != len(aistream.DefaultApprovalChoices()) { - t.Fatalf("approval-requested stream event missing choices: %#v", annotatedValue["choices"]) + choices := approvalChoicesFromMetadata(t, annotatedInterrupt.Metadata) + if len(choices) != len(aistream.DefaultApprovalChoices()) { + t.Fatalf("approval interrupt missing choices: %#v", annotatedInterrupt.Metadata["choices"]) } - firstChoice, ok := choices[0].(map[string]any) - if !ok || firstChoice["key"] != aistream.ApprovalChoiceApprove || firstChoice["label"] != "Allow once" { + firstChoice := choices[0] + if firstChoice.Key != aistream.ApprovalChoiceApprove || firstChoice.Label != "Allow once" { t.Fatalf("approval-requested stream event has bad choice shape: %#v", choices[0]) } - continuation, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]agui.ToolApprovalResponse{approvalCtx.ID: { + continuation, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]aistream.ToolApprovalResponse{approvalCtx.ID: { ID: prompt.ID, Approved: true, }}, time.Unix(20, 0)) @@ -331,18 +333,21 @@ func TestApprovalLifecycleCarriesNoticeTargetAndContinuation(t *testing.T) { if len(continuation.Prompts) != 0 { t.Fatalf("continuation must not request approval again: %#v", continuation.Prompts) } + if len(continuation.Interrupts) != 0 || continuation.ApprovalID != "" || continuation.ToolCallID != "" { + t.Fatalf("finished continuation kept pending approval state: interrupts=%#v approval=%q tool=%q", continuation.Interrupts, continuation.ApprovalID, continuation.ToolCallID) + } if continuation.Status.State != "complete" { t.Fatalf("approved continuation should finish the run, got %#v", continuation.Status) } - continuationCarriers, err := aistream.PackRunFromSeq(continuation, "$anchor", aistream.CarrierBudgetBytes, approvalCtx.SeqStart) + continuationCarriers, err := aistream.PackRunByTimeFromSeq(continuation, approvalCtx.SeqStart, demoStreamCarrierMaxSpan) if err != nil { t.Fatal(err) } if len(continuationCarriers) == 0 || len(continuationCarriers[0].Envelopes) == 0 || continuationCarriers[0].Envelopes[0].Seq != nextSeq { t.Fatalf("continuation should resume at seq %d, got %#v", nextSeq, continuationCarriers) } - if continuation.Events[0]["type"] != agui.EventCustom || continuation.Events[0]["name"] != agui.ApprovalCustomResponded { - t.Fatalf("continuation must start by acknowledging approval: %#v", continuation.Events) + if continuation.Events[0].Type() != agui.EventToolCallResult || toolResultApprovalID(continuation.Events[0]) != prompt.ID { + t.Fatalf("continuation must start with approval tool result: %#v", continuation.Events) } } @@ -361,7 +366,7 @@ func TestApprovalContinuationResumesOriginalRunAfterApprovedTool(t *testing.T) { AgentName: "AI", SeqStart: 12, } - run, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]agui.ToolApprovalResponse{approvalCtx.ID: { + run, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]aistream.ToolApprovalResponse{approvalCtx.ID: { ID: approvalCtx.ID, Approved: true, }}, time.Unix(20, 0)) @@ -371,26 +376,23 @@ func TestApprovalContinuationResumesOriginalRunAfterApprovedTool(t *testing.T) { if len(run.Events) == 0 { t.Fatal("expected continuation events") } - if run.Events[0]["type"] != agui.EventCustom || run.Events[0]["name"] != agui.ApprovalCustomResponded { - t.Fatalf("first continuation event should acknowledge approval, got %#v", run.Events[0]) + if run.Events[0].Type() != agui.EventToolCallResult || toolResultApprovalID(run.Events[0]) != approvalCtx.ID { + t.Fatalf("first continuation event should be approval tool result, got %#v", run.Events[0]) } seenApprovedTool := false seenLaterTool := false seenFinished := false for _, evt := range run.Events { - if evt["type"] == agui.EventToolCallEnd && evt["toolCallId"] == approvalCtx.ToolCallID { - if evt["state"] == agui.ToolStateApprovalResponded { - result := jsonResultMap(t, evt["result"]) - if result["approved"] != true { - t.Fatalf("approved result missing approval state: %#v", result) - } + if evt.Type() == agui.EventToolCallResult && evt.Get("toolCallId") == approvalCtx.ToolCallID { + result := jsonResultMap(t, evt.Get("content")) + if result["approved"] == true { seenApprovedTool = true } } - if evt["type"] == agui.EventToolCallStart && evt["toolCallId"] == "dummy-tool-2-fetch" { + if evt.Type() == agui.EventToolCallStart && evt.Get("toolCallId") == "dummy-tool-2-fetch" { seenLaterTool = true } - if evt["type"] == agui.EventRunFinished { + if evt.Type() == agui.EventRunFinished { seenFinished = true } } @@ -403,6 +405,9 @@ func TestApprovalContinuationResumesOriginalRunAfterApprovedTool(t *testing.T) { if len(run.Prompts) != 0 { t.Fatalf("finished continuation should not keep pending prompts: %#v", run.Prompts) } + if len(run.Interrupts) != 0 || run.ApprovalID != "" || run.ToolCallID != "" { + t.Fatalf("finished continuation kept pending approval state: interrupts=%#v approval=%q tool=%q", run.Interrupts, run.ApprovalID, run.ToolCallID) + } } func TestApprovalContinuationStopsOriginalRunAfterDeniedTool(t *testing.T) { @@ -420,7 +425,7 @@ func TestApprovalContinuationStopsOriginalRunAfterDeniedTool(t *testing.T) { AgentName: "AI", SeqStart: 12, } - run, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]agui.ToolApprovalResponse{approvalCtx.ID: { + run, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]aistream.ToolApprovalResponse{approvalCtx.ID: { ID: approvalCtx.ID, Approved: false, Reason: "denied", @@ -430,11 +435,11 @@ func TestApprovalContinuationStopsOriginalRunAfterDeniedTool(t *testing.T) { } seenDeniedTool := false for _, evt := range run.Events { - if evt["type"] == agui.EventToolCallStart && evt["toolCallId"] == "dummy-tool-2-fetch" { + if evt.Type() == agui.EventToolCallStart && evt.Get("toolCallId") == "dummy-tool-2-fetch" { t.Fatalf("denied approval must not continue later tools: %#v", run.Events) } - if evt["type"] == agui.EventToolCallEnd && evt["toolCallId"] == approvalCtx.ToolCallID && evt["state"] == agui.ToolStateApprovalResponded { - result := jsonResultMap(t, evt["result"]) + if evt.Type() == agui.EventToolCallResult && evt.Get("toolCallId") == approvalCtx.ToolCallID { + result := jsonResultMap(t, evt.Get("content")) if result["state"] != agui.ToolResultStateError || result["reason"] != "denied" { t.Fatalf("bad denied result: %#v", result) } @@ -447,6 +452,9 @@ func TestApprovalContinuationStopsOriginalRunAfterDeniedTool(t *testing.T) { if run.Status.State != "error" { t.Fatalf("denied continuation status = %#v", run.Status) } + if len(run.Prompts) != 0 || len(run.Interrupts) != 0 || run.ApprovalID != "" || run.ToolCallID != "" { + t.Fatalf("denied continuation kept pending approval state: prompts=%#v interrupts=%#v approval=%q tool=%q", run.Prompts, run.Interrupts, run.ApprovalID, run.ToolCallID) + } } func TestBuildAIRunToolsDenyProducesStructuredDeniedResult(t *testing.T) { @@ -455,10 +463,10 @@ func TestBuildAIRunToolsDenyProducesStructuredDeniedResult(t *testing.T) { t.Fatal(err) } for _, evt := range run.Events { - if evt["type"] != agui.EventToolCallEnd { + if evt.Type() != agui.EventToolCallResult { continue } - result := jsonResultMap(t, evt["result"]) + result := jsonResultMap(t, evt.Get("content")) if result["state"] == agui.ToolResultStateError && result["reason"] == "denied" { return } @@ -471,22 +479,32 @@ func TestBuildAIRunToolsOmitPlaceholderArgsAndEmitTerminalResult(t *testing.T) { if err != nil { t.Fatal(err) } + seenEnd := false + seenResult := false for _, evt := range run.Events { - if evt["type"] == agui.EventToolCallArgs { + if evt.Type() == agui.EventToolCallArgs { t.Fatalf("plain demo tool should not emit placeholder args: %#v", evt) } - if evt["type"] == agui.EventToolCallEnd { - if evt["input"] != nil { + if evt.Type() == agui.EventToolCallEnd { + if evt.Get("input") != nil { t.Fatalf("plain demo tool should omit placeholder input: %#v", evt) } - result := jsonResultMap(t, evt["result"]) + if evt.Has("result") { + t.Fatalf("TOOL_CALL_END must not carry result: %#v", evt) + } + seenEnd = true + } + if evt.Type() == agui.EventToolCallResult { + result := jsonResultMap(t, evt.Get("content")) if result["state"] != agui.ToolResultStateComplete || result["status"] != "success" { t.Fatalf("plain demo tool should emit terminal success result: %#v", evt) } - return + seenResult = true } } - t.Fatal("missing TOOL_CALL_END event") + if !seenEnd || !seenResult { + t.Fatalf("missing TOOL_CALL_END/TOOL_CALL_RESULT events: %#v", run.Events) + } } func TestBuildAIRunToolsPrelimUsesAGUIToolResult(t *testing.T) { @@ -495,10 +513,10 @@ func TestBuildAIRunToolsPrelimUsesAGUIToolResult(t *testing.T) { t.Fatal(err) } for _, evt := range run.Events { - if evt["type"] != agui.EventToolCallResult { + if evt.Type() != agui.EventToolCallResult { continue } - if evt["state"] != agui.ToolResultStateStreaming || evt["toolCallId"] == "" || evt["content"] == "" { + if evt.Get("state") != agui.ToolResultStateStreaming || evt.Get("toolCallId") == "" || evt.Get("content") == "" { t.Fatalf("bad TOOL_CALL_RESULT event: %#v", evt) } return @@ -511,38 +529,58 @@ func TestBuildAIRunFinalSnapshotPreservesToolParts(t *testing.T) { if err != nil { t.Fatal(err) } - var snapshot []agui.UIMessage + var snapshot []agui.Message seenRunFinished := false for _, evt := range run.Events { - switch evt["type"] { + switch evt.Type() { case agui.EventMessagesSnapshot: if seenRunFinished { t.Fatal("final snapshot must be emitted before RUN_FINISHED") } var ok bool - snapshot, ok = evt["messages"].([]agui.UIMessage) + snapshot, ok = evt.Get("messages").([]agui.Message) if !ok { - t.Fatalf("bad snapshot payload: %#v", evt["messages"]) + t.Fatalf("bad snapshot payload: %#v", evt.Get("messages")) } case agui.EventRunFinished: seenRunFinished = true } } - if len(snapshot) != 1 { - t.Fatalf("expected one final UI message snapshot, got %#v", snapshot) + if len(snapshot) == 0 { + t.Fatalf("expected final message snapshot, got %#v", snapshot) } seenToolCall := false seenToolResult := false - for _, part := range snapshot[0].Parts { - switch part["type"] { - case "tool-call": + for _, message := range snapshot { + if message.Role == agui.RoleAssistant && len(message.ToolCalls) > 0 { seenToolCall = true - case "tool-result": + } + if message.Role == agui.RoleTool && message.ToolCallID != "" { seenToolResult = true } } if !seenToolCall || !seenToolResult { - t.Fatalf("final snapshot lost tool parts: %#v", snapshot[0].Parts) + t.Fatalf("final snapshot lost tool messages: %#v", snapshot) + } +} + +func TestBuildAIRunFinalUIMessagePreservesTextToolTextOrder(t *testing.T) { + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream-tools 420 fetch search --seed=7 --chunk-chars=32:32", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + message := run.FinalBeeperAIMessage(0, true) + var order []string + for _, part := range message.Parts { + switch part["type"] { + case "text": + order = append(order, "text") + case "tool-call": + order = append(order, "tool-call") + } + } + if strings.Join(order, "|") != "text|tool-call|text|tool-call|text" { + t.Fatalf("final UIMessage did not preserve text/tool order: %v\nparts: %#v", order, message.Parts) } } @@ -554,22 +592,22 @@ func TestBuildAIRunToolsFailureDeltaAndInputError(t *testing.T) { seenFailure := false seenInputError := false for _, evt := range run.Events { - if evt["type"] != agui.EventToolCallEnd && evt["type"] != agui.EventToolCallArgs { + if evt.Type() != agui.EventToolCallResult && evt.Type() != agui.EventToolCallArgs { continue } - toolCallID, _ := evt["toolCallId"].(string) - if evt["type"] == agui.EventToolCallArgs && strings.Contains(toolCallID, "fetch") { + toolCallID, _ := evt.Get("toolCallId").(string) + if evt.Type() == agui.EventToolCallArgs && strings.Contains(toolCallID, "fetch") { t.Fatalf("delta tool without real input should not emit placeholder args: %#v", evt) } - if evt["type"] == agui.EventToolCallEnd { + if evt.Type() == agui.EventToolCallResult { if strings.Contains(toolCallID, "shell") { - result := jsonResultMap(t, evt["result"]) + result := jsonResultMap(t, evt.Get("content")) if result["state"] == agui.ToolResultStateError { seenFailure = true } } if strings.Contains(toolCallID, "parser") { - result := jsonResultMap(t, evt["result"]) + result := jsonResultMap(t, evt.Get("content")) if result["reason"] == "input-error" { seenInputError = true } @@ -587,14 +625,14 @@ func TestBuildAIRunToolsProviderTagAddsRawEventPassthrough(t *testing.T) { t.Fatal(err) } for _, evt := range run.Events { - raw, ok := evt["rawEvent"].(map[string]any) + raw, ok := evt.Get("rawEvent").(map[string]any) if !ok { continue } if raw["provider"] != "dummybridge" || raw["tool"] != "shell" { t.Fatalf("bad raw provider event: %#v", raw) } - carriers, err := aistream.PackRun(*run, "$anchor", aistream.CarrierBudgetBytes) + carriers, err := aistream.PackRun(*run) if err != nil { t.Fatal(err) } @@ -623,40 +661,48 @@ func TestBuildAIRunTerminalErrorAndAbortStates(t *testing.T) { } } -func TestBuildAIRunOver64KBPacksTo58KCarriers(t *testing.T) { +func TestBuildAIRunOver64KBStreamsWithoutCarrierSizeSplit(t *testing.T) { run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream 1 --chars=70000 --actions=1 --seed=7", time.Unix(10, 0)) if err != nil { t.Fatal(err) } - carriers, err := aistream.PackRun(*run, "$anchor", aistream.CarrierBudgetBytes) + carriers, err := aistream.PackRun(*run) if err != nil { t.Fatal(err) } - if len(carriers) < 2 { - t.Fatalf("expected split carriers, got %d", len(carriers)) - } - for i, carrier := range carriers { - if size := aistream.JSONSize(aistream.CarrierContent(carrier.Envelopes)); size > aistream.CarrierBudgetBytes { - t.Fatalf("carrier %d size = %d", i, size) - } - } - for _, carrier := range carriers { - for _, envelope := range carrier.Envelopes { - if envelope.Part["type"] != agui.EventMessagesSnapshot { - continue - } - raw, err := json.Marshal(envelope.Part) - if err != nil { - t.Fatal(err) - } - if strings.Contains(string(raw), strings.Repeat("a", 60*1024)) { - t.Fatal("final snapshot should not repeat full streamed text") - } - } + if len(carriers) != 1 { + t.Fatalf("stream packing must not split by size, got %d carriers", len(carriers)) } if len(aistream.ReconstructText(carriers)) < 60*1024 { t.Fatalf("expected large reconstructed output, got %d", len(aistream.ReconstructText(carriers))) } + projection := aimatrix.ProjectFinal(*run, nil) + if !projection.NeedsAttachment { + t.Fatal("large final UIMessage should use final-parts attachment projection") + } + if len(projection.Message.Parts) == 0 { + t.Fatal("large final projection should preserve full UIMessage parts for attachment upload") + } +} + +func TestBuildAIRunStream50PacksByCadence(t *testing.T) { + run, err := buildAIRun(context.Background(), "run-1", "thread-1", "stream 50 --seed=7 --no-approval", time.Unix(10, 0)) + if err != nil { + t.Fatal(err) + } + carriers, err := aistream.PackRunByTimeFromSeq(*run, 1, demoStreamCarrierMaxSpan) + if err != nil { + t.Fatal(err) + } + if len(carriers) < 10 { + t.Fatalf("stream 50 should produce incremental carriers, got %d", len(carriers)) + } + start := time.Unix(100, 0) + first := aistream.CarrierTimestamp(*run, carriers[0], start) + last := aistream.CarrierTimestamp(*run, carriers[len(carriers)-1], start) + if first.IsZero() || last.IsZero() || !last.After(first) { + t.Fatalf("carrier timestamps should preserve stream cadence, first=%s last=%s", first, last) + } } func TestBuildAIRunPlansChaosCreatesMultipleRuns(t *testing.T) { @@ -692,9 +738,9 @@ func TestBuildAIRunRandomHonorsVirtualDelays(t *testing.T) { } var first, last int64 for _, evt := range run.Events { - ts, _ := evt["timestamp"].(int64) + ts, _ := evt.Get("timestamp").(int64) if ts == 0 { - if n, ok := evt["timestamp"].(int); ok { + if n, ok := evt.Get("timestamp").(int); ok { ts = int64(n) } } @@ -720,13 +766,11 @@ func TestRandomModeApprovalPause(t *testing.T) { if run.ApprovalID == "" { continue } - for _, evt := range run.Events { - if evt["type"] == agui.EventRunFinished { - t.Fatalf("approval run emitted RUN_FINISHED with seed %d", seed) - } + if run.Status.State != "interrupted" { + t.Fatalf("expected approval run to interrupt, got %q", run.Status.State) } - if run.Status.State != "streaming" { - t.Fatalf("expected approval run to remain streaming, got %q", run.Status.State) + if len(firstInterrupts(t, run.Events)) == 0 { + t.Fatalf("approval run missing interrupt outcome with seed %d", seed) } return } @@ -854,6 +898,40 @@ func TestBuildDemoVisibleTextDoesNotCutMarkdownSyntax(t *testing.T) { if strings.Contains(text, "https://dummybridge.") && !strings.Contains(text, "https://dummybridge.local/") { t.Fatalf("cut markdown URL for chars=%d seed=%d: %q", chars, seed, text) } + if joinedMarkdownBlockRE.MatchString(text) { + t.Fatalf("markdown block joined to incomplete text for chars=%d seed=%d: %q", chars, seed, text) + } + } + } +} + +func TestSliceByStepKeepsNaturalTextUnits(t *testing.T) { + text := strings.Join([]string{ + "First complete sentence. Second complete sentence.", + "Review the [release notes](https://dummybridge.local/docs/streaming) entry for **review-ready** output.", + "Third complete sentence. Fourth complete sentence.", + }, "\n\n") + + parts := []string{ + sliceByStep(text, 3, 0), + sliceByStep(text, 3, 1), + sliceByStep(text, 3, 2), + } + joined := strings.Join(parts, "\n\n") + for _, expected := range []string{ + "First complete sentence.", + "Second complete sentence.", + "Review the [release notes](https://dummybridge.local/docs/streaming) entry for **review-ready** output.", + "Third complete sentence.", + "Fourth complete sentence.", + } { + if !strings.Contains(joined, expected) { + t.Fatalf("sliced text lost %q:\n%#v", expected, parts) + } + } + for _, part := range parts { + if strings.HasSuffix(part, "complete") || strings.HasSuffix(part, "Review the") { + t.Fatalf("slice ended with a cut-off unit: %#v", parts) } } } @@ -916,7 +994,7 @@ func TestMultiApprovalContinuationKeepsLaterPrompts(t *testing.T) { AgentName: "AI", SeqStart: 12, } - run, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]agui.ToolApprovalResponse{approvalCtx.ID: { + run, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]aistream.ToolApprovalResponse{approvalCtx.ID: { ID: approvalCtx.ID, Approved: true, }}, time.Unix(20, 0)) @@ -929,8 +1007,17 @@ func TestMultiApprovalContinuationKeepsLaterPrompts(t *testing.T) { if run.Prompts[0].ToolName != "fetch" { t.Fatalf("expected preserved prompt to belong to fetch, got %#v", run.Prompts[0]) } - if run.Status.State != "streaming" { - t.Fatalf("expected continuation with pending approval to remain streaming, got %#v", run.Status) + if run.Status.State != "interrupted" { + t.Fatalf("expected continuation with pending approval to interrupt, got %#v", run.Status) + } + if len(firstInterrupts(t, run.Events)) != 1 { + t.Fatalf("expected continuation with pending approval to finish with one interrupt: %#v", run.Events) + } + if len(run.Interrupts) != 1 || run.Interrupts[0].ID != run.Prompts[0].ID { + t.Fatalf("pending continuation should expose only the new interrupt: prompts=%#v interrupts=%#v", run.Prompts, run.Interrupts) + } + if run.ApprovalID != run.Prompts[0].ID || run.ToolCallID != run.Prompts[0].ToolCallID { + t.Fatalf("pending continuation should target the new approval: prompts=%#v approval=%q tool=%q", run.Prompts, run.ApprovalID, run.ToolCallID) } secondCtx := aistream.ApprovalContext{ @@ -946,7 +1033,7 @@ func TestMultiApprovalContinuationKeepsLaterPrompts(t *testing.T) { AgentName: approvalCtx.AgentName, SeqStart: 100, } - finished, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), secondCtx, map[string]agui.ToolApprovalResponse{ + finished, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), secondCtx, map[string]aistream.ToolApprovalResponse{ approvalCtx.ID: { ID: approvalCtx.ID, Approved: true, @@ -965,6 +1052,9 @@ func TestMultiApprovalContinuationKeepsLaterPrompts(t *testing.T) { if len(finished.Prompts) != 0 { t.Fatalf("finished continuation should not keep prompts: %#v", finished.Prompts) } + if len(finished.Interrupts) != 0 || finished.ApprovalID != "" || finished.ToolCallID != "" { + t.Fatalf("finished continuation kept pending approval state: interrupts=%#v approval=%q tool=%q", finished.Interrupts, finished.ApprovalID, finished.ToolCallID) + } } func TestApprovalContinuationReplaysRandomRunWithImplicitSeed(t *testing.T) { @@ -1000,7 +1090,7 @@ func TestApprovalContinuationReplaysRandomRunWithImplicitSeed(t *testing.T) { AgentName: "AI", SeqStart: 50, } - continuation, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]agui.ToolApprovalResponse{approvalCtx.ID: { + continuation, err := buildAIApprovalContinuationRunWithApprovals(context.Background(), approvalCtx, map[string]aistream.ToolApprovalResponse{approvalCtx.ID: { ID: approvalCtx.ID, Approved: true, }}, now.Add(time.Hour)) @@ -1010,8 +1100,8 @@ func TestApprovalContinuationReplaysRandomRunWithImplicitSeed(t *testing.T) { if len(continuation.Events) == 0 { t.Fatalf("expected continuation events for random run, got none") } - if continuation.Events[0]["type"] != agui.EventCustom || continuation.Events[0]["name"] != agui.ApprovalCustomResponded { - t.Fatalf("first continuation event should acknowledge approval, got %#v", continuation.Events[0]) + if continuation.Events[0].Type() != agui.EventToolCallResult || toolResultApprovalID(continuation.Events[0]) != approvalCtx.ID { + t.Fatalf("first continuation event should be approval tool result, got %#v", continuation.Events[0]) } return } @@ -1055,3 +1145,90 @@ func jsonResultMap(t *testing.T, value any) map[string]any { } return out } + +func firstInterrupts(t *testing.T, events []agui.Event) []agui.Interrupt { + t.Helper() + for _, evt := range events { + if evt.Type() != agui.EventRunFinished { + continue + } + interrupts := eventInterrupts(t, evt) + if len(interrupts) > 0 { + return interrupts + } + } + return nil +} + +func eventInterrupts(t *testing.T, evt agui.Event) []agui.Interrupt { + t.Helper() + switch outcome := evt.Get("outcome").(type) { + case agui.RunFinishedOutcome: + if outcome.Type != agui.OutcomeInterrupt { + return nil + } + return outcome.Interrupts + case map[string]any: + if outcome["type"] != agui.OutcomeInterrupt { + return nil + } + rawInterrupts, ok := outcome["interrupts"].([]any) + if !ok { + t.Fatalf("bad interrupt payload: %#v", outcome["interrupts"]) + } + interrupts := make([]agui.Interrupt, 0, len(rawInterrupts)) + for _, raw := range rawInterrupts { + value, ok := raw.(map[string]any) + if !ok { + t.Fatalf("bad interrupt value: %#v", raw) + } + metadata, _ := value["metadata"].(map[string]any) + responseSchema, _ := value["responseSchema"].(map[string]any) + interrupts = append(interrupts, agui.Interrupt{ + ID: stringFromAny(value["id"]), + Reason: stringFromAny(value["reason"]), + Message: stringFromAny(value["message"]), + ToolCallID: stringFromAny(value["toolCallId"]), + ExpiresAt: stringFromAny(value["expiresAt"]), + ResponseSchema: responseSchema, + Metadata: metadata, + }) + } + return interrupts + default: + t.Fatalf("bad outcome payload: %#v", evt.Get("outcome")) + return nil + } +} + +func approvalChoicesFromMetadata(t *testing.T, metadata map[string]any) []aistream.ApprovalChoice { + t.Helper() + switch raw := metadata["choices"].(type) { + case []aistream.ApprovalChoice: + return raw + case []any: + choices := make([]aistream.ApprovalChoice, 0, len(raw)) + for _, item := range raw { + value, ok := item.(map[string]any) + if !ok { + t.Fatalf("bad approval choice: %#v", item) + } + choices = append(choices, aistream.ApprovalChoice{ + Key: stringFromAny(value["key"]), + Label: stringFromAny(value["label"]), + Alias: stringFromAny(value["alias"]), + Style: stringFromAny(value["style"]), + Shortcut: stringFromAny(value["shortcut"]), + }) + } + return choices + default: + t.Fatalf("bad approval choices: %#v", raw) + return nil + } +} + +func stringFromAny(value any) string { + text, _ := value.(string) + return text +} diff --git a/pkg/connector/ai_text.go b/pkg/connector/ai_text.go index 40a6945..689a2f7 100644 --- a/pkg/connector/ai_text.go +++ b/pkg/connector/ai_text.go @@ -84,6 +84,10 @@ func buildLoremText(chars int, rng *rand.Rand) string { return trimText(sb.String(), chars) } +func buildCompleteLoremText(chars int, rng *rand.Rand) string { + return trimCompleteText(buildLoremText(chars+128, rng), chars) +} + func buildDemoVisibleText(chars int, rng *rand.Rand) string { if chars <= 0 { return "" @@ -96,8 +100,8 @@ func buildDemoVisibleText(chars int, rng *rand.Rand) string { return buildLoremText(max(48, min(168, remaining+48)), rand.New(rand.NewSource(rng.Int63()))) }}, {weight: 4, minLen: 96, build: func(rng *rand.Rand, _ int) string { - return fmt.Sprintf("%s Review the [%s](%s) entry for **%s** output and _staged_ formatting transitions.", - buildLoremText(72+rng.Intn(48), rand.New(rand.NewSource(rng.Int63()))), + return fmt.Sprintf("%s\n\nReview the [%s](%s) entry for **%s** output and _staged_ formatting transitions.", + buildCompleteLoremText(72+rng.Intn(48), rand.New(rand.NewSource(rng.Int63()))), demoMarkdownLabels[rng.Intn(len(demoMarkdownLabels))], demoMarkdownURLs[rng.Intn(len(demoMarkdownURLs))], demoMarkdownEmphasis[rng.Intn(len(demoMarkdownEmphasis))]) @@ -183,7 +187,7 @@ func trimVisibleText(text string, limit int) string { if isMarkdownSensitiveBlock(block) { kept = append(kept, trimMarkdownBlock(block, limit)) } else { - kept = append(kept, trimText(block, limit)) + kept = append(kept, trimCompleteText(block, limit)) } } break @@ -194,7 +198,7 @@ func trimVisibleText(text string, limit int) string { if len(kept) > 0 { return strings.Join(kept, "\n\n") } - return trimText(text, limit) + return trimCompleteText(text, limit) } func isMarkdownSensitiveBlock(block string) bool { @@ -224,12 +228,16 @@ func trimMarkdownBlock(block string, limit int) string { } } if trimmed == "" { - return trimText(block, limit) + return trimCompleteText(block, limit) } return trimmed } func trimText(text string, limit int) string { + return trimCompleteText(text, limit) +} + +func trimCompleteText(text string, limit int) string { text = strings.TrimSpace(text) if limit <= 0 || len(text) <= limit { return text @@ -241,10 +249,11 @@ func trimText(text string, limit int) string { return strings.TrimSpace(text[:i]) } } - for i := min(limit, len(text)); i >= minCutoff; i-- { - if text[i-1] == ' ' { - return strings.Trim(strings.TrimSpace(text[:i]), ".,;:") + for i := min(limit+128, len(text)); i > limit; i++ { + switch text[i-1] { + case '.', '!', '?': + return strings.TrimSpace(text[:i]) } } - return strings.Trim(strings.TrimSpace(text[:limit]), ".,;:") + return text } diff --git a/pkg/connector/ai_types.go b/pkg/connector/ai_types.go index 80c3ee2..c6dbaaa 100644 --- a/pkg/connector/ai_types.go +++ b/pkg/connector/ai_types.go @@ -5,7 +5,6 @@ import ( "errors" "time" - "github.com/beeper/ai-bridge/pkg/ag-ui" "github.com/beeper/ai-bridge/pkg/ai-stream" ) @@ -142,7 +141,7 @@ type aiRuntime struct { type aiRunner struct { runtime aiRuntime - approvals map[string]agui.ToolApprovalResponse + approvals map[string]aistream.ToolApprovalResponse } type aiRunPlan struct { diff --git a/pkg/connector/client.go b/pkg/connector/client.go index 378b3d4..a196907 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -13,6 +13,7 @@ import ( "github.com/beeper/ai-bridge/pkg/ag-ui" "github.com/beeper/ai-bridge/pkg/ai-stream" aibridgev2 "github.com/beeper/ai-bridge/pkg/ai-stream/bridgev2" + aimatrix "github.com/beeper/ai-bridge/pkg/ai-stream/matrix" "github.com/rs/zerolog/log" "go.mau.fi/util/exsync" "go.mau.fi/util/jsontime" @@ -42,7 +43,7 @@ type DummyClient struct { } type aiRunSession struct { - Decisions map[string]agui.ToolApprovalResponse + Decisions map[string]aistream.ToolApprovalResponse } var _ bridgev2.NetworkAPI = (*DummyClient)(nil) @@ -53,7 +54,9 @@ var _ bridgev2.MessageRequestAcceptingNetworkAPI = (*DummyClient)(nil) var _ bridgev2.ReactionHandlingNetworkAPI = (*DummyClient)(nil) const ( - dummyAIAgentName string = "Dummy" + dummyAIAgentName string = "Dummy" + defaultAIApprovalTimeout = 5 * time.Minute + demoStreamCarrierMaxSpan = 750 * time.Millisecond ) var delayedRemoteEchoPattern = regexp.MustCompile(`(?i)^remote-echo\s+delay\s+([0-9]+(?:ms|s|m|h))$`) @@ -262,9 +265,6 @@ func (dc *DummyClient) HandleMatrixReaction(ctx context.Context, msg *bridgev2.M if dc == nil || dc.UserLogin == nil || msg == nil || msg.TargetMessage == nil || msg.Content == nil || msg.Portal == nil { return &database.Reaction{}, nil } - if isApprovalOptionReaction(msg) { - return &database.Reaction{}, nil - } approvalID := string(msg.TargetMessage.ID) if !strings.HasPrefix(approvalID, "approval-") { return &database.Reaction{}, nil @@ -307,14 +307,6 @@ func (dc *DummyClient) HandleMatrixReaction(ctx context.Context, msg *bridgev2.M return &database.Reaction{}, nil } -func isApprovalOptionReaction(msg *bridgev2.MatrixReaction) bool { - if msg == nil || msg.Event == nil { - return false - } - _, ok := msg.Event.Content.Raw["com.beeper.ai.approval_option"] - return ok -} - func (dc *DummyClient) resolveApprovalOnce(approvalID, selectedKey string) (string, bool) { dc.approvalSelectionsOnce.Do(func() { dc.approvalSelections = exsync.NewMap[string, string]() @@ -516,20 +508,48 @@ func cloneMessageContent(content *event.MessageEventContent) *event.MessageEvent return &cloned } +type aiRunTarget struct { + portal *bridgev2.Portal + bot bridgev2.MatrixAPI + roomID id.RoomID + threadID string + sender networkid.UserID + agentName string +} + func (dc *DummyClient) queueAIResponse(ctx context.Context, portal *bridgev2.Portal, inbound *event.MessageEventContent) { if portal == nil { return } + dc.queueAIResponseToTarget(ctx, aiRunTarget{ + portal: portal, + threadID: string(portal.ID), + sender: dummyAISenderForPortal(portal), + agentName: dummyAIAgentNameForPortal(portal), + }, inbound) +} + +func (dc *DummyClient) queueAIResponseInRoom(ctx context.Context, bot bridgev2.MatrixAPI, roomID id.RoomID, inbound *event.MessageEventContent) { + if bot == nil || roomID == "" { + return + } + dc.queueAIResponseToTarget(ctx, aiRunTarget{ + bot: bot, + roomID: roomID, + threadID: string(roomID), + sender: networkid.UserID(dummyAIAgentName), + agentName: dummyAIAgentName, + }, inbound) +} +func (dc *DummyClient) queueAIResponseToTarget(ctx context.Context, target aiRunTarget, inbound *event.MessageEventContent) { now := time.Now() runID := "run-" + string(randomMessageID()) - sender := dummyAISenderForPortal(portal) - agentName := dummyAIAgentNameForPortal(portal) var body string if inbound != nil { body = inbound.Body } - plans, err := buildAIRunPlans(ctx, runID, string(portal.ID), body, now, string(sender), agentName) + plans, err := buildAIRunPlans(ctx, runID, target.threadID, body, now, string(target.sender), target.agentName) if err != nil { log.Warn().Err(err).Msg("Failed to build AI runs") return @@ -545,7 +565,7 @@ func (dc *DummyClient) queueAIResponse(ctx context.Context, portal *bridgev2.Por } dc.wg.Add(1) - go func(portal *bridgev2.Portal, sender networkid.UserID, messageID networkid.MessageID, run aistream.Run, command string, delay time.Duration) { + go func(target aiRunTarget, messageID networkid.MessageID, run aistream.Run, command string, delay time.Duration) { defer dc.wg.Done() if delay > 0 { timer := time.NewTimer(delay) @@ -556,11 +576,14 @@ func (dc *DummyClient) queueAIResponse(ctx context.Context, portal *bridgev2.Por return } } - dc.ensureAISenderInvited(portal, sender) anchorAt := time.Now() - dc.UserLogin.QueueRemoteEvent(aibridgev2.Anchor(portal.PortalKey, sender, initialAIAnchorRun(run), anchorAt)) - dc.queueAIRunStreamAndMetadata(portal, sender, messageID, run, command, anchorAt) - }(portal, sender, placeholderID, *plan.Run, effectiveCommand, plan.Delay) + targetEventID, err := target.sendAnchor(dc, initialAIAnchorRun(run), messageID, anchorAt) + if err != nil { + log.Warn().Err(err).Str("run_id", run.RunID).Msg("Failed to send AI anchor") + return + } + dc.emitAIRunStream(target, messageID, targetEventID, run, command, 1, anchorAt) + }(target, placeholderID, *plan.Run, effectiveCommand, plan.Delay) } } @@ -571,69 +594,59 @@ func initialAIAnchorRun(run aistream.Run) aistream.Run { return run } -func (dc *DummyClient) queueAIRunStreamAndMetadata(portal *bridgev2.Portal, sender networkid.UserID, messageID networkid.MessageID, run aistream.Run, command string, anchorAt time.Time) { - targetEventID := dc.waitForMessageMXID(portal, messageID, 30*time.Second) - if targetEventID == "" { - log.Warn(). - Str("run_id", run.RunID). - Str("message_id", string(messageID)). - Msg("Timed out waiting for AI anchor Matrix event") - return - } - dc.emitAIRunStream(portal, sender, messageID, targetEventID, run, command, 1, anchorAt) -} - // emitAIRunStream packs and emits one segment of an AI run — used both for // the initial run and for any approval continuation. It queues approval // prompts produced by the segment, repacks once approval event IDs are // known, and finally emits the carriers and (if the run terminated) the // final metadata edit. -func (dc *DummyClient) emitAIRunStream(portal *bridgev2.Portal, sender networkid.UserID, messageID networkid.MessageID, targetEventID id.EventID, run aistream.Run, command string, startSeq int, anchorAt time.Time) { +func (dc *DummyClient) emitAIRunStream(target aiRunTarget, messageID networkid.MessageID, targetEventID id.EventID, run aistream.Run, command string, startSeq int, anchorAt time.Time) { sizingRun := run annotateApprovalEventIDs(&sizingRun, approvalEventIDPlaceholders(sizingRun.Prompts)) - carriers, err := aistream.PackRunFromSeq(sizingRun, string(targetEventID), aistream.CarrierBudgetBytes, startSeq) + carriers, err := aistream.PackRunByTimeFromSeq(sizingRun, startSeq, demoStreamCarrierMaxSpan) if err != nil { log.Warn().Err(err).Str("run_id", run.RunID).Msg("Failed to pack AI stream") return } - carriers = splitCarriersForTimedEmission(carriers) nextSeq := aistream.NextSeq(carriers) - queuedPrompts := run.Prompts - if len(queuedPrompts) > 1 { + approvalQueue := aistream.NewApprovalQueue(aistream.ApprovalTimeout{After: defaultAIApprovalTimeout}) + approvalQueue.AddAll(run.Prompts) + activePrompt, hasActivePrompt := approvalQueue.Active() + if pending := approvalQueue.Pending(); len(pending) > 0 { log.Warn(). Str("run_id", run.RunID). - Int("approval_prompts", len(queuedPrompts)). - Msg("AI run produced multiple simultaneous approval prompts; queueing the first prompt only") - queuedPrompts = queuedPrompts[:1] + Int("pending_approval_prompts", len(pending)). + Msg("AI run produced multiple approval prompts; keeping one active interrupt and queueing the rest") } - approvalEventIDs := make(map[string]id.EventID, len(queuedPrompts)) - for _, prompt := range queuedPrompts { + approvalEventIDs := make(map[string]id.EventID, 1) + if hasActivePrompt { + prompt := activePrompt prompt.SeqStart = nextSeq - ctx := dc.queueAIApprovalPrompt(portal, sender, run, prompt, targetEventID, command, time.Now()) - if approvalEventID := dc.waitForMessageMXID(portal, networkid.MessageID(ctx.ID), 10*time.Second); approvalEventID != "" { - approvalEventIDs[ctx.ID] = approvalEventID + approvalCtx, approvalEventID, err := target.sendApprovalPrompt(dc, run, prompt, targetEventID, command, time.Now()) + if err == nil && approvalEventID != "" { + target.scheduleApprovalTimeout(dc, approvalCtx, approvalQueue.Timeout()) + approvalEventIDs[approvalCtx.ID] = approvalEventID log.Info(). Str("run_id", run.RunID). - Str("approval_id", ctx.ID). + Str("approval_id", approvalCtx.ID). Stringer("approval_event_id", approvalEventID). - Int("approval_seq_start", ctx.SeqStart). + Int("approval_seq_start", approvalCtx.SeqStart). Msg("AI approval notice ready for reaction") } else { log.Warn(). + Err(err). Str("run_id", run.RunID). - Str("approval_id", ctx.ID). - Int("approval_seq_start", ctx.SeqStart). + Str("approval_id", approvalCtx.ID). + Int("approval_seq_start", approvalCtx.SeqStart). Msg("Timed out waiting for AI approval notice Matrix event") } } if len(approvalEventIDs) > 0 { annotateApprovalEventIDs(&run, approvalEventIDs) - carriers, err = aistream.PackRunFromSeq(run, string(targetEventID), aistream.CarrierBudgetBytes, startSeq) + carriers, err = aistream.PackRunByTimeFromSeq(run, startSeq, demoStreamCarrierMaxSpan) if err != nil { log.Warn().Err(err).Str("run_id", run.RunID).Msg("Failed to repack AI stream with approval event IDs") return } - carriers = splitCarriersForTimedEmission(carriers) if actualNextSeq := aistream.NextSeq(carriers); actualNextSeq != nextSeq { log.Warn(). Str("run_id", run.RunID). @@ -642,13 +655,12 @@ func (dc *DummyClient) emitAIRunStream(portal *bridgev2.Portal, sender networkid Msg("AI approval event ID repack changed stream sequence count") return } - } else if len(queuedPrompts) > 0 { + } else if hasActivePrompt { log.Info(). Str("run_id", run.RunID). - Int("approval_prompts", len(queuedPrompts)). Msg("Sending approval stream without approval event IDs") } - dc.queuePackedAICarriers(portal, sender, targetEventID, run, carriers, startSeq, anchorAt) + target.sendCarriers(dc, targetEventID, run, carriers, startSeq, anchorAt) if len(run.Prompts) > 0 && run.Status.State == "streaming" { log.Info(). Str("run_id", run.RunID). @@ -657,10 +669,39 @@ func (dc *DummyClient) emitAIRunStream(portal *bridgev2.Portal, sender networkid Msg("AI run paused for approval") } if run.Status.State != "streaming" { - dc.queueAIRunFinalMetadata(portal, sender, messageID, run) + target.sendFinal(dc, messageID, targetEventID, run, time.Now()) } } +func (dc *DummyClient) scheduleAIApprovalTimeout(portal *bridgev2.Portal, approvalMessageID networkid.MessageID, timeout aistream.ApprovalTimeout) { + if dc == nil || portal == nil || approvalMessageID == "" || timeout.After <= 0 { + return + } + dc.wg.Add(1) + go func() { + defer dc.wg.Done() + timer := time.NewTimer(timeout.After) + defer timer.Stop() + select { + case <-dc.clientContext().Done(): + return + case <-timer.C: + } + approvalID := string(approvalMessageID) + if _, firstResolution := dc.resolveApprovalOnce(approvalID, timeout.Reason); !firstResolution { + return + } + ctx := dc.clientContext() + approvalMessage, err := dc.lookupMessage(ctx, portal.Receiver, approvalMessageID) + if err != nil || approvalMessage == nil { + log.Warn().Err(err).Str("approval_id", approvalID).Msg("Timed-out AI approval message was not found") + return + } + response := aistream.TimedOutApprovalResponse(approvalID) + dc.queueAIApprovalResponse(ctx, portal, approvalMessage, response) + }() +} + func (dc *DummyClient) queuePackedAICarriers(portal *bridgev2.Portal, sender networkid.UserID, targetEventID id.EventID, run aistream.Run, carriers []aistream.Carrier, startSeq int, anchorAt time.Time) { streamStart := time.Now() // minCarrierTimestamp guarantees every carrier lands strictly after the @@ -681,95 +722,174 @@ func (dc *DummyClient) queuePackedAICarriers(portal *bridgev2.Portal, sender net } } -func splitCarriersForTimedEmission(carriers []aistream.Carrier) []aistream.Carrier { - out := make([]aistream.Carrier, 0, len(carriers)) - for _, carrier := range carriers { - if len(carrier.Envelopes) <= 1 { - out = append(out, carrier) - continue +func (target aiRunTarget) sendAnchor(dc *DummyClient, run aistream.Run, messageID networkid.MessageID, timestamp time.Time) (id.EventID, error) { + if target.portal != nil { + dc.ensureAISenderInvited(target.portal, target.sender) + dc.UserLogin.QueueRemoteEvent(aibridgev2.Anchor(target.portal.PortalKey, target.sender, run, timestamp)) + eventID := dc.waitForMessageMXID(target.portal, messageID, 30*time.Second) + if eventID == "" { + return "", fmt.Errorf("timed out waiting for AI anchor Matrix event") } - for _, env := range carrier.Envelopes { - out = append(out, aistream.Carrier{Envelopes: []aistream.Envelope{env}}) + return eventID, nil + } + content, extra := aimatrix.AnchorContent(run) + return dc.sendAIMessageToRoom(target.bot, target.roomID, content, extra, timestamp) +} + +func (target aiRunTarget) sendApprovalPrompt(dc *DummyClient, run aistream.Run, prompt aistream.ApprovalPrompt, targetEventID id.EventID, command string, timestamp time.Time) (aistream.ApprovalContext, id.EventID, error) { + approvalCtx := approvalContextForPrompt(run, prompt, targetEventID, command) + if target.portal != nil { + approvalCtx = dc.queueAIApprovalPrompt(target.portal, target.sender, run, prompt, targetEventID, command, timestamp) + eventID := dc.waitForMessageMXID(target.portal, networkid.MessageID(approvalCtx.ID), 10*time.Second) + if eventID == "" { + return approvalCtx, "", fmt.Errorf("timed out waiting for AI approval notice Matrix event") } + return approvalCtx, eventID, nil } - return out + eventID, err := dc.sendAIApprovalPromptToRoom(target.bot, target.roomID, approvalCtx, timestamp) + return approvalCtx, eventID, err } -func (dc *DummyClient) sleepUntilCarrierTime(run aistream.Run, carrier aistream.Carrier, streamStart time.Time) { - target := carrierTimestamp(run, carrier, streamStart) - if target.IsZero() { +func (target aiRunTarget) scheduleApprovalTimeout(dc *DummyClient, approvalCtx aistream.ApprovalContext, timeout aistream.ApprovalTimeout) { + if target.portal != nil { + dc.scheduleAIApprovalTimeout(target.portal, networkid.MessageID(approvalCtx.ID), timeout) return } - delay := time.Until(target) - if delay <= 0 { + if dc == nil || target.bot == nil || target.roomID == "" || approvalCtx.ID == "" || timeout.After <= 0 { return } - timer := time.NewTimer(delay) - select { - case <-timer.C: - case <-dc.done(): - timer.Stop() + dc.wg.Add(1) + go func() { + defer dc.wg.Done() + timer := time.NewTimer(timeout.After) + defer timer.Stop() + select { + case <-dc.clientContext().Done(): + return + case <-timer.C: + } + if _, firstResolution := dc.resolveApprovalOnce(approvalCtx.ID, timeout.Reason); !firstResolution { + return + } + response := aistream.TimedOutApprovalResponse(approvalCtx.ID) + approvals := dc.recordAIApprovalDecision(approvalCtx.RunID, response) + run, err := buildAIApprovalContinuationRunWithApprovals(dc.clientContext(), approvalCtx, approvals, time.Now()) + if err != nil { + log.Warn().Err(err).Str("approval_id", approvalCtx.ID).Msg("Failed to build timed-out AI approval continuation") + return + } + dc.emitAIRunStream(target, networkid.MessageID(approvalCtx.MessageID), id.EventID(approvalCtx.TargetEvent), run, approvalCtx.Command, approvalCtx.SeqStart, time.Now()) + }() +} + +func (target aiRunTarget) sendCarriers(dc *DummyClient, targetEventID id.EventID, run aistream.Run, carriers []aistream.Carrier, startSeq int, anchorAt time.Time) { + if target.portal != nil { + dc.queuePackedAICarriers(target.portal, target.sender, targetEventID, run, carriers, startSeq, anchorAt) + return } + dc.sendPackedAICarriersToRoom(target.bot, target.roomID, targetEventID, run, carriers, startSeq, anchorAt) } -func carrierTimestamp(run aistream.Run, carrier aistream.Carrier, streamStart time.Time) time.Time { - base := runStartTimestamp(run) - if base.IsZero() { - return time.Time{} +func (target aiRunTarget) sendFinal(dc *DummyClient, messageID networkid.MessageID, targetEventID id.EventID, run aistream.Run, timestamp time.Time) { + if target.portal != nil { + dc.queueAIRunFinalMetadata(target.portal, target.sender, messageID, run) + return } - var latest time.Time - for _, env := range carrier.Envelopes { - eventTime := eventTimestamp(env.Part) - if eventTime.IsZero() { - continue + dc.sendAIRunFinalToRoom(target.bot, target.roomID, targetEventID, run, timestamp) +} + +func approvalContextForPrompt(run aistream.Run, prompt aistream.ApprovalPrompt, targetEventID id.EventID, command string) aistream.ApprovalContext { + return aistream.ApprovalContext{ + ID: prompt.ID, + ThreadID: run.ThreadID, + RunID: run.RunID, + MessageID: run.MessageID, + Command: command, + ToolCallID: prompt.ToolCallID, + ToolName: prompt.ToolName, + Title: prompt.Title, + Description: prompt.Description, + PlanText: prompt.PlanText, + ExpiresAt: prompt.ExpiresAt, + Choices: aistream.DefaultApprovalChoices(), + TargetEvent: string(targetEventID), + AgentID: run.AgentID, + AgentName: run.AgentName, + Model: run.Model, + SeqStart: prompt.SeqStart, + PreviewText: run.Preview.Text, + PreviewTruncated: run.Preview.Truncated, + Metadata: prompt.Metadata, + } +} + +func (dc *DummyClient) sendAIApprovalPromptToRoom(bot bridgev2.MatrixAPI, roomID id.RoomID, approvalCtx aistream.ApprovalContext, timestamp time.Time) (id.EventID, error) { + content, extra := aimatrix.ApprovalContent(approvalCtx, aistream.DefaultApprovalChoices()) + return dc.sendAIMessageToRoom(bot, roomID, content, extra, timestamp) +} + +func (dc *DummyClient) sendPackedAICarriersToRoom(bot bridgev2.MatrixAPI, roomID id.RoomID, targetEventID id.EventID, run aistream.Run, carriers []aistream.Carrier, startSeq int, anchorAt time.Time) { + streamStart := time.Now() + minCarrierTimestamp := anchorAt.Add(time.Millisecond) + if streamStart.Before(minCarrierTimestamp) { + streamStart = minCarrierTimestamp + } + for i, carrier := range carriers { + dc.sleepUntilCarrierTime(run, carrier, streamStart) + now := time.Now() + if now.Before(minCarrierTimestamp) { + now = minCarrierTimestamp } - if latest.IsZero() || eventTime.After(latest) { - latest = eventTime + minCarrierTimestamp = now.Add(time.Nanosecond) + content, extra := aimatrix.CarrierContent(run, carrier, targetEventID) + if _, err := dc.sendAIMessageToRoom(bot, roomID, content, extra, now); err != nil { + log.Warn().Err(err).Str("run_id", run.RunID).Int("carrier_index", startSeq+i).Msg("Failed to send AI stream carrier to Matrix room") + return } } - if latest.IsZero() { - return time.Time{} +} + +func (dc *DummyClient) sendAIRunFinalToRoom(bot bridgev2.MatrixAPI, roomID id.RoomID, targetEventID id.EventID, run aistream.Run, timestamp time.Time) { + content, extra := aimatrix.FinalContent(run) + content.SetEdit(targetEventID) + raw := map[string]any{ + "m.new_content": extra, + "com.beeper.dont_render_edited": true, + } + if _, err := dc.sendAIMessageToRoom(bot, roomID, content, raw, timestamp); err != nil { + log.Warn().Err(err).Str("run_id", run.RunID).Msg("Failed to send AI final edit to Matrix room") } - return streamStart.Add(latest.Sub(base)) } -func runStartTimestamp(run aistream.Run) time.Time { - for _, evt := range run.Events { - if ts := eventTimestamp(evt); !ts.IsZero() { - return ts - } +func (dc *DummyClient) sendAIMessageToRoom(bot bridgev2.MatrixAPI, roomID id.RoomID, content *event.MessageEventContent, extra map[string]any, timestamp time.Time) (id.EventID, error) { + resp, err := bot.SendMessage(dc.clientContext(), roomID, event.EventMessage, &event.Content{ + Parsed: content, + Raw: extra, + }, &bridgev2.MatrixSendExtra{Timestamp: timestamp}) + if err != nil { + return "", err + } + if resp == nil { + return "", nil } - return time.Time{} + return resp.EventID, nil } -func eventTimestamp(evt agui.Event) time.Time { - raw, ok := evt["timestamp"] - if !ok { - return time.Time{} +func (dc *DummyClient) sleepUntilCarrierTime(run aistream.Run, carrier aistream.Carrier, streamStart time.Time) { + target := aistream.CarrierTimestamp(run, carrier, streamStart) + if target.IsZero() { + return } - var millis int64 - switch value := raw.(type) { - case int64: - millis = value - case int: - millis = int64(value) - case int32: - millis = int64(value) - case float64: - millis = int64(value) - case json.Number: - parsed, err := value.Int64() - if err != nil { - return time.Time{} - } - millis = parsed - default: - return time.Time{} + delay := time.Until(target) + if delay <= 0 { + return } - if millis <= 0 { - return time.Time{} + timer := time.NewTimer(delay) + select { + case <-timer.C: + case <-dc.done(): + timer.Stop() } - return time.UnixMilli(millis) } func (dc *DummyClient) waitForMessageMXID( @@ -780,65 +900,47 @@ func (dc *DummyClient) waitForMessageMXID( if dc == nil || dc.UserLogin == nil || dc.UserLogin.Bridge == nil || dc.UserLogin.Bridge.DB == nil || portal == nil { return "" } - parent := dc.clientContext() - ctx, cancel := context.WithTimeout(parent, timeout) - defer cancel() receivers := []networkid.UserLoginID{portal.Receiver} if dc.UserLogin.ID != "" && dc.UserLogin.ID != portal.Receiver { receivers = append(receivers, dc.UserLogin.ID) } - ticker := time.NewTicker(100 * time.Millisecond) - defer ticker.Stop() - - for ctx.Err() == nil { - select { - case <-ctx.Done(): - return "" - case <-ticker.C: + perReceiverTimeout := timeout + if perReceiverTimeout <= 0 { + perReceiverTimeout = 5 * time.Second + } + if len(receivers) > 1 { + perReceiverTimeout /= time.Duration(len(receivers)) + if perReceiverTimeout < time.Second { + perReceiverTimeout = time.Second } - for _, receiver := range receivers { - mxid := dc.lookupMessageMXID(ctx, receiver, messageID) - if mxid != "" { - return mxid - } + } + for _, receiver := range receivers { + eventID, err := aibridgev2.WaitForMessageEventID( + dc.clientContext(), + dc.UserLogin.Bridge, + receiver, + messageID, + networkid.PartID("0"), + perReceiverTimeout, + ) + if err == nil && eventID != "" { + return eventID } } return "" } -func (dc *DummyClient) lookupMessageMXID(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID) id.EventID { - message, err := dc.UserLogin.Bridge.DB.Message.GetFirstPartByID(ctx, receiver, messageID) - if err != nil || message == nil { - return "" +func (dc *DummyClient) lookupMessage(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID) (*database.Message, error) { + if dc == nil || dc.UserLogin == nil || dc.UserLogin.Bridge == nil || dc.UserLogin.Bridge.DB == nil { + return nil, nil } - return message.MXID + return dc.UserLogin.Bridge.DB.Message.GetFirstPartByID(ctx, receiver, messageID) } func (dc *DummyClient) queueAIApprovalPrompt(portal *bridgev2.Portal, sender networkid.UserID, run aistream.Run, prompt aistream.ApprovalPrompt, targetEventID id.EventID, command string, timestamp time.Time) aistream.ApprovalContext { - choices := aistream.DefaultApprovalChoices() - approvalCtx := aistream.ApprovalContext{ - ID: prompt.ID, - ThreadID: run.ThreadID, - RunID: run.RunID, - MessageID: run.MessageID, - Command: command, - ToolCallID: prompt.ToolCallID, - ToolName: prompt.ToolName, - TargetEvent: string(targetEventID), - AgentID: run.AgentID, - AgentName: run.AgentName, - Model: run.Model, - SeqStart: prompt.SeqStart, - PreviewText: run.Preview.Text, - PreviewTruncated: run.Preview.Truncated, - } + approvalCtx := approvalContextForPrompt(run, prompt, targetEventID, command) dc.UserLogin.QueueRemoteEvent(aibridgev2.ApprovalPrompt(portal.PortalKey, sender, approvalCtx, timestamp)) - - for i, choice := range choices { - choice := choice - dc.UserLogin.QueueRemoteEvent(aibridgev2.ApprovalOptionReaction(portal.PortalKey, sender, approvalCtx, choice, timestamp.Add(time.Duration(i+1)*time.Millisecond))) - } return approvalCtx } @@ -846,20 +948,43 @@ func annotateApprovalEventIDs(run *aistream.Run, eventIDs map[string]id.EventID) if run == nil || len(eventIDs) == 0 { return } - for _, evt := range run.Events { - if evt["type"] != agui.EventCustom || evt["name"] != agui.ApprovalCustomRequested { + for i := range run.Interrupts { + eventID := eventIDs[run.Interrupts[i].ID] + if eventID == "" { continue } - value, _ := evt["value"].(map[string]any) - if value == nil { + aistream.SetApprovalInterruptEventID(&run.Interrupts[i], string(eventID)) + } + for _, evt := range run.Events { + if evt.Type() != agui.EventRunFinished { continue } - approvalID := aistream.ApprovalIDFromRequestedValue(value) - eventID := eventIDs[approvalID] - if eventID == "" { - continue + annotateApprovalOutcomeEventIDs(evt, eventIDs) + } +} + +func annotateApprovalOutcomeEventIDs(evt agui.Event, eventIDs map[string]id.EventID) { + switch outcome := evt.Get("outcome").(type) { + case agui.RunFinishedOutcome: + for i := range outcome.Interrupts { + eventID := eventIDs[outcome.Interrupts[i].ID] + if eventID == "" { + continue + } + aistream.SetApprovalInterruptEventID(&outcome.Interrupts[i], string(eventID)) + } + evt.Set("outcome", outcome) + case *agui.RunFinishedOutcome: + if outcome == nil { + return + } + for i := range outcome.Interrupts { + eventID := eventIDs[outcome.Interrupts[i].ID] + if eventID == "" { + continue + } + aistream.SetApprovalInterruptEventID(&outcome.Interrupts[i], string(eventID)) } - aistream.SetApprovalRequestedEventID(value, string(eventID)) } } @@ -877,7 +1002,7 @@ func approvalEventIDPlaceholders(prompts []aistream.ApprovalPrompt) map[string]i return placeholders } -func (dc *DummyClient) queueAIApprovalResponse(ctx context.Context, portal *bridgev2.Portal, approvalMessage *database.Message, response agui.ToolApprovalResponse) { +func (dc *DummyClient) queueAIApprovalResponse(ctx context.Context, portal *bridgev2.Portal, approvalMessage *database.Message, response aistream.ToolApprovalResponse) { approvalCtx, ok := dc.approvalContextForMessage(ctx, portal, approvalMessage) if !ok { log.Warn().Str("approval_id", messageIDString(approvalMessage)).Msg("Missing AI approval metadata") @@ -902,8 +1027,12 @@ func (dc *DummyClient) queueAIApprovalResponse(ctx context.Context, portal *brid if sender == "" { sender = dummyAISenderForPortal(portal) } - dc.ensureAISenderInvited(portal, sender) - dc.emitAIRunStream(portal, sender, networkid.MessageID(approvalCtx.MessageID), targetEventID, run, approvalCtx.Command, approvalCtx.SeqStart, now) + dc.emitAIRunStream(aiRunTarget{ + portal: portal, + threadID: approvalCtx.ThreadID, + sender: sender, + agentName: approvalCtx.AgentName, + }, networkid.MessageID(approvalCtx.MessageID), targetEventID, run, approvalCtx.Command, approvalCtx.SeqStart, now) log.Info(). Str("run_id", approvalCtx.RunID). Str("approval_id", approvalCtx.ID). @@ -916,7 +1045,7 @@ func (dc *DummyClient) queueAIApprovalResponse(ctx context.Context, portal *brid Msg("Queued AI approval continuation") } -func buildAIApprovalContinuationRunWithApprovals(ctx context.Context, approvalCtx aistream.ApprovalContext, approvals map[string]agui.ToolApprovalResponse, now time.Time) (aistream.Run, error) { +func buildAIApprovalContinuationRunWithApprovals(ctx context.Context, approvalCtx aistream.ApprovalContext, approvals map[string]aistream.ToolApprovalResponse, now time.Time) (aistream.Run, error) { cmd, err := parseCommand(approvalCtx.Command) if err != nil { return aistream.Run{}, err @@ -936,13 +1065,19 @@ func buildAIApprovalContinuationRunWithApprovals(ctx context.Context, approvalCt run.RunID = approvalCtx.RunID run.ThreadID = approvalCtx.ThreadID run.MessageID = approvalCtx.MessageID - run.ToolCallID = approvalCtx.ToolCallID - run.ApprovalID = approvalCtx.ID // Keep only prompts that the continuation segment newly emitted (i.e. // approvals raised by tools that ran AFTER the resolved one). The // already-resolved approval has been removed from the event range above // and must not be queued again. run.Prompts = filterPendingPrompts(run.Prompts, approvalCtx.ID, run.Events) + run.Interrupts = filterPendingInterrupts(run.Interrupts, run.Prompts, run.Events) + if len(run.Prompts) > 0 { + run.ApprovalID = run.Prompts[0].ID + run.ToolCallID = run.Prompts[0].ToolCallID + } else { + run.ApprovalID = "" + run.ToolCallID = "" + } return *run, nil } @@ -950,15 +1085,9 @@ func filterPendingPrompts(prompts []aistream.ApprovalPrompt, resolvedID string, if len(prompts) == 0 { return nil } - requested := make(map[string]bool, len(events)) - for _, evt := range events { - if evt["type"] != agui.EventCustom || evt["name"] != agui.ApprovalCustomRequested { - continue - } - value, _ := evt["value"].(map[string]any) - if id := aistream.ApprovalIDFromRequestedValue(value); id != "" { - requested[id] = true - } + requested := approvalInterruptIDsFromEvents(events) + if len(requested) == 0 { + return nil } out := prompts[:0] for _, prompt := range prompts { @@ -973,25 +1102,87 @@ func filterPendingPrompts(prompts []aistream.ApprovalPrompt, resolvedID string, return out } +func filterPendingInterrupts(interrupts []agui.Interrupt, prompts []aistream.ApprovalPrompt, events []agui.Event) []agui.Interrupt { + if len(prompts) == 0 { + return nil + } + pending := make(map[string]bool, len(prompts)) + for _, prompt := range prompts { + pending[prompt.ID] = true + } + var out []agui.Interrupt + for _, interrupt := range approvalInterruptsFromEvents(events) { + if pending[interrupt.ID] { + out = append(out, interrupt) + } + } + if len(out) > 0 { + return out + } + for _, interrupt := range interrupts { + if pending[interrupt.ID] { + out = append(out, interrupt) + } + } + return out +} + +func approvalInterruptIDsFromEvents(events []agui.Event) map[string]bool { + requested := map[string]bool{} + for _, interrupt := range approvalInterruptsFromEvents(events) { + if interrupt.ID != "" { + requested[interrupt.ID] = true + } + } + return requested +} + +func approvalInterruptsFromEvents(events []agui.Event) []agui.Interrupt { + var interrupts []agui.Interrupt + for _, evt := range events { + if evt.Type() != agui.EventRunFinished { + continue + } + switch outcome := evt.Get("outcome").(type) { + case agui.RunFinishedOutcome: + if outcome.Type != agui.OutcomeInterrupt { + continue + } + interrupts = append(interrupts, outcome.Interrupts...) + case *agui.RunFinishedOutcome: + if outcome == nil || outcome.Type != agui.OutcomeInterrupt { + continue + } + interrupts = append(interrupts, outcome.Interrupts...) + } + } + return interrupts +} + func approvalContinuationStart(events []agui.Event, approvalID string) int { for i, evt := range events { - if evt["type"] != agui.EventCustom || evt["name"] != agui.ApprovalCustomResponded { + if evt.Type() != agui.EventToolCallResult { continue } - value, _ := evt["value"].(map[string]any) - approval, _ := value["approval"].(agui.ToolApprovalResponse) - if approval.ID == approvalID { + if toolResultApprovalID(evt) == approvalID { return i } - if raw, ok := value["approval"].(map[string]any); ok { - if idValue, _ := raw["id"].(string); idValue == approvalID { - return i - } - } } return -1 } +func toolResultApprovalID(evt agui.Event) string { + content, _ := evt.Get("content").(string) + if content == "" { + return "" + } + result, ok := aistream.ParseApprovalToolResult(content) + if !ok { + return "" + } + return result.ApprovalID +} + func (dc *DummyClient) approvalContextForMessage(ctx context.Context, portal *bridgev2.Portal, message *database.Message) (aistream.ApprovalContext, bool) { var fetch func(context.Context, networkid.MessageID) (*database.Message, error) if dc != nil && dc.UserLogin != nil && dc.UserLogin.Bridge != nil && dc.UserLogin.Bridge.DB != nil && portal != nil { @@ -1034,44 +1225,112 @@ func approvalContextFromAny(value any) (aistream.ApprovalContext, bool) { } return validApprovalContext(*typed) case map[string]any: - if nested, ok := typed["com.beeper.ai.approval"]; ok { + if nested, ok := typed[aistream.BeeperAIApprovalKey]; ok { return approvalContextFromAny(nested) } - case *map[string]any: - if typed == nil { - return aistream.ApprovalContext{}, false - } - return approvalContextFromAny(*typed) + return validApprovalContext(approvalContextFromMap(typed)) case json.RawMessage: return approvalContextFromJSON(typed) case []byte: return approvalContextFromJSON(typed) - case string: - return approvalContextFromJSON([]byte(typed)) } - var ctx aistream.ApprovalContext - raw, err := json.Marshal(value) - if err != nil { - return aistream.ApprovalContext{}, false - } - if err = json.Unmarshal(raw, &ctx); err != nil { - return aistream.ApprovalContext{}, false + return aistream.ApprovalContext{}, false +} + +func approvalContextFromMap(raw map[string]any) aistream.ApprovalContext { + return aistream.ApprovalContext{ + ID: stringField(raw, "id"), + ThreadID: stringField(raw, "threadId"), + RunID: stringField(raw, "runId"), + MessageID: stringField(raw, "messageId"), + Command: stringField(raw, "command"), + ToolCallID: stringField(raw, "toolCallId"), + ToolName: stringField(raw, "toolName"), + Title: stringField(raw, "title"), + Description: stringField(raw, "description"), + PlanText: stringField(raw, "planText"), + ExpiresAt: stringField(raw, "expiresAt"), + Choices: approvalChoicesField(raw, "choices"), + TargetEvent: stringField(raw, "targetEvent"), + AgentID: stringField(raw, "agentId"), + AgentName: stringField(raw, "agentName"), + Model: stringField(raw, "model"), + SeqStart: intField(raw, "seqStart"), + PreviewText: stringField(raw, "previewText"), + PreviewTruncated: boolField(raw, "previewTruncated"), + Metadata: mapField(raw, "metadata"), } - return validApprovalContext(ctx) } func approvalContextFromJSON(raw []byte) (aistream.ApprovalContext, bool) { - var decoded any - if err := json.Unmarshal(raw, &decoded); err == nil { - if approvalCtx, ok := approvalContextFromAny(decoded); ok { + var ctx aistream.ApprovalContext + if err := json.Unmarshal(raw, &ctx); err == nil { + if approvalCtx, ok := validApprovalContext(ctx); ok { return approvalCtx, true } } - var ctx aistream.ApprovalContext - if err := json.Unmarshal(raw, &ctx); err != nil { + var wrapper map[string]any + if err := json.Unmarshal(raw, &wrapper); err != nil { return aistream.ApprovalContext{}, false } - return validApprovalContext(ctx) + return approvalContextFromAny(wrapper) +} + +func stringField(raw map[string]any, key string) string { + value, _ := raw[key].(string) + return value +} + +func intField(raw map[string]any, key string) int { + switch value := raw[key].(type) { + case int: + return value + case int64: + return int(value) + case float64: + return int(value) + default: + return 0 + } +} + +func boolField(raw map[string]any, key string) bool { + value, _ := raw[key].(bool) + return value +} + +func mapField(raw map[string]any, key string) map[string]any { + switch value := raw[key].(type) { + case map[string]any: + return value + default: + return nil + } +} + +func approvalChoicesField(raw map[string]any, key string) []aistream.ApprovalChoice { + switch value := raw[key].(type) { + case []aistream.ApprovalChoice: + return value + case []any: + choices := make([]aistream.ApprovalChoice, 0, len(value)) + for _, item := range value { + rawChoice, ok := item.(map[string]any) + if !ok { + return nil + } + choices = append(choices, aistream.ApprovalChoice{ + Key: stringField(rawChoice, "key"), + Label: stringField(rawChoice, "label"), + Alias: stringField(rawChoice, "alias"), + Style: stringField(rawChoice, "style"), + Shortcut: stringField(rawChoice, "shortcut"), + }) + } + return choices + default: + return nil + } } func messageIDString(message *database.Message) string { @@ -1105,12 +1364,12 @@ func (dc *DummyClient) ensureAIRunSession(runID string) { dc.aiRunSessions = make(map[string]*aiRunSession) } if dc.aiRunSessions[runID] == nil { - dc.aiRunSessions[runID] = &aiRunSession{Decisions: make(map[string]agui.ToolApprovalResponse)} + dc.aiRunSessions[runID] = &aiRunSession{Decisions: make(map[string]aistream.ToolApprovalResponse)} } } -func (dc *DummyClient) recordAIApprovalDecision(runID string, response agui.ToolApprovalResponse) map[string]agui.ToolApprovalResponse { - decisions := make(map[string]agui.ToolApprovalResponse) +func (dc *DummyClient) recordAIApprovalDecision(runID string, response aistream.ToolApprovalResponse) map[string]aistream.ToolApprovalResponse { + decisions := make(map[string]aistream.ToolApprovalResponse) if response.ID == "" { return decisions } @@ -1125,7 +1384,7 @@ func (dc *DummyClient) recordAIApprovalDecision(runID string, response agui.Tool } session := dc.aiRunSessions[runID] if session == nil { - session = &aiRunSession{Decisions: make(map[string]agui.ToolApprovalResponse)} + session = &aiRunSession{Decisions: make(map[string]aistream.ToolApprovalResponse)} dc.aiRunSessions[runID] = session } session.Decisions[response.ID] = response diff --git a/pkg/connector/client_test.go b/pkg/connector/client_test.go index 1c12db6..07747ca 100644 --- a/pkg/connector/client_test.go +++ b/pkg/connector/client_test.go @@ -42,20 +42,14 @@ func TestGetRemoteEchoBehavior(t *testing.T) { func TestSleepUntilCarrierTimeWithoutConnectedContext(t *testing.T) { base := time.Now() + builder := agui.NewEventBuilder("dummybridge/test", func() time.Time { return base }) run := aistream.Run{ - Events: []agui.Event{{ - "type": agui.EventRunStarted, - "timestamp": base.UnixMilli(), - "threadId": "thread-1", - }}, + Events: []agui.Event{builder.RunStarted("thread-1", "run-1")}, } + builder = agui.NewEventBuilder("dummybridge/test", func() time.Time { return base.Add(time.Millisecond) }) carrier := aistream.Carrier{ Envelopes: []aistream.Envelope{{ - Part: agui.Event{ - "type": agui.EventTextMessageContent, - "timestamp": base.Add(time.Millisecond).UnixMilli(), - "messageId": "message-1", - }, + Event: builder.TextMessageContent("message-1", "hello"), }}, } diff --git a/pkg/connector/commands.go b/pkg/connector/commands.go index e3934e1..1212972 100644 --- a/pkg/connector/commands.go +++ b/pkg/connector/commands.go @@ -268,10 +268,6 @@ var FileCommand = &commands.FullHandler{ } func runStreamCommand(e *commands.Event, name string) { - if e.Portal == nil { - e.Reply("Can only stream within a portal") - return - } login := e.User.GetDefaultLogin() if login == nil { e.Reply("No login") @@ -287,7 +283,11 @@ func runStreamCommand(e *commands.Event, name string) { e.Reply(err.Error()) return } - client.queueAIResponse(e.Ctx, e.Portal, &event.MessageEventContent{Body: body}) + if e.Portal != nil { + client.queueAIResponse(e.Ctx, e.Portal, &event.MessageEventContent{Body: body}) + } else { + client.queueAIResponseInRoom(e.Ctx, e.Bot, e.RoomID, &event.MessageEventContent{Body: body}) + } e.Reply("Started %s", name) }