Skip to content

Commit

Permalink
fixing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
fmilthaler committed Oct 1, 2023
1 parent 15d6927 commit 7fed1bf
Showing 1 changed file with 31 additions and 49 deletions.
80 changes: 31 additions & 49 deletions tests/test_momentum_indicators.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytest

from finquant.momentum_indicators import macd
from finquant.momentum_indicators import mpl_macd
from finquant.momentum_indicators import relative_strength_index as rsi

plt.switch_backend("Agg")
Expand All @@ -17,7 +18,6 @@ def test_rsi():
rsi(df)
# get data from axis object
ax = plt.gca()
# ax.lines[0] is the data we passed to plot_bollinger_band
line1 = ax.lines[0]
stock_plot = line1.get_xydata()
xlabel_plot = ax.get_xlabel()
Expand All @@ -33,7 +33,7 @@ def test_rsi_standalone():
x = np.sin(np.linspace(1, 10, 100))
xlabel_orig = "Date"
ylabel_orig = "RSI"
labels_orig = ["rsi"]
labels_orig = ["overbought", "oversold", "rsi"]
title_orig = "RSI Plot"
df = pd.DataFrame({"Stock": x}, index=np.linspace(1, 10, 100))
df.index.name = "Date"
Expand All @@ -60,53 +60,35 @@ def test_rsi_standalone():
assert title_plot == title_orig


def test_macd():
def test_mpl_macd():
axes0_ylabel_orig = "Price"
axes4_ylabel_orig = "Volume $10^{6}$"
# Create sample data for testing
x = np.sin(np.linspace(1, 10, 100))
xlabel_orig = "Date"
ylabel_orig = "Price"
df = pd.DataFrame({"Stock": x}, index=np.linspace(1, 10, 100))
df.index.name = "Date"
macd(df)
# get data from axis object
ax = plt.gca()
# ax.lines[0] is the data we passed to plot_bollinger_band
line1 = ax.lines[0]
stock_plot = line1.get_xydata()
xlabel_plot = ax.get_xlabel()
ylabel_plot = ax.get_ylabel()
# tests
assert (df["Stock"].index.values == stock_plot[:, 0]).all()
assert (df["Stock"].values == stock_plot[:, 1]).all()
assert xlabel_orig == xlabel_plot
assert ylabel_orig == ylabel_plot
df = pd.DataFrame({"Close": x}, index=pd.date_range("2015-01-01", periods=100, freq="D"))
df.name = "DIS"

# Call mpl_macd function
fig, axes = mpl_macd(df)

def test_macd_standalone():
labels_orig = ["MACD", "diff", "SIGNAL"]
axes0_ylabel_plot = axes[0].get_ylabel()
axes4_ylabel_plot = axes[4].get_ylabel()

# Check if the function returned valid figures and axes objects
assert isinstance(fig, plt.Figure)
assert isinstance(axes, list)
assert len(axes) == 6 # Assuming there are six subplots in the returned figure
assert axes0_ylabel_orig == axes0_ylabel_plot
assert axes4_ylabel_orig == axes4_ylabel_plot

def test_mpl_macd_invalid_window_parameters():
# Create sample data with invalid window parameters
x = np.sin(np.linspace(1, 10, 100))
xlabel_orig = "Date"
ylabel_orig = "MACD"
df = pd.DataFrame({"Stock": x}, index=np.linspace(1, 10, 100))
df.index.name = "Date"
macd(df, standalone=True)
# get data from axis object
ax = plt.gca()
labels_plot = ax.get_legend_handles_labels()[1]
xlabel_plot = ax.get_xlabel()
ylabel_plot = ax.get_ylabel()
assert labels_plot == labels_orig
assert xlabel_plot == xlabel_orig
assert ylabel_plot == ylabel_orig
# ax.lines[0] is macd data
# ax.lines[1] is diff data
# ax.lines[2] is macd_s data
# tests
for i, key in ((0, "macd"), (1, "diff"), (2, "macd_s")):
line = ax.lines[i]
data_plot = line.get_xydata()
# tests
assert (df[key].index.values == data_plot[:, 0]).all()
# for comparing values, we need to remove nan
a, b = df[key].values, data_plot[:, 1]
a, b = map(lambda x: x[~np.isnan(x)], (a, b))
assert (a == b).all()
df = pd.DataFrame({"Close": x}, index=pd.date_range("2015-01-01", periods=100, freq="D"))
df.name = "DIS"

# Call mpl_macd function with invalid window parameters and check for ValueError
with pytest.raises(ValueError):
mpl_macd(df, longer_ema_window=10, shorter_ema_window=20, signal_ema_window=30)
with pytest.raises(ValueError):
mpl_macd(df, longer_ema_window=10, shorter_ema_window=5, signal_ema_window=30)

0 comments on commit 7fed1bf

Please sign in to comment.