diff --git a/crates/trigger-http2/src/headers.rs b/crates/trigger-http2/src/headers.rs index d9a2843c9..eed6a4075 100644 --- a/crates/trigger-http2/src/headers.rs +++ b/crates/trigger-http2/src/headers.rs @@ -139,6 +139,8 @@ fn prepare_header_key(key: &str) -> String { #[cfg(test)] mod tests { use super::*; + use anyhow::Result; + use spin_http::routes::Router; #[test] fn test_spin_header_keys() { @@ -155,4 +157,175 @@ mod tests { "spin-raw-component-route".to_string() ); } + + #[test] + fn test_default_headers() -> Result<()> { + let scheme = "https"; + let host = "fermyon.dev"; + let trigger_route = "/foo/..."; + let component_path = "/foo"; + let path_info = "/bar"; + let client_addr: SocketAddr = "127.0.0.1:8777".parse().unwrap(); + + let req_uri = format!( + "{}://{}{}{}?key1=value1&key2=value2", + scheme, host, component_path, path_info + ); + + let req = http::Request::builder() + .method("POST") + .uri(req_uri) + .body("")?; + + let (router, _) = Router::build("/", [("DUMMY", &trigger_route.into())])?; + let route_match = router.route("/foo/bar")?; + + let default_headers = compute_default_headers(req.uri(), host, &route_match, client_addr)?; + + assert_eq!( + search(&FULL_URL, &default_headers).unwrap(), + "https://fermyon.dev/foo/bar?key1=value1&key2=value2".to_string() + ); + assert_eq!( + search(&PATH_INFO, &default_headers).unwrap(), + "/bar".to_string() + ); + assert_eq!( + search(&MATCHED_ROUTE, &default_headers).unwrap(), + "/foo/...".to_string() + ); + assert_eq!( + search(&BASE_PATH, &default_headers).unwrap(), + "/".to_string() + ); + assert_eq!( + search(&RAW_COMPONENT_ROUTE, &default_headers).unwrap(), + "/foo/...".to_string() + ); + assert_eq!( + search(&COMPONENT_ROUTE, &default_headers).unwrap(), + "/foo".to_string() + ); + assert_eq!( + search(&CLIENT_ADDR, &default_headers).unwrap(), + "127.0.0.1:8777".to_string() + ); + + Ok(()) + } + + #[test] + fn test_default_headers_with_named_wildcards() -> Result<()> { + let scheme = "https"; + let host = "fermyon.dev"; + let trigger_route = "/foo/:userid/..."; + let component_path = "/foo"; + let path_info = "/bar"; + let client_addr: SocketAddr = "127.0.0.1:8777".parse().unwrap(); + + let req_uri = format!( + "{}://{}{}/42{}?key1=value1&key2=value2", + scheme, host, component_path, path_info + ); + + let req = http::Request::builder() + .method("POST") + .uri(req_uri) + .body("")?; + + let (router, _) = Router::build("/", [("DUMMY", &trigger_route.into())])?; + let route_match = router.route("/foo/42/bar")?; + + let default_headers = compute_default_headers(req.uri(), host, &route_match, client_addr)?; + + assert_eq!( + search(&FULL_URL, &default_headers).unwrap(), + "https://fermyon.dev/foo/42/bar?key1=value1&key2=value2".to_string() + ); + assert_eq!( + search(&PATH_INFO, &default_headers).unwrap(), + "/bar".to_string() + ); + assert_eq!( + search(&MATCHED_ROUTE, &default_headers).unwrap(), + "/foo/:userid/...".to_string() + ); + assert_eq!( + search(&BASE_PATH, &default_headers).unwrap(), + "/".to_string() + ); + assert_eq!( + search(&RAW_COMPONENT_ROUTE, &default_headers).unwrap(), + "/foo/:userid/...".to_string() + ); + assert_eq!( + search(&COMPONENT_ROUTE, &default_headers).unwrap(), + "/foo/:userid".to_string() + ); + assert_eq!( + search(&CLIENT_ADDR, &default_headers).unwrap(), + "127.0.0.1:8777".to_string() + ); + + assert_eq!( + search( + &["SPIN_PATH_MATCH_USERID", "X_PATH_MATCH_USERID"], + &default_headers + ) + .unwrap(), + "42".to_string() + ); + + Ok(()) + } + + #[test] + fn forbidden_headers_are_removed() { + let mut req = Request::get("http://test.spin.internal") + .header("Host", "test.spin.internal") + .header("accept", "text/plain") + .body(Default::default()) + .unwrap(); + + strip_forbidden_headers(&mut req); + + assert_eq!(1, req.headers().len()); + assert!(req.headers().get("Host").is_none()); + + let mut req = Request::get("http://test.spin.internal") + .header("Host", "test.spin.internal:1234") + .header("accept", "text/plain") + .body(Default::default()) + .unwrap(); + + strip_forbidden_headers(&mut req); + + assert_eq!(1, req.headers().len()); + assert!(req.headers().get("Host").is_none()); + } + + #[test] + fn non_forbidden_headers_are_not_removed() { + let mut req = Request::get("http://test.example.com") + .header("Host", "test.example.org") + .header("accept", "text/plain") + .body(Default::default()) + .unwrap(); + + strip_forbidden_headers(&mut req); + + assert_eq!(2, req.headers().len()); + assert!(req.headers().get("Host").is_some()); + } + + fn search(keys: &[&str; 2], headers: &[([String; 2], String)]) -> Option { + let mut res: Option = None; + for (k, v) in headers { + if k[0] == keys[0] && k[1] == keys[1] { + res = Some(v.clone()); + } + } + + res + } } diff --git a/crates/trigger-http2/src/lib.rs b/crates/trigger-http2/src/lib.rs index 8263915b1..994911099 100644 --- a/crates/trigger-http2/src/lib.rs +++ b/crates/trigger-http2/src/lib.rs @@ -10,7 +10,6 @@ mod wagi; mod wasi; use std::{ - collections::HashMap, error::Error, net::{Ipv4Addr, SocketAddr, ToSocketAddrs}, path::PathBuf, @@ -21,7 +20,6 @@ use anyhow::{bail, Context}; use clap::Args; use serde::Deserialize; use spin_app::App; -use spin_http::{config::HttpTriggerConfig, routes::Router}; use spin_trigger2::Trigger; use wasmtime_wasi_http::bindings::wasi::http::types::ErrorCode; @@ -72,9 +70,6 @@ pub struct HttpTrigger { /// If the port is set to 0, the actual address will be determined by the OS. listen_addr: SocketAddr, tls_config: Option, - router: Router, - // Component ID -> component trigger config - component_trigger_configs: HashMap, } impl Trigger for HttpTrigger { @@ -109,38 +104,9 @@ impl HttpTrigger { ) -> anyhow::Result { Self::validate_app(app)?; - let component_trigger_configs = HashMap::from_iter( - app.trigger_configs::("http")? - .into_iter() - .map(|(_, config)| (config.component.clone(), config)), - ); - - let component_routes = component_trigger_configs - .iter() - .map(|(component_id, config)| (component_id.as_str(), &config.route)); - let (router, duplicate_routes) = Router::build("/", component_routes)?; - if !duplicate_routes.is_empty() { - tracing::error!( - "The following component routes are duplicates and will never be used:" - ); - for dup in &duplicate_routes { - tracing::error!( - " {}: {} (duplicate of {})", - dup.replaced_id, - dup.route(), - dup.effective_id, - ); - } - } - tracing::trace!( - "Constructed router: {:?}", - router.routes().collect::>() - ); Ok(Self { listen_addr, tls_config, - router, - component_trigger_configs, }) } @@ -149,16 +115,8 @@ impl HttpTrigger { let Self { listen_addr, tls_config, - router, - component_trigger_configs, } = self; - let server = Arc::new(HttpServer::new( - listen_addr, - tls_config, - trigger_app, - router, - component_trigger_configs, - )?); + let server = Arc::new(HttpServer::new(listen_addr, tls_config, trigger_app)?); Ok(server) } @@ -221,142 +179,7 @@ pub fn dns_error(rcode: String, info_code: u16) -> ErrorCode { #[cfg(test)] mod tests { - use anyhow::Result; - use http::Request; - - use super::{headers::*, *}; - - #[test] - fn test_default_headers() -> Result<()> { - let scheme = "https"; - let host = "fermyon.dev"; - let trigger_route = "/foo/..."; - let component_path = "/foo"; - let path_info = "/bar"; - let client_addr: SocketAddr = "127.0.0.1:8777".parse().unwrap(); - - let req_uri = format!( - "{}://{}{}{}?key1=value1&key2=value2", - scheme, host, component_path, path_info - ); - - let req = http::Request::builder() - .method("POST") - .uri(req_uri) - .body("")?; - - let (router, _) = Router::build("/", [("DUMMY", &trigger_route.into())])?; - let route_match = router.route("/foo/bar")?; - - let default_headers = compute_default_headers(req.uri(), host, &route_match, client_addr)?; - - assert_eq!( - search(&FULL_URL, &default_headers).unwrap(), - "https://fermyon.dev/foo/bar?key1=value1&key2=value2".to_string() - ); - assert_eq!( - search(&PATH_INFO, &default_headers).unwrap(), - "/bar".to_string() - ); - assert_eq!( - search(&MATCHED_ROUTE, &default_headers).unwrap(), - "/foo/...".to_string() - ); - assert_eq!( - search(&BASE_PATH, &default_headers).unwrap(), - "/".to_string() - ); - assert_eq!( - search(&RAW_COMPONENT_ROUTE, &default_headers).unwrap(), - "/foo/...".to_string() - ); - assert_eq!( - search(&COMPONENT_ROUTE, &default_headers).unwrap(), - "/foo".to_string() - ); - assert_eq!( - search(&CLIENT_ADDR, &default_headers).unwrap(), - "127.0.0.1:8777".to_string() - ); - - Ok(()) - } - - #[test] - fn test_default_headers_with_named_wildcards() -> Result<()> { - let scheme = "https"; - let host = "fermyon.dev"; - let trigger_route = "/foo/:userid/..."; - let component_path = "/foo"; - let path_info = "/bar"; - let client_addr: SocketAddr = "127.0.0.1:8777".parse().unwrap(); - - let req_uri = format!( - "{}://{}{}/42{}?key1=value1&key2=value2", - scheme, host, component_path, path_info - ); - - let req = http::Request::builder() - .method("POST") - .uri(req_uri) - .body("")?; - - let (router, _) = Router::build("/", [("DUMMY", &trigger_route.into())])?; - let route_match = router.route("/foo/42/bar")?; - - let default_headers = compute_default_headers(req.uri(), host, &route_match, client_addr)?; - - assert_eq!( - search(&FULL_URL, &default_headers).unwrap(), - "https://fermyon.dev/foo/42/bar?key1=value1&key2=value2".to_string() - ); - assert_eq!( - search(&PATH_INFO, &default_headers).unwrap(), - "/bar".to_string() - ); - assert_eq!( - search(&MATCHED_ROUTE, &default_headers).unwrap(), - "/foo/:userid/...".to_string() - ); - assert_eq!( - search(&BASE_PATH, &default_headers).unwrap(), - "/".to_string() - ); - assert_eq!( - search(&RAW_COMPONENT_ROUTE, &default_headers).unwrap(), - "/foo/:userid/...".to_string() - ); - assert_eq!( - search(&COMPONENT_ROUTE, &default_headers).unwrap(), - "/foo/:userid".to_string() - ); - assert_eq!( - search(&CLIENT_ADDR, &default_headers).unwrap(), - "127.0.0.1:8777".to_string() - ); - - assert_eq!( - search( - &["SPIN_PATH_MATCH_USERID", "X_PATH_MATCH_USERID"], - &default_headers - ) - .unwrap(), - "42".to_string() - ); - - Ok(()) - } - - fn search(keys: &[&str; 2], headers: &[([String; 2], String)]) -> Option { - let mut res: Option = None; - for (k, v) in headers { - if k[0] == keys[0] && k[1] == keys[1] { - res = Some(v.clone()); - } - } - - res - } + use super::*; #[test] fn parse_listen_addr_prefers_ipv4() { @@ -364,43 +187,4 @@ mod tests { assert_eq!(addr.ip(), Ipv4Addr::LOCALHOST); assert_eq!(addr.port(), 12345); } - - #[test] - fn forbidden_headers_are_removed() { - let mut req = Request::get("http://test.spin.internal") - .header("Host", "test.spin.internal") - .header("accept", "text/plain") - .body(Default::default()) - .unwrap(); - - strip_forbidden_headers(&mut req); - - assert_eq!(1, req.headers().len()); - assert!(req.headers().get("Host").is_none()); - - let mut req = Request::get("http://test.spin.internal") - .header("Host", "test.spin.internal:1234") - .header("accept", "text/plain") - .body(Default::default()) - .unwrap(); - - strip_forbidden_headers(&mut req); - - assert_eq!(1, req.headers().len()); - assert!(req.headers().get("Host").is_none()); - } - - #[test] - fn non_forbidden_headers_are_not_removed() { - let mut req = Request::get("http://test.example.com") - .header("Host", "test.example.org") - .header("accept", "text/plain") - .body(Default::default()) - .unwrap(); - - strip_forbidden_headers(&mut req); - - assert_eq!(2, req.headers().len()); - assert!(req.headers().get("Host").is_some()); - } } diff --git a/crates/trigger-http2/src/server.rs b/crates/trigger-http2/src/server.rs index 3923bc9c4..8ead8feeb 100644 --- a/crates/trigger-http2/src/server.rs +++ b/crates/trigger-http2/src/server.rs @@ -58,9 +58,42 @@ impl HttpServer { listen_addr: SocketAddr, tls_config: Option, trigger_app: TriggerApp, - router: Router, - component_trigger_configs: HashMap, ) -> anyhow::Result { + // This needs to be a vec before building the router to handle duplicate routes + let component_trigger_configs = Vec::from_iter( + trigger_app + .app() + .trigger_configs::("http")? + .into_iter() + .map(|(_, config)| (config.component.clone(), config)), + ); + + // Build router + let component_routes = component_trigger_configs + .iter() + .map(|(component_id, config)| (component_id.as_str(), &config.route)); + let (router, duplicate_routes) = Router::build("/", component_routes)?; + if !duplicate_routes.is_empty() { + tracing::error!( + "The following component routes are duplicates and will never be used:" + ); + for dup in &duplicate_routes { + tracing::error!( + " {}: {} (duplicate of {})", + dup.replaced_id, + dup.route(), + dup.effective_id, + ); + } + } + tracing::trace!( + "Constructed router: {:?}", + router.routes().collect::>() + ); + + // Now that router is built we can merge duplicate routes by component + let component_trigger_configs = HashMap::from_iter(component_trigger_configs); + let component_handler_types = component_trigger_configs .keys() .map(|component_id| {