diff --git a/src/util/plot.rs b/src/util/plot.rs index c334f3b4..d592f52d 100644 --- a/src/util/plot.rs +++ b/src/util/plot.rs @@ -90,6 +90,7 @@ pub use self::Grid::{Off, On}; pub use self::Markers::{Circle, Line, Point}; use self::PlotOptions::{Domain, Images, Legends, Pairs, Path}; use std::collections::HashMap; +use std::fmt::Display; type Vector = Vec; @@ -107,6 +108,69 @@ pub enum Markers { Point, Line, Circle, + Pixel, + TriangleDown, + TriangleUp, + TriangleLeft, + TriangleRight, + Square, + Pentagon, + Star, + Hexagon1, + Hexagon2, + Plus, + X, + Diamond, + ThinDiamond, + VLine, + HLine, +} + +impl Display for Markers { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let str = match self { + Markers::Point => ".".to_string(), + Markers::Line => "-".to_string(), + Markers::Circle => "o".to_string(), + Markers::Pixel => ",".to_string(), + Markers::TriangleDown => "v".to_string(), + Markers::TriangleUp => "^".to_string(), + Markers::TriangleLeft => "<".to_string(), + Markers::TriangleRight => ">".to_string(), + Markers::Square => "s".to_string(), + Markers::Pentagon => "p".to_string(), + Markers::Star => "*".to_string(), + Markers::Hexagon1 => "h".to_string(), + Markers::Hexagon2 => "H".to_string(), + Markers::Plus => "+".to_string(), + Markers::X => "x".to_string(), + Markers::Diamond => "D".to_string(), + Markers::ThinDiamond => "d".to_string(), + Markers::VLine => "|".to_string(), + Markers::HLine => "_".to_string(), + }; + write!(f, "{}", str) + } +} + +#[derive(Debug, Copy, Clone, Hash, PartialOrd, PartialEq, Eq)] +pub enum LineStyle { + Solid, + Dashed, + Dotted, + DashDot, +} + +impl Display for LineStyle { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let str = match self { + LineStyle::Solid => "solid".to_string(), + LineStyle::Dashed => "dashed".to_string(), + LineStyle::Dotted => "dotted".to_string(), + LineStyle::DashDot => "dashdot".to_string(), + }; + write!(f, "{}", str) + } } #[derive(Debug, Copy, Clone, Hash, PartialOrd, PartialEq, Eq)] @@ -155,6 +219,9 @@ pub trait Plot { fn set_marker(&mut self, styles: Vec) -> &mut Self; fn set_style(&mut self, style: PlotStyle) -> &mut Self; fn tight_layout(&mut self) -> &mut Self; + fn set_line_style(&mut self, style: Vec) -> &mut Self; + fn set_color(&mut self, color: Vec<&str>) -> &mut Self; + fn set_alpha(&mut self, alpha: Vec) -> &mut Self; fn savefig(&self) -> PyResult<()>; } @@ -172,6 +239,9 @@ pub struct Plot2D { ylim: Option<(f64, f64)>, legends: Vec, markers: Vec, + line_style: Vec, + color: Vec, + alpha: Vec, path: String, fig_size: Option<(usize, usize)>, dpi: usize, @@ -187,7 +257,6 @@ impl Plot2D { default_options.insert(Domain, false); default_options.insert(Images, false); default_options.insert(Pairs, false); - default_options.insert(Legends, false); default_options.insert(Path, false); Plot2D { @@ -203,6 +272,9 @@ impl Plot2D { ylim: None, legends: vec![], markers: vec![], + line_style: vec![], + color: vec![], + alpha: vec![], path: "".to_string(), fig_size: None, dpi: 300, @@ -279,9 +351,6 @@ impl Plot for Plot2D { } fn set_legend(&mut self, legends: Vec<&str>) -> &mut Self { - if let Some(x) = self.options.get_mut(&Legends) { - *x = true - } self.legends = legends .into_iter() .map(|x| x.to_owned()) @@ -327,6 +396,21 @@ impl Plot for Plot2D { self } + fn set_line_style(&mut self, style: Vec) -> &mut Self { + self.line_style = style; + self + } + + fn set_color(&mut self, color: Vec<&str>) -> &mut Self { + self.color = color.into_iter().map(|x| x.to_owned()).collect(); + self + } + + fn set_alpha(&mut self, alpha: Vec) -> &mut Self { + self.alpha = alpha; + self + } + fn savefig(&self) -> PyResult<()> { // Check domain match self.options.get(&Domain) { @@ -362,21 +446,6 @@ impl Plot for Plot2D { _ => (), } - // Check legends - match self.options.get(&Legends) { - Some(x) => { - assert!(*x, "Legends are not defined"); - assert_eq!( - self.images.len() + self.pairs.len(), - self.legends.len(), - "Legends are not matched with images" - ); - } - None => { - assert!(false, "Legends are None"); - } - } - // Plot Python::with_gil(|py| { // Input data @@ -402,6 +471,10 @@ impl Plot for Plot2D { let ylabel = self.ylabel.clone(); let legends = self.legends.clone(); let path = self.path.clone(); + let markers = self.markers.iter().map(|x| format!("{}", x)).collect::>(); + let line_style = self.line_style.iter().map(|x| format!("{}", x)).collect::>(); + let color = self.color.clone(); + let alpha = self.alpha.clone(); // Global variables to plot let globals = vec![("plt", py.import("matplotlib.pyplot")?)].into_py_dict(py); @@ -460,77 +533,57 @@ impl Plot for Plot2D { plot_string.push_str(&format!("plt.ylabel(r\"{}\")\n", y)[..]); } match self.xscale { - PlotScale::Linear => plot_string.push_str(&format!("plt.xscale(\"linear\")\n")[..]), - PlotScale::Log => plot_string.push_str(&format!("plt.xscale(\"log\")\n")[..]), + PlotScale::Linear => plot_string.push_str(&"plt.xscale(\"linear\")\n".to_string()[..]), + PlotScale::Log => plot_string.push_str(&"plt.xscale(\"log\")\n".to_string()[..]), } match self.yscale { - PlotScale::Linear => plot_string.push_str(&format!("plt.yscale(\"linear\")\n")[..]), - PlotScale::Log => plot_string.push_str(&format!("plt.yscale(\"log\")\n")[..]), + PlotScale::Linear => plot_string.push_str(&"plt.yscale(\"linear\")\n".to_string()[..]), + PlotScale::Log => plot_string.push_str(&"plt.yscale(\"log\")\n".to_string()[..]), } if let Some(xl) = self.xlim { - plot_string.push_str(&format!("plt.xlim(xl)\n")[..]); + plot_string.push_str(&"plt.xlim(xl)\n".to_string()[..]); } if let Some(yl) = self.ylim { - plot_string.push_str(&format!("plt.ylim(yl)\n")[..]); + plot_string.push_str(&"plt.ylim(yl)\n".to_string()[..]); } - if self.markers.len() == 0 { - for i in 0..y_length { - plot_string - .push_str(&format!("plt.plot(x,y[{}],label=r\"{}\")\n", i, legends[i])[..]) + for i in 0..y_length { + let mut inner_string = format!("x,y[{}]", i); + if !markers.is_empty() { + inner_string.push_str(&format!(",marker=\"{}\"", markers[i])[..]); } - for i in 0..pair_length { - plot_string.push_str( - &format!( - "plt.plot(pair[{}][0],pair[{}][1],label=r\"{}\")\n", - i, - i, - legends[i + y_length] - )[..], - ) + if !line_style.is_empty() { + inner_string.push_str(&format!(",linestyle=\"{}\"", line_style[i])[..]); } - } else { - for i in 0..y_length { - match self.markers[i] { - Line => plot_string.push_str( - &format!("plt.plot(x,y[{}],label=r\"{}\")\n", i, legends[i])[..], - ), - Point => plot_string.push_str( - &format!("plt.plot(x,y[{}],\".\",label=r\"{}\")\n", i, legends[i])[..], - ), - Circle => plot_string.push_str( - &format!("plt.plot(x,y[{}],\"o\",label=r\"{}\")\n", i, legends[i])[..], - ), - } + if !color.is_empty() { + inner_string.push_str(&format!(",color=\"{}\"", color[i])[..]); + } + if !legends.is_empty() { + inner_string.push_str(&format!(",label=r\"{}\"", legends[i])[..]); + } + if !alpha.is_empty() { + inner_string.push_str(&format!(",alpha={}", alpha[i])[..]); + } + plot_string.push_str(&format!("plt.plot({})\n", inner_string)[..]); + } + for i in 0..pair_length { + let mut inner_string = format!("pair[{}][0],pair[{}][1]", i, i); + if !markers.is_empty() { + inner_string.push_str(&format!(",marker=\"{}\"", markers[i + y_length])[..]); + } + if !line_style.is_empty() { + inner_string.push_str(&format!(",linestyle=\"{}\"", line_style[i + y_length])[..]); + } + if !color.is_empty() { + inner_string.push_str(&format!(",color=\"{}\"", color[i + y_length])[..]); + } + if !legends.is_empty() { + inner_string.push_str(&format!(",label=r\"{}\"", legends[i + y_length])[..]); } - for i in 0..pair_length { - match self.markers[i + y_length] { - Line => plot_string.push_str( - &format!( - "plt.plot(pair[{}][0],pair[{}][1],label=r\"{}\")\n", - i, - i, - legends[i + y_length] - )[..], - ), - Point => plot_string.push_str( - &format!( - "plt.plot(pair[{}][0],pair[{}][1],\".\",label=r\"{}\")\n", - i, - i, - legends[i + y_length] - )[..], - ), - Circle => plot_string.push_str( - &format!( - "plt.plot(pair[{}][0],pair[{}][1],\"o\",label=r\"{}\")\n", - i, - i, - legends[i + y_length] - )[..], - ), - } + if !alpha.is_empty() { + inner_string.push_str(&format!(",alpha={}", alpha[i + y_length])[..]); } + plot_string.push_str(&format!("plt.plot({})\n", inner_string)[..]); } if self.tight {