Skip to content

Commit

Permalink
Merge branch 'features/plot' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
Axect committed Mar 11, 2024
2 parents 0add558 + f26c33a commit d98d12c
Showing 1 changed file with 131 additions and 78 deletions.
209 changes: 131 additions & 78 deletions src/util/plot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ pub use self::Grid::{Off, On};
pub use self::Markers::{Circle, Line, Point};
use self::PlotOptions::{Domain, Images, Legends, Pairs, Path};
use std::collections::HashMap;
use std::fmt::Display;

type Vector = Vec<f64>;

Expand All @@ -107,6 +108,69 @@ pub enum Markers {
Point,
Line,
Circle,
Pixel,
TriangleDown,
TriangleUp,
TriangleLeft,
TriangleRight,
Square,
Pentagon,
Star,
Hexagon1,
Hexagon2,
Plus,
X,
Diamond,
ThinDiamond,
VLine,
HLine,
}

impl Display for Markers {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let str = match self {
Markers::Point => ".".to_string(),
Markers::Line => "-".to_string(),
Markers::Circle => "o".to_string(),
Markers::Pixel => ",".to_string(),
Markers::TriangleDown => "v".to_string(),
Markers::TriangleUp => "^".to_string(),
Markers::TriangleLeft => "<".to_string(),
Markers::TriangleRight => ">".to_string(),
Markers::Square => "s".to_string(),
Markers::Pentagon => "p".to_string(),
Markers::Star => "*".to_string(),
Markers::Hexagon1 => "h".to_string(),
Markers::Hexagon2 => "H".to_string(),
Markers::Plus => "+".to_string(),
Markers::X => "x".to_string(),
Markers::Diamond => "D".to_string(),
Markers::ThinDiamond => "d".to_string(),
Markers::VLine => "|".to_string(),
Markers::HLine => "_".to_string(),
};
write!(f, "{}", str)
}
}

#[derive(Debug, Copy, Clone, Hash, PartialOrd, PartialEq, Eq)]
pub enum LineStyle {
Solid,
Dashed,
Dotted,
DashDot,
}

impl Display for LineStyle {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let str = match self {
LineStyle::Solid => "solid".to_string(),
LineStyle::Dashed => "dashed".to_string(),
LineStyle::Dotted => "dotted".to_string(),
LineStyle::DashDot => "dashdot".to_string(),
};
write!(f, "{}", str)
}
}

#[derive(Debug, Copy, Clone, Hash, PartialOrd, PartialEq, Eq)]
Expand Down Expand Up @@ -155,6 +219,9 @@ pub trait Plot {
fn set_marker(&mut self, styles: Vec<Markers>) -> &mut Self;
fn set_style(&mut self, style: PlotStyle) -> &mut Self;
fn tight_layout(&mut self) -> &mut Self;
fn set_line_style(&mut self, style: Vec<LineStyle>) -> &mut Self;
fn set_color(&mut self, color: Vec<&str>) -> &mut Self;
fn set_alpha(&mut self, alpha: Vec<f64>) -> &mut Self;
fn savefig(&self) -> PyResult<()>;
}

Expand All @@ -172,6 +239,9 @@ pub struct Plot2D {
ylim: Option<(f64, f64)>,
legends: Vec<String>,
markers: Vec<Markers>,
line_style: Vec<LineStyle>,
color: Vec<String>,
alpha: Vec<f64>,
path: String,
fig_size: Option<(usize, usize)>,
dpi: usize,
Expand All @@ -187,7 +257,6 @@ impl Plot2D {
default_options.insert(Domain, false);
default_options.insert(Images, false);
default_options.insert(Pairs, false);
default_options.insert(Legends, false);
default_options.insert(Path, false);

Plot2D {
Expand All @@ -203,6 +272,9 @@ impl Plot2D {
ylim: None,
legends: vec![],
markers: vec![],
line_style: vec![],
color: vec![],
alpha: vec![],
path: "".to_string(),
fig_size: None,
dpi: 300,
Expand Down Expand Up @@ -279,9 +351,6 @@ impl Plot for Plot2D {
}

fn set_legend(&mut self, legends: Vec<&str>) -> &mut Self {
if let Some(x) = self.options.get_mut(&Legends) {
*x = true
}
self.legends = legends
.into_iter()
.map(|x| x.to_owned())
Expand Down Expand Up @@ -327,6 +396,21 @@ impl Plot for Plot2D {
self
}

fn set_line_style(&mut self, style: Vec<LineStyle>) -> &mut Self {
self.line_style = style;
self
}

fn set_color(&mut self, color: Vec<&str>) -> &mut Self {
self.color = color.into_iter().map(|x| x.to_owned()).collect();
self
}

fn set_alpha(&mut self, alpha: Vec<f64>) -> &mut Self {
self.alpha = alpha;
self
}

fn savefig(&self) -> PyResult<()> {
// Check domain
match self.options.get(&Domain) {
Expand Down Expand Up @@ -362,21 +446,6 @@ impl Plot for Plot2D {
_ => (),
}

// Check legends
match self.options.get(&Legends) {
Some(x) => {
assert!(*x, "Legends are not defined");
assert_eq!(
self.images.len() + self.pairs.len(),
self.legends.len(),
"Legends are not matched with images"
);
}
None => {
assert!(false, "Legends are None");
}
}

// Plot
Python::with_gil(|py| {
// Input data
Expand All @@ -402,6 +471,10 @@ impl Plot for Plot2D {
let ylabel = self.ylabel.clone();
let legends = self.legends.clone();
let path = self.path.clone();
let markers = self.markers.iter().map(|x| format!("{}", x)).collect::<Vec<String>>();
let line_style = self.line_style.iter().map(|x| format!("{}", x)).collect::<Vec<String>>();
let color = self.color.clone();
let alpha = self.alpha.clone();

// Global variables to plot
let globals = vec![("plt", py.import("matplotlib.pyplot")?)].into_py_dict(py);
Expand Down Expand Up @@ -460,77 +533,57 @@ impl Plot for Plot2D {
plot_string.push_str(&format!("plt.ylabel(r\"{}\")\n", y)[..]);
}
match self.xscale {
PlotScale::Linear => plot_string.push_str(&format!("plt.xscale(\"linear\")\n")[..]),
PlotScale::Log => plot_string.push_str(&format!("plt.xscale(\"log\")\n")[..]),
PlotScale::Linear => plot_string.push_str(&"plt.xscale(\"linear\")\n".to_string()[..]),
PlotScale::Log => plot_string.push_str(&"plt.xscale(\"log\")\n".to_string()[..]),
}
match self.yscale {
PlotScale::Linear => plot_string.push_str(&format!("plt.yscale(\"linear\")\n")[..]),
PlotScale::Log => plot_string.push_str(&format!("plt.yscale(\"log\")\n")[..]),
PlotScale::Linear => plot_string.push_str(&"plt.yscale(\"linear\")\n".to_string()[..]),
PlotScale::Log => plot_string.push_str(&"plt.yscale(\"log\")\n".to_string()[..]),
}
if let Some(xl) = self.xlim {
plot_string.push_str(&format!("plt.xlim(xl)\n")[..]);
plot_string.push_str(&"plt.xlim(xl)\n".to_string()[..]);
}
if let Some(yl) = self.ylim {
plot_string.push_str(&format!("plt.ylim(yl)\n")[..]);
plot_string.push_str(&"plt.ylim(yl)\n".to_string()[..]);
}

if self.markers.len() == 0 {
for i in 0..y_length {
plot_string
.push_str(&format!("plt.plot(x,y[{}],label=r\"{}\")\n", i, legends[i])[..])
for i in 0..y_length {
let mut inner_string = format!("x,y[{}]", i);
if !markers.is_empty() {
inner_string.push_str(&format!(",marker=\"{}\"", markers[i])[..]);
}
for i in 0..pair_length {
plot_string.push_str(
&format!(
"plt.plot(pair[{}][0],pair[{}][1],label=r\"{}\")\n",
i,
i,
legends[i + y_length]
)[..],
)
if !line_style.is_empty() {
inner_string.push_str(&format!(",linestyle=\"{}\"", line_style[i])[..]);
}
} else {
for i in 0..y_length {
match self.markers[i] {
Line => plot_string.push_str(
&format!("plt.plot(x,y[{}],label=r\"{}\")\n", i, legends[i])[..],
),
Point => plot_string.push_str(
&format!("plt.plot(x,y[{}],\".\",label=r\"{}\")\n", i, legends[i])[..],
),
Circle => plot_string.push_str(
&format!("plt.plot(x,y[{}],\"o\",label=r\"{}\")\n", i, legends[i])[..],
),
}
if !color.is_empty() {
inner_string.push_str(&format!(",color=\"{}\"", color[i])[..]);
}
if !legends.is_empty() {
inner_string.push_str(&format!(",label=r\"{}\"", legends[i])[..]);
}
if !alpha.is_empty() {
inner_string.push_str(&format!(",alpha={}", alpha[i])[..]);
}
plot_string.push_str(&format!("plt.plot({})\n", inner_string)[..]);
}
for i in 0..pair_length {
let mut inner_string = format!("pair[{}][0],pair[{}][1]", i, i);
if !markers.is_empty() {
inner_string.push_str(&format!(",marker=\"{}\"", markers[i + y_length])[..]);
}
if !line_style.is_empty() {
inner_string.push_str(&format!(",linestyle=\"{}\"", line_style[i + y_length])[..]);
}
if !color.is_empty() {
inner_string.push_str(&format!(",color=\"{}\"", color[i + y_length])[..]);
}
if !legends.is_empty() {
inner_string.push_str(&format!(",label=r\"{}\"", legends[i + y_length])[..]);
}
for i in 0..pair_length {
match self.markers[i + y_length] {
Line => plot_string.push_str(
&format!(
"plt.plot(pair[{}][0],pair[{}][1],label=r\"{}\")\n",
i,
i,
legends[i + y_length]
)[..],
),
Point => plot_string.push_str(
&format!(
"plt.plot(pair[{}][0],pair[{}][1],\".\",label=r\"{}\")\n",
i,
i,
legends[i + y_length]
)[..],
),
Circle => plot_string.push_str(
&format!(
"plt.plot(pair[{}][0],pair[{}][1],\"o\",label=r\"{}\")\n",
i,
i,
legends[i + y_length]
)[..],
),
}
if !alpha.is_empty() {
inner_string.push_str(&format!(",alpha={}", alpha[i + y_length])[..]);
}
plot_string.push_str(&format!("plt.plot({})\n", inner_string)[..]);
}

if self.tight {
Expand Down

0 comments on commit d98d12c

Please sign in to comment.