// subplot_grid/subplot_grid.rs

1use plotlars::{
2    Arrangement, Axis, BarPlot, BoxPlot, CandlestickPlot, ColorBar, Direction, HeatMap, Histogram,
3    Legend, Line, Mode, Orientation, Palette, Plot, Rgb, SankeyDiagram, Scatter3dPlot, ScatterGeo,
4    ScatterMap, ScatterPlot, ScatterPolar, Shape, SubplotGrid, Text, TickDirection, TimeSeriesPlot,
5    ValueExponent,
6};
7use polars::prelude::*;
8
9fn main() {
10    regular_grid_example();
11    irregular_grid_example();
12    mixed_grid_example();
13}
14
15fn regular_grid_example() {
16    let dataset1 = LazyCsvReader::new(PlPath::new("data/animal_statistics.csv"))
17        .finish()
18        .unwrap()
19        .collect()
20        .unwrap();
21
22    let plot1 = BarPlot::builder()
23        .data(&dataset1)
24        .labels("animal")
25        .values("value")
26        .orientation(Orientation::Vertical)
27        .group("gender")
28        .sort_groups_by(|a, b| a.len().cmp(&b.len()))
29        .error("error")
30        .colors(vec![Rgb(255, 127, 80), Rgb(64, 224, 208)])
31        .plot_title(Text::from("Bar Plot").x(-0.05).y(1.35).size(14))
32        .y_title(Text::from("value").x(-0.055).y(0.76))
33        .x_title(Text::from("animal").x(0.97).y(-0.2))
34        .legend(
35            &Legend::new()
36                .orientation(Orientation::Horizontal)
37                .x(0.4)
38                .y(1.2),
39        )
40        .build();
41
42    let dataset2 = LazyCsvReader::new(PlPath::new("data/penguins.csv"))
43        .finish()
44        .unwrap()
45        .select([
46            col("species"),
47            col("sex").alias("gender"),
48            col("flipper_length_mm").cast(DataType::Int16),
49            col("body_mass_g").cast(DataType::Int16),
50        ])
51        .collect()
52        .unwrap();
53
54    let axis = Axis::new()
55        .show_line(true)
56        .tick_direction(TickDirection::OutSide)
57        .value_thousands(true);
58
59    let plot2 = ScatterPlot::builder()
60        .data(&dataset2)
61        .x("body_mass_g")
62        .y("flipper_length_mm")
63        .group("species")
64        .sort_groups_by(|a, b| {
65            if a.len() == b.len() {
66                a.cmp(b)
67            } else {
68                a.len().cmp(&b.len())
69            }
70        })
71        .opacity(0.5)
72        .size(12)
73        .colors(vec![Rgb(178, 34, 34), Rgb(65, 105, 225), Rgb(255, 140, 0)])
74        .shapes(vec![Shape::Circle, Shape::Square, Shape::Diamond])
75        .plot_title(Text::from("Scatter Plot").x(-0.075).y(1.35).size(14))
76        .x_title(Text::from("body mass (g)").y(-0.4))
77        .y_title(Text::from("flipper length (mm)").x(-0.078).y(0.5))
78        .legend_title("species")
79        .x_axis(&axis.clone().value_range(vec![2500.0, 6500.0]))
80        .y_axis(&axis.clone().value_range(vec![170.0, 240.0]))
81        .legend(&Legend::new().x(0.98).y(0.95))
82        .build();
83
84    let dataset3 = LazyCsvReader::new(PlPath::new("data/debilt_2023_temps.csv"))
85        .with_has_header(true)
86        .with_try_parse_dates(true)
87        .finish()
88        .unwrap()
89        .with_columns(vec![
90            (col("tavg") / lit(10)).alias("avg"),
91            (col("tmin") / lit(10)).alias("min"),
92            (col("tmax") / lit(10)).alias("max"),
93        ])
94        .collect()
95        .unwrap();
96
97    let plot3 = TimeSeriesPlot::builder()
98        .data(&dataset3)
99        .x("date")
100        .y("avg")
101        .additional_series(vec!["min", "max"])
102        .colors(vec![Rgb(128, 128, 128), Rgb(0, 122, 255), Rgb(255, 128, 0)])
103        .lines(vec![Line::Solid, Line::Dot, Line::Dot])
104        .plot_title(Text::from("Time Series Plot").x(-0.05).y(1.35).size(14))
105        .y_title(Text::from("temperature (ÂșC)").x(-0.055).y(0.6))
106        .legend(&Legend::new().x(0.9).y(1.25))
107        .build();
108
109    let plot4 = BoxPlot::builder()
110        .data(&dataset2)
111        .labels("species")
112        .values("body_mass_g")
113        .orientation(Orientation::Vertical)
114        .group("gender")
115        .box_points(true)
116        .point_offset(-1.5)
117        .jitter(0.01)
118        .opacity(0.1)
119        .colors(vec![Rgb(0, 191, 255), Rgb(57, 255, 20), Rgb(255, 105, 180)])
120        .plot_title(Text::from("Box Plot").x(-0.075).y(1.35).size(14))
121        .x_title(Text::from("species").y(-0.3))
122        .y_title(Text::from("body mass (g)").x(-0.08).y(0.5))
123        .legend_title(Text::from("gender").size(12))
124        .y_axis(&Axis::new().value_thousands(true))
125        .legend(&Legend::new().x(1.0))
126        .build();
127
128    SubplotGrid::regular()
129        .plots(vec![&plot1, &plot2, &plot3, &plot4])
130        .rows(2)
131        .cols(2)
132        .v_gap(0.4)
133        .title(
134            Text::from("Regular Subplot Grid")
135                .size(16)
136                .font("Arial bold")
137                .y(0.95),
138        )
139        .build()
140        .plot();
141}
142
143fn irregular_grid_example() {
144    let dataset1 = LazyCsvReader::new(PlPath::new("data/penguins.csv"))
145        .finish()
146        .unwrap()
147        .select([
148            col("species"),
149            col("sex").alias("gender"),
150            col("flipper_length_mm").cast(DataType::Int16),
151            col("body_mass_g").cast(DataType::Int16),
152        ])
153        .collect()
154        .unwrap();
155
156    let axis = Axis::new()
157        .show_line(true)
158        .show_grid(true)
159        .value_thousands(true)
160        .tick_direction(TickDirection::OutSide);
161
162    let plot1 = Histogram::builder()
163        .data(&dataset1)
164        .x("body_mass_g")
165        .group("species")
166        .opacity(0.5)
167        .colors(vec![Rgb(255, 165, 0), Rgb(147, 112, 219), Rgb(46, 139, 87)])
168        .plot_title(Text::from("Histogram").x(0.0).y(1.35).size(14))
169        .x_title(Text::from("body mass (g)").x(0.94).y(-0.35))
170        .y_title(Text::from("count").x(-0.062).y(0.83))
171        .x_axis(&axis)
172        .y_axis(&axis)
173        .legend_title(Text::from("species"))
174        .legend(&Legend::new().x(0.87).y(1.2))
175        .build();
176
177    let dataset2 = LazyCsvReader::new(PlPath::new("data/stock_prices.csv"))
178        .finish()
179        .unwrap()
180        .collect()
181        .unwrap();
182
183    let increasing = Direction::new()
184        .line_color(Rgb(0, 200, 100))
185        .line_width(0.5);
186
187    let decreasing = Direction::new()
188        .line_color(Rgb(200, 50, 50))
189        .line_width(0.5);
190
191    let plot2 = CandlestickPlot::builder()
192        .data(&dataset2)
193        .dates("date")
194        .open("open")
195        .high("high")
196        .low("low")
197        .close("close")
198        .increasing(&increasing)
199        .decreasing(&decreasing)
200        .whisker_width(0.1)
201        .plot_title(Text::from("Candlestick").x(0.0).y(1.35).size(14))
202        .y_title(Text::from("price ($)").x(-0.06).y(0.76))
203        .y_axis(&Axis::new().show_axis(true).show_grid(true))
204        .build();
205
206    let dataset3 = LazyCsvReader::new(PlPath::new("data/heatmap.csv"))
207        .finish()
208        .unwrap()
209        .collect()
210        .unwrap();
211
212    let plot3 = HeatMap::builder()
213        .data(&dataset3)
214        .x("x")
215        .y("y")
216        .z("z")
217        .color_bar(
218            &ColorBar::new()
219                .value_exponent(ValueExponent::None)
220                .separate_thousands(true)
221                .tick_length(5)
222                .tick_step(5000.0),
223        )
224        .plot_title(Text::from("Heat Map").x(0.0).y(1.35).size(14))
225        .color_scale(Palette::Viridis)
226        .build();
227
228    SubplotGrid::irregular()
229        .plots(vec![
230            (&plot1, 0, 0, 1, 1),
231            (&plot2, 0, 1, 1, 1),
232            (&plot3, 1, 0, 1, 2),
233        ])
234        .rows(2)
235        .cols(2)
236        .v_gap(0.35)
237        .h_gap(0.05)
238        .title(
239            Text::from("Irregular Subplot Grid")
240                .size(16)
241                .font("Arial bold")
242                .y(0.95),
243        )
244        .build()
245        .plot();
246}
247
248fn mixed_grid_example() {
249    // 2D cartesian scatter (baseline)
250    let penguins = LazyCsvReader::new(PlPath::new("data/penguins.csv"))
251        .finish()
252        .unwrap()
253        .collect()
254        .unwrap()
255        .lazy()
256        .select([
257            col("species"),
258            col("bill_length_mm"),
259            col("flipper_length_mm"),
260            col("body_mass_g"),
261        ])
262        .collect()
263        .unwrap();
264
265    let scatter_2d = ScatterPlot::builder()
266        .data(&penguins)
267        .x("bill_length_mm")
268        .y("flipper_length_mm")
269        .group("species")
270        .opacity(0.65)
271        .size(10)
272        .plot_title(Text::from("Penguins 2D").y(1.3))
273        .build();
274
275    // 3D scene subplot
276    let scatter_3d = Scatter3dPlot::builder()
277        .data(&penguins)
278        .x("bill_length_mm")
279        .y("flipper_length_mm")
280        .z("body_mass_g")
281        .group("species")
282        .opacity(0.35)
283        .size(6)
284        .plot_title(Text::from("Penguins 3D").y(1.45))
285        .build();
286
287    // Polar subplot
288    let polar_df = LazyCsvReader::new(PlPath::new("data/product_comparison_polar.csv"))
289        .finish()
290        .unwrap()
291        .collect()
292        .unwrap();
293
294    let polar = ScatterPolar::builder()
295        .data(&polar_df)
296        .theta("angle")
297        .r("score")
298        .group("product")
299        .mode(Mode::LinesMarkers)
300        .size(10)
301        .plot_title(Text::from("Product Comparison (Polar)").y(1.5).x(0.72))
302        .legend(&Legend::new().x(0.8))
303        .build();
304
305    // Domain-based subplot (Sankey)
306    let sankey_df = LazyCsvReader::new(PlPath::new("data/energy_transition.csv"))
307        .finish()
308        .unwrap()
309        .collect()
310        .unwrap();
311
312    let sankey = SankeyDiagram::builder()
313        .data(&sankey_df)
314        .sources("source")
315        .targets("target")
316        .values("value")
317        .orientation(Orientation::Horizontal)
318        .arrangement(Arrangement::Freeform)
319        .plot_title(Text::from("Energy Flow").y(1.2))
320        .build();
321
322    // Mapbox subplot
323    let map_df = LazyCsvReader::new(PlPath::new("data/cities.csv"))
324        .finish()
325        .unwrap()
326        .collect()
327        .unwrap();
328
329    let scatter_map = ScatterMap::builder()
330        .data(&map_df)
331        .latitude("latitude")
332        .longitude("longitude")
333        .group("city")
334        .zoom(4)
335        .center([50.0, 5.0])
336        .opacity(0.8)
337        .plot_title(Text::from("Cities (Mapbox)").y(1.2))
338        .build();
339
340    // Geo subplot
341    let geo_df = LazyCsvReader::new(PlPath::new("data/world_cities.csv"))
342        .finish()
343        .unwrap()
344        .collect()
345        .unwrap();
346
347    let scatter_geo = ScatterGeo::builder()
348        .data(&geo_df)
349        .lat("lat")
350        .lon("lon")
351        .group("continent")
352        .mode(Mode::Markers)
353        .size(10)
354        .color(Rgb(255, 140, 0))
355        .shape(Shape::Circle)
356        .plot_title(Text::from("Global Cities (Geo)").x(0.65).y(1.2))
357        .legend(&Legend::new().x(0.8))
358        .build();
359
360    SubplotGrid::regular()
361        .plots(vec![
362            &scatter_2d,
363            &scatter_3d,
364            &polar,
365            &sankey,
366            &scatter_map,
367            &scatter_geo,
368        ])
369        .rows(2)
370        .cols(3)
371        .h_gap(0.12)
372        .v_gap(0.22)
373        .title(
374            Text::from("Mixed Subplot Grid")
375                .size(16)
376                .font("Arial bold")
377                .y(0.95),
378        )
379        .build()
380        .plot();
381}