paged.rs 10.3 KB
Newer Older
1
2
3
4
5
6
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

7
use reqwest::header::{self, HeaderMap};
8
9
10
11
use serde::de::DeserializeOwned;
use thiserror::Error;
use url::Url;

12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
use crate::api::{ApiError, Client, Endpoint, Query};

/// Errors which may occur while following paginated results.
#[derive(Debug, Error)]
// TODO: use `#[non_exhaustive]` instead of the hidden `_NonExhaustive` variant below.
pub enum PaginationError {
    /// A `Link` HTTP header could not be parsed.
    #[error("failed to parse a Link HTTP header: {}", source)]
    LinkHeader {
        /// The underlying header parse failure.
        #[from]
        source: LinkHeaderParseError,
    },
    /// The URL found inside a `Link` HTTP header could not be parsed.
    #[error("failed to parse a Link HTTP header URL: {}", source)]
    InvalidUrl {
        /// The underlying URL parse failure.
        #[from]
        source: url::ParseError,
    },
    /// Placeholder variant which forces callers to keep a `_` arm when matching.
    ///
    /// **DO NOT USE**
    #[doc(hidden)]
    #[error("unreachable...")]
    _NonExhaustive,
}
39

40
#[derive(Debug)]
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
/// A single parsed `Link` header value; all fields borrow from the header text.
struct LinkHeader<'a> {
    /// The link target with the surrounding `<` and `>` removed.
    url: &'a str,
    /// The `key=value` parameters that follow the target, in order of appearance.
    /// Values have a surrounding pair of double quotes removed.
    params: Vec<(&'a str, &'a str)>,
}

impl<'a> LinkHeader<'a> {
    /// Parses one `Link` header value of the form `<url>; key=value; key2="value2"`.
    ///
    /// The text is split on `;`. The first piece must be the target wrapped in `<` and `>`;
    /// every following piece must contain a `=` separating a key from its value. A value
    /// wrapped in a pair of double quotes has the quotes removed.
    ///
    /// # Errors
    ///
    /// - `LinkHeaderParseError::NoBrackets` if the first piece is not wrapped in `<` and `>`.
    /// - `LinkHeaderParseError::MissingParamValue` if a parameter has no `=`.
    fn parse(s: &'a str) -> Result<Self, LinkHeaderParseError> {
        let mut parts = s.split(';');

        let url_part = parts.next().expect("a split always has at least one part");
        let url = {
            let part = url_part.trim();
            // `<` and `>` are different characters, so passing both checks implies at least
            // two bytes and the slice below cannot be out of range.
            if part.starts_with('<') && part.ends_with('>') {
                &part[1..part.len() - 1]
            } else {
                return Err(LinkHeaderParseError::NoBrackets);
            }
        };

        let params = parts
            .map(|part| {
                let part = part.trim();
                let mut halves = part.splitn(2, '=');
                let key = halves.next().expect("a split always has at least one part");
                let raw_value = match halves.next() {
                    Some(raw_value) => raw_value,
                    None => return Err(LinkHeaderParseError::MissingParamValue),
                };

                // A lone `"` both starts and ends with a quote; without the length check the
                // slice `[1..0]` would panic on such a (malformed) header. Treat it as a bare
                // value instead.
                let is_quoted =
                    raw_value.len() >= 2 && raw_value.starts_with('"') && raw_value.ends_with('"');
                let value = if is_quoted {
                    &raw_value[1..raw_value.len() - 1]
                } else {
                    raw_value
                };

                Ok((key, value))
            })
            .collect::<Result<Vec<_>, LinkHeaderParseError>>()?;

        Ok(Self {
            url,
            params,
        })
    }
}

/// An error which can occur when parsing a `Link` header.
#[derive(Debug, Error)]
pub enum LinkHeaderParseError {
    /// The header value could not be converted to a string.
    #[error("invalid header")]
    InvalidHeader {
        /// The underlying string conversion failure.
        #[from]
        source: reqwest::header::ToStrError,
    },
    /// The target of a `Link` header was not wrapped in `<>` brackets.
    #[error("missing brackets around url")]
    NoBrackets,
    /// A parameter of a `Link` header had a key but no `=value` part.
    #[error("missing parameter value")]
    MissingParamValue,
}

impl LinkHeaderParseError {
    fn invalid_header(source: reqwest::header::ToStrError) -> Self {
        Self::InvalidHeader {
            source,
        }
    }
}

/// Pagination options for GitLab.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Pagination {
    /// Return all results.
    ///
    /// Note that some endpoints may have a server-side limit to the number of results (e.g.,
    /// `/projects` is limited to 10000 results).
    All,
    /// Limit to a number of results.
    Limit(usize),
}

impl Default for Pagination {
    /// Defaults to fetching every result.
    fn default() -> Self {
        Self::All
    }
}

/// The largest page size that is ever requested.
const MAX_PAGE_SIZE: usize = 100;

impl Pagination {
    /// The number of items to request per page.
    ///
    /// This is the requested limit when it is smaller than `MAX_PAGE_SIZE`, and
    /// `MAX_PAGE_SIZE` otherwise.
    fn page_limit(self) -> usize {
        if let Pagination::Limit(wanted) = self {
            MAX_PAGE_SIZE.min(wanted)
        } else {
            MAX_PAGE_SIZE
        }
    }

    /// Decides whether fetching should stop after the page that was just received.
    ///
    /// `last_page_size` is the number of items on that page and `results` holds everything
    /// collected so far (including that page).
    fn is_last_page<T>(self, last_page_size: usize, results: &[T]) -> bool {
        // A page shorter than what was asked for means there is nothing left to fetch.
        let short_page = last_page_size < self.page_limit();

        // With an explicit limit, stop as soon as enough results have been gathered.
        let limit_reached = match self {
            Pagination::Limit(limit) => results.len() >= limit,
            Pagination::All => false,
        };

        short_page || limit_reached
    }
}

/// A query modifier that paginates an endpoint.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Paged<E> {
    /// The endpoint whose results are fetched page by page.
    endpoint: E,
    /// How many results to collect before stopping.
    pagination: Pagination,
}

/// Wrap an endpoint so that querying it collects data across its pages.
///
/// `pagination` selects whether all results or only a limited number are gathered.
pub fn paged<E>(endpoint: E, pagination: Pagination) -> Paged<E> {
    Paged { endpoint, pagination }
}

/// A trait to indicate that an endpoint is pageable.
pub trait Pageable {
    /// Whether the endpoint uses keyset pagination or not.
    ///
    /// The default implementation returns `false`; implementors override this to opt in.
    fn use_keyset_pagination(&self) -> bool {
        false
    }
}

179
// Querying a `Paged` endpoint issues one request per page and concatenates the items.
impl<E, T, C> Query<Vec<T>, C> for Paged<E>
180
where
181
    E: Endpoint,
182
183
    E: Pageable,
    T: DeserializeOwned,
184
    C: Client,
185
{
186
    /// Fetches pages from the wrapped endpoint until `Pagination::is_last_page` reports
    /// completion (or, in keyset mode, until no `rel="next"` link is returned) and returns
    /// all deserialized items in the order received.
    ///
    /// Fails if building or sending a request fails, if a response body is not valid JSON,
    /// if the server answers with a non-success status, if a page does not deserialize as
    /// `Vec<T>`, or (keyset mode) if a `Link` header cannot be parsed.
    fn query(&self, client: &C) -> Result<Vec<T>, ApiError<C::Error>> {
187
188
189
190
191
192
193
194
195
196
197
198
199
200
        // Base URL shared by every page request: the endpoint path plus its own parameters.
        let url = {
            let mut url = client.rest_endpoint(&self.endpoint.endpoint())?;
            self.endpoint.add_parameters(url.query_pairs_mut());
            url
        };

        // Page numbers start at 1; `page_num` is only advanced in page-number mode.
        let mut page_num = 1;
        let per_page = self.pagination.page_limit();
        let per_page_str = format!("{}", per_page);

        let mut results = Vec::new();
        // In keyset mode, the URL for the following page as taken from the response headers.
        let mut next_url = None;
        let use_keyset_pagination = self.endpoint.use_keyset_pagination();

201
202
        // The request body is computed once and cloned into each page request.
        let body = self.endpoint.body()?;

203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
        loop {
            // Prefer the server-provided next URL; otherwise build the page URL ourselves.
            let page_url = if let Some(url) = next_url.take() {
                url
            } else {
                let page_str = format!("{}", page_num);
                let mut page_url = url.clone();

                {
                    let mut pairs = page_url.query_pairs_mut();
                    pairs.append_pair("per_page", &per_page_str);

                    if use_keyset_pagination {
                        pairs.append_pair("pagination", "keyset");
                    } else {
                        pairs.append_pair("page", &page_str);
                    }
                }

                page_url
            };

224
            let req = client.build_rest(self.endpoint.method(), page_url);
225
226
227
228
229
            // Attach the body and its content type when the endpoint provides one.
            let req = if let Some((mime, data)) = body.as_ref() {
                req.header(header::CONTENT_TYPE, *mime).body(data.clone())
            } else {
                req
            };
230
231
232
233
234
235
236
            let rsp = client.rest(req)?;
            let status = rsp.status();

            // The headers must be read here: the response is consumed by the JSON reader below.
            if use_keyset_pagination {
                next_url = next_page_from_headers(rsp.headers())?;
            }

237
            // The body is parsed as JSON before the status check so that an error payload
            // can be handed to `ApiError::from_gitlab`.
            let v = serde_json::from_reader(rsp)?;
238
            if !status.is_success() {
239
                return Err(ApiError::from_gitlab(v));
240
241
242
            }

            let page =
243
                serde_json::from_value::<Vec<T>>(v).map_err(ApiError::data_type::<Vec<T>>)?;
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
            // Remember the page size before `page` is moved into `results`.
            let page_len = page.len();
            results.extend(page);

            // Gitlab used to have issues returning paginated results; these have been fixed since,
            // but if it is needed, the bug manifests as Gitlab returning *all* results instead of
            // just the requested results. This can cause an infinite loop here if the number of
            // total results is exactly equal to `per_page`.
            if self.pagination.is_last_page(page_len, &results) {
                break;
            }

            // Keyset mode ends when the server gave no next link; page-number mode advances.
            if use_keyset_pagination {
                if next_url.is_none() {
                    break;
                }
            } else {
                page_num += 1;
            }
        }

        Ok(results)
    }
}

/// Finds the URL of the next page among the `Link` headers of a response.
///
/// Every `Link` header is parsed first, so a malformed header is reported even when another
/// header carries the next link. The first header with a `rel` parameter equal to `next`
/// provides the URL; `Ok(None)` is returned when no header has one.
fn next_page_from_headers(headers: &HeaderMap) -> Result<Option<Url>, PaginationError> {
    // Pass 1: decode and parse all headers, bailing out on the first failure.
    let mut links = Vec::new();
    for raw in headers.get_all(reqwest::header::LINK).iter() {
        let text = raw
            .to_str()
            .map_err(LinkHeaderParseError::invalid_header)?;
        links.push(LinkHeader::parse(text)?);
    }

    // Pass 2: pick the first link that is marked as the next page.
    for link in links {
        let points_to_next = link
            .params
            .iter()
            .any(|&(key, value)| key == "rel" && value == "next");

        if points_to_next {
            let next: Url = link.url.parse()?;
            return Ok(Some(next));
        }
    }

    Ok(None)
}
295
296
297

#[cfg(test)]
mod tests {
Ben Boeckel's avatar
Ben Boeckel committed
298
299
    use crate::api::paged::LinkHeader;
    use crate::api::{LinkHeaderParseError, Pagination};
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348

    #[test]
    fn test_link_header_no_brackets() {
        // A target that is not wrapped in `<` and `>` must be rejected.
        match LinkHeader::parse("url; param=value").unwrap_err() {
            LinkHeaderParseError::NoBrackets => {},
            err => panic!("unexpected error: {}", err),
        }
    }

    #[test]
    fn test_link_header_no_param_value() {
        // A parameter without `=value` must be rejected.
        match LinkHeader::parse("<url>; param").unwrap_err() {
            LinkHeaderParseError::MissingParamValue => {},
            err => panic!("unexpected error: {}", err),
        }
    }

    #[test]
    fn test_link_header_no_params() {
        // A bare target parses to its URL and an empty parameter list.
        let LinkHeader {
            url,
            params,
        } = LinkHeader::parse("<url>").unwrap();
        assert_eq!(url, "url");
        assert_eq!(params.len(), 0);
    }

    #[test]
    fn test_link_header_quoted_params() {
        // Quoted values are returned without their surrounding quotes.
        let parsed = LinkHeader::parse("<url>; param=\"value\"; param2=\"value\"").unwrap();
        assert_eq!(parsed.url, "url");

        let expected = [("param", "value"), ("param2", "value")];
        assert_eq!(parsed.params.len(), expected.len());
        for (actual, wanted) in parsed.params.iter().zip(expected.iter()) {
            assert_eq!(actual.0, wanted.0);
            assert_eq!(actual.1, wanted.1);
        }
    }

    #[test]
    fn test_link_header_bare_params() {
        // Unquoted values are returned unchanged.
        let parsed = LinkHeader::parse("<url>; param=value; param2=value").unwrap();
        assert_eq!(parsed.url, "url");

        let expected = [("param", "value"), ("param2", "value")];
        assert_eq!(parsed.params.len(), expected.len());
        for (actual, wanted) in parsed.params.iter().zip(expected.iter()) {
            assert_eq!(actual.0, wanted.0);
            assert_eq!(actual.1, wanted.1);
        }
    }
Ben Boeckel's avatar
Ben Boeckel committed
349
350
351
352
353

    #[test]
    fn pagination_default() {
        // The default pagination fetches everything.
        let default_pagination: Pagination = Default::default();
        assert_eq!(default_pagination, Pagination::All);
    }
354
}