pg_query/
parse_result.rs

1use std::collections::HashMap;
2use std::collections::HashSet;
3use std::iter::FromIterator;
4use std::string::String;
5
6use itertools::join;
7
8use crate::*;
9
/// Unwraps a single-field tuple variant.
///
/// `cast!(value, Enum::Variant)` evaluates to the payload of `value` when it is
/// `Enum::Variant(..)`. When `value` is a reference, the payload is yielded by reference.
///
/// # Panics
///
/// Panics with `mismatch variant when cast to <variant path>` if `value` is any other variant.
macro_rules! cast {
    ($target:expr, $pat:path) => {{
        match $target {
            $pat(inner) => inner,
            _ => panic!("mismatch variant when cast to {}", stringify!($pat)),
        }
    }};
}
20
21impl protobuf::ParseResult {
22    pub fn deparse(&self) -> Result<String> {
23        crate::deparse(self)
24    }
25
26    // Note: this doesn't iterate over every possible node type, since we only care about a subset of nodes.
27    pub fn nodes(&self) -> Vec<(NodeRef, i32, Context, bool)> {
28        self.stmts
29            .iter()
30            .filter_map(|s|
31            // RawStmt  ->  Node   ->    NodeEnum           ->              NodeRef
32            s.stmt.as_ref().and_then(|s| s.node.as_ref()).map(|n| n.nodes()))
33            .flatten()
34            .collect()
35    }
36
37    /// Returns a mutable reference to nested nodes.
38    ///
39    /// # Safety
40    ///
41    /// The caller may have to deal with dangling pointers, and passing an
42    /// invalid tree back to libpg_query may cause it to panic.
43    pub unsafe fn nodes_mut(&mut self) -> Vec<(NodeMut, i32, Context)> {
44        self.stmts
45            .iter_mut()
46            .filter_map(|s|
47            // RawStmt  ->  Node   ->    NodeEnum           ->              NodeMut
48            s.stmt.as_mut().and_then(|s| s.node.as_mut()).map(|n| n.nodes_mut()))
49            .flatten()
50            .collect()
51    }
52}
53
/// Result from calling [parse]
///
/// Bundles the raw protobuf parse tree with the summary information that
/// `ParseResult::new` extracts from it. The collected lists are built from hash sets,
/// so their element order is unspecified.
#[derive(Debug)]
pub struct ParseResult {
    /// The protobuf parse tree this result was built from.
    pub protobuf: protobuf::ParseResult,
    /// Trimmed lines of the `stderr` text passed to `new` that start with `WARNING`.
    pub warnings: Vec<String>,
    /// Table names (`schema.table` when a schema name is present), each paired with the
    /// context in which the table was found.
    pub tables: Vec<(String, Context)>,
    /// Maps an alias name to the table name it was declared for.
    pub aliases: HashMap<String, String>,
    /// Names of the common table expressions found in the tree.
    pub cte_names: Vec<String>,
    /// Function names, each paired with `Context::Call` (function calls) or `Context::DDL`
    /// (create / drop / rename of a function).
    pub functions: Vec<(String, Context)>,
    /// Column references from nodes flagged as filter columns, stored as
    /// `(second-to-last name component, last name component)`; the first element is `None`
    /// for single-component references. Presumably the first element is the table or alias
    /// qualifier and the second the column name.
    pub filter_columns: Vec<(Option<String>, String)>,
}
65
66impl ParseResult {
67    pub fn new(protobuf: protobuf::ParseResult, stderr: String) -> Self {
68        let warnings = stderr.lines().filter_map(|l| if l.starts_with("WARNING") { Some(l.trim().into()) } else { None }).collect();
69        let mut tables: HashSet<(String, Context)> = HashSet::new();
70        let mut aliases: HashMap<String, String> = HashMap::new();
71        let mut cte_names: HashSet<String> = HashSet::new();
72        let mut functions: HashSet<(String, Context)> = HashSet::new();
73        let mut filter_columns: HashSet<(Option<String>, String)> = HashSet::new();
74
75        for (node, _depth, context, has_filter_columns) in protobuf.nodes().into_iter() {
76            match node {
77                NodeRef::CommonTableExpr(s) => {
78                    cte_names.insert(s.ctename.to_owned());
79                }
80                NodeRef::RangeVar(v) => {
81                    // TODO: this incorrectly returns no tables: parse('with f as (select * from f limit 1) select * from f')
82                    let table = if !v.schemaname.is_empty() { format!("{}.{}", v.schemaname, v.relname) } else { v.relname.to_owned() };
83                    if cte_names.contains(&table) {
84                        continue;
85                    }
86                    tables.insert((table.to_owned(), context));
87                    v.alias.as_ref().and_then(|alias| aliases.insert(alias.aliasname.to_owned(), table));
88                }
89                NodeRef::FuncCall(c) => {
90                    let funcname = join(c.funcname.iter().filter_map(|n| n.node.as_ref().map(|n| &cast!(n, NodeEnum::String).sval)), ".");
91                    functions.insert((funcname, Context::Call));
92                }
93                NodeRef::DropStmt(s) => {
94                    match protobuf::ObjectType::try_from(s.remove_type) {
95                        Ok(protobuf::ObjectType::ObjectTable) => {
96                            for o in &s.objects {
97                                if let Some(NodeEnum::List(list)) = &o.node {
98                                    let table =
99                                        join(list.items.iter().filter_map(|i| i.node.as_ref().map(|n| &cast!(n, NodeEnum::String).sval)), ".");
100                                    tables.insert((table, Context::DDL));
101                                };
102                            }
103                        }
104                        Ok(protobuf::ObjectType::ObjectRule) | Ok(protobuf::ObjectType::ObjectTrigger) => {
105                            for o in &s.objects {
106                                if let Some(NodeEnum::List(list)) = &o.node {
107                                    // Unlike ObjectTable, this ignores the last string (the rule/trigger name)
108                                    let table = join(
109                                        list.items[0..list.items.len() - 1]
110                                            .iter()
111                                            .filter_map(|i| i.node.as_ref().map(|n| &cast!(n, NodeEnum::String).sval)),
112                                        ".",
113                                    );
114                                    tables.insert((table, Context::DDL));
115                                };
116                            }
117                        }
118                        Ok(protobuf::ObjectType::ObjectFunction) => {
119                            // Only one function can be dropped in a statement
120                            if let Some(NodeEnum::ObjectWithArgs(object)) = &s.objects[0].node {
121                                if let Some(NodeEnum::String(string)) = &object.objname[0].node {
122                                    functions.insert((string.sval.to_string(), Context::DDL));
123                                }
124                            }
125                        }
126                        _ => (),
127                    }
128                }
129                NodeRef::CreateFunctionStmt(s) => {
130                    if let Some(NodeEnum::String(string)) = &s.funcname[0].node {
131                        functions.insert((string.sval.to_string(), Context::DDL));
132                    }
133                }
134                NodeRef::RenameStmt(s) => {
135                    if let Ok(protobuf::ObjectType::ObjectFunction) = protobuf::ObjectType::try_from(s.rename_type) {
136                        if let Some(object) = &s.object {
137                            if let Some(NodeEnum::ObjectWithArgs(object)) = &object.node {
138                                if let Some(NodeEnum::String(string)) = &object.objname[0].node {
139                                    functions.insert((string.sval.to_string(), Context::DDL));
140                                    functions.insert((s.newname.to_string(), Context::DDL));
141                                }
142                            }
143                        }
144                    }
145                }
146                NodeRef::ColumnRef(c) => {
147                    if !has_filter_columns {
148                        continue;
149                    }
150                    let f: Vec<String> = c
151                        .fields
152                        .iter()
153                        .filter_map(|n| match n.node.as_ref() {
154                            Some(NodeEnum::String(s)) => Some(s.sval.to_string()),
155                            _ => None,
156                        })
157                        .rev()
158                        .collect();
159                    if f.len() > 0 {
160                        filter_columns.insert((f.get(1).cloned(), f[0].to_string()));
161                    }
162                }
163                _ => (),
164            }
165        }
166
167        Self {
168            protobuf,
169            warnings,
170            tables: Vec::from_iter(tables),
171            aliases,
172            cte_names: Vec::from_iter(cte_names),
173            functions: Vec::from_iter(functions),
174            filter_columns: Vec::from_iter(filter_columns),
175        }
176    }
177
178    /// Returns all referenced tables in the query
179    pub fn tables(&self) -> Vec<String> {
180        let mut tables = HashSet::new();
181        self.tables.iter().for_each(|(t, _c)| {
182            tables.insert(t.to_string());
183        });
184        Vec::from_iter(tables)
185    }
186
187    /// Returns only tables that were selected from
188    pub fn select_tables(&self) -> Vec<String> {
189        self.tables
190            .iter()
191            .filter_map(|(table, context)| match context {
192                Context::Select => Some(table.to_string()),
193                _ => None,
194            })
195            .collect()
196    }
197
198    /// Returns only tables that were modified by the query
199    pub fn dml_tables(&self) -> Vec<String> {
200        self.tables
201            .iter()
202            .filter_map(|(table, context)| match context {
203                Context::DML => Some(table.to_string()),
204                _ => None,
205            })
206            .collect()
207    }
208
209    /// Returns only tables that were modified by DDL statements
210    pub fn ddl_tables(&self) -> Vec<String> {
211        self.tables
212            .iter()
213            .filter_map(|(table, context)| match context {
214                Context::DDL => Some(table.to_string()),
215                _ => None,
216            })
217            .collect()
218    }
219
220    /// Returns all function references
221    pub fn functions(&self) -> Vec<String> {
222        let mut functions = HashSet::new();
223        self.functions.iter().for_each(|(f, _c)| {
224            functions.insert(f.to_string());
225        });
226        Vec::from_iter(functions)
227    }
228
229    /// Returns DDL functions
230    pub fn ddl_functions(&self) -> Vec<String> {
231        self.functions
232            .iter()
233            .filter_map(|(function, context)| match context {
234                Context::DDL => Some(function.to_string()),
235                _ => None,
236            })
237            .collect()
238    }
239
240    /// Returns functions that were called
241    pub fn call_functions(&self) -> Vec<String> {
242        self.functions
243            .iter()
244            .filter_map(|(function, context)| match context {
245                Context::Call => Some(function.to_string()),
246                _ => None,
247            })
248            .collect()
249    }
250
    /// Converts the parsed query back into a SQL string
    ///
    /// Delegates to `crate::deparse` on the stored protobuf tree and returns its result unchanged.
    pub fn deparse(&self) -> Result<String> {
        crate::deparse(&self.protobuf)
    }
255
    /// Intelligently truncates queries to a max length.
    ///
    /// Delegates to `crate::truncate` on the stored protobuf tree, passing `max_length`
    /// through, and returns its result unchanged.
    ///
    /// # Example
    ///
    /// ```rust
    /// let query = "INSERT INTO \"x\" (a, b, c, d, e, f) VALUES ($1)";
    /// let result = pg_query::parse(query).unwrap();
    /// assert_eq!(result.truncate(32).unwrap(), "INSERT INTO x (...) VALUES (...)")
    /// ```
    pub fn truncate(&self, max_length: usize) -> Result<String> {
        crate::truncate(&self.protobuf, max_length)
    }
268
269    /// Returns all statement types in the query
270    pub fn statement_types(&self) -> Vec<&str> {
271        self.protobuf
272            .stmts
273            .iter()
274            .filter_map(|s| match s.stmt.as_ref().and_then(|s| s.node.as_ref()) {
275                Some(NodeEnum::InsertStmt(..)) => Some("InsertStmt"),
276                Some(NodeEnum::DeleteStmt(..)) => Some("DeleteStmt"),
277                Some(NodeEnum::UpdateStmt(..)) => Some("UpdateStmt"),
278                Some(NodeEnum::SelectStmt(..)) => Some("SelectStmt"),
279                Some(NodeEnum::AlterTableStmt(..)) => Some("AlterTableStmt"),
280                Some(NodeEnum::AlterTableCmd(..)) => Some("AlterTableCmd"),
281                Some(NodeEnum::AlterDomainStmt(..)) => Some("AlterDomainStmt"),
282                Some(NodeEnum::SetOperationStmt(..)) => Some("SetOperationStmt"),
283                Some(NodeEnum::GrantStmt(..)) => Some("GrantStmt"),
284                Some(NodeEnum::GrantRoleStmt(..)) => Some("GrantRoleStmt"),
285                Some(NodeEnum::AlterDefaultPrivilegesStmt(..)) => Some("AlterDefaultPrivilegesStmt"),
286                Some(NodeEnum::ClosePortalStmt(..)) => Some("ClosePortalStmt"),
287                Some(NodeEnum::ClusterStmt(..)) => Some("ClusterStmt"),
288                Some(NodeEnum::CopyStmt(..)) => Some("CopyStmt"),
289                Some(NodeEnum::CreateStmt(..)) => Some("CreateStmt"),
290                Some(NodeEnum::DefineStmt(..)) => Some("DefineStmt"),
291                Some(NodeEnum::DropStmt(..)) => Some("DropStmt"),
292                Some(NodeEnum::TruncateStmt(..)) => Some("TruncateStmt"),
293                Some(NodeEnum::CommentStmt(..)) => Some("CommentStmt"),
294                Some(NodeEnum::FetchStmt(..)) => Some("FetchStmt"),
295                Some(NodeEnum::IndexStmt(..)) => Some("IndexStmt"),
296                Some(NodeEnum::CreateFunctionStmt(..)) => Some("CreateFunctionStmt"),
297                Some(NodeEnum::AlterFunctionStmt(..)) => Some("AlterFunctionStmt"),
298                Some(NodeEnum::DoStmt(..)) => Some("DoStmt"),
299                Some(NodeEnum::RenameStmt(..)) => Some("RenameStmt"),
300                Some(NodeEnum::RuleStmt(..)) => Some("RuleStmt"),
301                Some(NodeEnum::NotifyStmt(..)) => Some("NotifyStmt"),
302                Some(NodeEnum::ListenStmt(..)) => Some("ListenStmt"),
303                Some(NodeEnum::UnlistenStmt(..)) => Some("UnlistenStmt"),
304                Some(NodeEnum::TransactionStmt(..)) => Some("TransactionStmt"),
305                Some(NodeEnum::ViewStmt(..)) => Some("ViewStmt"),
306                Some(NodeEnum::LoadStmt(..)) => Some("LoadStmt"),
307                Some(NodeEnum::CreateDomainStmt(..)) => Some("CreateDomainStmt"),
308                Some(NodeEnum::CreatedbStmt(..)) => Some("CreatedbStmt"),
309                Some(NodeEnum::DropdbStmt(..)) => Some("DropdbStmt"),
310                Some(NodeEnum::VacuumStmt(..)) => Some("VacuumStmt"),
311                Some(NodeEnum::ExplainStmt(..)) => Some("ExplainStmt"),
312                Some(NodeEnum::CreateTableAsStmt(..)) => Some("CreateTableAsStmt"),
313                Some(NodeEnum::CreateSeqStmt(..)) => Some("CreateSeqStmt"),
314                Some(NodeEnum::AlterSeqStmt(..)) => Some("AlterSeqStmt"),
315                Some(NodeEnum::VariableSetStmt(..)) => Some("VariableSetStmt"),
316                Some(NodeEnum::VariableShowStmt(..)) => Some("VariableShowStmt"),
317                Some(NodeEnum::DiscardStmt(..)) => Some("DiscardStmt"),
318                Some(NodeEnum::CreateTrigStmt(..)) => Some("CreateTrigStmt"),
319                Some(NodeEnum::CreatePlangStmt(..)) => Some("CreatePlangStmt"),
320                Some(NodeEnum::CreateRoleStmt(..)) => Some("CreateRoleStmt"),
321                Some(NodeEnum::AlterRoleStmt(..)) => Some("AlterRoleStmt"),
322                Some(NodeEnum::DropRoleStmt(..)) => Some("DropRoleStmt"),
323                Some(NodeEnum::LockStmt(..)) => Some("LockStmt"),
324                Some(NodeEnum::ConstraintsSetStmt(..)) => Some("ConstraintsSetStmt"),
325                Some(NodeEnum::ReindexStmt(..)) => Some("ReindexStmt"),
326                Some(NodeEnum::CheckPointStmt(..)) => Some("CheckPointStmt"),
327                Some(NodeEnum::CreateSchemaStmt(..)) => Some("CreateSchemaStmt"),
328                Some(NodeEnum::AlterDatabaseStmt(..)) => Some("AlterDatabaseStmt"),
329                Some(NodeEnum::AlterDatabaseSetStmt(..)) => Some("AlterDatabaseSetStmt"),
330                Some(NodeEnum::AlterRoleSetStmt(..)) => Some("AlterRoleSetStmt"),
331                Some(NodeEnum::CreateConversionStmt(..)) => Some("CreateConversionStmt"),
332                Some(NodeEnum::CreateCastStmt(..)) => Some("CreateCastStmt"),
333                Some(NodeEnum::CreateOpClassStmt(..)) => Some("CreateOpClassStmt"),
334                Some(NodeEnum::CreateOpFamilyStmt(..)) => Some("CreateOpFamilyStmt"),
335                Some(NodeEnum::AlterOpFamilyStmt(..)) => Some("AlterOpFamilyStmt"),
336                Some(NodeEnum::PrepareStmt(..)) => Some("PrepareStmt"),
337                Some(NodeEnum::ExecuteStmt(..)) => Some("ExecuteStmt"),
338                Some(NodeEnum::DeallocateStmt(..)) => Some("DeallocateStmt"),
339                Some(NodeEnum::DeclareCursorStmt(..)) => Some("DeclareCursorStmt"),
340                Some(NodeEnum::CreateTableSpaceStmt(..)) => Some("CreateTableSpaceStmt"),
341                Some(NodeEnum::DropTableSpaceStmt(..)) => Some("DropTableSpaceStmt"),
342                Some(NodeEnum::AlterObjectDependsStmt(..)) => Some("AlterObjectDependsStmt"),
343                Some(NodeEnum::AlterObjectSchemaStmt(..)) => Some("AlterObjectSchemaStmt"),
344                Some(NodeEnum::AlterOwnerStmt(..)) => Some("AlterOwnerStmt"),
345                Some(NodeEnum::AlterOperatorStmt(..)) => Some("AlterOperatorStmt"),
346                Some(NodeEnum::AlterTypeStmt(..)) => Some("AlterTypeStmt"),
347                Some(NodeEnum::DropOwnedStmt(..)) => Some("DropOwnedStmt"),
348                Some(NodeEnum::ReassignOwnedStmt(..)) => Some("ReassignOwnedStmt"),
349                Some(NodeEnum::CompositeTypeStmt(..)) => Some("CompositeTypeStmt"),
350                Some(NodeEnum::CreateEnumStmt(..)) => Some("CreateEnumStmt"),
351                Some(NodeEnum::CreateRangeStmt(..)) => Some("CreateRangeStmt"),
352                Some(NodeEnum::AlterEnumStmt(..)) => Some("AlterEnumStmt"),
353                Some(NodeEnum::AlterTsdictionaryStmt(..)) => Some("AlterTsdictionaryStmt"),
354                Some(NodeEnum::AlterTsconfigurationStmt(..)) => Some("AlterTsconfigurationStmt"),
355                Some(NodeEnum::CreateFdwStmt(..)) => Some("CreateFdwStmt"),
356                Some(NodeEnum::AlterFdwStmt(..)) => Some("AlterFdwStmt"),
357                Some(NodeEnum::CreateForeignServerStmt(..)) => Some("CreateForeignServerStmt"),
358                Some(NodeEnum::AlterForeignServerStmt(..)) => Some("AlterForeignServerStmt"),
359                Some(NodeEnum::CreateUserMappingStmt(..)) => Some("CreateUserMappingStmt"),
360                Some(NodeEnum::AlterUserMappingStmt(..)) => Some("AlterUserMappingStmt"),
361                Some(NodeEnum::DropUserMappingStmt(..)) => Some("DropUserMappingStmt"),
362                Some(NodeEnum::AlterTableSpaceOptionsStmt(..)) => Some("AlterTableSpaceOptionsStmt"),
363                Some(NodeEnum::AlterTableMoveAllStmt(..)) => Some("AlterTableMoveAllStmt"),
364                Some(NodeEnum::SecLabelStmt(..)) => Some("SecLabelStmt"),
365                Some(NodeEnum::CreateForeignTableStmt(..)) => Some("CreateForeignTableStmt"),
366                Some(NodeEnum::ImportForeignSchemaStmt(..)) => Some("ImportForeignSchemaStmt"),
367                Some(NodeEnum::CreateExtensionStmt(..)) => Some("CreateExtensionStmt"),
368                Some(NodeEnum::AlterExtensionStmt(..)) => Some("AlterExtensionStmt"),
369                Some(NodeEnum::AlterExtensionContentsStmt(..)) => Some("AlterExtensionContentsStmt"),
370                Some(NodeEnum::CreateEventTrigStmt(..)) => Some("CreateEventTrigStmt"),
371                Some(NodeEnum::AlterEventTrigStmt(..)) => Some("AlterEventTrigStmt"),
372                Some(NodeEnum::RefreshMatViewStmt(..)) => Some("RefreshMatViewStmt"),
373                Some(NodeEnum::ReplicaIdentityStmt(..)) => Some("ReplicaIdentityStmt"),
374                Some(NodeEnum::AlterSystemStmt(..)) => Some("AlterSystemStmt"),
375                Some(NodeEnum::CreatePolicyStmt(..)) => Some("CreatePolicyStmt"),
376                Some(NodeEnum::AlterPolicyStmt(..)) => Some("AlterPolicyStmt"),
377                Some(NodeEnum::CreateTransformStmt(..)) => Some("CreateTransformStmt"),
378                Some(NodeEnum::CreateAmStmt(..)) => Some("CreateAmStmt"),
379                Some(NodeEnum::CreatePublicationStmt(..)) => Some("CreatePublicationStmt"),
380                Some(NodeEnum::AlterPublicationStmt(..)) => Some("AlterPublicationStmt"),
381                Some(NodeEnum::CreateSubscriptionStmt(..)) => Some("CreateSubscriptionStmt"),
382                Some(NodeEnum::AlterSubscriptionStmt(..)) => Some("AlterSubscriptionStmt"),
383                Some(NodeEnum::DropSubscriptionStmt(..)) => Some("DropSubscriptionStmt"),
384                Some(NodeEnum::CreateStatsStmt(..)) => Some("CreateStatsStmt"),
385                Some(NodeEnum::AlterCollationStmt(..)) => Some("AlterCollationStmt"),
386                Some(NodeEnum::CallStmt(..)) => Some("CallStmt"),
387                Some(NodeEnum::AlterStatsStmt(..)) => Some("AlterStatsStmt"),
388                _ => None,
389            })
390            .collect()
391    }
392}