diff --git a/Cargo.lock b/Cargo.lock index bdcb978..256c6d3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5,3 +5,12 @@ version = 4 [[package]] name = "mini_web_server" version = "0.1.0" +dependencies = [ + "urlencoding", +] + +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" diff --git a/Cargo.toml b/Cargo.toml index 72b688f..099c852 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,3 +4,4 @@ version = "0.1.0" edition = "2024" [dependencies] +urlencoding = "2.1.3" \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 1470cf0..209b96e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,12 +1,45 @@ +use std::env; use std::fs; use std::path::Path; use std::io::{BufRead, BufReader, Write}; use std::net::{TcpListener, TcpStream}; +use std::process::exit; +use std::sync::OnceLock; +use urlencoding::decode; use mini_web_server::ThreadPool; +static GLOBAL_PATH: OnceLock = OnceLock::new(); + + fn main() { - let listener = TcpListener::bind("127.0.0.1:7878").unwrap(); + let args: Vec = env::args().collect(); + let addr = String::from("127.0.0.1:7878"); + + match args.len() { + 1 => { + let path = String::from("./public"); + GLOBAL_PATH.set(path).unwrap(); + } + 2 => { + let path = String::from(&args[1]); + GLOBAL_PATH.set(path).unwrap(); + if Path::new(GLOBAL_PATH.get().unwrap()).try_exists().is_err() { + println!("path is invalid or not exist: {}", GLOBAL_PATH.get().unwrap()); + exit(0); + } + } + + _ => { + println!("usage: {} path/to/asset/", args[0]); + exit(0); + } + } + println!("set root path: {}", GLOBAL_PATH.get().unwrap()); + + + let listener = TcpListener::bind(addr).unwrap(); + println!("serverd on http://{}/", listener.local_addr().unwrap()); let pool = ThreadPool::new(10); @@ -66,7 +99,6 @@ fn handle_connection(mut stream: TcpStream) { } // 6. 安全地构建文件路径 - let base_dir = "./public"; let requested_path = path.trim_start_matches('/'); if requested_path.contains("..") { stream.write(b"HTTP/1.1 403 Forbidden\r\n\r\n").unwrap(); @@ -92,9 +124,19 @@ fn handle_connection(mut stream: TcpStream) { // }; // v4 - let base_path = Path::new(base_dir); + let base_path = Path::new(GLOBAL_PATH.get().unwrap()); let requested_path = path.trim_start_matches('/'); // 去掉开头的 '/' + let requested_path = match decode(requested_path) { + Ok(decoded) => decoded.into_owned(), + + Err(_) => { + let _ = stream.write(b"HTTP/1.1 400 Bad Request\r\n\r\n"); + let _ = stream.flush(); + return; + } + }; + // 禁止路径遍历 if requested_path.contains("..") { stream.write(b"HTTP/1.1 403 Forbidden\r\n\r\n").unwrap();