Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aiscript-runtime/src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ pub struct Endpoint {
pub path_specs: Vec<PathSpec>,
#[allow(unused)]
pub return_type: Option<String>,
pub path: Vec<Field>,
pub query: Vec<Field>,
pub body: RequestBody,
pub statements: String,
Expand Down
103 changes: 101 additions & 2 deletions aiscript-runtime/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use aiscript_vm::{ReturnValue, Vm, VmError};
use axum::{
Form, Json, RequestExt,
body::Body,
extract::{self, FromRequest, Request},
extract::{self, FromRequest, RawPathParams, Request},
http::{HeaderName, HeaderValue},
response::{IntoResponse, Response},
};
Expand Down Expand Up @@ -57,6 +57,7 @@ pub struct Field {
#[derive(Clone)]
pub struct Endpoint {
pub annotation: RouteAnnotation,
pub path_params: Vec<Field>,
pub query_params: Vec<Field>,
pub body_type: BodyKind,
pub body_fields: Vec<Field>,
Expand All @@ -70,6 +71,7 @@ pub struct Endpoint {

enum ProcessingState {
ValidatingAuth,
ValidatingPath,
ValidatingQuery,
ValidatingBody,
Executing(JoinHandle<Result<ReturnValue, VmError>>),
Expand All @@ -79,6 +81,7 @@ pub struct RequestProcessor {
endpoint: Endpoint,
request: Request<Body>,
jwt_claim: Option<Value>,
path_data: HashMap<String, Value>,
query_data: HashMap<String, Value>,
body_data: HashMap<String, Value>,
state: ProcessingState,
Expand All @@ -89,12 +92,13 @@ impl RequestProcessor {
let state = if endpoint.annotation.is_auth_required() {
ProcessingState::ValidatingAuth
} else {
ProcessingState::ValidatingQuery
ProcessingState::ValidatingPath
};
Self {
endpoint,
request,
jwt_claim: None,
path_data: HashMap::new(),
query_data: HashMap::new(),
body_data: HashMap::new(),
state,
Expand Down Expand Up @@ -307,6 +311,99 @@ impl Future for RequestProcessor {
}
}
}
self.state = ProcessingState::ValidatingPath;
}
ProcessingState::ValidatingPath => {
let raw_path_params = {
// Extract path parameters using Axum's RawPathParams extractor
let future = self.request.extract_parts::<RawPathParams>();

tokio::pin!(future);
match future.poll(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Ok(params)) => params,
Poll::Ready(Err(e)) => {
return Poll::Ready(Ok(format!(
"Failed to extract path parameters: {}",
e
)
.into_response()));
}
}
};

// Process and validate each path parameter
for (param_name, param_value) in &raw_path_params {
// Find the corresponding path parameter field
if let Some(field) = self
.endpoint
.path_params
.iter()
.find(|f| f.name == param_name)
{
// Convert the value to the appropriate type based on the field definition
let value = match field.field_type {
FieldType::Str => Value::String(param_value.to_string()),
FieldType::Number => {
if let Ok(num) = param_value.parse::<i64>() {
Value::Number(num.into())
} else if let Ok(float) = param_value.parse::<f64>() {
match serde_json::Number::from_f64(float) {
Some(n) => Value::Number(n),
None => {
return Poll::Ready(Ok(
format!("Invalid path parameter type for {}: could not convert to number", param_name)
.into_response()
));
}
}
} else {
return Poll::Ready(Ok(format!(
"Invalid path parameter type for {}: expected a number",
param_name
)
.into_response()));
}
}
FieldType::Bool => match param_value.to_lowercase().as_str() {
"true" => Value::Bool(true),
"false" => Value::Bool(false),
_ => {
return Poll::Ready(Ok(
format!("Invalid path parameter type for {}: expected a boolean", param_name)
.into_response()
));
}
},
_ => {
return Poll::Ready(Ok(format!(
"Unsupported path parameter type for {}",
param_name
)
.into_response()));
}
};

// Validate the value using our existing validation infrastructure
if let Err(e) = Self::validate_field(field, &value) {
return Poll::Ready(Ok(e.into_response()));
}

// Store the validated parameter
self.path_data.insert(param_name.to_string(), value);
}
}

// Check for missing required parameters
for field in &self.endpoint.path_params {
if !self.path_data.contains_key(&field.name) && field.required {
return Poll::Ready(Ok(
ServerError::MissingField(field.name.clone()).into_response()
));
}
}

// Move to the next state
self.state = ProcessingState::ValidatingQuery;
}
ProcessingState::ValidatingQuery => {
Expand Down Expand Up @@ -400,6 +497,7 @@ impl Future for RequestProcessor {
} else {
None
};
let path_data = mem::take(&mut self.path_data);
let query_data = mem::take(&mut self.query_data);
let body_data = mem::take(&mut self.body_data);
let pg_connection = self.endpoint.pg_connection.clone();
Expand All @@ -417,6 +515,7 @@ impl Future for RequestProcessor {
vm.eval_function(
0,
&[
Value::Object(path_data.into_iter().collect()),
Value::Object(query_data.into_iter().collect()),
Value::Object(body_data.into_iter().collect()),
Value::Object(
Expand Down
1 change: 1 addition & 0 deletions aiscript-runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ async fn run_server(
for endpoint_spec in route.endpoints {
let endpoint = Endpoint {
annotation: endpoint_spec.annotation.or(&route.annotation),
path_params: endpoint_spec.path.into_iter().map(convert_field).collect(),
query_params: endpoint_spec.query.into_iter().map(convert_field).collect(),
body_type: endpoint_spec.body.kind,
body_fields: endpoint_spec
Expand Down
38 changes: 21 additions & 17 deletions aiscript-runtime/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ impl<'a> Parser<'a> {
let docs = self.parse_docs();

// Parse structured parts (query and body)
let mut path = Vec::new();
let mut query = Vec::new();
let mut body = RequestBody::default();

Expand All @@ -98,6 +99,9 @@ impl<'a> Parser<'a> {
} else if self.scanner.check_identifier("body") {
self.advance();
body.fields = self.parse_fields()?;
} else if self.scanner.check_identifier("path") {
self.advance();
path = self.parse_fields()?;
} else if self.scanner.check(TokenType::At) {
let directives = DirectiveParser::new(&mut self.scanner).parse_directives();
for directive in directives {
Expand Down Expand Up @@ -126,12 +130,16 @@ impl<'a> Parser<'a> {
}
// Parse the handler function body
let script = self.read_raw_script()?;
let statements = format!("ai fn handler(query, body, request, header){{{}}}", script);
let statements = format!(
"ai fn handler(path, query, body, request, header){{{}}}",
script
);
self.consume(TokenType::CloseBrace, "Expect '}' after endpoint")?;
Ok(Endpoint {
annotation,
path_specs,
return_type: None,
path,
query,
body,
statements,
Expand Down Expand Up @@ -304,30 +312,26 @@ impl<'a> Parser<'a> {
path.push('/');
self.advance();
}
TokenType::Less => {
self.advance(); // Consume <
TokenType::Colon => {
self.advance(); // Consume :

// Parse parameter name
if !self.check(TokenType::Identifier) {
return Err("Expected parameter name".to_string());
return Err("Expected parameter name after ':'".to_string());
}
let name = self.current.lexeme.to_string();
self.advance();

self.consume(TokenType::Colon, "Expected ':' after parameter name")?;

// Parse parameter type
if !self.check(TokenType::Identifier) {
return Err("Expected parameter type".to_string());
}
let param_type = self.current.lexeme.to_string();
self.advance();

self.consume(TokenType::Greater, "Expected '>' after parameter type")?;

path.push(':');
// Add parameter to path in the format Axum expects: {id}
path.push('{');
path.push_str(&name);
params.push(PathParameter { name, param_type });
path.push('}');

// Add parameter to our list
params.push(PathParameter {
name,
param_type: "str".to_string(), // Default type, will be overridden by path block
});
}
TokenType::Identifier => {
path.push_str(self.current.lexeme);
Expand Down
16 changes: 16 additions & 0 deletions examples/routes/path.ai
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
get /users/:id/posts/:postId {
path {
@string(min_len=3)
id: str,
postId: int,
}

query {
refresh: bool = true
}

print(path);
let userId = path.id;
let postId = path.postId;
return f"Accessing post {postId} for user {userId}";
}
Loading