pg_query/
query.rs

1use std::ffi::{CStr, CString};
2use std::os::raw::c_char;
3
4use prost::Message;
5
6use crate::bindings::*;
7use crate::error::*;
8use crate::parse_result::ParseResult;
9use crate::protobuf;
10
/// Represents the resulting fingerprint containing both the raw integer form as well as the
/// corresponding 16 character hex value.
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Eq` so fingerprints can be printed, copied and
/// compared directly by callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Fingerprint {
    /// The fingerprint in integer form, as reported by `libpg_query`.
    pub value: u64,
    /// The fingerprint in hex string form, as reported by `libpg_query`.
    pub hex: String,
}
17
18/// Parses the given SQL statement into the given abstract syntax tree.
19///
20/// # Example
21///
22/// ```rust
23/// use pg_query::{Node, NodeEnum, NodeRef};
24///
25/// let result = pg_query::parse("SELECT * FROM contacts");
26/// assert!(result.is_ok());
27/// let result = result.unwrap();
28/// assert_eq!(result.tables(), vec!["contacts"]);
29/// assert!(matches!(result.protobuf.nodes()[0].0, NodeRef::SelectStmt(_)));
30/// ```
31pub fn parse(statement: &str) -> Result<ParseResult> {
32    let input = CString::new(statement)?;
33    let result = unsafe { pg_query_parse_protobuf(input.as_ptr()) };
34    let parse_result = if !result.error.is_null() {
35        let message = unsafe { CStr::from_ptr((*result.error).message) }.to_string_lossy().to_string();
36        Err(Error::Parse(message))
37    } else {
38        let data = unsafe { std::slice::from_raw_parts(result.parse_tree.data as *const u8, result.parse_tree.len as usize) };
39        let stderr = unsafe { CStr::from_ptr(result.stderr_buffer) }.to_string_lossy().to_string();
40        protobuf::ParseResult::decode(data).map_err(Error::Decode).map(|result| ParseResult::new(result, stderr))
41    };
42    unsafe { pg_query_free_protobuf_parse_result(result) };
43    parse_result
44}
45
46/// Converts a parsed tree back into a string.
47///
48/// # Example
49///
50/// ```rust
51/// use pg_query::{Node, NodeEnum, NodeRef};
52///
53/// let result = pg_query::parse("INSERT INTO other (name) SELECT name FROM contacts");
54/// let result = result.unwrap();
55/// let insert = result.protobuf.nodes()[0].0;
56/// let select = result.protobuf.nodes()[1].0;
57/// assert!(matches!(insert, NodeRef::InsertStmt(_)));
58/// assert!(matches!(select, NodeRef::SelectStmt(_)));
59///
60/// // The entire parse result can be deparsed:
61/// assert_eq!(result.deparse().unwrap(), "INSERT INTO other (name) SELECT name FROM contacts");
62/// // Or an individual node can be deparsed:
63/// assert_eq!(insert.deparse().unwrap(), "INSERT INTO other (name) SELECT name FROM contacts");
64/// assert_eq!(select.deparse().unwrap(), "SELECT name FROM contacts");
65/// ```
66///
67/// Note that this function will panic if called on a node not defined in `deparseStmt`
68pub fn deparse(protobuf: &protobuf::ParseResult) -> Result<String> {
69    let buffer = protobuf.encode_to_vec();
70    let len = buffer.len();
71    let data = buffer.as_ptr() as *const c_char as *mut c_char;
72    let protobuf = PgQueryProtobuf { data, len };
73    let result = unsafe { pg_query_deparse_protobuf(protobuf) };
74
75    let deparse_result = if !result.error.is_null() {
76        let message = unsafe { CStr::from_ptr((*result.error).message) }.to_string_lossy().to_string();
77        Err(Error::Parse(message))
78    } else {
79        let query = unsafe { CStr::from_ptr(result.query) }.to_string_lossy().to_string();
80        Ok(query)
81    };
82
83    unsafe { pg_query_free_deparse_result(result) };
84    deparse_result
85}
86
87/// Normalizes the given SQL statement, returning a parametized version.
88///
89/// # Example
90///
91/// ```rust
92/// let result = pg_query::normalize("SELECT * FROM contacts WHERE name='Paul'");
93/// assert!(result.is_ok());
94/// let result = result.unwrap();
95/// assert_eq!(result, "SELECT * FROM contacts WHERE name=$1");
96/// ```
97pub fn normalize(statement: &str) -> Result<String> {
98    let input = CString::new(statement).unwrap();
99    let result = unsafe { pg_query_normalize(input.as_ptr()) };
100    let normalized_query = if !result.error.is_null() {
101        let message = unsafe { CStr::from_ptr((*result.error).message) }.to_string_lossy().to_string();
102        Err(Error::Parse(message))
103    } else {
104        let n = unsafe { CStr::from_ptr(result.normalized_query) };
105        Ok(n.to_string_lossy().to_string())
106    };
107    unsafe { pg_query_free_normalize_result(result) };
108    normalized_query
109}
110
111/// Fingerprints the given SQL statement. Useful for comparing parse trees across different implementations
112/// of `libpg_query`.
113///
114/// # Example
115///
116/// ```rust
117/// let result = pg_query::fingerprint("SELECT * FROM contacts WHERE name='Paul'");
118/// assert!(result.is_ok());
119/// let result = result.unwrap();
120/// assert_eq!(result.hex, "0e2581a461ece536");
121/// ```
122pub fn fingerprint(statement: &str) -> Result<Fingerprint> {
123    let input = CString::new(statement)?;
124    let result = unsafe { pg_query_fingerprint(input.as_ptr()) };
125    let fingerprint = if !result.error.is_null() {
126        let message = unsafe { CStr::from_ptr((*result.error).message) }.to_string_lossy().to_string();
127        Err(Error::Parse(message))
128    } else {
129        let hex = unsafe { CStr::from_ptr(result.fingerprint_str) };
130        Ok(Fingerprint { value: result.fingerprint, hex: hex.to_string_lossy().to_string() })
131    };
132    unsafe { pg_query_free_fingerprint_result(result) };
133    fingerprint
134}
135
136/// An experimental API which parses a PLPGSQL function. This currently returns the raw JSON structure.
137///
138/// # Example
139///
140/// ```rust
141/// let result = pg_query::parse_plpgsql("
142///     CREATE OR REPLACE FUNCTION cs_fmt_browser_version(v_name varchar, v_version varchar)
143///     RETURNS varchar AS $$
144///     BEGIN
145///         IF v_version IS NULL THEN
146///             RETURN v_name;
147///         END IF;
148///         RETURN v_name || '/' || v_version;
149///     END;
150///     $$ LANGUAGE plpgsql;
151/// ");
152/// assert!(result.is_ok());
153/// ```
154pub fn parse_plpgsql(stmt: &str) -> Result<serde_json::Value> {
155    let input = CString::new(stmt)?;
156    let result = unsafe { pg_query_parse_plpgsql(input.as_ptr()) };
157    let structure = if !result.error.is_null() {
158        let message = unsafe { CStr::from_ptr((*result.error).message) }.to_string_lossy().to_string();
159        Err(Error::Parse(message))
160    } else {
161        let raw = unsafe { CStr::from_ptr(result.plpgsql_funcs) };
162        serde_json::from_str(&raw.to_string_lossy()).map_err(|e| Error::InvalidJson(e.to_string()))
163    };
164    unsafe { pg_query_free_plpgsql_parse_result(result) };
165    structure
166}
167
168/// Split a well-formed query into separate statements.
169///
170/// # Example
171///
172/// ```rust
173/// let query = r#"select /*;*/ 1; select "2;", (select 3);"#;
174/// let statements = pg_query::split_with_parser(query).unwrap();
175/// assert_eq!(statements, vec!["select /*;*/ 1", r#" select "2;", (select 3)"#]);
176/// ```
177///
178/// However, `split_with_parser` will fail on malformed statements
179///
180/// ```rust
181/// let query = "select 1; this statement is not sql; select 2;";
182/// let result = pg_query::split_with_parser(query);
183/// let err = r#"syntax error at or near "this""#;
184/// assert_eq!(result, Err(pg_query::Error::Split(err.to_string())));
185/// ```
186pub fn split_with_parser(query: &str) -> Result<Vec<&str>> {
187    let input = CString::new(query)?;
188    let result = unsafe { pg_query_split_with_parser(input.as_ptr()) };
189    let split_result = if !result.error.is_null() {
190        let message = unsafe { CStr::from_ptr((*result.error).message) }.to_string_lossy().to_string();
191        Err(Error::Split(message))
192    } else {
193        let n_stmts = result.n_stmts as usize;
194        let mut statements = Vec::with_capacity(n_stmts);
195        for offset in 0..n_stmts {
196            let split_stmt = unsafe { *result.stmts.add(offset).read() };
197            let start = split_stmt.stmt_location as usize;
198            let end = start + split_stmt.stmt_len as usize;
199            statements.push(&query[start..end]);
200            // not sure the start..end slice'll hold up for non-utf8 charsets
201        }
202        Ok(statements)
203    };
204    unsafe { pg_query_free_split_result(result) };
205    split_result
206}
207
208/// Scan a sql query into a its component of tokens.
209///
210/// # Example
211///
212/// ```rust
213/// use pg_query::protobuf::*;
214/// let sql = "SELECT update AS left /* comment */ FROM between";
215/// let result = pg_query::scan(sql).unwrap();
216/// let tokens: Vec<std::string::String> = result.tokens.iter().map(|token| {
217///     format!("{:?}", token)
218/// }).collect();
219/// assert_eq!(
220///     tokens,
221///     vec![
222///         "ScanToken { start: 0, end: 6, token: Select, keyword_kind: ReservedKeyword }",
223///         "ScanToken { start: 7, end: 13, token: Update, keyword_kind: UnreservedKeyword }",
224///         "ScanToken { start: 14, end: 16, token: As, keyword_kind: ReservedKeyword }",
225///         "ScanToken { start: 17, end: 21, token: Left, keyword_kind: TypeFuncNameKeyword }",
226///         "ScanToken { start: 22, end: 35, token: CComment, keyword_kind: NoKeyword }",
227///         "ScanToken { start: 36, end: 40, token: From, keyword_kind: ReservedKeyword }",
228///         "ScanToken { start: 41, end: 48, token: Between, keyword_kind: ColNameKeyword }"
229///     ]);
230/// ```
231pub fn scan(sql: &str) -> Result<protobuf::ScanResult> {
232    let input = CString::new(sql)?;
233    let result = unsafe { pg_query_scan(input.as_ptr()) };
234    let scan_result = if !result.error.is_null() {
235        let message = unsafe { CStr::from_ptr((*result.error).message) }.to_string_lossy().to_string();
236        Err(Error::Scan(message))
237    } else {
238        let data = unsafe { std::slice::from_raw_parts(result.pbuf.data as *const u8, result.pbuf.len as usize) };
239        protobuf::ScanResult::decode(data).map_err(Error::Decode)
240    };
241    unsafe { pg_query_free_scan_result(result) };
242    scan_result
243}
244
245/// Split a potentially-malformed query into separate statements. Note that
246/// invalid tokens will be skipped
247/// ```rust
248/// let query = r#"select /*;*/ 1; asdf; select "2;", (select 3); asdf"#;
249/// let statements = pg_query::split_with_scanner(query).unwrap();
250/// assert_eq!(statements, vec![
251///     "select /*;*/ 1",
252///     // skipped " asdf" since it was an invalid token
253///     r#" select "2;", (select 3)"#,
254/// ]);
255/// ```
256pub fn split_with_scanner(query: &str) -> Result<Vec<&str>> {
257    let input = CString::new(query)?;
258    let result = unsafe { pg_query_split_with_scanner(input.as_ptr()) };
259    let split_result = if !result.error.is_null() {
260        let message = unsafe { CStr::from_ptr((*result.error).message) }.to_string_lossy().to_string();
261        Err(Error::Split(message))
262    } else {
263        // don't use result.stderr_buffer since it appears unused unless
264        // libpg_query is compiled with DEBUG defined.
265        let n_stmts = result.n_stmts as usize;
266        let mut start: usize;
267        let mut end: usize;
268        let mut statements = Vec::with_capacity(n_stmts);
269        for offset in 0..n_stmts {
270            let split_stmt = unsafe { *result.stmts.add(offset).read() };
271            start = split_stmt.stmt_location as usize;
272            // TODO: consider comparing the new value of start to the old value
273            // of end to see if any region larger than a statement-separator got skipped
274            end = start + split_stmt.stmt_len as usize;
275            statements.push(&query[start..end]);
276        }
277        Ok(statements)
278    };
279    unsafe { pg_query_free_split_result(result) };
280    split_result
281}