Skip to content

Commit

Permalink
Ver 0.35.1
Browse files Browse the repository at this point in the history
Implement PlotType
  • Loading branch information
Axect committed Mar 29, 2024
2 parents dad6071 + 86ac163 commit db37f51
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "peroxide"
version = "0.35.0"
version = "0.35.1"
authors = ["axect <[email protected]>"]
edition = "2018"
description = "Rust comprehensive scientific computation library contains linear algebra, numerical analysis, statistics and machine learning tools with farmiliar syntax"
Expand Down
7 changes: 7 additions & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# Release 0.35.1 (2024-03-29)

- Add `PlotType` for `Plot2D`
- `PlotType::Scatter`
- `PlotType::Line` (default)
- `PlotType::Bar`

# Release 0.35.0 (2024-03-29)

## Change some plot functions
Expand Down
Binary file modified example_data/test_plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
79 changes: 73 additions & 6 deletions src/util/plot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,20 @@
//! let y1 = x.fmap(|t| t.powi(2));
//! let y2 = x.fmap(|t| t.powi(3));
//!
//! let normal = Normal(0f64, 0.1);
//! let eps = normal.sample(100);
//! let y3 = y2.add_v(&eps);
//!
//! let mut plt = Plot2D::new();
//! plt.set_domain(x)
//! .insert_image(y1)
//! .insert_image(y2)
//! .set_legend(vec![r"$y=x^2$", r"$y=x^3$"])
//! .insert_image(y3)
//! .set_legend(vec![r"$y=x^2$", r"$y=x^3$", r"$y=x^2 + \epsilon$"])
//! .set_line_style(vec![(0, LineStyle::Dashed), (1, LineStyle::Dotted)])
//! .set_color(vec![(0, "red"), (1, "darkblue")])
//! .set_plot_type(vec![(2, PlotType::Scatter)])
//! .set_marker(vec![(2, Markers::Point)])
//! .set_color(vec![(0, "red"), (1, "darkblue"), (2, "olive")])
//! .set_xlabel(r"$x$")
//! .set_ylabel(r"$y$")
//! .set_style(PlotStyle::Nature) // if you want to use scienceplots
Expand Down Expand Up @@ -63,8 +70,9 @@
//! - `set_style` : Set style of plot (`PlotStyle::Nature`, `PlotStyle::IEEE`, `PlotStyle::Default` (default), `PlotStyle::Science`)
//! - `tight_layout` : Set tight layout of plot (optional)
//! - `set_line_style` : Set line style of plot (optional; `LineStyle::{Solid, Dashed, Dotted, DashDot}`)
//! - `set_color` : Set color of plot (optional; Vec<&str>)
//! - `set_alpha` : Set alpha of plot (optional; Vec<f64>)
//! - `set_color` : Set color of plot (optional; Vec<(usize, &str)>)
//! - `set_alpha` : Set alpha of plot (optional; Vec<(usize, f64)>)
//! - `set_plot_type` : Set plot type of plot (optional; `PlotType::{Scatter, Line, Bar}`)
//! - `savefig` : Save plot with given path
extern crate pyo3;
Expand Down Expand Up @@ -180,6 +188,24 @@ pub enum PlotScale {
Log,
}

#[derive(Debug, Copy, Clone, Hash, PartialOrd, PartialEq, Eq)]
pub enum PlotType {
Scatter,
Line,
Bar,
}

impl Display for PlotType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let str = match self {
PlotType::Scatter => "scatter".to_string(),
PlotType::Line => "line".to_string(),
PlotType::Bar => "bar".to_string(),
};
write!(f, "{}", str)
}
}

pub trait Plot {
fn set_domain(&mut self, x: Vec<f64>) -> &mut Self;
fn insert_image(&mut self, y: Vec<f64>) -> &mut Self;
Expand All @@ -203,6 +229,7 @@ pub trait Plot {
fn set_line_style(&mut self, style: Vec<(usize, LineStyle)>) -> &mut Self;
fn set_color(&mut self, color: Vec<(usize, &str)>) -> &mut Self;
fn set_alpha(&mut self, alpha: Vec<(usize, f64)>) -> &mut Self;
fn set_plot_type(&mut self, plot_type: Vec<(usize, PlotType)>) -> &mut Self;
fn savefig(&self) -> PyResult<()>;
}

Expand All @@ -229,6 +256,7 @@ pub struct Plot2D {
grid: Grid,
style: PlotStyle,
tight: bool,
plot_type: Vec<(usize, PlotType)>,
options: HashMap<PlotOptions, bool>,
}

Expand Down Expand Up @@ -262,6 +290,7 @@ impl Plot2D {
grid: On,
style: PlotStyle::Default,
tight: false,
plot_type: vec![],
options: default_options,
}
}
Expand Down Expand Up @@ -392,6 +421,11 @@ impl Plot for Plot2D {
self
}

fn set_plot_type(&mut self, plot_type: Vec<(usize, PlotType)>) -> &mut Self {
self.plot_type = plot_type;
self
}

fn savefig(&self) -> PyResult<()> {
// Check domain
match self.options.get(&Domain) {
Expand Down Expand Up @@ -456,6 +490,7 @@ impl Plot for Plot2D {
let line_style = self.line_style.iter().map(|(i, x)| (i, format!("{}", x))).collect::<Vec<_>>();
let color = self.color.clone();
let alpha = self.alpha.clone();
let plot_type = self.plot_type.clone();

// Global variables to plot
let globals = vec![("plt", py.import("matplotlib.pyplot")?)].into_py_dict(py);
Expand Down Expand Up @@ -553,7 +588,23 @@ impl Plot for Plot2D {
let alpha = alpha.iter().find(|(j, _)| j == &i).unwrap().1;
inner_string.push_str(&format!(",alpha={}", alpha)[..]);
}
plot_string.push_str(&format!("plt.plot({})\n", inner_string)[..]);
let is_corresponding_plot_type = !plot_type.is_empty() && (plot_type.iter().any(|(j, _)| j == &i));
if is_corresponding_plot_type {
let plot_type = plot_type.iter().find(|(j, _)| j == &i).unwrap().1;
match plot_type {
PlotType::Scatter => {
plot_string.push_str(&format!("plt.scatter({})\n", inner_string)[..]);
}
PlotType::Line => {
plot_string.push_str(&format!("plt.plot({})\n", inner_string)[..]);
}
PlotType::Bar => {
plot_string.push_str(&format!("plt.bar({})\n", inner_string)[..]);
}
}
} else {
plot_string.push_str(&format!("plt.plot({})\n", inner_string)[..]);
}
}
for i in 0..pair_length {
let mut inner_string = format!("pair[{}][0],pair[{}][1]", i, i);
Expand All @@ -580,7 +631,23 @@ impl Plot for Plot2D {
let alpha = alpha.iter().find(|(j, _)| j == &(i + y_length)).unwrap().1;
inner_string.push_str(&format!(",alpha={}", alpha)[..]);
}
plot_string.push_str(&format!("plt.plot({})\n", inner_string)[..]);
let is_corresponding_plot_type = !plot_type.is_empty() && (plot_type.iter().any(|(j, _)| j == &(i + y_length)));
if is_corresponding_plot_type {
let plot_type = plot_type.iter().find(|(j, _)| j == &(i + y_length)).unwrap().1;
match plot_type {
PlotType::Scatter => {
plot_string.push_str(&format!("plt.scatter({})\n", inner_string)[..]);
}
PlotType::Line => {
plot_string.push_str(&format!("plt.plot({})\n", inner_string)[..]);
}
PlotType::Bar => {
plot_string.push_str(&format!("plt.bar({})\n", inner_string)[..]);
}
}
} else {
plot_string.push_str(&format!("plt.plot({})\n", inner_string)[..]);
}
}

if self.tight {
Expand Down

0 comments on commit db37f51

Please sign in to comment.