Memoize failed speculative derived-table parses in `parse_table_factor`.

On inputs such as `SELECT 1 FROM ((((...` the derived-table arm and the
nested-join fallback both descend through the remaining paren chain, which
doubles the work per nesting level. The parser now records each token index
at which the derived-table attempt failed and skips the retry at that index.
The cache is created in `Parser::new` and cleared wherever the parser resets
`index` to 0. A Criterion benchmark and a 5-second-timeout regression test
cover the pathological input.

diff --git a/sqlparser_bench/benches/sqlparser_bench.rs b/sqlparser_bench/benches/sqlparser_bench.rs
index a3c39bc54..8654a313f 100644
--- a/sqlparser_bench/benches/sqlparser_bench.rs
+++ b/sqlparser_bench/benches/sqlparser_bench.rs
@@ -245,6 +245,34 @@ fn parse_prefix_case_chain(c: &mut Criterion) {
     group.finish();
 }
 
+/// Benchmark parsing pathological paren chains that previously caused 2^N
+/// work in `parse_table_factor`. The input `SELECT 1 FROM ((((...` rejects
+/// at EOF, which used to force exponential backtracking through the chain.
+fn parse_table_factor_paren_chain(c: &mut Criterion) {
+    let mut group = c.benchmark_group("parse_table_factor_paren_chain");
+    let dialect = GenericDialect {};
+
+    for &n in &[10usize, 20, 30] {
+        let mut sql = String::from("SELECT 1 ");
+        for _ in 0..5 {
+            sql.push_str("FROM ");
+            sql.push_str(&"(".repeat(n));
+            sql.push(' ');
+        }
+
+        group.bench_function(format!("chain_{n}"), |b| {
+            b.iter(|| {
+                let _ = Parser::new(&dialect)
+                    .with_recursion_limit(256)
+                    .try_with_sql(std::hint::black_box(&sql))
+                    .and_then(|mut p| p.parse_statements());
+            });
+        });
+    }
+
+    group.finish();
+}
+
 criterion_group!(
     benches,
     basic_queries,
@@ -253,6 +281,7 @@ criterion_group!(
     parse_compound_chain,
     parse_compound_keyword_chain,
     parse_prefix_keyword_call_chain,
-    parse_prefix_case_chain
+    parse_prefix_case_chain,
+    parse_table_factor_paren_chain
 );
 criterion_main!(benches);
diff --git a/src/parser/mod.rs b/src/parser/mod.rs
index a7e641f98..28ce22cbd 100644
--- a/src/parser/mod.rs
+++ b/src/parser/mod.rs
@@ -15,7 +15,7 @@
 #[cfg(not(feature = "std"))]
 use alloc::{
     boxed::Box,
-    collections::BTreeMap,
+    collections::{BTreeMap, BTreeSet},
     format,
     string::{String, ToString},
     vec,
@@ -26,7 +26,7 @@ use core::{
     str::FromStr,
 };
 #[cfg(feature = "std")]
-use std::collections::BTreeMap;
+use std::collections::{BTreeMap, BTreeSet};
 
 use helpers::attached_token::AttachedToken;
 
@@ -369,6 +369,10 @@ pub struct Parser<'a> {
     /// Cached failures from the speculative reserved-word prefix arm. See
     /// [`Parser::parse_prefix`] for the 2^N patterns this guards.
     failed_reserved_word_prefix_positions: BTreeMap,
+    /// Cached failures from the speculative derived-table arm of
+    /// `parse_table_factor`. See [`Parser::parse_table_factor`] for the 2^N
+    /// pattern this guards.
+    failed_derived_table_factor_positions: BTreeSet<usize>,
 }
 
 /// Copy marker for a [`ParserError`] cached by the `parse_prefix` failure
@@ -414,6 +418,7 @@ impl<'a> Parser<'a> {
             options: ParserOptions::new().with_trailing_commas(dialect.supports_trailing_commas()),
             failed_prefix_positions: BTreeMap::new(),
             failed_reserved_word_prefix_positions: BTreeMap::new(),
+            failed_derived_table_factor_positions: BTreeSet::new(),
         }
     }
 
@@ -477,6 +482,7 @@ impl<'a> Parser<'a> {
         self.index = 0;
         self.failed_prefix_positions.clear();
         self.failed_reserved_word_prefix_positions.clear();
+        self.failed_derived_table_factor_positions.clear();
         self
     }
 
@@ -16172,9 +16178,27 @@ impl<'a> Parser<'a> {
             // `parse_derived_table_factor` below will return success after parsing the
             // subquery, followed by the closing ')', and the alias of the derived table.
             // In the example above this is case (3).
-            if let Some(mut table) =
-                self.maybe_parse(|parser| parser.parse_derived_table_factor(NotLateral))?
+            //
+            // Memoize failures to break the 2^N work on inputs like
+            // `FROM ((((...`, where the nested-join fallback recurses back into
+            // `parse_table_factor` and re-attempts the same speculative parse.
+            let derived_pos = self.index;
+            let derived = if self
+                .failed_derived_table_factor_positions
+                .contains(&derived_pos)
             {
+                None
+            } else {
+                match self.maybe_parse(|parser| parser.parse_derived_table_factor(NotLateral))? {
+                    Some(t) => Some(t),
+                    None => {
+                        self.failed_derived_table_factor_positions
+                            .insert(derived_pos);
+                        None
+                    }
+                }
+            };
+            if let Some(mut table) = derived {
                 while let Some(kw) = self.parse_one_of_keywords(&[Keyword::PIVOT, Keyword::UNPIVOT])
                 {
                     table = match kw {
diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs
index b561f8935..6bcdf4d3b 100644
--- a/tests/sqlparser_common.rs
+++ b/tests/sqlparser_common.rs
@@ -19111,3 +19111,34 @@ fn parse_prefix_case_chain_no_exponential_blowup() {
     rx.recv_timeout(Duration::from_secs(5))
         .expect("parser should reject this quickly, not loop exponentially");
 }
+
+/// Regression test for the 2^N parse-time blowup in `parse_table_factor` on
+/// inputs like `SELECT 1 FROM ((((...`. The speculative derived-table arm
+/// and the nested-join fallback both recurse through the remaining paren
+/// chain, doubling work per level. Post-fix the per-position failure cache
+/// short-circuits the second descent.
+#[test]
+fn parse_table_factor_paren_chain_no_exponential_blowup() {
+    use std::sync::mpsc;
+    use std::thread;
+    use std::time::Duration;
+
+    let mut sql = String::from("SELECT 1 ");
+    for _ in 0..5 {
+        sql.push_str("FROM ");
+        sql.push_str(&"(".repeat(30));
+        sql.push(' ');
+    }
+
+    let (tx, rx) = mpsc::channel();
+    thread::spawn(move || {
+        let _ = Parser::new(&GenericDialect {})
+            .with_recursion_limit(256)
+            .try_with_sql(&sql)
+            .and_then(|mut p| p.parse_statements());
+        let _ = tx.send(());
+    });
+
+    rx.recv_timeout(Duration::from_secs(5))
+        .expect("parser should reject this quickly, not loop exponentially");
+}